mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-18 22:08:13 +02:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8a2177ffab | |||
| 3a7bbfbb88 | |||
| 7c01641de9 | |||
| 1c1086eea4 | |||
| 8f4f40f894 | |||
| 7f16ba706a |
+1
-1
@@ -10,7 +10,7 @@
|
||||
# ============================================
|
||||
|
||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||
version: "v1.3.18"
|
||||
version: "v1.3.19"
|
||||
|
||||
# 服务器配置
|
||||
server:
|
||||
|
||||
+38
-6
@@ -2,7 +2,7 @@
|
||||
|
||||
[English](robot_en.md)
|
||||
|
||||
本文档说明如何通过**钉钉**、**飞书**与 CyberStrikeAI 对话(长连接模式),在手机端即可使用,无需在服务器上打开网页。按下面步骤操作可避免常见弯路。
|
||||
本文档说明如何通过**钉钉**、**飞书**与 **企业微信** 与 CyberStrikeAI 对话(长连接 / 回调模式),在手机端即可使用,无需在服务器上打开网页。按下面步骤操作可避免常见弯路。
|
||||
|
||||
---
|
||||
|
||||
@@ -19,12 +19,13 @@
|
||||
|
||||
---
|
||||
|
||||
## 二、支持的平台(长连接)
|
||||
## 二、支持的平台(长连接 / 回调)
|
||||
|
||||
| 平台 | 说明 |
|
||||
|------|------|
|
||||
| 钉钉 | 使用 Stream 长连接,程序主动连接钉钉接收消息 |
|
||||
| 飞书 | 使用长连接,程序主动连接飞书接收消息 |
|
||||
| 平台 | 说明 |
|
||||
|----------|------|
|
||||
| 钉钉 | 使用 Stream 长连接,程序主动连接钉钉接收消息 |
|
||||
| 飞书 | 使用长连接,程序主动连接飞书接收消息 |
|
||||
| 企业微信 | 使用 HTTP 回调接收消息,被动回包 + 主动调用企业微信发送消息 API |
|
||||
|
||||
下面第三节会按平台写清:在开放平台要做什么、要复制哪些字段、填到 CyberStrikeAI 的哪一栏。
|
||||
|
||||
@@ -101,6 +102,37 @@
|
||||
|
||||
---
|
||||
|
||||
### 3.3 企业微信 (WeCom)
|
||||
|
||||
> 企业微信目前采用「HTTP 回调 + 主动发送消息 API」的方式工作:
|
||||
> - 用户发消息 → 企业微信以加密 XML **回调到你的服务器**(本程序的 `/api/robot/wecom`);
|
||||
> - CyberStrikeAI 解密并调用 AI → 使用企业微信的 `message/send` 接口**主动发消息给用户**。
|
||||
|
||||
**配置概览:**
|
||||
|
||||
- 在企业微信管理后台创建或选择一个**自建应用**。
|
||||
- 在该应用的「接收消息」处配置回调 URL、Token、EncodingAESKey。
|
||||
- 在 CyberStrikeAI 的 `config.yaml` 中填入:
|
||||
- `robots.wecom.corp_id`:企业 ID(CorpID)
|
||||
- `robots.wecom.agent_id`:应用的 AgentId
|
||||
- `robots.wecom.token`:消息回调使用的 Token
|
||||
- `robots.wecom.encoding_aes_key`:消息回调使用的 EncodingAESKey
|
||||
- `robots.wecom.secret`:该应用的 Secret(用于调用企业微信主动发送消息接口)
|
||||
|
||||
> **重要:IP 白名单(errcode 60020)**
|
||||
> CyberStrikeAI 使用 `https://qyapi.weixin.qq.com/cgi-bin/message/send` 主动发送 AI 回复。
|
||||
> 若企业微信日志或本程序日志中出现 `errcode 60020 not allow to access from your ip`:
|
||||
>
|
||||
> - 说明你的服务器出口 IP **没有加入企业微信的 IP 白名单**;
|
||||
> - 请在企业微信管理后台中找到该自建应用的**「安全设置 / IP 白名单」**(具体入口可能因版本略有不同),将运行 CyberStrikeAI 的服务器公网 IP(如 `110.xxx.xxx.xxx`)加入白名单;
|
||||
> - 保存后等待生效,再次发送消息测试。
|
||||
>
|
||||
> 如果 IP 未加入白名单,企业微信会拒绝主动发送消息,表现为:
|
||||
> - 回调接口 `/api/robot/wecom` 能正常收到并处理消息;
|
||||
> - 但手机端**始终收不到 AI 回复**,日志中有 `not allow to access from your ip` 提示。
|
||||
|
||||
---
|
||||
|
||||
## 四、机器人命令
|
||||
|
||||
在钉钉/飞书中向机器人发送以下**文本命令**(仅支持文本):
|
||||
|
||||
+36
-6
@@ -2,7 +2,7 @@
|
||||
|
||||
[中文](robot.md)
|
||||
|
||||
This document explains how to chat with CyberStrikeAI from **DingTalk** and **Lark (Feishu)** using long-lived connections—no need to open a browser on the server. Following the steps below helps avoid common mistakes.
|
||||
This document explains how to chat with CyberStrikeAI from **DingTalk**, **Lark (Feishu)**, and **WeCom (Enterprise WeChat)** using long-lived connections or HTTP callbacks—no need to open a browser on the server. Following the steps below helps avoid common mistakes.
|
||||
|
||||
---
|
||||
|
||||
@@ -19,12 +19,13 @@ Settings are written to the `robots` section of `config.yaml`; you can also edit
|
||||
|
||||
---
|
||||
|
||||
## 2. Supported platforms (long-lived connection)
|
||||
## 2. Supported platforms (long-lived / callback)
|
||||
|
||||
| Platform | Description |
|
||||
|----------|-------------|
|
||||
| DingTalk | Stream long-lived connection; the app connects to DingTalk to receive messages |
|
||||
| Lark (Feishu) | Long-lived connection; the app connects to Lark to receive messages |
|
||||
| Platform | Description |
|
||||
|----------------|-------------|
|
||||
| DingTalk | Stream long-lived connection; the app connects to DingTalk to receive messages |
|
||||
| Lark (Feishu) | Long-lived connection; the app connects to Lark to receive messages |
|
||||
| WeCom (Qiye WX)| HTTP callback to receive messages; CyberStrikeAI replies via WeCom’s message sending API |
|
||||
|
||||
Section 3 below describes, per platform, what to do in the developer console and which fields to copy into CyberStrikeAI.
|
||||
|
||||
@@ -100,6 +101,35 @@ If you only have a **custom bot** Webhook URL (`oapi.dingtalk.com/robot/send?acc
|
||||
|
||||
---
|
||||
|
||||
### 3.3 WeCom (Enterprise WeChat)
|
||||
|
||||
> WeCom uses a **“HTTP callback + active message send API”** model:
|
||||
> - User sends a message → WeCom sends an **encrypted XML callback** to your server (CyberStrikeAI’s `/api/robot/wecom`).
|
||||
> - CyberStrikeAI decrypts it, calls the AI, then uses WeCom’s `message/send` API to **actively push the reply** to the user.
|
||||
|
||||
**Configuration overview:**
|
||||
|
||||
- In the WeCom admin console, create or select a **custom app** (自建应用).
|
||||
- In that app’s settings, configure the message **callback URL**, **Token**, and **EncodingAESKey**.
|
||||
- In CyberStrikeAI’s `config.yaml`, fill in:
|
||||
- `robots.wecom.corp_id`: your CorpID (企业 ID)
|
||||
- `robots.wecom.agent_id`: the app’s AgentId
|
||||
- `robots.wecom.token`: the Token used for message callbacks
|
||||
- `robots.wecom.encoding_aes_key`: the EncodingAESKey used for callbacks
|
||||
- `robots.wecom.secret`: the app’s Secret (used when calling WeCom APIs to send messages)
|
||||
|
||||
> **Important: IP allowlist (errcode 60020)**
|
||||
> CyberStrikeAI calls `https://qyapi.weixin.qq.com/cgi-bin/message/send` to actively send AI replies.
|
||||
> If logs show `errcode 60020 not allow to access from your ip`:
|
||||
>
|
||||
> - Your server’s outbound IP is **not in WeCom’s IP allowlist**.
|
||||
> - In the WeCom admin console, open the custom app’s **Security / IP allowlist** settings (name may vary slightly), and add the public IP of the machine running CyberStrikeAI (e.g. `110.xxx.xxx.xxx`).
|
||||
> - Save and wait for it to take effect, then test again.
|
||||
>
|
||||
> If the IP is not whitelisted, WeCom will reject active message sending. You will see that `/api/robot/wecom` receives and processes callbacks, but users **never see AI replies**, and logs contain `not allow to access from your ip`.
|
||||
|
||||
---
|
||||
|
||||
## 4. Bot commands
|
||||
|
||||
Send these **text commands** to the bot in DingTalk or Lark (text only):
|
||||
|
||||
@@ -345,8 +345,29 @@ func (mc *MemoryCompressor) adjustRecentStartForToolCalls(msgs []ChatMessage, re
|
||||
adjusted--
|
||||
}
|
||||
|
||||
// Ensure at least one user message is included in recent messages to avoid Qwen model error
|
||||
// Qwen models require a user message in the message array, otherwise they return:
|
||||
// "No user query found in messages"
|
||||
hasUserMessage := false
|
||||
for i := adjusted; i < len(msgs); i++ {
|
||||
if strings.EqualFold(msgs[i].Role, "user") {
|
||||
hasUserMessage = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// If no user message in recent messages, adjust backwards to include one
|
||||
if !hasUserMessage {
|
||||
for adjusted > 0 {
|
||||
adjusted--
|
||||
if strings.EqualFold(msgs[adjusted].Role, "user") {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if adjusted != recentStart {
|
||||
mc.logger.Debug("adjusted recent window to keep tool call context",
|
||||
mc.logger.Debug("adjusted recent window to keep tool call context and user message",
|
||||
zap.Int("original_recent_start", recentStart),
|
||||
zap.Int("adjusted_recent_start", adjusted),
|
||||
)
|
||||
|
||||
@@ -582,9 +582,18 @@ func Default() *Config {
|
||||
},
|
||||
Retrieval: RetrievalConfig{
|
||||
TopK: 5,
|
||||
SimilarityThreshold: 0.7,
|
||||
SimilarityThreshold: 0.65, // 降低阈值到 0.65,减少漏检
|
||||
HybridWeight: 0.7,
|
||||
},
|
||||
Indexing: IndexingConfig{
|
||||
ChunkSize: 768, // 增加到 768,更好的上下文保持
|
||||
ChunkOverlap: 50,
|
||||
MaxChunksPerItem: 20, // 限制单个知识项最多 20 个块,避免消耗过多配额
|
||||
MaxRPM: 100, // 默认 100 RPM,避免 429 错误
|
||||
RateLimitDelayMs: 600, // 600ms 间隔,对应 100 RPM
|
||||
MaxRetries: 3,
|
||||
RetryDelayMs: 1000,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
+381
-77
@@ -1,11 +1,15 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -141,56 +145,9 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
|
||||
return "请输入内容或发送「帮助」/ help 查看命令。"
|
||||
}
|
||||
|
||||
// 命令分发(支持中英文)
|
||||
switch {
|
||||
case text == robotCmdHelp || text == "help" || text == "?" || text == "?":
|
||||
return h.cmdHelp()
|
||||
case text == robotCmdList || text == robotCmdListAlt || text == "list":
|
||||
return h.cmdList()
|
||||
case strings.HasPrefix(text, robotCmdSwitch+" ") || strings.HasPrefix(text, robotCmdContinue+" ") || strings.HasPrefix(text, "switch ") || strings.HasPrefix(text, "continue "):
|
||||
var id string
|
||||
switch {
|
||||
case strings.HasPrefix(text, robotCmdSwitch+" "):
|
||||
id = strings.TrimSpace(text[len(robotCmdSwitch)+1:])
|
||||
case strings.HasPrefix(text, robotCmdContinue+" "):
|
||||
id = strings.TrimSpace(text[len(robotCmdContinue)+1:])
|
||||
case strings.HasPrefix(text, "switch "):
|
||||
id = strings.TrimSpace(text[7:])
|
||||
default:
|
||||
id = strings.TrimSpace(text[9:])
|
||||
}
|
||||
return h.cmdSwitch(platform, userID, id)
|
||||
case text == robotCmdNew || text == "new":
|
||||
return h.cmdNew(platform, userID)
|
||||
case text == robotCmdClear || text == "clear":
|
||||
return h.cmdClear(platform, userID)
|
||||
case text == robotCmdCurrent || text == "current":
|
||||
return h.cmdCurrent(platform, userID)
|
||||
case text == robotCmdStop || text == "stop":
|
||||
return h.cmdStop(platform, userID)
|
||||
case text == robotCmdRoles || text == robotCmdRolesList || text == "roles":
|
||||
return h.cmdRoles()
|
||||
case strings.HasPrefix(text, robotCmdRoles+" ") || strings.HasPrefix(text, robotCmdSwitchRole+" ") || strings.HasPrefix(text, "role "):
|
||||
var roleName string
|
||||
switch {
|
||||
case strings.HasPrefix(text, robotCmdRoles+" "):
|
||||
roleName = strings.TrimSpace(text[len(robotCmdRoles)+1:])
|
||||
case strings.HasPrefix(text, robotCmdSwitchRole+" "):
|
||||
roleName = strings.TrimSpace(text[len(robotCmdSwitchRole)+1:])
|
||||
default:
|
||||
roleName = strings.TrimSpace(text[5:])
|
||||
}
|
||||
return h.cmdSwitchRole(platform, userID, roleName)
|
||||
case strings.HasPrefix(text, robotCmdDelete+" ") || strings.HasPrefix(text, "delete "):
|
||||
var convID string
|
||||
if strings.HasPrefix(text, robotCmdDelete+" ") {
|
||||
convID = strings.TrimSpace(text[len(robotCmdDelete)+1:])
|
||||
} else {
|
||||
convID = strings.TrimSpace(text[7:])
|
||||
}
|
||||
return h.cmdDelete(platform, userID, convID)
|
||||
case text == robotCmdVersion || text == "version":
|
||||
return h.cmdVersion()
|
||||
// 先尝试作为命令处理(支持中英文)
|
||||
if cmdReply, ok := h.handleRobotCommand(platform, userID, text); ok {
|
||||
return cmdReply
|
||||
}
|
||||
|
||||
// 普通消息:走 Agent
|
||||
@@ -404,6 +361,62 @@ func (h *RobotHandler) cmdVersion() string {
|
||||
return "CyberStrikeAI " + v
|
||||
}
|
||||
|
||||
// handleRobotCommand 处理机器人内置命令;若匹配到命令返回 (回复内容, true),否则返回 ("", false)
|
||||
func (h *RobotHandler) handleRobotCommand(platform, userID, text string) (string, bool) {
|
||||
switch {
|
||||
case text == robotCmdHelp || text == "help" || text == "?" || text == "?":
|
||||
return h.cmdHelp(), true
|
||||
case text == robotCmdList || text == robotCmdListAlt || text == "list":
|
||||
return h.cmdList(), true
|
||||
case strings.HasPrefix(text, robotCmdSwitch+" ") || strings.HasPrefix(text, robotCmdContinue+" ") || strings.HasPrefix(text, "switch ") || strings.HasPrefix(text, "continue "):
|
||||
var id string
|
||||
switch {
|
||||
case strings.HasPrefix(text, robotCmdSwitch+" "):
|
||||
id = strings.TrimSpace(text[len(robotCmdSwitch)+1:])
|
||||
case strings.HasPrefix(text, robotCmdContinue+" "):
|
||||
id = strings.TrimSpace(text[len(robotCmdContinue)+1:])
|
||||
case strings.HasPrefix(text, "switch "):
|
||||
id = strings.TrimSpace(text[7:])
|
||||
default:
|
||||
id = strings.TrimSpace(text[9:])
|
||||
}
|
||||
return h.cmdSwitch(platform, userID, id), true
|
||||
case text == robotCmdNew || text == "new":
|
||||
return h.cmdNew(platform, userID), true
|
||||
case text == robotCmdClear || text == "clear":
|
||||
return h.cmdClear(platform, userID), true
|
||||
case text == robotCmdCurrent || text == "current":
|
||||
return h.cmdCurrent(platform, userID), true
|
||||
case text == robotCmdStop || text == "stop":
|
||||
return h.cmdStop(platform, userID), true
|
||||
case text == robotCmdRoles || text == robotCmdRolesList || text == "roles":
|
||||
return h.cmdRoles(), true
|
||||
case strings.HasPrefix(text, robotCmdRoles+" ") || strings.HasPrefix(text, robotCmdSwitchRole+" ") || strings.HasPrefix(text, "role "):
|
||||
var roleName string
|
||||
switch {
|
||||
case strings.HasPrefix(text, robotCmdRoles+" "):
|
||||
roleName = strings.TrimSpace(text[len(robotCmdRoles)+1:])
|
||||
case strings.HasPrefix(text, robotCmdSwitchRole+" "):
|
||||
roleName = strings.TrimSpace(text[len(robotCmdSwitchRole)+1:])
|
||||
default:
|
||||
roleName = strings.TrimSpace(text[5:])
|
||||
}
|
||||
return h.cmdSwitchRole(platform, userID, roleName), true
|
||||
case strings.HasPrefix(text, robotCmdDelete+" ") || strings.HasPrefix(text, "delete "):
|
||||
var convID string
|
||||
if strings.HasPrefix(text, robotCmdDelete+" ") {
|
||||
convID = strings.TrimSpace(text[len(robotCmdDelete)+1:])
|
||||
} else {
|
||||
convID = strings.TrimSpace(text[7:])
|
||||
}
|
||||
return h.cmdDelete(platform, userID, convID), true
|
||||
case text == robotCmdVersion || text == "version":
|
||||
return h.cmdVersion(), true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
// —————— 企业微信 ——————
|
||||
|
||||
// wecomXML 企业微信回调 XML(明文模式下的简化结构;加密模式需先解密再解析)
|
||||
@@ -418,14 +431,14 @@ type wecomXML struct {
|
||||
Encrypt string `xml:"Encrypt"` // 加密模式下消息在此
|
||||
}
|
||||
|
||||
// wecomReplyXML 被动回复 XML
|
||||
// wecomReplyXML 被动回复 XML(仅用于兼容,当前使用手动构造 XML)
|
||||
type wecomReplyXML struct {
|
||||
XMLName xml.Name `xml:"xml"`
|
||||
ToUserName string `xml:"ToUserName"`
|
||||
FromUserName string `xml:"FromUserName"`
|
||||
CreateTime int64 `xml:"CreateTime"`
|
||||
MsgType string `xml:"MsgType"`
|
||||
Content string `xml:"Content"`
|
||||
FromUserName string `xml:"FromUserName"`
|
||||
CreateTime int64 `xml:"CreateTime"`
|
||||
MsgType string `xml:"MsgType"`
|
||||
Content string `xml:"Content"`
|
||||
}
|
||||
|
||||
// HandleWecomGET 企业微信 URL 校验(GET)
|
||||
@@ -434,15 +447,51 @@ func (h *RobotHandler) HandleWecomGET(c *gin.Context) {
|
||||
c.String(http.StatusNotFound, "")
|
||||
return
|
||||
}
|
||||
// Gin 的 Query() 会自动 URL 解码,拿到的就是正确的 base64 字符串
|
||||
echostr := c.Query("echostr")
|
||||
msgSignature := c.Query("msg_signature")
|
||||
timestamp := c.Query("timestamp")
|
||||
nonce := c.Query("nonce")
|
||||
|
||||
// 验证签名:将 token、timestamp、nonce、echostr 四个参数排序后拼接计算 SHA1
|
||||
signature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, echostr)
|
||||
if signature != msgSignature {
|
||||
h.logger.Warn("企业微信 URL 验证签名失败", zap.String("expected", msgSignature), zap.String("got", signature))
|
||||
c.String(http.StatusBadRequest, "invalid signature")
|
||||
return
|
||||
}
|
||||
|
||||
if echostr == "" {
|
||||
c.String(http.StatusBadRequest, "missing echostr")
|
||||
return
|
||||
}
|
||||
// 明文模式时企业微信可能直接传 echostr,先直接返回以通过校验
|
||||
|
||||
// 如果配置了 EncodingAESKey,说明是加密模式,需要解密 echostr
|
||||
if h.config.Robots.Wecom.EncodingAESKey != "" {
|
||||
decrypted, err := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, echostr)
|
||||
if err != nil {
|
||||
h.logger.Warn("企业微信 echostr 解密失败", zap.Error(err))
|
||||
c.String(http.StatusBadRequest, "decrypt failed")
|
||||
return
|
||||
}
|
||||
c.String(http.StatusOK, string(decrypted))
|
||||
return
|
||||
}
|
||||
|
||||
// 明文模式直接返回 echostr
|
||||
c.String(http.StatusOK, echostr)
|
||||
}
|
||||
|
||||
// signWecomRequest 生成企业微信请求签名
|
||||
// 企业微信签名算法:将 token、timestamp、nonce、echostr 四个值排序后拼接成字符串,再计算 SHA1
|
||||
func (h *RobotHandler) signWecomRequest(token, timestamp, nonce, echostr string) string {
|
||||
strs := []string{token, timestamp, nonce, echostr}
|
||||
sort.Strings(strs)
|
||||
s := strings.Join(strs, "")
|
||||
hash := sha1.Sum([]byte(s))
|
||||
return fmt.Sprintf("%x", hash)
|
||||
}
|
||||
|
||||
// wecomDecrypt 企业微信消息解密(AES-256-CBC,PKCS7,明文格式:16字节随机+4字节长度+消息+corpID)
|
||||
func wecomDecrypt(encodingAESKey, encryptedB64 string) ([]byte, error) {
|
||||
key, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
|
||||
@@ -484,54 +533,228 @@ func wecomDecrypt(encodingAESKey, encryptedB64 string) ([]byte, error) {
|
||||
return plain[20 : 20+msgLen], nil
|
||||
}
|
||||
|
||||
// wecomEncrypt 企业微信消息加密(AES-256-CBC,PKCS7,明文格式:16字节随机+4字节长度+消息+corpID)
|
||||
func wecomEncrypt(encodingAESKey, message, corpID string) (string, error) {
|
||||
key, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return "", fmt.Errorf("encoding_aes_key 解码后应为 32 字节")
|
||||
}
|
||||
// 构造明文:16 字节随机 + 4 字节长度 (大端) + 消息 + corpID
|
||||
random := make([]byte, 16)
|
||||
if _, err := rand.Read(random); err != nil {
|
||||
// 降级方案:使用时间戳生成随机数
|
||||
for i := range random {
|
||||
random[i] = byte(time.Now().UnixNano() % 256)
|
||||
}
|
||||
}
|
||||
msgLen := len(message)
|
||||
msgBytes := []byte(message)
|
||||
corpBytes := []byte(corpID)
|
||||
plain := make([]byte, 16+4+msgLen+len(corpBytes))
|
||||
copy(plain[:16], random)
|
||||
binary.BigEndian.PutUint32(plain[16:20], uint32(msgLen))
|
||||
copy(plain[20:20+msgLen], msgBytes)
|
||||
copy(plain[20+msgLen:], corpBytes)
|
||||
// PKCS7 填充
|
||||
padding := aes.BlockSize - len(plain)%aes.BlockSize
|
||||
pad := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
plain = append(plain, pad...)
|
||||
// AES-256-CBC 加密
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
iv := key[:16]
|
||||
ciphertext := make([]byte, len(plain))
|
||||
mode := cipher.NewCBCEncrypter(block, iv)
|
||||
mode.CryptBlocks(ciphertext, plain)
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// HandleWecomPOST 企业微信消息回调(POST),支持明文与加密模式
|
||||
func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
|
||||
if !h.config.Robots.Wecom.Enabled {
|
||||
h.logger.Debug("企业微信机器人未启用,跳过请求")
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
bodyRaw, _ := io.ReadAll(c.Request.Body)
|
||||
// 从 URL 获取签名参数(加密模式回复时需要用到)
|
||||
timestamp := c.Query("timestamp")
|
||||
nonce := c.Query("nonce")
|
||||
msgSignature := c.Query("msg_signature")
|
||||
|
||||
// 先读取请求体,后续解析/签名验证都会用到
|
||||
bodyRaw, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
h.logger.Warn("企业微信 POST 读取请求体失败", zap.Error(err))
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
h.logger.Debug("企业微信 POST 收到请求", zap.String("body", string(bodyRaw)))
|
||||
|
||||
// 验证请求签名防止伪造。企业微信签名算法同 URL 验证,使用 token、timestamp、nonce、 Encrypt 四个字段
|
||||
if msgSignature != "" {
|
||||
var tmp wecomXML
|
||||
if err := xml.Unmarshal(bodyRaw, &tmp); err == nil {
|
||||
expected := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, tmp.Encrypt)
|
||||
if expected != msgSignature {
|
||||
h.logger.Warn("企业微信 POST 签名验证失败", zap.String("expected", expected), zap.String("got", msgSignature))
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var body wecomXML
|
||||
if err := xml.Unmarshal(bodyRaw, &body); err != nil {
|
||||
h.logger.Debug("企业微信 POST 解析 XML 失败", zap.Error(err))
|
||||
h.logger.Warn("企业微信 POST 解析 XML 失败", zap.Error(err))
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
h.logger.Debug("企业微信 XML 解析成功", zap.String("ToUserName", body.ToUserName), zap.String("FromUserName", body.FromUserName), zap.String("MsgType", body.MsgType), zap.String("Content", body.Content), zap.String("Encrypt", body.Encrypt))
|
||||
|
||||
// 保存企业 ID(用于明文模式回复)
|
||||
enterpriseID := body.ToUserName
|
||||
|
||||
// 加密模式:先解密再解析内层 XML
|
||||
if body.Encrypt != "" && h.config.Robots.Wecom.EncodingAESKey != "" {
|
||||
h.logger.Debug("企业微信进入加密模式解密流程")
|
||||
decrypted, err := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, body.Encrypt)
|
||||
if err != nil {
|
||||
h.logger.Warn("企业微信消息解密失败", zap.Error(err))
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
h.logger.Debug("企业微信解密成功", zap.String("decrypted", string(decrypted)))
|
||||
if err := xml.Unmarshal(decrypted, &body); err != nil {
|
||||
h.logger.Warn("企业微信解密后 XML 解析失败", zap.Error(err))
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
h.logger.Debug("企业微信内层 XML 解析成功", zap.String("FromUserName", body.FromUserName), zap.String("Content", body.Content))
|
||||
}
|
||||
if body.MsgType != "text" {
|
||||
c.XML(http.StatusOK, wecomReplyXML{
|
||||
ToUserName: body.FromUserName,
|
||||
FromUserName: body.ToUserName,
|
||||
CreateTime: time.Now().Unix(),
|
||||
MsgType: "text",
|
||||
Content: "暂仅支持文本消息,请发送文字。",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userID := body.FromUserName
|
||||
text := strings.TrimSpace(body.Content)
|
||||
reply := h.HandleMessage("wecom", userID, text)
|
||||
// 加密模式需加密回复(此处简化为明文回复;若企业要求加密需再实现加密)
|
||||
c.XML(http.StatusOK, wecomReplyXML{
|
||||
ToUserName: body.FromUserName,
|
||||
FromUserName: body.ToUserName,
|
||||
CreateTime: time.Now().Unix(),
|
||||
MsgType: "text",
|
||||
Content: reply,
|
||||
})
|
||||
|
||||
// 限制回复内容长度(企业微信限制 2048 字节)
|
||||
maxReplyLen := 2000
|
||||
limitReply := func(s string) string {
|
||||
if len(s) > maxReplyLen {
|
||||
return s[:maxReplyLen] + "\n\n(内容过长,已截断)"
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
if body.MsgType != "text" {
|
||||
h.logger.Debug("企业微信收到非文本消息", zap.String("MsgType", body.MsgType))
|
||||
h.sendWecomReply(c, userID, enterpriseID, limitReply("暂仅支持文本消息,请发送文字。"), timestamp, nonce)
|
||||
return
|
||||
}
|
||||
|
||||
// 文本消息:先判断是否为内置命令(如 帮助/列表/新对话 等),这类命令处理很快,可以直接走被动回复,避免依赖主动发送 API。
|
||||
if cmdReply, ok := h.handleRobotCommand("wecom", userID, text); ok {
|
||||
h.logger.Debug("企业微信收到命令消息,走被动回复", zap.String("userID", userID), zap.String("text", text))
|
||||
h.sendWecomReply(c, userID, enterpriseID, limitReply(cmdReply), timestamp, nonce)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Debug("企业微信开始处理消息(异步 AI)", zap.String("userID", userID), zap.String("text", text))
|
||||
|
||||
// 企业微信被动回复有 5 秒超时限制,而 AI 调用通常超过该时长。
|
||||
// 这里采用推荐做法:立即返回 success(或空串),然后通过主动发送接口推送完整回复。
|
||||
c.String(http.StatusOK, "success")
|
||||
|
||||
// 异步处理消息并通过企业微信主动消息接口发送结果
|
||||
go func() {
|
||||
reply := h.HandleMessage("wecom", userID, text)
|
||||
reply = limitReply(reply)
|
||||
h.logger.Debug("企业微信消息处理完成", zap.String("userID", userID), zap.String("reply", reply))
|
||||
// 调用企业微信 API 主动发送消息
|
||||
h.sendWecomMessageViaAPI(userID, enterpriseID, reply)
|
||||
}()
|
||||
}
|
||||
|
||||
// sendWecomReply 发送企业微信回复(加密模式自动加密)
|
||||
// 参数:toUser=用户 ID, fromUser=企业 ID(明文模式)/CorpID(加密模式), content=回复内容,timestamp/nonce=请求参数
|
||||
func (h *RobotHandler) sendWecomReply(c *gin.Context, toUser, fromUser, content, timestamp, nonce string) {
|
||||
// 加密模式:判断 EncodingAESKey 是否配置
|
||||
if h.config.Robots.Wecom.EncodingAESKey != "" {
|
||||
// 加密模式使用 CorpID 进行加密
|
||||
corpID := h.config.Robots.Wecom.CorpID
|
||||
if corpID == "" {
|
||||
h.logger.Warn("企业微信加密模式缺少 CorpID 配置")
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
|
||||
// 构造完整的明文 XML 回复(格式严格按企业微信文档要求)
|
||||
plainResp := fmt.Sprintf(`<xml>
|
||||
<ToUserName><![CDATA[%s]]></ToUserName>
|
||||
<FromUserName><![CDATA[%s]]></FromUserName>
|
||||
<CreateTime>%d</CreateTime>
|
||||
<MsgType><![CDATA[text]]></MsgType>
|
||||
<Content><![CDATA[%s]]></Content>
|
||||
</xml>`, toUser, fromUser, time.Now().Unix(), content)
|
||||
|
||||
encrypted, err := wecomEncrypt(h.config.Robots.Wecom.EncodingAESKey, plainResp, corpID)
|
||||
if err != nil {
|
||||
h.logger.Warn("企业微信回复加密失败", zap.Error(err))
|
||||
c.String(http.StatusOK, "")
|
||||
return
|
||||
}
|
||||
// 使用请求中的 timestamp/nonce 生成签名(企业微信要求回复时使用与请求相同的 timestamp 和 nonce)
|
||||
msgSignature := h.signWecomRequest(h.config.Robots.Wecom.Token, timestamp, nonce, encrypted)
|
||||
|
||||
h.logger.Debug("企业微信发送加密回复",
|
||||
zap.String("Encrypt", encrypted[:50]+"..."),
|
||||
zap.String("MsgSignature", msgSignature),
|
||||
zap.String("TimeStamp", timestamp),
|
||||
zap.String("Nonce", nonce))
|
||||
|
||||
// 加密模式仅返回 4 个核心字段(企业微信官方要求)
|
||||
xmlResp := fmt.Sprintf(`<xml><Encrypt><![CDATA[%s]]></Encrypt><MsgSignature><![CDATA[%s]]></MsgSignature><TimeStamp><![CDATA[%s]]></TimeStamp><Nonce><![CDATA[%s]]></Nonce></xml>`, encrypted, msgSignature, timestamp, nonce)
|
||||
// also log the final response body so we can cross-check with the
|
||||
// network traffic or developer console
|
||||
h.logger.Debug("企业微信加密回复包", zap.String("xml", xmlResp))
|
||||
// for additional confidence, decrypt the payload ourselves and log it
|
||||
if dec, err2 := wecomDecrypt(h.config.Robots.Wecom.EncodingAESKey, encrypted); err2 == nil {
|
||||
h.logger.Debug("企业微信加密回复解密检查", zap.String("plain", string(dec)))
|
||||
} else {
|
||||
h.logger.Warn("企业微信加密回复解密检查失败", zap.Error(err2))
|
||||
}
|
||||
|
||||
// 使用 c.Writer.Write 直接写入响应,避免 c.String 的转义问题
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
// use text/xml as that's what WeCom examples show
|
||||
c.Writer.Header().Set("Content-Type", "text/xml; charset=utf-8")
|
||||
_, _ = c.Writer.Write([]byte(xmlResp))
|
||||
h.logger.Debug("企业微信加密回复已发送")
|
||||
return
|
||||
}
|
||||
|
||||
// 明文模式
|
||||
h.logger.Debug("企业微信发送明文回复", zap.String("ToUserName", toUser), zap.String("FromUserName", fromUser), zap.String("Content", content[:50]+"..."))
|
||||
|
||||
// 手动构造 XML 响应(使用 CDATA 包裹所有字段,并包含 AgentID)
|
||||
xmlResp := fmt.Sprintf(`<xml>
|
||||
<ToUserName><![CDATA[%s]]></ToUserName>
|
||||
<FromUserName><![CDATA[%s]]></FromUserName>
|
||||
<CreateTime>%d</CreateTime>
|
||||
<MsgType><![CDATA[text]]></MsgType>
|
||||
<Content><![CDATA[%s]]></Content>
|
||||
</xml>`, toUser, fromUser, time.Now().Unix(), content)
|
||||
|
||||
// log the exact plaintext response for debugging
|
||||
h.logger.Debug("企业微信明文回复包", zap.String("xml", xmlResp))
|
||||
|
||||
// use text/xml as recommended by WeCom docs
|
||||
c.Header("Content-Type", "text/xml; charset=utf-8")
|
||||
c.String(http.StatusOK, xmlResp)
|
||||
h.logger.Debug("企业微信明文回复已发送")
|
||||
}
|
||||
|
||||
// —————— 测试接口(需登录,用于验证机器人逻辑,无需钉钉/飞书客户端) ——————
|
||||
@@ -562,6 +785,87 @@ func (h *RobotHandler) HandleRobotTest(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"reply": reply})
|
||||
}
|
||||
|
||||
// sendWecomMessageViaAPI 通过企业微信 API 主动发送消息(用于异步处理后的结果发送)
|
||||
func (h *RobotHandler) sendWecomMessageViaAPI(toUser, toParty, content string) {
|
||||
if !h.config.Robots.Wecom.Enabled {
|
||||
return
|
||||
}
|
||||
|
||||
secret := h.config.Robots.Wecom.Secret
|
||||
corpID := h.config.Robots.Wecom.CorpID
|
||||
agentID := h.config.Robots.Wecom.AgentID
|
||||
|
||||
if secret == "" || corpID == "" {
|
||||
h.logger.Warn("企业微信主动 API 缺少 secret 或 corpID 配置")
|
||||
return
|
||||
}
|
||||
|
||||
// 第 1 步:获取 access_token
|
||||
tokenURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid=%s&corpsecret=%s", corpID, secret)
|
||||
resp, err := http.Get(tokenURL)
|
||||
if err != nil {
|
||||
h.logger.Warn("企业微信获取 token 失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
h.logger.Warn("企业微信 token 响应解析失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
if tokenResp.ErrCode != 0 {
|
||||
h.logger.Warn("企业微信 token 获取错误", zap.String("errmsg", tokenResp.ErrMsg), zap.Int("errcode", tokenResp.ErrCode))
|
||||
return
|
||||
}
|
||||
|
||||
// 第 2 步:构造发送消息请求
|
||||
msgReq := map[string]interface{}{
|
||||
"touser": toUser,
|
||||
"msgtype": "text",
|
||||
"agentid": agentID,
|
||||
"text": map[string]interface{}{
|
||||
"content": content,
|
||||
},
|
||||
}
|
||||
|
||||
msgBody, err := json.Marshal(msgReq)
|
||||
if err != nil {
|
||||
h.logger.Warn("企业微信消息序列化失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 第 3 步:发送消息
|
||||
sendURL := fmt.Sprintf("https://qyapi.weixin.qq.com/cgi-bin/message/send?access_token=%s", tokenResp.AccessToken)
|
||||
msgResp, err := http.Post(sendURL, "application/json", bytes.NewReader(msgBody))
|
||||
if err != nil {
|
||||
h.logger.Warn("企业微信主动发送消息失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
defer msgResp.Body.Close()
|
||||
|
||||
var sendResp struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
InvalidUser string `json:"invaliduser"`
|
||||
MsgID string `json:"msgid"`
|
||||
}
|
||||
if err := json.NewDecoder(msgResp.Body).Decode(&sendResp); err != nil {
|
||||
h.logger.Warn("企业微信发送响应解析失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
if sendResp.ErrCode == 0 {
|
||||
h.logger.Debug("企业微信主动发送消息成功", zap.String("msgid", sendResp.MsgID))
|
||||
} else {
|
||||
h.logger.Warn("企业微信主动发送消息失败", zap.String("errmsg", sendResp.ErrMsg), zap.Int("errcode", sendResp.ErrCode), zap.String("invaliduser", sendResp.InvalidUser))
|
||||
}
|
||||
}
|
||||
|
||||
// —————— 钉钉 ——————
|
||||
|
||||
// HandleDingtalkPOST 钉钉事件回调(流式接入等);当前为占位,返回 200
|
||||
|
||||
@@ -544,9 +544,21 @@ func (idx *Indexer) IndexItem(ctx context.Context, itemID string) error {
|
||||
idx.mu.Unlock()
|
||||
}
|
||||
|
||||
// 如果连续失败 2 个块,立即停止处理该知识项(降低阈值,更快停止)
|
||||
// 如果连续失败 5 个块,立即停止处理该知识项
|
||||
// 这样可以避免继续浪费 API 调用,同时也能更快地检测到配置问题
|
||||
if itemErrorCount >= 2 {
|
||||
// 对于大文档(超过 10 个块),允许失败比例不超过 50%
|
||||
maxConsecutiveFailures := 5
|
||||
if len(chunks) > 10 && itemErrorCount > len(chunks)/2 {
|
||||
idx.logger.Error("知识项向量化失败比例过高,停止处理",
|
||||
zap.String("itemId", itemID),
|
||||
zap.Int("totalChunks", len(chunks)),
|
||||
zap.Int("failedChunks", itemErrorCount),
|
||||
zap.Int("firstErrorChunkIndex", firstErrorChunkIndex),
|
||||
zap.Error(firstError),
|
||||
)
|
||||
return fmt.Errorf("知识项向量化失败比例过高 (%d/%d个块失败): %v", itemErrorCount, len(chunks), firstError)
|
||||
}
|
||||
if itemErrorCount >= maxConsecutiveFailures {
|
||||
idx.logger.Error("知识项连续向量化失败,停止处理",
|
||||
zap.String("itemId", itemID),
|
||||
zap.Int("totalChunks", len(chunks)),
|
||||
@@ -649,7 +661,7 @@ func (idx *Indexer) RebuildIndex(ctx context.Context) error {
|
||||
|
||||
failedCount := 0
|
||||
consecutiveFailures := 0
|
||||
maxConsecutiveFailures := 2 // 连续失败 2 次后立即停止(降低阈值,更快停止)
|
||||
maxConsecutiveFailures := 5 // 连续失败 5 次后立即停止(允许偶尔的临时错误)
|
||||
firstFailureItemID := ""
|
||||
var firstFailureError error
|
||||
|
||||
|
||||
@@ -657,7 +657,7 @@ func (m *Manager) UpdateItem(id, category, title, content string) (*KnowledgeIte
|
||||
|
||||
// 删除旧目录(如果为空)
|
||||
oldDir := filepath.Dir(item.FilePath)
|
||||
if entries, err := os.ReadDir(oldDir); err == nil && len(entries) == 0 {
|
||||
if isEmpty, _ := isEmptyDir(oldDir); isEmpty {
|
||||
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
|
||||
if oldDir != m.basePath {
|
||||
if err := os.Remove(oldDir); err != nil {
|
||||
@@ -712,7 +712,7 @@ func (m *Manager) DeleteItem(id string) error {
|
||||
|
||||
// 删除空目录(如果为空)
|
||||
dir := filepath.Dir(filePath)
|
||||
if entries, err := os.ReadDir(dir); err == nil && len(entries) == 0 {
|
||||
if isEmpty, _ := isEmptyDir(dir); isEmpty {
|
||||
// 只有当目录不是知识库根目录时才删除(避免删除根目录)
|
||||
if dir != m.basePath {
|
||||
if err := os.Remove(dir); err != nil {
|
||||
@@ -724,6 +724,21 @@ func (m *Manager) DeleteItem(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// isEmptyDir 检查目录是否为空(忽略隐藏文件和 . 开头的文件)
|
||||
func isEmptyDir(dir string) (bool, error) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
// 忽略隐藏文件(以 . 开头)
|
||||
if !strings.HasPrefix(entry.Name(), ".") {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// LogRetrieval 记录检索日志
|
||||
func (m *Manager) LogRetrieval(conversationID, messageID, query, riskType string, retrievedItems []string) error {
|
||||
id := uuid.New().String()
|
||||
|
||||
@@ -69,8 +69,8 @@ func cosineSimilarity(a, b []float32) float64 {
|
||||
return dotProduct / (math.Sqrt(normA) * math.Sqrt(normB))
|
||||
}
|
||||
|
||||
// bm25Score 计算BM25分数(改进版,更接近标准BM25)
|
||||
// 注意:这是单文档版本的BM25,缺少全局IDF,但比之前的简化版本更准确
|
||||
// bm25Score 计算 BM25 分数(带缓存的改进版本)
|
||||
// 注意:由于缺少全局文档统计,使用简化 IDF 计算
|
||||
func (r *Retriever) bm25Score(query, text string) float64 {
|
||||
queryTerms := strings.Fields(strings.ToLower(query))
|
||||
if len(queryTerms) == 0 {
|
||||
@@ -83,44 +83,56 @@ func (r *Retriever) bm25Score(query, text string) float64 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// BM25参数
|
||||
k1 := 1.5 // 词频饱和度参数
|
||||
b := 0.75 // 长度归一化参数
|
||||
avgDocLength := 100.0 // 估算的平均文档长度(用于归一化)
|
||||
// BM25 参数(标准值)
|
||||
k1 := 1.2 // 词频饱和度参数(标准范围 1.2-2.0)
|
||||
b := 0.75 // 长度归一化参数(标准值)
|
||||
avgDocLength := 150.0 // 估算的平均文档长度(基于典型知识块大小)
|
||||
docLength := float64(len(textTerms))
|
||||
|
||||
score := 0.0
|
||||
for _, term := range queryTerms {
|
||||
// 计算词频(TF)
|
||||
termFreq := 0
|
||||
for _, textTerm := range textTerms {
|
||||
if textTerm == term {
|
||||
termFreq++
|
||||
}
|
||||
}
|
||||
|
||||
if termFreq > 0 {
|
||||
// BM25公式的核心部分
|
||||
// TF部分:termFreq / (termFreq + k1 * (1 - b + b * (docLength / avgDocLength)))
|
||||
tf := float64(termFreq)
|
||||
lengthNorm := 1 - b + b*(docLength/avgDocLength)
|
||||
tfScore := tf / (tf + k1*lengthNorm)
|
||||
|
||||
// 简化IDF:使用词长度作为权重(短词通常更重要)
|
||||
// 实际BM25需要全局文档统计,这里用简化版本
|
||||
idfWeight := 1.0
|
||||
if len(term) > 2 {
|
||||
// 长词稍微降低权重(但实际BM25中,罕见词IDF更高)
|
||||
idfWeight = 1.0 + math.Log(1.0+float64(len(term))/10.0)
|
||||
}
|
||||
|
||||
score += tfScore * idfWeight
|
||||
}
|
||||
// 计算词频映射
|
||||
textTermFreq := make(map[string]int, len(textTerms))
|
||||
for _, term := range textTerms {
|
||||
textTermFreq[term]++
|
||||
}
|
||||
|
||||
// 归一化到0-1范围
|
||||
score := 0.0
|
||||
matchedQueryTerms := 0
|
||||
|
||||
for _, term := range queryTerms {
|
||||
termFreq, exists := textTermFreq[term]
|
||||
if !exists || termFreq == 0 {
|
||||
continue
|
||||
}
|
||||
matchedQueryTerms++
|
||||
|
||||
// BM25 TF 计算公式
|
||||
tf := float64(termFreq)
|
||||
lengthNorm := 1 - b + b*(docLength/avgDocLength)
|
||||
tfScore := tf / (tf + k1*lengthNorm)
|
||||
|
||||
// 改进的 IDF 计算:使用词长度和出现频率估算
|
||||
// 短词(2-3 字符)通常更重要,长词 IDF 略低
|
||||
idfWeight := 1.0
|
||||
termLen := len(term)
|
||||
if termLen <= 2 {
|
||||
// 极短词(如 go, js)给予更高权重
|
||||
idfWeight = 1.2 + math.Log(1.0+float64(termFreq)/20.0)
|
||||
} else if termLen <= 4 {
|
||||
// 短词(4 字符)标准权重
|
||||
idfWeight = 1.0 + math.Log(1.0+float64(termFreq)/15.0)
|
||||
} else {
|
||||
// 长词稍微降低权重
|
||||
idfWeight = 0.9 + math.Log(1.0+float64(termFreq)/10.0)
|
||||
}
|
||||
|
||||
score += tfScore * idfWeight
|
||||
}
|
||||
|
||||
// 归一化:考虑匹配的查询词比例
|
||||
if len(queryTerms) > 0 {
|
||||
score = score / float64(len(queryTerms))
|
||||
// 使用匹配比例作为额外因子
|
||||
matchRatio := float64(matchedQueryTerms) / float64(len(queryTerms))
|
||||
score = (score / float64(len(queryTerms))) * (1 + matchRatio) / 2
|
||||
}
|
||||
|
||||
return math.Min(score, 1.0)
|
||||
@@ -173,7 +185,7 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
SELECT e.id, e.item_id, e.chunk_index, e.chunk_text, e.embedding, i.category, i.title
|
||||
FROM knowledge_embeddings e
|
||||
JOIN knowledge_base_items i ON e.item_id = i.id
|
||||
WHERE i.category = ? COLLATE NOCASE
|
||||
WHERE TRIM(i.category) = TRIM(?) COLLATE NOCASE
|
||||
`, req.RiskType)
|
||||
} else {
|
||||
rows, err = r.db.Query(`
|
||||
@@ -357,7 +369,10 @@ func (r *Retriever) Search(ctx context.Context, req *SearchRequest) ([]*Retrieva
|
||||
zap.Float64("threshold", threshold),
|
||||
zap.Float64("maxSimilarity", maxSimilarity),
|
||||
)
|
||||
} else if len(filteredCandidates) > topK {
|
||||
}
|
||||
|
||||
// 统一在最终返回前严格限制 Top-K 数量
|
||||
if len(filteredCandidates) > topK {
|
||||
// 如果过滤后结果太多,只取Top-K
|
||||
filteredCandidates = filteredCandidates[:topK]
|
||||
}
|
||||
|
||||
+24
-42
@@ -5,6 +5,14 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// formatTime 格式化时间为 RFC3339 格式,零时间返回空字符串
|
||||
func formatTime(t time.Time) string {
|
||||
if t.IsZero() {
|
||||
return ""
|
||||
}
|
||||
return t.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// KnowledgeItem 知识库项
|
||||
type KnowledgeItem struct {
|
||||
ID string `json:"id"`
|
||||
@@ -22,12 +30,12 @@ type KnowledgeItemSummary struct {
|
||||
Category string `json:"category"`
|
||||
Title string `json:"title"`
|
||||
FilePath string `json:"filePath"`
|
||||
Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前150字符)
|
||||
Content string `json:"content,omitempty"` // 可选:内容预览(如果提供,通常只包含前 150 字符)
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义JSON序列化,确保时间格式正确
|
||||
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
|
||||
func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) {
|
||||
type Alias KnowledgeItemSummary
|
||||
aux := &struct {
|
||||
@@ -37,25 +45,12 @@ func (k *KnowledgeItemSummary) MarshalJSON() ([]byte, error) {
|
||||
}{
|
||||
Alias: (*Alias)(k),
|
||||
}
|
||||
|
||||
// 格式化创建时间
|
||||
if k.CreatedAt.IsZero() {
|
||||
aux.CreatedAt = ""
|
||||
} else {
|
||||
aux.CreatedAt = k.CreatedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// 格式化更新时间
|
||||
if k.UpdatedAt.IsZero() {
|
||||
aux.UpdatedAt = ""
|
||||
} else {
|
||||
aux.UpdatedAt = k.UpdatedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
aux.CreatedAt = formatTime(k.CreatedAt)
|
||||
aux.UpdatedAt = formatTime(k.UpdatedAt)
|
||||
return json.Marshal(aux)
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义JSON序列化,确保时间格式正确
|
||||
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
|
||||
func (k *KnowledgeItem) MarshalJSON() ([]byte, error) {
|
||||
type Alias KnowledgeItem
|
||||
aux := &struct {
|
||||
@@ -65,21 +60,8 @@ func (k *KnowledgeItem) MarshalJSON() ([]byte, error) {
|
||||
}{
|
||||
Alias: (*Alias)(k),
|
||||
}
|
||||
|
||||
// 格式化创建时间
|
||||
if k.CreatedAt.IsZero() {
|
||||
aux.CreatedAt = ""
|
||||
} else {
|
||||
aux.CreatedAt = k.CreatedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// 格式化更新时间
|
||||
if k.UpdatedAt.IsZero() {
|
||||
aux.UpdatedAt = ""
|
||||
} else {
|
||||
aux.UpdatedAt = k.UpdatedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
aux.CreatedAt = formatTime(k.CreatedAt)
|
||||
aux.UpdatedAt = formatTime(k.UpdatedAt)
|
||||
return json.Marshal(aux)
|
||||
}
|
||||
|
||||
@@ -89,7 +71,7 @@ type KnowledgeChunk struct {
|
||||
ItemID string `json:"itemId"`
|
||||
ChunkIndex int `json:"chunkIndex"`
|
||||
ChunkText string `json:"chunkText"`
|
||||
Embedding []float32 `json:"-"` // 向量嵌入,不序列化到JSON
|
||||
Embedding []float32 `json:"-"` // 向量嵌入,不序列化到 JSON
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
@@ -108,11 +90,11 @@ type RetrievalLog struct {
|
||||
MessageID string `json:"messageId,omitempty"`
|
||||
Query string `json:"query"`
|
||||
RiskType string `json:"riskType,omitempty"`
|
||||
RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项ID列表
|
||||
RetrievedItems []string `json:"retrievedItems"` // 检索到的知识项 ID 列表
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义JSON序列化,确保时间格式正确
|
||||
// MarshalJSON 自定义 JSON 序列化,确保时间格式正确
|
||||
func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
|
||||
type Alias RetrievalLog
|
||||
return json.Marshal(&struct {
|
||||
@@ -120,21 +102,21 @@ func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
|
||||
CreatedAt string `json:"createdAt"`
|
||||
}{
|
||||
Alias: (*Alias)(r),
|
||||
CreatedAt: r.CreatedAt.Format(time.RFC3339),
|
||||
CreatedAt: formatTime(r.CreatedAt),
|
||||
})
|
||||
}
|
||||
|
||||
// CategoryWithItems 分类及其下的知识项(用于按分类分页)
|
||||
type CategoryWithItems struct {
|
||||
Category string `json:"category"` // 分类名称
|
||||
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
|
||||
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
|
||||
Category string `json:"category"` // 分类名称
|
||||
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
|
||||
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
|
||||
}
|
||||
|
||||
// SearchRequest 搜索请求
|
||||
type SearchRequest struct {
|
||||
Query string `json:"query"`
|
||||
RiskType string `json:"riskType,omitempty"` // 可选:指定风险类型
|
||||
TopK int `json:"topK,omitempty"` // 返回Top-K结果,默认5
|
||||
Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认0.7
|
||||
TopK int `json:"topK,omitempty"` // 返回 Top-K 结果,默认 5
|
||||
Threshold float64 `json:"threshold,omitempty"` // 相似度阈值,默认 0.7
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user