mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-06 22:33:54 +02:00
Compare commits
149 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ba77e1837e | |||
| eacad60fd6 | |||
| 70bf5c93bf | |||
| 08bd278d8c | |||
| 22746d64a3 | |||
| 199392a5d5 | |||
| aafb4cb584 | |||
| 96e3dd397c | |||
| ec0f17145b | |||
| ed53da0999 | |||
| dc440fc511 | |||
| 009ae59033 | |||
| f348b3245a | |||
| 0018c5219c | |||
| 01a3e3677a | |||
| a12ecdb46f | |||
| 9f59230d74 | |||
| 085c6a1c72 | |||
| 7b3860971f | |||
| f6f7b7b237 | |||
| d5cf4b3b16 | |||
| 3e58d8355b | |||
| eb01ade63b | |||
| d1dc15fa44 | |||
| 73a39ef868 | |||
| a022baef03 | |||
| 59312d428e | |||
| 951d14ef14 | |||
| 0eb22da6e9 | |||
| 5fd9ef0514 | |||
| 9a4f3c7d35 | |||
| ead2ce3ecc | |||
| 8733f3a2d2 | |||
| 8642f3ba31 | |||
| 6a262a7367 | |||
| eb9192ddb3 | |||
| 5587e75628 | |||
| 74bbb453e2 | |||
| 66842f6206 | |||
| dc1779275d | |||
| 10dff937b1 | |||
| d4e1fe3bbe | |||
| 179976ae57 | |||
| 1c758bb98c | |||
| 17c4f38ee3 | |||
| cd7e57d121 | |||
| 0f2c3f65cc | |||
| 7779666e27 | |||
| c74bd4403b | |||
| 04d23ddb43 | |||
| 0874e84393 | |||
| 57f57f30b1 | |||
| f37d613a0c | |||
| 87d0ff9154 | |||
| b3418f39b8 | |||
| f9e1ca0e2d | |||
| 2c45879669 | |||
| 1cdcfa2c2d | |||
| eab5b73846 | |||
| d961ba1ec7 | |||
| 1ba5e57ec6 | |||
| 1216d25f96 | |||
| fde693408e | |||
| 352a81a869 | |||
| b2562b1010 | |||
| 0d8ba51087 | |||
| 0b847fcea3 | |||
| bf2f49fe62 | |||
| 75e64b1a86 | |||
| 2167735022 | |||
| 4ee292cc1f | |||
| 961205940f | |||
| ffe797bd06 | |||
| b6c864547e | |||
| da369c2edc | |||
| 54dc31a616 | |||
| 9e0b985221 | |||
| eb47077082 | |||
| f9a482857d | |||
| 679a68b12f | |||
| 840a26c7ef | |||
| 030e69c02d | |||
| d9683cdb44 | |||
| 60a063dd7d | |||
| 5f0c1805a7 | |||
| cb7e66001b | |||
| 4ea838f1d7 | |||
| 573648fc4b | |||
| f0e090abea | |||
| 549dcf518c | |||
| c74e20c54a | |||
| c94a9fd9e9 | |||
| ce9749a8ef | |||
| 145da12017 | |||
| 5111f4c311 | |||
| 8f6384a083 | |||
| 762f778e1e | |||
| 4a11ba8f14 | |||
| 86090af4df | |||
| 2dea6e36bd | |||
| 38ce695708 | |||
| 41fe90faa3 | |||
| 9f54bdb1bf | |||
| 08e727aa41 | |||
| 176c17d630 | |||
| 62710f6619 | |||
| e4dbb96b3e | |||
| 832532213a | |||
| eb04ac0c3a | |||
| 1946508325 | |||
| 89d1c5124f | |||
| 1e7a3299a5 | |||
| cae3a77331 | |||
| 2e1e57ce27 | |||
| 45b6ed2847 | |||
| 88eadf13a4 | |||
| dca5666b18 | |||
| e5d52cdf85 | |||
| 65e48826ff | |||
| 0cff507272 | |||
| 30afd71c05 | |||
| d2b6a154de | |||
| 278d5aa25c | |||
| 215f5a4a93 | |||
| 44185d748d | |||
| fe47f1f058 | |||
| 99ce183f41 | |||
| 2ed1947f36 | |||
| 97f3e8c179 | |||
| 38b0c31b87 | |||
| cb839da4d1 | |||
| 5ed730f17c | |||
| 30b1e5f820 | |||
| 8e5c70703e | |||
| 3cc3b25a7b | |||
| 44cf63fa52 | |||
| 12057c065b | |||
| c4e0b9735c | |||
| 218e9b9880 | |||
| 82d840966e | |||
| c62ff3bde9 | |||
| df2506b651 | |||
| efe9172f85 | |||
| b788bc6dab | |||
| 9134f2bbcb | |||
| d76cf2a162 | |||
| 2f96feb98f | |||
| a374c3950c | |||
| a93e3455fa |
@@ -174,9 +174,11 @@ The `run.sh` script will automatically:
|
|||||||
- ✅ Build the project
|
- ✅ Build the project
|
||||||
- ✅ Start the server
|
- ✅ Start the server
|
||||||
|
|
||||||
|
**Networking defaults:** `run.sh` starts the server with **`--https`** and the repo **`config.yaml`** (local self-signed TLS; better for many concurrent streams). Use **`./run.sh --http`** for plain HTTP. In production, set **`server.tls_cert_path`** / **`server.tls_key_path`** in **`config.yaml`** (see comments there). For manual runs, add **`--https`** or **`CYBERSTRIKE_HTTPS=1`**; if **`-config`** is wrong, the binary prints a short usage hint on stderr.
|
||||||
|
|
||||||
**First-Time Configuration:**
|
**First-Time Configuration:**
|
||||||
1. **Configure OpenAI-compatible API** (required before first use)
|
1. **Configure OpenAI-compatible API** (required before first use)
|
||||||
- Open http://localhost:8080 after launch
|
- After launch, open **`https://127.0.0.1:8080/`** (or **`https://localhost:8080/`**; replace **8080** with `server.port` in `config.yaml`) and accept the self-signed certificate warning once. If you used `./run.sh --http`, use **`http://`** instead.
|
||||||
- Go to `Settings` → Fill in your API credentials:
|
- Go to `Settings` → Fill in your API credentials:
|
||||||
```yaml
|
```yaml
|
||||||
openai:
|
openai:
|
||||||
@@ -197,21 +199,23 @@ The `run.sh` script will automatically:
|
|||||||
|
|
||||||
**Alternative Launch Methods:**
|
**Alternative Launch Methods:**
|
||||||
```bash
|
```bash
|
||||||
# Direct Go run (requires manual setup)
|
# Direct Go run (set up env yourself); add --https to match run.sh defaults
|
||||||
go run cmd/server/main.go
|
go run cmd/server/main.go --https
|
||||||
|
|
||||||
# Manual build
|
# Manual build
|
||||||
go build -o cyberstrike-ai cmd/server/main.go
|
go build -o cyberstrike-ai cmd/server/main.go
|
||||||
./cyberstrike-ai
|
./cyberstrike-ai --https
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If server logs show `client sent an HTTP request to an HTTPS server`, a client is still using **`http://`** on a TLS-only port—switch the URL to **`https://`**.
|
||||||
|
|
||||||
**Note:** The Python virtual environment (`venv/`) is automatically created and managed by `run.sh`. Tools that require Python (like `api-fuzzer`, `http-framework-test`, etc.) will automatically use this environment.
|
**Note:** The Python virtual environment (`venv/`) is automatically created and managed by `run.sh`. Tools that require Python (like `api-fuzzer`, `http-framework-test`, etc.) will automatically use this environment.
|
||||||
|
|
||||||
### Version Update (No Breaking Changes)
|
### Version Update (No Breaking Changes)
|
||||||
|
|
||||||
**CyberStrikeAI one-click upgrade (recommended):**
|
**CyberStrikeAI one-click upgrade (recommended):**
|
||||||
1. (First time) enable the script: `chmod +x upgrade.sh`
|
1. (First time) enable the script: `chmod +x upgrade.sh`
|
||||||
2. Upgrade with: `./upgrade.sh` (optional flags: `--tag vX.Y.Z`, `--no-venv`, `--preserve-custom`, `--yes`)
|
2. Upgrade with: `./upgrade.sh` (optional flags: `--tag vX.Y.Z`, `--no-venv`, `--yes`). Local `tools/`, `roles/`, and `skills/` are always preserved.
|
||||||
3. The script will back up your `config.yaml` and `data/`, upgrade the code from GitHub Release, update `config.yaml`'s `version`, then restart the server.
|
3. The script will back up your `config.yaml` and `data/`, upgrade the code from GitHub Release, update `config.yaml`'s `version`, then restart the server.
|
||||||
|
|
||||||
Recommended one-liner:
|
Recommended one-liner:
|
||||||
|
|||||||
+9
-5
@@ -173,9 +173,11 @@ chmod +x run.sh && ./run.sh
|
|||||||
- ✅ 编译构建项目
|
- ✅ 编译构建项目
|
||||||
- ✅ 启动服务器
|
- ✅ 启动服务器
|
||||||
|
|
||||||
|
**网络默认:** `run.sh` 会以 **`--https`** 并传入项目根 **`config.yaml`** 启动(本机自签证书,多路流式场景更稳)。只要明文 HTTP 用 **`./run.sh --http`**。生产环境在 **`config.yaml`** 的 **`server.tls_cert_path` / `server.tls_key_path`** 配正式证书(见文件内注释)。手动启动可加 **`--https`** 或环境变量 **`CYBERSTRIKE_HTTPS=1`**;`-config` 写错时程序会在终端提示正确写法。
|
||||||
|
|
||||||
**首次配置:**
|
**首次配置:**
|
||||||
1. **配置 AI 模型 API**(首次使用前必填)
|
1. **配置 AI 模型 API**(首次使用前必填)
|
||||||
- 启动后访问 http://localhost:8080
|
- 启动后在浏览器打开 **`https://127.0.0.1:8080/`**(或 **`https://localhost:8080/`**;端口以 `config.yaml` 中 **`server.port`** 为准,默认 8080),并按提示信任自签证书。若使用 **`./run.sh --http`**,则改用 **`http://`** 访问。
|
||||||
- 进入 `设置` → 填写 API 配置信息:
|
- 进入 `设置` → 填写 API 配置信息:
|
||||||
```yaml
|
```yaml
|
||||||
openai:
|
openai:
|
||||||
@@ -196,20 +198,22 @@ chmod +x run.sh && ./run.sh
|
|||||||
|
|
||||||
**其他启动方式:**
|
**其他启动方式:**
|
||||||
```bash
|
```bash
|
||||||
# 直接运行(需手动配置环境)
|
# 直接运行(需自行配环境);与 run.sh 默认一致可加 --https
|
||||||
go run cmd/server/main.go
|
go run cmd/server/main.go --https
|
||||||
|
|
||||||
# 手动编译
|
# 手动编译
|
||||||
go build -o cyberstrike-ai cmd/server/main.go
|
go build -o cyberstrike-ai cmd/server/main.go
|
||||||
./cyberstrike-ai
|
./cyberstrike-ai --https
|
||||||
```
|
```
|
||||||
|
|
||||||
|
若日志出现 `client sent an HTTP request to an HTTPS server`,说明仍有客户端用 **`http://`** 访问只提供 HTTPS 的端口,请改为 **`https://`**。
|
||||||
|
|
||||||
**说明:** Python 虚拟环境(`venv/`)由 `run.sh` 自动创建和管理。需要 Python 的工具(如 `api-fuzzer`、`http-framework-test` 等)会自动使用该环境。
|
**说明:** Python 虚拟环境(`venv/`)由 `run.sh` 自动创建和管理。需要 Python 的工具(如 `api-fuzzer`、`http-framework-test` 等)会自动使用该环境。
|
||||||
|
|
||||||
### CyberStrikeAI 版本更新(无兼容性问题)
|
### CyberStrikeAI 版本更新(无兼容性问题)
|
||||||
|
|
||||||
1. (首次使用)启用脚本:`chmod +x upgrade.sh`
|
1. (首次使用)启用脚本:`chmod +x upgrade.sh`
|
||||||
2. 一键升级:`./upgrade.sh`(可选参数:`--tag vX.Y.Z`、`--no-venv`、`--preserve-custom`、`--yes`)
|
2. 一键升级:`./upgrade.sh`(可选参数:`--tag vX.Y.Z`、`--no-venv`、`--yes`)。本地的 `tools/`、`roles/`、`skills/` 会始终保留不被覆盖。
|
||||||
3. 脚本会备份你的 `config.yaml` 和 `data/`,从 GitHub Release 升级代码,更新 `config.yaml` 的 `version` 字段后重启服务。
|
3. 脚本会备份你的 `config.yaml` 和 `data/`,从 GitHub Release 升级代码,更新 `config.yaml` 的 `version` 字段后重启服务。
|
||||||
|
|
||||||
推荐的一键指令:
|
推荐的一键指令:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"cyberstrike-ai/internal/logger"
|
"cyberstrike-ai/internal/logger"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/security"
|
"cyberstrike-ai/internal/security"
|
||||||
|
"cyberstrike-ai/internal/storage"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@@ -32,6 +33,23 @@ func main() {
|
|||||||
// 创建安全工具执行器
|
// 创建安全工具执行器
|
||||||
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
||||||
|
|
||||||
|
// 初始化结果存储(与 internal/app/app.go 同样的逻辑)。
|
||||||
|
// stdio 模式下原本不初始化,导致 'exec' 等查询型工具报"结果存储未初始化"。
|
||||||
|
resultStorageDir := "tmp"
|
||||||
|
if cfg.Agent.ResultStorageDir != "" {
|
||||||
|
resultStorageDir = cfg.Agent.ResultStorageDir
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "创建结果存储目录失败: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "初始化结果存储失败: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
executor.SetResultStorage(resultStorage)
|
||||||
|
|
||||||
// 注册工具
|
// 注册工具
|
||||||
executor.RegisterTools(mcpServer)
|
executor.RegisterTools(mcpServer)
|
||||||
|
|
||||||
|
|||||||
+43
-3
@@ -9,22 +9,62 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var configPath = flag.String("config", "config.yaml", "配置文件路径")
|
var configPath = flag.String("config", "config.yaml", "配置文件路径")
|
||||||
|
var httpsBootstrap = flag.Bool("https", false, "启用主站 HTTPS:未配置 tls_cert_path/tls_key_path 时使用内存自签证书(本地测试);与 run.sh 默认行为一致")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
// 环境变量兼容(便于 systemd/docker 等不传参场景)
|
||||||
|
if !*httpsBootstrap {
|
||||||
|
v := strings.TrimSpace(os.Getenv("CYBERSTRIKE_HTTPS"))
|
||||||
|
if v == "1" || strings.EqualFold(v, "true") || strings.EqualFold(v, "yes") {
|
||||||
|
*httpsBootstrap = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 加载配置
|
// 加载配置
|
||||||
cfg, err := config.Load(*configPath)
|
cp := strings.TrimSpace(*configPath)
|
||||||
|
if cp == "" {
|
||||||
|
cp = "config.yaml"
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(cp, "-") {
|
||||||
|
fmt.Fprintf(os.Stderr, "无效的 -config 路径 %q。\n若同时需要 HTTPS,请写成: ./cyberstrike-ai --https -config config.yaml(-config 后必须是 yaml 文件路径)。\n", cp)
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
cfg, err := config.Load(cp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("加载配置失败: %v\n", err)
|
fmt.Printf("加载配置失败: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if *httpsBootstrap {
|
||||||
|
config.ApplyDevHTTPSBootstrap(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
port := cfg.Server.Port
|
||||||
|
if port <= 0 {
|
||||||
|
port = 8080
|
||||||
|
}
|
||||||
|
scheme := "http"
|
||||||
|
if config.MainWebUIUsesHTTPS(&cfg.Server) {
|
||||||
|
scheme = "https"
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Printf("→ Web 界面: %s://127.0.0.1:%d/\n", scheme, port)
|
||||||
|
if scheme == "https" && cfg.Server.TLSAutoSelfSign {
|
||||||
|
fmt.Println(" (内存自签证书:浏览器首次需确认「继续访问」)")
|
||||||
|
}
|
||||||
|
if scheme == "https" && config.ServerHTTPRedirectEnabled(&cfg.Server) {
|
||||||
|
fmt.Printf(" (http://127.0.0.1:%d/ 将自动跳转到 HTTPS)\n", port)
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
// MCP 启用且 auth_header_value 为空时,自动生成随机密钥并写回配置
|
// MCP 启用且 auth_header_value 为空时,自动生成随机密钥并写回配置
|
||||||
if err := config.EnsureMCPAuth(*configPath, cfg); err != nil {
|
if err := config.EnsureMCPAuth(cp, cfg); err != nil {
|
||||||
fmt.Printf("MCP 鉴权配置失败: %v\n", err)
|
fmt.Printf("MCP 鉴权配置失败: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -44,7 +84,7 @@ func main() {
|
|||||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
// 创建应用
|
// 创建应用
|
||||||
application, err := app.New(cfg, log)
|
application, err := app.New(cfg, log, cp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("应用初始化失败", "error", err)
|
log.Fatal("应用初始化失败", "error", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+45
-4
@@ -10,11 +10,22 @@
|
|||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||||
version: "v1.6.5"
|
version: "v1.6.18"
|
||||||
# 服务器配置
|
# 服务器配置
|
||||||
server:
|
server:
|
||||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||||
port: 8080 # HTTP 服务端口,可通过浏览器访问 http://localhost:8080
|
port: 8080 # 服务端口;未启用 TLS 时为 http://localhost:8080
|
||||||
|
# --- 可选:HTTPS + HTTP/2(缓解浏览器对同源 HTTP/1.1 的并发连接数限制,多路 Deep 流式更稳)---
|
||||||
|
# 启用 TLS 的条件(满足其一即可):tls_enabled: true,或 tls_auto_self_sign: true,或同时配置了 tls_cert_path + tls_key_path。
|
||||||
|
# 启用后请用 https://127.0.0.1:<本端口>/ 访问;若仍用 http:// 访问同端口,将自动 308 跳转到 HTTPS(可用 tls_http_redirect: false 关闭)。
|
||||||
|
tls_enabled: true
|
||||||
|
# 启用 HTTPS 时,明文 HTTP 是否自动跳转到 HTTPS(默认 true;同端口嗅探 TLS/HTTP 后分流)
|
||||||
|
# tls_http_redirect: true
|
||||||
|
# 方式 A(推荐生产):PEM 证书与私钥路径
|
||||||
|
# tls_cert_path: /path/to/fullchain.pem
|
||||||
|
# tls_key_path: /path/to/privkey.pem
|
||||||
|
# 方式 B(仅本地/测试):无证书文件时内存自签(浏览器会提示不受信任;SAN 含 localhost / 127.0.0.1)
|
||||||
|
tls_auto_self_sign: true
|
||||||
# 认证配置
|
# 认证配置
|
||||||
auth:
|
auth:
|
||||||
password: # Web 登录密码,请修改为强密码
|
password: # Web 登录密码,请修改为强密码
|
||||||
@@ -41,6 +52,13 @@ openai:
|
|||||||
api_key: sk-xxxxxxx # API 密钥(必填)
|
api_key: sk-xxxxxxx # API 密钥(必填)
|
||||||
model: qwen3-max # 模型名称(必填)
|
model: qwen3-max # 模型名称(必填)
|
||||||
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
|
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
|
||||||
|
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinking(extended thinking),mode: off 关闭
|
||||||
|
reasoning:
|
||||||
|
mode: on # auto | on | off;off 时不附加任何推理扩展字段
|
||||||
|
effort: max # low | medium | high | max;空表示不指定(openai_compat 下 auto 且无强度时不发请求扩展)
|
||||||
|
allow_client_reasoning: true # false 时忽略对话请求体 reasoning,仅以下方为准
|
||||||
|
profile: openai_compat # auto | deepseek_compat | openai_compat | output_config_effort
|
||||||
|
# extra_request_fields: {} # 可选:管理员自定义根级 JSON 片段(高级)
|
||||||
# ============================================
|
# ============================================
|
||||||
# 信息收集(FOFA)配置(可选)
|
# 信息收集(FOFA)配置(可选)
|
||||||
# ============================================
|
# ============================================
|
||||||
@@ -53,10 +71,10 @@ fofa:
|
|||||||
# Agent 配置
|
# Agent 配置
|
||||||
# 达到最大迭代次数时,AI 会自动总结测试结果
|
# 达到最大迭代次数时,AI 会自动总结测试结果
|
||||||
agent:
|
agent:
|
||||||
max_iterations: 120 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
max_iterations: 1200 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
||||||
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
|
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
|
||||||
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
|
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
|
||||||
tool_timeout_minutes: 30 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||||
# system_prompt_path: prompts/single-react.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
# system_prompt_path: prompts/single-react.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
||||||
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
|
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
|
||||||
hitl:
|
hitl:
|
||||||
@@ -110,6 +128,21 @@ multi_agent:
|
|||||||
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
|
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
|
||||||
deep_model_retry_max_retries: 0 # >0:ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
|
deep_model_retry_max_retries: 0 # >0:ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
|
||||||
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
||||||
|
# Eino callbacks + OpenTelemetry:框架级 span(与 Zap 对齐);默认不向终端用户 UI 推 eino_trace_*(见 sse_trace_to_client)
|
||||||
|
eino_callbacks:
|
||||||
|
enabled: true
|
||||||
|
# log_only=仅 Zap+OTel(推荐默认)| sse/full=才启用流式回调副本关闭等(full 含 stream hooks)
|
||||||
|
mode: log_only
|
||||||
|
sse_trace_to_client: false # true:且 mode 为 sse/full 时,向前端时间线推送 eino_trace_*(排障/内网演示用)
|
||||||
|
max_input_summary_runes: 400
|
||||||
|
max_output_summary_runes: 400
|
||||||
|
zap_verbose: false # true:Debug 附带 input/output 摘要
|
||||||
|
otel:
|
||||||
|
enabled: true
|
||||||
|
service_name: cyberstrike-ai
|
||||||
|
exporter: stdout # none | stdout(开发/本机)| otlphttp(生产接 Collector)
|
||||||
|
otlp_endpoint: localhost:4318 # otlphttp 时使用,host:port,路径固定 /v1/traces
|
||||||
|
sample_ratio: 1.0 # 0~1,ParentBased+TraceIDRatio
|
||||||
# 数据库配置
|
# 数据库配置
|
||||||
database:
|
database:
|
||||||
path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息
|
path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息
|
||||||
@@ -202,6 +235,14 @@ knowledge:
|
|||||||
# 用于在手机端通过企业微信/钉钉/飞书与 CyberStrikeAI 对话,无需部署在服务器上也可使用
|
# 用于在手机端通过企业微信/钉钉/飞书与 CyberStrikeAI 对话,无需部署在服务器上也可使用
|
||||||
# 在系统设置 -> 机器人设置 中可配置
|
# 在系统设置 -> 机器人设置 中可配置
|
||||||
robots:
|
robots:
|
||||||
|
wechat: # 微信 iLink(个人微信 ClawBot,扫码绑定)
|
||||||
|
enabled: false
|
||||||
|
bot_token: ""
|
||||||
|
ilink_bot_id: ""
|
||||||
|
ilink_user_id: ""
|
||||||
|
base_url: https://ilinkai.weixin.qq.com
|
||||||
|
bot_type: "3"
|
||||||
|
bot_agent: CyberStrikeAI/1.0
|
||||||
wecom: # 企业微信
|
wecom: # 企业微信
|
||||||
enabled: false
|
enabled: false
|
||||||
token: ""
|
token: ""
|
||||||
|
|||||||
@@ -27,7 +27,14 @@ require (
|
|||||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||||
github.com/pkoukk/tiktoken-go v0.1.8
|
github.com/pkoukk/tiktoken-go v0.1.8
|
||||||
github.com/robfig/cron/v3 v3.0.1
|
github.com/robfig/cron/v3 v3.0.1
|
||||||
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||||
|
go.opentelemetry.io/otel v1.34.0
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0
|
||||||
|
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0
|
||||||
|
go.opentelemetry.io/otel/sdk v1.34.0
|
||||||
|
go.opentelemetry.io/otel/trace v1.34.0
|
||||||
go.uber.org/zap v1.26.0
|
go.uber.org/zap v1.26.0
|
||||||
|
golang.org/x/net v0.35.0
|
||||||
golang.org/x/text v0.26.0
|
golang.org/x/text v0.26.0
|
||||||
golang.org/x/time v0.14.0
|
golang.org/x/time v0.14.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
@@ -39,6 +46,7 @@ require (
|
|||||||
github.com/buger/jsonparser v1.1.1 // indirect
|
github.com/buger/jsonparser v1.1.1 // indirect
|
||||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||||
|
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17 // indirect
|
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17 // indirect
|
||||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||||
@@ -46,6 +54,8 @@ require (
|
|||||||
github.com/evanphx/json-patch v0.5.2 // indirect
|
github.com/evanphx/json-patch v0.5.2 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
|
github.com/go-logr/logr v1.4.2 // indirect
|
||||||
|
github.com/go-logr/stdr v1.2.2 // indirect
|
||||||
github.com/go-playground/locales v0.14.1 // indirect
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||||
@@ -53,6 +63,7 @@ require (
|
|||||||
github.com/gogo/protobuf v1.3.2 // indirect
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
github.com/google/jsonschema-go v0.3.0 // indirect
|
github.com/google/jsonschema-go v0.3.0 // indirect
|
||||||
github.com/goph/emperror v0.17.2 // indirect
|
github.com/goph/emperror v0.17.2 // indirect
|
||||||
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
|
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
|
||||||
github.com/leodido/go-urn v1.2.4 // indirect
|
github.com/leodido/go-urn v1.2.4 // indirect
|
||||||
@@ -71,14 +82,20 @@ require (
|
|||||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||||
github.com/yargevad/filepathx v1.0.0 // indirect
|
github.com/yargevad/filepathx v1.0.0 // indirect
|
||||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||||
|
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 // indirect
|
||||||
|
go.opentelemetry.io/otel/metric v1.34.0 // indirect
|
||||||
|
go.opentelemetry.io/proto/otlp v1.5.0 // indirect
|
||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
golang.org/x/arch v0.15.0 // indirect
|
golang.org/x/arch v0.15.0 // indirect
|
||||||
golang.org/x/crypto v0.39.0 // indirect
|
golang.org/x/crypto v0.39.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
|
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
|
||||||
golang.org/x/net v0.24.0 // indirect
|
|
||||||
golang.org/x/oauth2 v0.30.0 // indirect
|
golang.org/x/oauth2 v0.30.0 // indirect
|
||||||
golang.org/x/sys v0.33.0 // indirect
|
golang.org/x/sys v0.33.0 // indirect
|
||||||
google.golang.org/protobuf v1.30.0 // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f // indirect
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect
|
||||||
|
google.golang.org/grpc v1.69.4 // indirect
|
||||||
|
google.golang.org/protobuf v1.36.3 // indirect
|
||||||
)
|
)
|
||||||
|
|
||||||
// 修复钉钉 Stream SDK 在长连接断开(熄屏/网络中断)后 "panic: send on closed channel" 问题
|
// 修复钉钉 Stream SDK 在长连接断开(熄屏/网络中断)后 "panic: send on closed channel" 问题
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uS
|
|||||||
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||||
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||||
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||||
|
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||||
|
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||||
github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4=
|
github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4=
|
||||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||||
@@ -59,6 +61,11 @@ github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
|||||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI=
|
github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI=
|
||||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
|
github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
|
||||||
|
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||||
|
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||||
|
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||||
|
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||||
|
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
@@ -75,8 +82,8 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69
|
|||||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
@@ -90,6 +97,8 @@ github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25d
|
|||||||
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
|
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
|
||||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 h1:VNqngBF40hVlDloBruUehVYC3ArSgIyScOAyMRqBxRg=
|
||||||
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1/go.mod h1:RBRO7fro65R6tjKzYgLAFo0t1QEXY1Dp+i/bvpRiqiQ=
|
||||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
@@ -154,6 +163,8 @@ github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtIS
|
|||||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||||
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||||
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI=
|
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI=
|
||||||
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg=
|
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg=
|
||||||
github.com/smarty/assertions v1.16.0 h1:EvHNkdRA4QHMrn75NZSoUQ/mAUXAYWfatfB01yTCzfY=
|
github.com/smarty/assertions v1.16.0 h1:EvHNkdRA4QHMrn75NZSoUQ/mAUXAYWfatfB01yTCzfY=
|
||||||
@@ -191,6 +202,26 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI
|
|||||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||||
|
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||||
|
go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY=
|
||||||
|
go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 h1:OeNbIYk/2C15ckl7glBlOBp5+WlYsOElzTNmiPW/x60=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0/go.mod h1:7Bept48yIeqxP2OZ9/AqIpYS94h2or0aB4FypJTc8ZM=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0 h1:BEj3SPM81McUZHYjRS5pEgNgnmzGJ5tRpU5krWnV8Bs=
|
||||||
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0/go.mod h1:9cKLGBDzI/F3NoHLQGm4ZrYdIHsvGt6ej6hUowxY0J4=
|
||||||
|
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0 h1:jBpDk4HAUsrnVO1FsfCfCOTEc/MkInJmvfCHYLFiT80=
|
||||||
|
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0/go.mod h1:H9LUIM1daaeZaz91vZcfeM0fejXPmgCYE8ZhzqfJuiU=
|
||||||
|
go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ=
|
||||||
|
go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE=
|
||||||
|
go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A=
|
||||||
|
go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU=
|
||||||
|
go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc=
|
||||||
|
go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8=
|
||||||
|
go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k=
|
||||||
|
go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE=
|
||||||
|
go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4=
|
||||||
|
go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4=
|
||||||
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
|
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
|
||||||
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
|
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
|
||||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||||
@@ -216,8 +247,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
|
|||||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
|
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||||
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
@@ -251,9 +282,14 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
|
|||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f h1:gap6+3Gk41EItBuyi4XX/bp4oqJ3UwuIMl25yGinuAA=
|
||||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:Ic02D47M+zbarjYYUlK57y316f2MoN0gjAwI3f2S95o=
|
||||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI=
|
||||||
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50=
|
||||||
|
google.golang.org/grpc v1.69.4 h1:MF5TftSMkd8GLw/m0KM6V8CMOCY6NZ1NQDPGFgbTt4A=
|
||||||
|
google.golang.org/grpc v1.69.4/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4=
|
||||||
|
google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU=
|
||||||
|
google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
|
|||||||
+54
-10
@@ -193,6 +193,10 @@ type ChatMessage struct {
|
|||||||
Content string `json:"content,omitempty"`
|
Content string `json:"content,omitempty"`
|
||||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||||
|
// ToolName 仅 tool 角色:从 Eino/轨迹 JSON 的 name 或 tool_name 恢复,供续跑构造 ToolMessage。
|
||||||
|
ToolName string `json:"tool_name,omitempty"`
|
||||||
|
// ReasoningContent 对应 OpenAI/DeepSeek 的 reasoning_content;思考模式 + 工具调用后续跑须回传(见 DeepSeek 文档)。
|
||||||
|
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarshalJSON 自定义JSON序列化,将tool_calls中的arguments转换为JSON字符串
|
// MarshalJSON 自定义JSON序列化,将tool_calls中的arguments转换为JSON字符串
|
||||||
@@ -206,11 +210,17 @@ func (cm ChatMessage) MarshalJSON() ([]byte, error) {
|
|||||||
if cm.Content != "" {
|
if cm.Content != "" {
|
||||||
aux["content"] = cm.Content
|
aux["content"] = cm.Content
|
||||||
}
|
}
|
||||||
|
if cm.ReasoningContent != "" {
|
||||||
|
aux["reasoning_content"] = cm.ReasoningContent
|
||||||
|
}
|
||||||
|
|
||||||
// 添加tool_call_id(如果存在)
|
// 添加tool_call_id(如果存在)
|
||||||
if cm.ToolCallID != "" {
|
if cm.ToolCallID != "" {
|
||||||
aux["tool_call_id"] = cm.ToolCallID
|
aux["tool_call_id"] = cm.ToolCallID
|
||||||
}
|
}
|
||||||
|
if cm.ToolName != "" {
|
||||||
|
aux["tool_name"] = cm.ToolName
|
||||||
|
}
|
||||||
|
|
||||||
// 转换tool_calls,将arguments转换为JSON字符串
|
// 转换tool_calls,将arguments转换为JSON字符串
|
||||||
if len(cm.ToolCalls) > 0 {
|
if len(cm.ToolCalls) > 0 {
|
||||||
@@ -438,6 +448,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
Content: msg.Content,
|
Content: msg.Content,
|
||||||
ToolCalls: msg.ToolCalls,
|
ToolCalls: msg.ToolCalls,
|
||||||
ToolCallID: msg.ToolCallID,
|
ToolCallID: msg.ToolCallID,
|
||||||
|
ToolName: msg.ToolName,
|
||||||
})
|
})
|
||||||
addedCount++
|
addedCount++
|
||||||
contentPreview := msg.Content
|
contentPreview := msg.Content
|
||||||
@@ -587,11 +598,17 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
thinkingStreamSeq++
|
thinkingStreamSeq++
|
||||||
thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq)
|
thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq)
|
||||||
thinkingStreamStarted := false
|
thinkingStreamStarted := false
|
||||||
|
var thinkingWire string
|
||||||
|
|
||||||
response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error {
|
response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error {
|
||||||
if delta == "" {
|
if delta == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
var deltaOut string
|
||||||
|
thinkingWire, deltaOut = openai.NormalizeStreamingDelta(thinkingWire, delta)
|
||||||
|
if deltaOut == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if !thinkingStreamStarted {
|
if !thinkingStreamStarted {
|
||||||
thinkingStreamStarted = true
|
thinkingStreamStarted = true
|
||||||
sendProgress("thinking_stream_start", " ", map[string]interface{}{
|
sendProgress("thinking_stream_start", " ", map[string]interface{}{
|
||||||
@@ -600,10 +617,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
"toolStream": false,
|
"toolStream": false,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
sendProgress("thinking_stream_delta", delta, map[string]interface{}{
|
sendProgress("thinking_stream_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
"streamId": thinkingStreamId,
|
"streamId": thinkingStreamId,
|
||||||
"iteration": i + 1,
|
"iteration": i + 1,
|
||||||
})
|
}, thinkingWire))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -657,8 +674,8 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
|
|
||||||
// 检查是否有工具调用
|
// 检查是否有工具调用
|
||||||
if len(choice.Message.ToolCalls) > 0 {
|
if len(choice.Message.ToolCalls) > 0 {
|
||||||
// 思考内容:如果本轮启用了思考流式增量(thinking_stream_*),前端会去重;
|
// ReAct 助手正文流式增量(thinking_stream_*)在 UI 上归为「思考」;若与 streamId 重复则前端会去重。
|
||||||
// 同时也需要在该“思考阶段结束”时补一条可落库的 thinking(用于刷新后持久化展示)。
|
// 该条 thinking 用于刷新后持久化展示(与流式聚合一致)。
|
||||||
if choice.Message.Content != "" {
|
if choice.Message.Content != "" {
|
||||||
sendProgress("thinking", choice.Message.Content, map[string]interface{}{
|
sendProgress("thinking", choice.Message.Content, map[string]interface{}{
|
||||||
"iteration": i + 1,
|
"iteration": i + 1,
|
||||||
@@ -816,10 +833,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||||
"messageGeneratedBy": "summary",
|
"messageGeneratedBy": "summary",
|
||||||
})
|
})
|
||||||
|
var summaryWire string
|
||||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||||
sendProgress("response_delta", delta, map[string]interface{}{
|
var deltaOut string
|
||||||
|
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
|
||||||
|
if deltaOut == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
})
|
}, summaryWire))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if strings.TrimSpace(streamText) != "" {
|
if strings.TrimSpace(streamText) != "" {
|
||||||
@@ -863,10 +886,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||||
"messageGeneratedBy": "summary",
|
"messageGeneratedBy": "summary",
|
||||||
})
|
})
|
||||||
|
var summaryWire string
|
||||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||||
sendProgress("response_delta", delta, map[string]interface{}{
|
var deltaOut string
|
||||||
|
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
|
||||||
|
if deltaOut == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
})
|
}, summaryWire))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if strings.TrimSpace(streamText) != "" {
|
if strings.TrimSpace(streamText) != "" {
|
||||||
@@ -910,10 +939,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||||
"messageGeneratedBy": "max_iter_summary",
|
"messageGeneratedBy": "max_iter_summary",
|
||||||
})
|
})
|
||||||
|
var summaryWire string
|
||||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||||
sendProgress("response_delta", delta, map[string]interface{}{
|
var deltaOut string
|
||||||
|
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
|
||||||
|
if deltaOut == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
})
|
}, summaryWire))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if strings.TrimSpace(streamText) != "" {
|
if strings.TrimSpace(streamText) != "" {
|
||||||
@@ -1909,6 +1944,15 @@ func (a *Agent) ExecuteMCPToolForConversation(ctx context.Context, conversationI
|
|||||||
return a.executeToolViaMCP(ctx, toolName, args)
|
return a.executeToolViaMCP(ctx, toolName, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
|
||||||
|
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
|
||||||
|
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||||
|
if a == nil || a.mcpServer == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
||||||
|
}
|
||||||
|
|
||||||
// CancelMCPToolExecutionWithNote 取消一次进行中的 MCP 工具(先内部后外部),与监控页「终止工具」一致;note 非空时合并进返回给模型的文本。
|
// CancelMCPToolExecutionWithNote 取消一次进行中的 MCP 工具(先内部后外部),与监控页「终止工具」一致;note 非空时合并进返回给模型的文本。
|
||||||
func (a *Agent) CancelMCPToolExecutionWithNote(executionID, note string) bool {
|
func (a *Agent) CancelMCPToolExecutionWithNote(executionID, note string) bool {
|
||||||
executionID = strings.TrimSpace(executionID)
|
executionID = strings.TrimSpace(executionID)
|
||||||
|
|||||||
@@ -0,0 +1,167 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseTraceMessages 解析落库的 last_react_input(OpenAI 风格 messages JSON 数组)。
|
||||||
|
func ParseTraceMessages(traceInputJSON string) ([]ChatMessage, error) {
|
||||||
|
traceInputJSON = strings.TrimSpace(traceInputJSON)
|
||||||
|
if traceInputJSON == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var raw []map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(traceInputJSON), &raw); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out := make([]ChatMessage, 0, len(raw))
|
||||||
|
for _, msgMap := range raw {
|
||||||
|
msg := ChatMessage{}
|
||||||
|
role, _ := msgMap["role"].(string)
|
||||||
|
if role == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
msg.Role = role
|
||||||
|
if content, ok := msgMap["content"].(string); ok {
|
||||||
|
msg.Content = content
|
||||||
|
}
|
||||||
|
if rc, ok := msgMap["reasoning_content"].(string); ok && strings.TrimSpace(rc) != "" {
|
||||||
|
msg.ReasoningContent = rc
|
||||||
|
}
|
||||||
|
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
|
||||||
|
if toolCallsArray, ok := toolCallsRaw.([]interface{}); ok {
|
||||||
|
for _, tcRaw := range toolCallsArray {
|
||||||
|
tcMap, ok := tcRaw.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
toolCall := ToolCall{}
|
||||||
|
if id, ok := tcMap["id"].(string); ok {
|
||||||
|
toolCall.ID = id
|
||||||
|
}
|
||||||
|
if toolType, ok := tcMap["type"].(string); ok {
|
||||||
|
toolCall.Type = toolType
|
||||||
|
}
|
||||||
|
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
|
||||||
|
toolCall.Function = FunctionCall{}
|
||||||
|
if name, ok := funcMap["name"].(string); ok {
|
||||||
|
toolCall.Function.Name = name
|
||||||
|
}
|
||||||
|
if argsRaw, ok := funcMap["arguments"]; ok {
|
||||||
|
if argsStr, ok := argsRaw.(string); ok {
|
||||||
|
var argsMap map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil {
|
||||||
|
toolCall.Function.Arguments = argsMap
|
||||||
|
}
|
||||||
|
} else if argsMap, ok := argsRaw.(map[string]interface{}); ok {
|
||||||
|
toolCall.Function.Arguments = argsMap
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if toolCall.ID != "" {
|
||||||
|
msg.ToolCalls = append(msg.ToolCalls, toolCall)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
|
||||||
|
msg.ToolCallID = toolCallID
|
||||||
|
}
|
||||||
|
if tn, ok := msgMap["tool_name"].(string); ok && strings.TrimSpace(tn) != "" {
|
||||||
|
msg.ToolName = strings.TrimSpace(tn)
|
||||||
|
} else if tn, ok := msgMap["name"].(string); ok && strings.TrimSpace(tn) != "" && strings.EqualFold(msg.Role, "tool") {
|
||||||
|
msg.ToolName = strings.TrimSpace(tn)
|
||||||
|
}
|
||||||
|
out = append(out, msg)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractLastUserTurnMessages 仅保留最后一次 user 提问起的消息(不含更早的用户轮次;跳过 system)。
|
||||||
|
// 与「继续对话」续跑所用轨迹范围一致:当前任务轮次,而非整段多轮对话历史。
|
||||||
|
func ExtractLastUserTurnMessages(msgs []ChatMessage) []ChatMessage {
|
||||||
|
if len(msgs) == 0 {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
lastUser := -1
|
||||||
|
for i, m := range msgs {
|
||||||
|
if strings.EqualFold(m.Role, "user") {
|
||||||
|
lastUser = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lastUser < 0 {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
trimmed := msgs[lastUser:]
|
||||||
|
out := make([]ChatMessage, 0, len(trimmed))
|
||||||
|
for _, m := range trimmed {
|
||||||
|
if strings.EqualFold(m.Role, "system") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, m)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractLastUserTurnTraceJSON 在 JSON 轨迹上裁剪为最后一次 user 起的片段(供落库格式直接处理)。
|
||||||
|
func ExtractLastUserTurnTraceJSON(traceInputJSON string) string {
|
||||||
|
traceInputJSON = strings.TrimSpace(traceInputJSON)
|
||||||
|
if traceInputJSON == "" {
|
||||||
|
return traceInputJSON
|
||||||
|
}
|
||||||
|
var arr []map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(traceInputJSON), &arr); err != nil {
|
||||||
|
return traceInputJSON
|
||||||
|
}
|
||||||
|
lastUser := -1
|
||||||
|
for i, m := range arr {
|
||||||
|
if r, _ := m["role"].(string); strings.EqualFold(r, "user") {
|
||||||
|
lastUser = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lastUser <= 0 {
|
||||||
|
return traceInputJSON
|
||||||
|
}
|
||||||
|
trimmed := arr[lastUser:]
|
||||||
|
b, err := json.Marshal(trimmed)
|
||||||
|
if err != nil {
|
||||||
|
return traceInputJSON
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeAssistantTraceOutput 将 last_react_output 合并进轨迹最后一条 assistant(与 loadHistoryFromAgentTrace 一致)。
|
||||||
|
func MergeAssistantTraceOutput(msgs []ChatMessage, assistantOut string) []ChatMessage {
|
||||||
|
assistantOut = strings.TrimSpace(assistantOut)
|
||||||
|
if assistantOut == "" || len(msgs) == 0 {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
out := append([]ChatMessage(nil), msgs...)
|
||||||
|
last := &out[len(out)-1]
|
||||||
|
if strings.EqualFold(last.Role, "assistant") && len(last.ToolCalls) == 0 {
|
||||||
|
last.Content = assistantOut
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
out = append(out, ChatMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: assistantOut,
|
||||||
|
})
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessagesToTraceJSON 将消息带序列化为 JSON(跳过 system)。
|
||||||
|
func MessagesToTraceJSON(msgs []ChatMessage) (string, error) {
|
||||||
|
filtered := make([]ChatMessage, 0, len(msgs))
|
||||||
|
for _, m := range msgs {
|
||||||
|
if strings.EqualFold(m.Role, "system") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, m)
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(filtered)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractLastUserTurnTraceJSON(t *testing.T) {
|
||||||
|
raw := []map[string]interface{}{
|
||||||
|
{"role": "user", "content": "old question"},
|
||||||
|
{"role": "assistant", "content": "old answer"},
|
||||||
|
{"role": "user", "content": "new target 1.1.1.1"},
|
||||||
|
{"role": "assistant", "tool_calls": []interface{}{map[string]interface{}{
|
||||||
|
"id": "c1", "type": "function",
|
||||||
|
"function": map[string]interface{}{"name": "nmap", "arguments": "{}"},
|
||||||
|
}}},
|
||||||
|
{"role": "tool", "tool_call_id": "c1", "content": "open ports"},
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(raw)
|
||||||
|
out := ExtractLastUserTurnTraceJSON(string(b))
|
||||||
|
var trimmed []map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(out), &trimmed); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(trimmed) != 3 {
|
||||||
|
t.Fatalf("expected 3 messages, got %d", len(trimmed))
|
||||||
|
}
|
||||||
|
if trimmed[0]["content"] != "new target 1.1.1.1" {
|
||||||
|
t.Fatalf("unexpected first message: %v", trimmed[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractLastUserTurnMessagesSkipsSystem(t *testing.T) {
|
||||||
|
msgs := []ChatMessage{
|
||||||
|
{Role: "system", Content: "sys"},
|
||||||
|
{Role: "user", Content: "q"},
|
||||||
|
{Role: "assistant", Content: "a"},
|
||||||
|
}
|
||||||
|
out := ExtractLastUserTurnMessages(msgs)
|
||||||
|
if len(out) != 2 {
|
||||||
|
t.Fatalf("expected 2, got %d", len(out))
|
||||||
|
}
|
||||||
|
if out[0].Role != "user" {
|
||||||
|
t.Fatal("expected user first")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeAssistantTraceOutput(t *testing.T) {
|
||||||
|
msgs := []ChatMessage{
|
||||||
|
{Role: "user", Content: "q"},
|
||||||
|
{Role: "assistant", Content: "draft"},
|
||||||
|
}
|
||||||
|
out := MergeAssistantTraceOutput(msgs, "final summary")
|
||||||
|
if out[len(out)-1].Content != "final summary" {
|
||||||
|
t.Fatalf("expected merged output, got %q", out[len(out)-1].Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
+100
-11
@@ -3,8 +3,10 @@ package app
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
|
"crypto/tls"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -16,6 +18,7 @@ import (
|
|||||||
"cyberstrike-ai/internal/c2"
|
"cyberstrike-ai/internal/c2"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/einoobserve"
|
||||||
"cyberstrike-ai/internal/handler"
|
"cyberstrike-ai/internal/handler"
|
||||||
"cyberstrike-ai/internal/knowledge"
|
"cyberstrike-ai/internal/knowledge"
|
||||||
"cyberstrike-ai/internal/logger"
|
"cyberstrike-ai/internal/logger"
|
||||||
@@ -29,6 +32,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// App 应用
|
// App 应用
|
||||||
@@ -52,6 +56,7 @@ type App struct {
|
|||||||
robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel
|
robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel
|
||||||
dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启
|
dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启
|
||||||
larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启
|
larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启
|
||||||
|
wechatCancel context.CancelFunc // 微信 iLink 长轮询取消函数
|
||||||
c2Manager *c2.Manager // C2 管理器(未启用 C2 时为 nil)
|
c2Manager *c2.Manager // C2 管理器(未启用 C2 时为 nil)
|
||||||
c2Watchdog *c2.SessionWatchdog // C2 会话看门狗
|
c2Watchdog *c2.SessionWatchdog // C2 会话看门狗
|
||||||
c2WatchdogCancel context.CancelFunc // 看门狗取消函数
|
c2WatchdogCancel context.CancelFunc // 看门狗取消函数
|
||||||
@@ -59,7 +64,7 @@ type App struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// New 创建新应用
|
// New 创建新应用
|
||||||
func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error) {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
router := gin.Default()
|
router := gin.Default()
|
||||||
|
|
||||||
@@ -90,6 +95,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
|||||||
|
|
||||||
// 创建MCP服务器(带数据库持久化)
|
// 创建MCP服务器(带数据库持久化)
|
||||||
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
||||||
|
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
||||||
|
|
||||||
// 创建安全工具执行器
|
// 创建安全工具执行器
|
||||||
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
||||||
@@ -290,10 +296,10 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取配置文件路径
|
// 配置文件路径必须由入口传入(与 flag -config 一致)。勿再用 os.Args[1],否则 ./cyberstrike-ai --https 会把 --https 当成路径。
|
||||||
configPath := "config.yaml"
|
configPath = strings.TrimSpace(configPath)
|
||||||
if len(os.Args) > 1 {
|
if configPath == "" {
|
||||||
configPath = os.Args[1]
|
configPath = "config.yaml"
|
||||||
}
|
}
|
||||||
|
|
||||||
skillsDir := skillpackage.SkillsRootFromConfig(cfg.SkillsDir, configPath)
|
skillsDir := skillpackage.SkillsRootFromConfig(cfg.SkillsDir, configPath)
|
||||||
@@ -444,9 +450,11 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
|||||||
configHandler.SetRetrieverUpdater(knowledgeRetriever)
|
configHandler.SetRetrieverUpdater(knowledgeRetriever)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书新配置生效
|
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书/微信新配置生效
|
||||||
configHandler.SetRobotRestarter(app)
|
configHandler.SetRobotRestarter(app)
|
||||||
|
|
||||||
|
wechatRobotHandler := handler.NewWechatRobotHandler(cfg, configHandler, log.Logger)
|
||||||
|
|
||||||
configHandler.SetC2Runtime(app)
|
configHandler.SetC2Runtime(app)
|
||||||
configHandler.SetC2ToolRegistrar(func() error {
|
configHandler.SetC2ToolRegistrar(func() error {
|
||||||
if app.config.C2.EnabledEffective() && app.c2Manager != nil {
|
if app.config.C2.EnabledEffective() && app.c2Manager != nil {
|
||||||
@@ -464,6 +472,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
|||||||
notificationHandler,
|
notificationHandler,
|
||||||
conversationHandler,
|
conversationHandler,
|
||||||
robotHandler,
|
robotHandler,
|
||||||
|
wechatRobotHandler,
|
||||||
groupHandler,
|
groupHandler,
|
||||||
configHandler,
|
configHandler,
|
||||||
externalMCPHandler,
|
externalMCPHandler,
|
||||||
@@ -528,18 +537,49 @@ func (a *App) RunWithContext(ctx context.Context) error {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 启动主服务器
|
// 启动主服务器(可选 HTTPS + HTTP/2,见 config server.tls_*)
|
||||||
addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
|
addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
|
||||||
a.logger.Info("启动HTTP服务器", zap.String("address", addr))
|
tlsMode, tlsConf, certFile, keyFile, tlsErr := prepareMainServerTLS(&a.config.Server)
|
||||||
|
if tlsErr != nil {
|
||||||
|
return tlsErr
|
||||||
|
}
|
||||||
|
|
||||||
srv := &http.Server{Addr: addr, Handler: a.router}
|
srv := &http.Server{Addr: addr, Handler: a.router}
|
||||||
|
var mainMux *mainServerMux
|
||||||
|
httpRedirect := config.ServerHTTPRedirectEnabled(&a.config.Server)
|
||||||
|
if tlsMode != mainTLSOff {
|
||||||
|
srv.TLSConfig = tlsConf
|
||||||
|
if err := http2.ConfigureServer(srv, &http2.Server{}); err != nil {
|
||||||
|
return fmt.Errorf("主服务 HTTP/2 配置失败: %w", err)
|
||||||
|
}
|
||||||
|
switch tlsMode {
|
||||||
|
case mainTLSFromFiles:
|
||||||
|
a.logger.Info("启动 HTTPS 主服务(已启用 HTTP/2 协商)",
|
||||||
|
zap.String("address", addr),
|
||||||
|
zap.String("cert", certFile),
|
||||||
|
)
|
||||||
|
case mainTLSInMemorySelfSigned:
|
||||||
|
a.logger.Info("启动 HTTPS 主服务(内存自签证书,仅测试;已启用 HTTP/2 协商)",
|
||||||
|
zap.String("address", addr),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if httpRedirect {
|
||||||
|
a.logger.Info("已启用 HTTP→HTTPS 自动跳转(同端口嗅探分流)", zap.String("address", addr))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
a.logger.Info("启动 HTTP 主服务", zap.String("address", addr))
|
||||||
|
}
|
||||||
|
|
||||||
// 监听 context 取消,优雅关闭 HTTP 服务器
|
// 监听 context 取消,优雅关闭 HTTP 服务器
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
if mainMux != nil {
|
||||||
|
if err := mainMux.Shutdown(shutdownCtx); err != nil {
|
||||||
|
a.logger.Error("HTTP/HTTPS 分流服务器关闭失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
} else if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||||
a.logger.Error("HTTP服务器关闭失败", zap.Error(err))
|
a.logger.Error("HTTP服务器关闭失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
if mcpServer != nil {
|
if mcpServer != nil {
|
||||||
@@ -549,7 +589,36 @@ func (a *App) RunWithContext(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
var err error
|
||||||
|
switch {
|
||||||
|
case tlsMode != mainTLSOff && httpRedirect:
|
||||||
|
var tlsConfReady *tls.Config
|
||||||
|
tlsConfReady, err = ensureMainTLSConfigCerts(tlsMode, tlsConf, certFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("加载 TLS 证书: %w", err)
|
||||||
|
}
|
||||||
|
srv.TLSConfig = tlsConfReady
|
||||||
|
var ln net.Listener
|
||||||
|
ln, err = net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
mainMux = newMainServerMux(ln, srv, portFromListenAddr(addr), a.logger.Logger)
|
||||||
|
err = mainMux.Serve()
|
||||||
|
case tlsMode == mainTLSOff:
|
||||||
|
err = srv.ListenAndServe()
|
||||||
|
case tlsMode == mainTLSFromFiles:
|
||||||
|
err = srv.ListenAndServeTLS(certFile, keyFile)
|
||||||
|
case tlsMode == mainTLSInMemorySelfSigned:
|
||||||
|
var ln net.Listener
|
||||||
|
ln, err = tls.Listen("tcp", addr, srv.TLSConfig)
|
||||||
|
if err == nil {
|
||||||
|
err = srv.Serve(ln)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
err = srv.ListenAndServe()
|
||||||
|
}
|
||||||
|
if err != nil && err != http.ErrServerClosed {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -557,6 +626,10 @@ func (a *App) RunWithContext(ctx context.Context) error {
|
|||||||
|
|
||||||
// Shutdown 关闭应用
|
// Shutdown 关闭应用
|
||||||
func (a *App) Shutdown() {
|
func (a *App) Shutdown() {
|
||||||
|
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
_ = einoobserve.ShutdownOtel(shutdownCtx)
|
||||||
|
shutdownCancel()
|
||||||
|
|
||||||
// 停止钉钉/飞书长连接
|
// 停止钉钉/飞书长连接
|
||||||
a.robotMu.Lock()
|
a.robotMu.Lock()
|
||||||
if a.dingCancel != nil {
|
if a.dingCancel != nil {
|
||||||
@@ -606,9 +679,14 @@ func (a *App) startRobotConnections() {
|
|||||||
a.dingCancel = cancel
|
a.dingCancel = cancel
|
||||||
go robot.StartDing(ctx, cfg.Robots, a.robotHandler, a.logger.Logger)
|
go robot.StartDing(ctx, cfg.Robots, a.robotHandler, a.logger.Logger)
|
||||||
}
|
}
|
||||||
|
if cfg.Robots.Wechat.Enabled && cfg.Robots.Wechat.BotToken != "" {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
a.wechatCancel = cancel
|
||||||
|
go robot.StartWechat(ctx, cfg.Robots, a.robotHandler, cfg.Version, a.logger.Logger)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RestartRobotConnections 重启钉钉/飞书长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter)
|
// RestartRobotConnections 重启钉钉/飞书/微信长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter)
|
||||||
func (a *App) RestartRobotConnections() {
|
func (a *App) RestartRobotConnections() {
|
||||||
a.robotMu.Lock()
|
a.robotMu.Lock()
|
||||||
if a.dingCancel != nil {
|
if a.dingCancel != nil {
|
||||||
@@ -619,6 +697,10 @@ func (a *App) RestartRobotConnections() {
|
|||||||
a.larkCancel()
|
a.larkCancel()
|
||||||
a.larkCancel = nil
|
a.larkCancel = nil
|
||||||
}
|
}
|
||||||
|
if a.wechatCancel != nil {
|
||||||
|
a.wechatCancel()
|
||||||
|
a.wechatCancel = nil
|
||||||
|
}
|
||||||
a.robotMu.Unlock()
|
a.robotMu.Unlock()
|
||||||
// 给旧 goroutine 一点时间退出
|
// 给旧 goroutine 一点时间退出
|
||||||
time.Sleep(200 * time.Millisecond)
|
time.Sleep(200 * time.Millisecond)
|
||||||
@@ -634,6 +716,7 @@ func setupRoutes(
|
|||||||
notificationHandler *handler.NotificationHandler,
|
notificationHandler *handler.NotificationHandler,
|
||||||
conversationHandler *handler.ConversationHandler,
|
conversationHandler *handler.ConversationHandler,
|
||||||
robotHandler *handler.RobotHandler,
|
robotHandler *handler.RobotHandler,
|
||||||
|
wechatRobotHandler *handler.WechatRobotHandler,
|
||||||
groupHandler *handler.GroupHandler,
|
groupHandler *handler.GroupHandler,
|
||||||
configHandler *handler.ConfigHandler,
|
configHandler *handler.ConfigHandler,
|
||||||
externalMCPHandler *handler.ExternalMCPHandler,
|
externalMCPHandler *handler.ExternalMCPHandler,
|
||||||
@@ -682,6 +765,12 @@ func setupRoutes(
|
|||||||
// 机器人测试(需登录):POST /api/robot/test,body: {"platform":"dingtalk","user_id":"test","text":"帮助"},用于验证机器人逻辑
|
// 机器人测试(需登录):POST /api/robot/test,body: {"platform":"dingtalk","user_id":"test","text":"帮助"},用于验证机器人逻辑
|
||||||
protected.POST("/robot/test", robotHandler.HandleRobotTest)
|
protected.POST("/robot/test", robotHandler.HandleRobotTest)
|
||||||
|
|
||||||
|
// 微信 iLink 扫码绑定(需登录)
|
||||||
|
protected.POST("/robot/wechat/qrcode", wechatRobotHandler.HandleWechatQRCode)
|
||||||
|
protected.GET("/robot/wechat/qrcode/status", wechatRobotHandler.HandleWechatQRCodeStatus)
|
||||||
|
protected.POST("/robot/wechat/qrcode/verify", wechatRobotHandler.HandleWechatVerifyCode)
|
||||||
|
protected.GET("/robot/wechat/status", wechatRobotHandler.HandleWechatStatus)
|
||||||
|
|
||||||
// Agent Loop
|
// Agent Loop
|
||||||
protected.POST("/agent-loop", agentHandler.AgentLoop)
|
protected.POST("/agent-loop", agentHandler.AgentLoop)
|
||||||
// Agent Loop 流式输出
|
// Agent Loop 流式输出
|
||||||
|
|||||||
@@ -0,0 +1,196 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// peekedConn 在已预读首字节后仍将连接交给 net/http 或 crypto/tls。
|
||||||
|
type peekedConn struct {
|
||||||
|
net.Conn
|
||||||
|
r *bufio.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *peekedConn) Read(p []byte) (int, error) {
|
||||||
|
return c.r.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// oneConnListener 供 http.Server.Serve 处理单条 TCP 连接(含 keep-alive)。
|
||||||
|
type oneConnListener struct {
|
||||||
|
conn net.Conn
|
||||||
|
addr net.Addr
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *oneConnListener) Accept() (net.Conn, error) {
|
||||||
|
var c net.Conn
|
||||||
|
l.once.Do(func() {
|
||||||
|
c = l.conn
|
||||||
|
l.conn = nil
|
||||||
|
})
|
||||||
|
if c == nil {
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *oneConnListener) Close() error { return nil }
|
||||||
|
func (l *oneConnListener) Addr() net.Addr { return l.addr }
|
||||||
|
|
||||||
|
func isTLSHandshakeRecord(b byte) bool {
|
||||||
|
return b == 0x16
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPToHTTPSRedirectHandler(httpsPort int) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
host := r.Host
|
||||||
|
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||||
|
host = h
|
||||||
|
}
|
||||||
|
var target string
|
||||||
|
if httpsPort == 443 {
|
||||||
|
target = fmt.Sprintf("https://%s%s", host, r.URL.RequestURI())
|
||||||
|
} else {
|
||||||
|
target = fmt.Sprintf("https://%s:%d%s", host, httpsPort, r.URL.RequestURI())
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, target, http.StatusPermanentRedirect)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func portFromListenAddr(addr string) int {
|
||||||
|
_, portStr, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return 443
|
||||||
|
}
|
||||||
|
p, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil || p <= 0 {
|
||||||
|
return 443
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureMainTLSConfigCerts(mode mainTLSMode, tlsConf *tls.Config, certFile, keyFile string) (*tls.Config, error) {
|
||||||
|
if mode != mainTLSFromFiles {
|
||||||
|
return tlsConf, nil
|
||||||
|
}
|
||||||
|
if tlsConf == nil {
|
||||||
|
tlsConf = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||||
|
}
|
||||||
|
if len(tlsConf.Certificates) > 0 {
|
||||||
|
return tlsConf, nil
|
||||||
|
}
|
||||||
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConf.Certificates = []tls.Certificate{cert}
|
||||||
|
return tlsConf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mainServerMux struct {
|
||||||
|
ln net.Listener
|
||||||
|
httpsSrv *http.Server
|
||||||
|
redirectSrv *http.Server
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMainServerMux(ln net.Listener, httpsSrv *http.Server, httpsPort int, logger *zap.Logger) *mainServerMux {
|
||||||
|
return &mainServerMux{
|
||||||
|
ln: ln,
|
||||||
|
httpsSrv: httpsSrv,
|
||||||
|
redirectSrv: &http.Server{Handler: newHTTPToHTTPSRedirectHandler(httpsPort), ReadHeaderTimeout: 10 * time.Second},
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mainServerMux) Serve() error {
|
||||||
|
for {
|
||||||
|
conn, err := m.ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return http.ErrServerClosed
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go m.handleConn(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mainServerMux) handleConn(raw net.Conn) {
|
||||||
|
if err := raw.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||||
|
_ = raw.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
br := bufio.NewReader(raw)
|
||||||
|
b, err := br.Peek(1)
|
||||||
|
if err != nil {
|
||||||
|
_ = raw.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = raw.SetReadDeadline(time.Time{})
|
||||||
|
|
||||||
|
pc := &peekedConn{Conn: raw, r: br}
|
||||||
|
ocl := &oneConnListener{conn: pc, addr: raw.LocalAddr()}
|
||||||
|
|
||||||
|
if isTLSHandshakeRecord(b[0]) {
|
||||||
|
m.serveHTTPS(pc, raw.LocalAddr())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := m.redirectSrv.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
m.logger.Debug("HTTP 重定向连接处理结束", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// serveHTTPS 在已嗅探为 TLS 的连接上完成握手,再按 ALPN 走 HTTP/2 或 HTTP/1.1。
|
||||||
|
// 不能对同一 http.Server 并发调用 Serve(TLSConfig!=nil),否则握手/ALPN 会异常(浏览器 ERR_SSL_PROTOCOL_ERROR)。
|
||||||
|
func (m *mainServerMux) serveHTTPS(pc *peekedConn, localAddr net.Addr) {
|
||||||
|
tlsConn := tls.Server(pc, m.httpsSrv.TLSConfig)
|
||||||
|
handCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := tlsConn.HandshakeContext(handCtx); err != nil {
|
||||||
|
m.logger.Debug("TLS 握手失败", zap.Error(err))
|
||||||
|
_ = pc.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := m.httpsSrv
|
||||||
|
if srv.TLSNextProto != nil {
|
||||||
|
proto := tlsConn.ConnectionState().NegotiatedProtocol
|
||||||
|
if fn := srv.TLSNextProto[proto]; fn != nil {
|
||||||
|
fn(srv, tlsConn, srv.Handler)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
plain := *srv
|
||||||
|
plain.TLSConfig = nil
|
||||||
|
ocl := &oneConnListener{conn: tlsConn, addr: localAddr}
|
||||||
|
if err := plain.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
m.logger.Debug("HTTPS 连接处理结束", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mainServerMux) Shutdown(ctx context.Context) error {
|
||||||
|
_ = m.ln.Close()
|
||||||
|
var err1, err2 error
|
||||||
|
if m.httpsSrv != nil {
|
||||||
|
err1 = m.httpsSrv.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
if m.redirectSrv != nil {
|
||||||
|
err2 = m.redirectSrv.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
@@ -0,0 +1,150 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewHTTPToHTTPSRedirectHandler(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
httpsPort int
|
||||||
|
host string
|
||||||
|
uri string
|
||||||
|
wantTarget string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "non standard port",
|
||||||
|
httpsPort: 8080,
|
||||||
|
host: "127.0.0.1:8080",
|
||||||
|
uri: "/login?next=/",
|
||||||
|
wantTarget: "https://127.0.0.1:8080/login?next=/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "standard port",
|
||||||
|
httpsPort: 443,
|
||||||
|
host: "example.com:80",
|
||||||
|
uri: "/",
|
||||||
|
wantTarget: "https://example.com/",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
h := newHTTPToHTTPSRedirectHandler(tt.httpsPort)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://"+tt.host+tt.uri, nil)
|
||||||
|
req.Host = tt.host
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusPermanentRedirect {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusPermanentRedirect)
|
||||||
|
}
|
||||||
|
if got := rec.Header().Get("Location"); got != tt.wantTarget {
|
||||||
|
t.Fatalf("Location = %q, want %q", got, tt.wantTarget)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTLSHandshakeRecord(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if !isTLSHandshakeRecord(0x16) {
|
||||||
|
t.Fatal("expected TLS handshake record")
|
||||||
|
}
|
||||||
|
if isTLSHandshakeRecord('G') {
|
||||||
|
t.Fatal("GET should not be TLS")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerHTTPRedirectEnabled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
disabled := false
|
||||||
|
enabled := true
|
||||||
|
if config.ServerHTTPRedirectEnabled(nil) {
|
||||||
|
t.Fatal("nil config should disable redirect")
|
||||||
|
}
|
||||||
|
if !config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true}) {
|
||||||
|
t.Fatal("HTTPS without explicit flag should enable redirect")
|
||||||
|
}
|
||||||
|
if config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true, TLSHTTPRedirect: &disabled}) {
|
||||||
|
t.Fatal("explicit false should disable redirect")
|
||||||
|
}
|
||||||
|
if !config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true, TLSHTTPRedirect: &enabled}) {
|
||||||
|
t.Fatal("explicit true should enable redirect")
|
||||||
|
}
|
||||||
|
if config.ServerHTTPRedirectEnabled(&config.ServerConfig{}) {
|
||||||
|
t.Fatal("plain HTTP should not redirect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMainServerMuxHTTPRedirectAndHTTPS(t *testing.T) {
|
||||||
|
cert, err := generateMainServerSelfSignedCert()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generate cert: %v", err)
|
||||||
|
}
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "ok")
|
||||||
|
})
|
||||||
|
srv := &http.Server{Handler: handler, TLSConfig: &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}}
|
||||||
|
if err := http2.ConfigureServer(srv, &http2.Server{}); err != nil {
|
||||||
|
t.Fatalf("configure http2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen: %v", err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
mux := newMainServerMux(ln, srv, portFromListenAddr(ln.Addr().String()), nil)
|
||||||
|
go func() { _ = mux.Serve() }()
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12},
|
||||||
|
},
|
||||||
|
CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
}
|
||||||
|
addr := ln.Addr().String()
|
||||||
|
|
||||||
|
httpResp, err := client.Get("http://" + addr + "/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("http get: %v", err)
|
||||||
|
}
|
||||||
|
_ = httpResp.Body.Close()
|
||||||
|
if httpResp.StatusCode != http.StatusPermanentRedirect {
|
||||||
|
t.Fatalf("http status = %d, want %d", httpResp.StatusCode, http.StatusPermanentRedirect)
|
||||||
|
}
|
||||||
|
if got := httpResp.Header.Get("Location"); got != "https://127.0.0.1:"+strconv.Itoa(portFromListenAddr(addr))+"/" {
|
||||||
|
t.Fatalf("Location = %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpsResp, err := client.Get("https://" + addr + "/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("https get: %v", err)
|
||||||
|
}
|
||||||
|
defer httpsResp.Body.Close()
|
||||||
|
if httpsResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("https status = %d, want %d", httpsResp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
body, _ := io.ReadAll(httpsResp.Body)
|
||||||
|
if string(body) != "ok" {
|
||||||
|
t.Fatalf("body = %q, want ok", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mainTLSMode 主 Web 服务 TLS 启动方式。
|
||||||
|
type mainTLSMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
mainTLSOff mainTLSMode = iota
|
||||||
|
mainTLSFromFiles
|
||||||
|
mainTLSInMemorySelfSigned
|
||||||
|
)
|
||||||
|
|
||||||
|
// prepareMainServerTLS 根据 server 配置决定主站是否启用 HTTPS(及 HTTP/2 协商)。
|
||||||
|
// fromFiles:使用 tls_cert_path + tls_key_path,由 http.Server.ListenAndServeTLS 加载 PEM。
|
||||||
|
// inMemory:tls_auto_self_sign 生成的自签证书,仅用于本地/测试。
|
||||||
|
func prepareMainServerTLS(cfg *config.ServerConfig) (mode mainTLSMode, tlsConf *tls.Config, certFile, keyFile string, err error) {
|
||||||
|
if cfg == nil || !config.MainWebUIUsesHTTPS(cfg) {
|
||||||
|
return mainTLSOff, nil, "", "", nil
|
||||||
|
}
|
||||||
|
certFile = strings.TrimSpace(cfg.TLSCertPath)
|
||||||
|
keyFile = strings.TrimSpace(cfg.TLSKeyPath)
|
||||||
|
if certFile != "" && keyFile != "" {
|
||||||
|
// 证书由 ListenAndServeTLS 从文件加载;此处仅提供最小 TLS 配置供 http2.ConfigureServer 合并 ALPN。
|
||||||
|
return mainTLSFromFiles, &tls.Config{MinVersion: tls.VersionTLS12}, certFile, keyFile, nil
|
||||||
|
}
|
||||||
|
if cfg.TLSAutoSelfSign {
|
||||||
|
cert, genErr := generateMainServerSelfSignedCert()
|
||||||
|
if genErr != nil {
|
||||||
|
return mainTLSOff, nil, "", "", fmt.Errorf("生成自签 TLS 证书: %w", genErr)
|
||||||
|
}
|
||||||
|
tlsConf = &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}
|
||||||
|
return mainTLSInMemorySelfSigned, tlsConf, "", "", nil
|
||||||
|
}
|
||||||
|
return mainTLSOff, nil, "", "", fmt.Errorf("server: 已启用 TLS(tls_enabled / tls_auto_self_sign / 证书路径),请设置 tls_cert_path 与 tls_key_path,或将 tls_auto_self_sign 设为 true(仅测试环境)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateMainServerSelfSignedCert() (tls.Certificate, error) {
|
||||||
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, err
|
||||||
|
}
|
||||||
|
serial, err := rand.Int(rand.Reader, big.NewInt(1<<62))
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, err
|
||||||
|
}
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: serial,
|
||||||
|
Subject: pkix.Name{CommonName: "CyberStrikeAI"},
|
||||||
|
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||||
|
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")},
|
||||||
|
DNSNames: []string{"localhost"},
|
||||||
|
}
|
||||||
|
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, err
|
||||||
|
}
|
||||||
|
keyDER, err := x509.MarshalECPrivateKey(priv)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, err
|
||||||
|
}
|
||||||
|
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||||
|
return tls.X509KeyPair(certPEM, keyPEM)
|
||||||
|
}
|
||||||
@@ -82,7 +82,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildChainFromConversation 从对话构建攻击链(简化版本:用户输入+最后一轮ReAct输入+大模型输出)
|
// BuildChainFromConversation 从对话构建攻击链(单次 LLM 调用;输入为当前任务轮次的 last_react 轨迹,与继续对话续跑范围一致)。
|
||||||
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
|
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
|
||||||
b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID))
|
b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID))
|
||||||
|
|
||||||
@@ -157,33 +157,34 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
|||||||
var reactInputFinal string
|
var reactInputFinal string
|
||||||
var dataSource string // 记录数据来源
|
var dataSource string // 记录数据来源
|
||||||
|
|
||||||
// 如果成功获取到保存的ReAct数据,直接使用
|
// 优先使用落库的代理轨迹(与继续对话 loadHistoryFromAgentTrace 同源),并裁剪为「当前任务轮次」
|
||||||
if reactInputJSON != "" && modelOutput != "" {
|
if reactInputJSON != "" {
|
||||||
// 计算 ReAct 输入的哈希值,用于追踪
|
trimmedJSON := agent.ExtractLastUserTurnTraceJSON(reactInputJSON)
|
||||||
hash := sha256.Sum256([]byte(reactInputJSON))
|
hash := sha256.Sum256([]byte(trimmedJSON))
|
||||||
reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识
|
reactInputHash := hex.EncodeToString(hash[:])[:16]
|
||||||
|
|
||||||
// 统计消息数量
|
|
||||||
var messageCount int
|
var messageCount int
|
||||||
var tempMessages []interface{}
|
if msgs, parseErr := agent.ParseTraceMessages(trimmedJSON); parseErr == nil {
|
||||||
if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil {
|
messageCount = len(msgs)
|
||||||
messageCount = len(tempMessages)
|
msgs = agent.MergeAssistantTraceOutput(msgs, modelOutput)
|
||||||
|
reactInputFinal = b.formatAgentTraceFromChatMessages(msgs)
|
||||||
|
} else {
|
||||||
|
b.logger.Warn("解析代理轨迹失败,回退原始 JSON 格式化", zap.Error(parseErr))
|
||||||
|
reactInputFinal = b.formatAgentTraceInputFromJSON(trimmedJSON)
|
||||||
|
if strings.TrimSpace(modelOutput) != "" {
|
||||||
|
reactInputFinal += "\n\n## 助手结论(last_react_output)\n\n" + modelOutput
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dataSource = "database_last_agent_trace"
|
dataSource = "last_user_turn_agent_trace"
|
||||||
b.logger.Info("使用保存的ReAct数据构建攻击链",
|
b.logger.Info("使用当前任务轮次代理轨迹构建攻击链(与续跑上下文范围一致)",
|
||||||
zap.String("conversationId", conversationID),
|
zap.String("conversationId", conversationID),
|
||||||
zap.String("dataSource", dataSource),
|
zap.String("dataSource", dataSource),
|
||||||
zap.Int("reactInputSize", len(reactInputJSON)),
|
zap.Int("traceInputSizeBeforeTrim", len(reactInputJSON)),
|
||||||
|
zap.Int("traceInputSizeAfterTrim", len(trimmedJSON)),
|
||||||
zap.Int("messageCount", messageCount),
|
zap.Int("messageCount", messageCount),
|
||||||
zap.String("reactInputHash", reactInputHash),
|
zap.String("reactInputHash", reactInputHash),
|
||||||
zap.Int("modelOutputSize", len(modelOutput)))
|
zap.Int("modelOutputSize", len(modelOutput)))
|
||||||
|
|
||||||
// 从保存的ReAct输入(JSON格式)中提取用户输入
|
|
||||||
// userInput = b.extractUserInputFromReActInput(reactInputJSON)
|
|
||||||
|
|
||||||
// 将JSON格式的messages转换为可读格式
|
|
||||||
reactInputFinal = b.formatAgentTraceInputFromJSON(reactInputJSON)
|
|
||||||
} else {
|
} else {
|
||||||
// 2. 如果没有保存的ReAct数据,从对话消息构建
|
// 2. 如果没有保存的ReAct数据,从对话消息构建
|
||||||
dataSource = "messages_table"
|
dataSource = "messages_table"
|
||||||
@@ -243,8 +244,15 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 构建简化的prompt,一次性传递给大模型
|
// 3. 按 token 预算压缩输入,再构建 prompt(避免超出模型上下文)
|
||||||
prompt := b.buildSimplePrompt(reactInputFinal, modelOutput)
|
reactInputFinal, modelOutput, _ = b.fitAttackChainPayload(reactInputFinal, modelOutput)
|
||||||
|
|
||||||
|
// 4. 构建 prompt 并单次调用大模型(助手结论已并入轨迹时不再重复传入)
|
||||||
|
promptAssistantOut := modelOutput
|
||||||
|
if reactInputJSON != "" {
|
||||||
|
promptAssistantOut = ""
|
||||||
|
}
|
||||||
|
prompt := b.buildSimplePrompt(reactInputFinal, promptAssistantOut)
|
||||||
// fmt.Println(prompt)
|
// fmt.Println(prompt)
|
||||||
// 6. 调用AI生成攻击链(一次性,不做任何处理)
|
// 6. 调用AI生成攻击链(一次性,不做任何处理)
|
||||||
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
|
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
|
||||||
@@ -301,7 +309,7 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
|
|||||||
// 目标:以主 agent(编排器)视角输出整轮迭代
|
// 目标:以主 agent(编排器)视角输出整轮迭代
|
||||||
// - 保留:编排器工具调用/结果、对子代理的 task 调度、子代理最终回复(不含推理)
|
// - 保留:编排器工具调用/结果、对子代理的 task 调度、子代理最终回复(不含推理)
|
||||||
// - 丢弃:thinking/planning/progress 等噪声、子代理的工具细节与推理过程
|
// - 丢弃:thinking/planning/progress 等噪声、子代理的工具细节与推理过程
|
||||||
if d.EventType == "progress" || d.EventType == "thinking" || d.EventType == "planning" {
|
if d.EventType == "progress" || d.EventType == "thinking" || d.EventType == "reasoning_chain" || d.EventType == "planning" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -366,10 +374,17 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
|
|||||||
return strings.TrimSpace(sb.String())
|
return strings.TrimSpace(sb.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildAgentTraceInput 构建最后一轮ReAct的输入(历史消息+当前用户输入)
|
// buildAgentTraceInput 构建最后一轮 ReAct 的输入(从最后一条 user 消息起,不含更早轮次)。
|
||||||
func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
|
func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
|
||||||
|
start := 0
|
||||||
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
|
if strings.EqualFold(messages[i].Role, "user") {
|
||||||
|
start = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
for _, msg := range messages {
|
for _, msg := range messages[start:] {
|
||||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
|
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
|
||||||
}
|
}
|
||||||
return builder.String()
|
return builder.String()
|
||||||
@@ -396,67 +411,66 @@ func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
|
|||||||
// return ""
|
// return ""
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// formatAgentTraceInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
|
// formatAgentTraceInputFromJSON 将 JSON 轨迹转为可读文本(会先按当前任务轮次裁剪)。
|
||||||
func (b *Builder) formatAgentTraceInputFromJSON(reactInputJSON string) string {
|
func (b *Builder) formatAgentTraceInputFromJSON(reactInputJSON string) string {
|
||||||
var messages []map[string]interface{}
|
trimmed := agent.ExtractLastUserTurnTraceJSON(reactInputJSON)
|
||||||
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
|
msgs, err := agent.ParseTraceMessages(trimmed)
|
||||||
|
if err != nil {
|
||||||
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
||||||
return reactInputJSON // 如果解析失败,返回原始JSON
|
return trimmed
|
||||||
}
|
}
|
||||||
|
return b.formatAgentTraceFromChatMessages(msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatAgentTraceFromChatMessages 将代理消息带格式化为攻击链分析输入(与续跑轨迹字段一致)。
|
||||||
|
func (b *Builder) formatAgentTraceFromChatMessages(msgs []agent.ChatMessage) string {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
for _, msg := range messages {
|
for _, msg := range msgs {
|
||||||
role, _ := msg["role"].(string)
|
role := msg.Role
|
||||||
content, _ := msg["content"].(string)
|
content := msg.Content
|
||||||
|
|
||||||
// 处理assistant消息:提取tool_calls信息
|
if strings.EqualFold(role, "assistant") && len(msg.ToolCalls) > 0 {
|
||||||
if role == "assistant" {
|
if content != "" {
|
||||||
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 {
|
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
|
||||||
// 如果有文本内容,先显示
|
}
|
||||||
if content != "" {
|
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(msg.ToolCalls)))
|
||||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
|
for i, tc := range msg.ToolCalls {
|
||||||
}
|
args := ""
|
||||||
// 详细显示每个工具调用
|
if tc.Function.Arguments != nil {
|
||||||
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls)))
|
if b, err := json.Marshal(tc.Function.Arguments); err == nil {
|
||||||
for i, toolCall := range toolCalls {
|
args = string(b)
|
||||||
if tc, ok := toolCall.(map[string]interface{}); ok {
|
|
||||||
toolCallID, _ := tc["id"].(string)
|
|
||||||
if funcData, ok := tc["function"].(map[string]interface{}); ok {
|
|
||||||
toolName, _ := funcData["name"].(string)
|
|
||||||
arguments, _ := funcData["arguments"].(string)
|
|
||||||
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
|
|
||||||
builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID))
|
|
||||||
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName))
|
|
||||||
builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
builder.WriteString("\n")
|
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
|
||||||
continue
|
builder.WriteString(fmt.Sprintf(" ID: %s\n", tc.ID))
|
||||||
|
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", tc.Function.Name))
|
||||||
|
builder.WriteString(fmt.Sprintf(" 参数: %s\n", args))
|
||||||
}
|
}
|
||||||
|
builder.WriteString("\n")
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理tool消息:显示tool_call_id和完整内容
|
if strings.EqualFold(role, "tool") {
|
||||||
if role == "tool" {
|
if msg.ToolCallID != "" {
|
||||||
toolCallID, _ := msg["tool_call_id"].(string)
|
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, msg.ToolCallID, content))
|
||||||
if toolCallID != "" {
|
|
||||||
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content))
|
|
||||||
} else {
|
} else {
|
||||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 其他消息类型(system, user等)正常显示
|
|
||||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||||||
}
|
}
|
||||||
|
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildSimplePrompt 构建简化的prompt
|
// buildSimplePrompt 构建简化的prompt
|
||||||
func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
||||||
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据对话记录和工具执行结果,构建一个逻辑清晰、有教育意义的攻击链图,完整展现渗透测试的思维过程和执行路径。
|
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据**当前任务轮次**的对话记录和工具执行结果,一次性输出攻击链 JSON(不要分多轮追问)。
|
||||||
|
|
||||||
|
## 输入范围(与「继续对话」续跑一致)
|
||||||
|
- 下方「ReAct 轨迹」仅包含**最后一次用户提问之后**的消息与工具结果(last_react 当前任务轮次),不含更早的用户提问轮次。
|
||||||
|
- 「助手结论」为同轮任务的最终输出摘要(last_react_output);节点须与轨迹中的实际工具执行一致,严禁编造。
|
||||||
|
|
||||||
## 核心目标
|
## 核心目标
|
||||||
|
|
||||||
@@ -618,12 +632,9 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
|||||||
5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability)
|
5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability)
|
||||||
6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径)
|
6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径)
|
||||||
|
|
||||||
## 最后一轮ReAct输入
|
## 当前任务 ReAct 轨迹(含工具执行;助手结论见轨迹末尾 assistant)
|
||||||
|
|
||||||
%s
|
%s
|
||||||
|
|
||||||
## 大模型输出
|
|
||||||
|
|
||||||
%s
|
%s
|
||||||
|
|
||||||
## 输出格式
|
## 输出格式
|
||||||
@@ -752,7 +763,15 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
|||||||
9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。
|
9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。
|
||||||
10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。
|
10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。
|
||||||
|
|
||||||
现在开始分析并构建攻击链:`, reactInput, modelOutput)
|
现在开始分析并构建攻击链:`, reactInput, assistantOutSection(modelOutput))
|
||||||
|
}
|
||||||
|
|
||||||
|
func assistantOutSection(modelOutput string) string {
|
||||||
|
modelOutput = strings.TrimSpace(modelOutput)
|
||||||
|
if modelOutput == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return "\n## 助手结论(补充)\n\n" + modelOutput + "\n"
|
||||||
}
|
}
|
||||||
|
|
||||||
// saveChain 保存攻击链到数据库
|
// saveChain 保存攻击链到数据库
|
||||||
@@ -812,7 +831,7 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"temperature": 0.3,
|
"temperature": 0.3,
|
||||||
"max_completion_tokens": 80000,
|
"max_completion_tokens": attackChainMaxCompletionTokens(b.maxTokens),
|
||||||
}
|
}
|
||||||
|
|
||||||
var apiResponse struct {
|
var apiResponse struct {
|
||||||
|
|||||||
@@ -0,0 +1,248 @@
|
|||||||
|
package attackchain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
attackChainTruncationMarker = "\n\n...[攻击链输入已截断 / attack chain input truncated]...\n\n"
|
||||||
|
attackChainSystemReserve = 256
|
||||||
|
attackChainSafetyReserve = 2048
|
||||||
|
)
|
||||||
|
|
||||||
|
// attackChainMaxCompletionTokens 为攻击链 JSON 输出预留的 completion token 上限。
|
||||||
|
func attackChainMaxCompletionTokens(maxTotal int) int {
|
||||||
|
const capTokens = 16384
|
||||||
|
if maxTotal <= 0 {
|
||||||
|
return 8192
|
||||||
|
}
|
||||||
|
v := maxTotal / 8
|
||||||
|
if v < 4096 {
|
||||||
|
v = 4096
|
||||||
|
}
|
||||||
|
if v > capTokens {
|
||||||
|
v = capTokens
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Builder) modelName() string {
|
||||||
|
if b.openAIConfig != nil && b.openAIConfig.Model != "" {
|
||||||
|
return b.openAIConfig.Model
|
||||||
|
}
|
||||||
|
return "gpt-4"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Builder) countTokens(text string) int {
|
||||||
|
if text == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
n, err := b.tokenCounter.Count(b.modelName(), text)
|
||||||
|
if err != nil {
|
||||||
|
return utf8.RuneCountInString(text) / 4
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// attackChainPayloadTokenBudget 计算 reactInput + modelOutput 可用的 token 预算。
|
||||||
|
func (b *Builder) attackChainPayloadTokenBudget() int {
|
||||||
|
maxTotal := b.maxTokens
|
||||||
|
if maxTotal <= 0 {
|
||||||
|
maxTotal = 100000
|
||||||
|
}
|
||||||
|
templateTok := b.countTokens(b.buildSimplePrompt("", ""))
|
||||||
|
completion := attackChainMaxCompletionTokens(maxTotal)
|
||||||
|
reserve := templateTok + attackChainSystemReserve + completion + attackChainSafetyReserve
|
||||||
|
budget := maxTotal - reserve
|
||||||
|
minBudget := maxTotal * 35 / 100
|
||||||
|
if budget < minBudget {
|
||||||
|
budget = minBudget
|
||||||
|
}
|
||||||
|
if budget < 4096 {
|
||||||
|
budget = 4096
|
||||||
|
}
|
||||||
|
return budget
|
||||||
|
}
|
||||||
|
|
||||||
|
// fitAttackChainPayload 在构建最终 prompt 前压缩 ReAct 轨迹与模型输出,避免超出模型上下文。
|
||||||
|
func (b *Builder) fitAttackChainPayload(reactInput, modelOutput string) (string, string, bool) {
|
||||||
|
budget := b.attackChainPayloadTokenBudget()
|
||||||
|
modelBudget := budget * 15 / 100
|
||||||
|
if modelBudget < 512 {
|
||||||
|
modelBudget = 512
|
||||||
|
}
|
||||||
|
reactBudget := budget - modelBudget
|
||||||
|
|
||||||
|
origReactTok := b.countTokens(reactInput)
|
||||||
|
origModelTok := b.countTokens(modelOutput)
|
||||||
|
truncated := false
|
||||||
|
|
||||||
|
outModel := modelOutput
|
||||||
|
if origModelTok > modelBudget {
|
||||||
|
outModel = truncateTextByTokens(b, modelOutput, modelBudget)
|
||||||
|
truncated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
outReact := reactInput
|
||||||
|
perToolLimits := []int{12000, 6000, 3000, 1500, 800}
|
||||||
|
for _, lim := range perToolLimits {
|
||||||
|
compact := compactFormattedToolBodies(outReact, lim)
|
||||||
|
if compact != outReact {
|
||||||
|
outReact = compact
|
||||||
|
truncated = true
|
||||||
|
}
|
||||||
|
if b.countTokens(outReact) <= reactBudget {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.countTokens(outReact) > reactBudget {
|
||||||
|
outReact = truncateTextByTokens(b, outReact, reactBudget)
|
||||||
|
truncated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if truncated {
|
||||||
|
b.logger.Info("攻击链输入已按 token 预算截断",
|
||||||
|
zap.Int("maxTotalTokens", b.maxTokens),
|
||||||
|
zap.Int("payloadBudget", budget),
|
||||||
|
zap.Int("reactBudget", reactBudget),
|
||||||
|
zap.Int("modelBudget", modelBudget),
|
||||||
|
zap.Int("reactInputTokensBefore", origReactTok),
|
||||||
|
zap.Int("reactInputTokensAfter", b.countTokens(outReact)),
|
||||||
|
zap.Int("modelOutputTokensBefore", origModelTok),
|
||||||
|
zap.Int("modelOutputTokensAfter", b.countTokens(outModel)),
|
||||||
|
zap.Int("maxCompletionTokens", attackChainMaxCompletionTokens(b.maxTokens)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return outReact, outModel, truncated
|
||||||
|
}
|
||||||
|
|
||||||
|
// compactFormattedToolBodies 缩短格式化 trace 中 [tool] 消息的正文,保留工具头与调用 ID。
|
||||||
|
func compactFormattedToolBodies(s string, maxRunesPerBody int) string {
|
||||||
|
if maxRunesPerBody <= 0 || s == "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
const marker = "[tool]"
|
||||||
|
var out strings.Builder
|
||||||
|
remaining := s
|
||||||
|
changed := false
|
||||||
|
for {
|
||||||
|
idx := strings.Index(remaining, marker)
|
||||||
|
if idx < 0 {
|
||||||
|
out.WriteString(remaining)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
out.WriteString(remaining[:idx])
|
||||||
|
remaining = remaining[idx:]
|
||||||
|
nl := strings.IndexByte(remaining, '\n')
|
||||||
|
if nl < 0 {
|
||||||
|
out.WriteString(remaining)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
header := remaining[:nl+1]
|
||||||
|
remaining = remaining[nl+1:]
|
||||||
|
bodyEnd := strings.Index(remaining, "\n\n[")
|
||||||
|
var body, rest string
|
||||||
|
if bodyEnd < 0 {
|
||||||
|
body = remaining
|
||||||
|
rest = ""
|
||||||
|
} else {
|
||||||
|
body = remaining[:bodyEnd]
|
||||||
|
rest = remaining[bodyEnd:]
|
||||||
|
}
|
||||||
|
if runeLen(body) > maxRunesPerBody {
|
||||||
|
body = truncateRunesWithNotice(body, maxRunesPerBody)
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
out.WriteString(header)
|
||||||
|
out.WriteString(body)
|
||||||
|
remaining = rest
|
||||||
|
if rest == "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !changed {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return out.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateTextByTokens(b *Builder, text string, maxTokens int) string {
|
||||||
|
if maxTokens <= 0 || text == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if b.countTokens(text) <= maxTokens {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
markerTok := b.countTokens(attackChainTruncationMarker)
|
||||||
|
usable := maxTokens - markerTok
|
||||||
|
if usable < 256 {
|
||||||
|
usable = maxTokens / 2
|
||||||
|
}
|
||||||
|
headBudget := usable * 60 / 100
|
||||||
|
tailBudget := usable - headBudget
|
||||||
|
head := takeTokensFromStart(b, text, headBudget)
|
||||||
|
tail := takeTokensFromEnd(b, text, tailBudget)
|
||||||
|
return head + attackChainTruncationMarker + tail
|
||||||
|
}
|
||||||
|
|
||||||
|
func takeTokensFromStart(b *Builder, text string, maxTokens int) string {
|
||||||
|
rs := []rune(text)
|
||||||
|
if len(rs) == 0 || maxTokens <= 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
lo, hi := 0, len(rs)
|
||||||
|
for lo < hi {
|
||||||
|
mid := (lo + hi + 1) / 2
|
||||||
|
if b.countTokens(string(rs[:mid])) <= maxTokens {
|
||||||
|
lo = mid
|
||||||
|
} else {
|
||||||
|
hi = mid - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(rs[:lo])
|
||||||
|
}
|
||||||
|
|
||||||
|
func takeTokensFromEnd(b *Builder, text string, maxTokens int) string {
|
||||||
|
rs := []rune(text)
|
||||||
|
if len(rs) == 0 || maxTokens <= 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
lo, hi := 0, len(rs)
|
||||||
|
for lo < hi {
|
||||||
|
mid := (lo + hi) / 2
|
||||||
|
if b.countTokens(string(rs[mid:])) <= maxTokens {
|
||||||
|
hi = mid
|
||||||
|
} else {
|
||||||
|
lo = mid + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(rs[lo:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateRunesWithNotice(s string, maxRunes int) string {
|
||||||
|
rs := []rune(s)
|
||||||
|
if len(rs) <= maxRunes {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
const notice = "\n...[工具输出已截断 / tool output truncated]...\n"
|
||||||
|
noticeRunes := []rune(notice)
|
||||||
|
keep := maxRunes - len(noticeRunes)
|
||||||
|
if keep < 200 {
|
||||||
|
keep = maxRunes * 2 / 3
|
||||||
|
}
|
||||||
|
if keep < 1 {
|
||||||
|
return notice
|
||||||
|
}
|
||||||
|
head := keep * 70 / 100
|
||||||
|
tail := keep - head
|
||||||
|
return string(rs[:head]) + notice + string(rs[len(rs)-tail:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func runeLen(s string) int {
|
||||||
|
return len([]rune(s))
|
||||||
|
}
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
package attackchain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testBuilder(maxTotal int) *Builder {
|
||||||
|
return &Builder{
|
||||||
|
logger: zap.NewNop(),
|
||||||
|
openAIConfig: &config.OpenAIConfig{Model: "gpt-4"},
|
||||||
|
tokenCounter: agent.NewTikTokenCounter(),
|
||||||
|
maxTokens: maxTotal,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompactFormattedToolBodies(t *testing.T) {
|
||||||
|
long := strings.Repeat("x", 20000)
|
||||||
|
in := "[user]: hi\n\n[tool] (tool_call_id: abc):\n" + long + "\n\n[assistant]: done\n"
|
||||||
|
out := compactFormattedToolBodies(in, 500)
|
||||||
|
if strings.Contains(out, strings.Repeat("x", 10000)) {
|
||||||
|
t.Fatal("expected tool body to be truncated")
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "[user]: hi") {
|
||||||
|
t.Fatal("expected user header preserved")
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "[assistant]: done") {
|
||||||
|
t.Fatal("expected assistant header preserved")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFitAttackChainPayloadWithinBudget(t *testing.T) {
|
||||||
|
b := testBuilder(32000)
|
||||||
|
react := strings.Repeat("scan ", 50000)
|
||||||
|
model := strings.Repeat("result ", 10000)
|
||||||
|
r, m, truncated := b.fitAttackChainPayload(react, model)
|
||||||
|
if !truncated {
|
||||||
|
t.Fatal("expected truncation for large payload")
|
||||||
|
}
|
||||||
|
prompt := b.buildSimplePrompt(r, m)
|
||||||
|
total := b.countTokens(prompt) + attackChainMaxCompletionTokens(b.maxTokens) + attackChainSystemReserve
|
||||||
|
if total > b.maxTokens+attackChainSafetyReserve {
|
||||||
|
t.Fatalf("prompt still too large: estimated %d > max %d", total, b.maxTokens)
|
||||||
|
}
|
||||||
|
_ = m
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAttackChainMaxCompletionTokens(t *testing.T) {
|
||||||
|
if got := attackChainMaxCompletionTokens(120000); got != 15000 && got != 16384 {
|
||||||
|
// 120000/8 = 15000
|
||||||
|
if got < 4096 || got > 16384 {
|
||||||
|
t.Fatalf("unexpected completion cap: %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got := attackChainMaxCompletionTokens(0); got != 8192 {
|
||||||
|
t.Fatalf("expected default 8192, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -239,13 +239,15 @@ func (m *Manager) StartListener(id string) (*database.C2Listener, error) {
|
|||||||
}
|
}
|
||||||
cfg.ApplyDefaults()
|
cfg.ApplyDefaults()
|
||||||
|
|
||||||
// 通过工厂创建具体实现
|
// 通过工厂创建具体实现。必须使用 rec 的副本:HTTP handler 在返回 JSON 前会清空
|
||||||
|
// rec.ImplantToken / EncryptionKey 做脱敏,若 listener 实现持有同一指针会导致 beacon 鉴权永久失败。
|
||||||
|
listenerRec := *rec
|
||||||
factory := m.registry.Get(rec.Type)
|
factory := m.registry.Get(rec.Type)
|
||||||
if factory == nil {
|
if factory == nil {
|
||||||
return nil, ErrUnsupportedType
|
return nil, ErrUnsupportedType
|
||||||
}
|
}
|
||||||
inst, err := factory(ListenerCreationCtx{
|
inst, err := factory(ListenerCreationCtx{
|
||||||
Listener: rec,
|
Listener: &listenerRec,
|
||||||
Config: cfg,
|
Config: cfg,
|
||||||
Manager: m,
|
Manager: m,
|
||||||
Logger: m.logger.With(zap.String("listener_id", rec.ID), zap.String("type", rec.Type)),
|
Logger: m.logger.With(zap.String("listener_id", rec.ID), zap.String("type", rec.Type)),
|
||||||
|
|||||||
@@ -0,0 +1,74 @@
|
|||||||
|
package c2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 回归:StartListener 返回的 rec 被 handler 脱敏清空 ImplantToken 后,运行中的 HTTP listener 仍能鉴权。
|
||||||
|
func TestStartListener_ImplantTokenSurvivesHandlerRedaction(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
|
||||||
|
lnPick, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
port := lnPick.Addr().(*net.TCPAddr).Port
|
||||||
|
_ = lnPick.Close()
|
||||||
|
|
||||||
|
mgr := NewManager(db, zap.NewNop(), tmp)
|
||||||
|
mgr.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
|
||||||
|
rec, err := mgr.CreateListener(CreateListenerInput{
|
||||||
|
Name: "t",
|
||||||
|
Type: string(ListenerTypeHTTPBeacon),
|
||||||
|
BindHost: "127.0.0.1",
|
||||||
|
BindPort: port,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
token := rec.ImplantToken
|
||||||
|
|
||||||
|
rec, err = mgr.StartListener(rec.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// 模拟 internal/handler/c2.go StartListener 在 JSON 响应前的脱敏
|
||||||
|
rec.ImplantToken = ""
|
||||||
|
rec.EncryptionKey = ""
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
body := `{"hostname":"n","username":"u","os":"Linux","arch":"amd64","internal_ip":"10.0.0.1","pid":42}`
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "http://127.0.0.1:"+strconv.Itoa(port)+"/check_in", strings.NewReader(body))
|
||||||
|
req.Header.Set("X-Implant-Token", token)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%s", resp.StatusCode, b)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(b), "session_id") {
|
||||||
|
t.Fatalf("expected session_id in body: %s", b)
|
||||||
|
}
|
||||||
|
_ = mgr.StopListener(rec.ID)
|
||||||
|
}
|
||||||
+191
-5
@@ -63,6 +63,126 @@ type MultiAgentConfig struct {
|
|||||||
EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"`
|
EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"`
|
||||||
// EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras.
|
// EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras.
|
||||||
EinoMiddleware MultiAgentEinoMiddlewareConfig `yaml:"eino_middleware,omitempty" json:"eino_middleware,omitempty"`
|
EinoMiddleware MultiAgentEinoMiddlewareConfig `yaml:"eino_middleware,omitempty" json:"eino_middleware,omitempty"`
|
||||||
|
// EinoCallbacks attaches CloudWeGo eino callbacks.InitCallbacks on ADK Runner context (structured logs + optional SSE trace).
|
||||||
|
EinoCallbacks MultiAgentEinoCallbacksConfig `yaml:"eino_callbacks,omitempty" json:"eino_callbacks,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultiAgentEinoCallbacksConfig enables Eino unified callbacks on each ADK agent run (deep / plan_execute / supervisor / eino_single).
|
||||||
|
// Modes: log_only (zap + optional OTel; no SSE to browser), sse (adds client SSE eino_trace_* when sse_trace_to_client), full (sse rules + stream callback copies closed).
|
||||||
|
type MultiAgentEinoCallbacksConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||||
|
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"` // log_only | sse | full; empty with enabled=true defaults to log_only
|
||||||
|
// SseTraceToClient when true emits eino_trace_* SSE for UI (use only for admin/debug; nil/false recommended in production).
|
||||||
|
SseTraceToClient *bool `yaml:"sse_trace_to_client,omitempty" json:"sse_trace_to_client,omitempty"`
|
||||||
|
// Otel configures OpenTelemetry trace export (independent of mode; exporter none disables export even if enabled).
|
||||||
|
Otel MultiAgentEinoCallbacksOtelConfig `yaml:"otel,omitempty" json:"otel,omitempty"`
|
||||||
|
// MaxInputSummaryRunes / MaxOutputSummaryRunes cap text placed in SSE payloads and debug logs (not full payloads).
|
||||||
|
MaxInputSummaryRunes int `yaml:"max_input_summary_runes,omitempty" json:"max_input_summary_runes,omitempty"`
|
||||||
|
MaxOutputSummaryRunes int `yaml:"max_output_summary_runes,omitempty" json:"max_output_summary_runes,omitempty"`
|
||||||
|
// ZapVerbose when true logs input/output summaries at zap.Debug on start/end; false uses Info with short fields only.
|
||||||
|
ZapVerbose bool `yaml:"zap_verbose,omitempty" json:"zap_verbose,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultiAgentEinoCallbacksOtelConfig OpenTelemetry for Eino callback spans (W3C trace in collector / stdout).
|
||||||
|
type MultiAgentEinoCallbacksOtelConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||||
|
ServiceName string `yaml:"service_name,omitempty" json:"service_name,omitempty"`
|
||||||
|
Exporter string `yaml:"exporter,omitempty" json:"exporter,omitempty"` // none | stdout | otlphttp
|
||||||
|
OTLPEndpoint string `yaml:"otlp_endpoint,omitempty" json:"otlp_endpoint,omitempty"` // host:port, e.g. localhost:4318 (path /v1/traces)
|
||||||
|
SampleRatio float64 `yaml:"sample_ratio,omitempty" json:"sample_ratio,omitempty"` // 0–1, default 1.0
|
||||||
|
}
|
||||||
|
|
||||||
|
// EinoCallbacksModeEffective returns off | log_only | sse | full.
|
||||||
|
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksModeEffective() string {
|
||||||
|
if !c.Enabled {
|
||||||
|
return "off"
|
||||||
|
}
|
||||||
|
m := strings.TrimSpace(strings.ToLower(c.Mode))
|
||||||
|
switch m {
|
||||||
|
case "log_only":
|
||||||
|
return "log_only"
|
||||||
|
case "sse":
|
||||||
|
return "sse"
|
||||||
|
case "full":
|
||||||
|
return "full"
|
||||||
|
case "":
|
||||||
|
return "log_only"
|
||||||
|
default:
|
||||||
|
return "log_only"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SseTraceToClientEffective is false unless explicitly set true (best practice: do not expose framework traces to end users by default).
|
||||||
|
func (c MultiAgentEinoCallbacksConfig) SseTraceToClientEffective() bool {
|
||||||
|
if c.SseTraceToClient == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return *c.SseTraceToClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShouldEmitEinoTraceSSE is true when client-visible trace events should be sent over progress/SSE.
|
||||||
|
func (c MultiAgentEinoCallbacksConfig) ShouldEmitEinoTraceSSE(mode string) bool {
|
||||||
|
if !c.SseTraceToClientEffective() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return mode == "sse" || mode == "full"
|
||||||
|
}
|
||||||
|
|
||||||
|
// OtelExporterEffective returns none | stdout | otlphttp.
|
||||||
|
func (c MultiAgentEinoCallbacksOtelConfig) OtelExporterEffective() string {
|
||||||
|
e := strings.TrimSpace(strings.ToLower(c.Exporter))
|
||||||
|
switch e {
|
||||||
|
case "none", "stdout", "otlphttp":
|
||||||
|
return e
|
||||||
|
case "":
|
||||||
|
if c.Enabled {
|
||||||
|
return "stdout"
|
||||||
|
}
|
||||||
|
return "none"
|
||||||
|
default:
|
||||||
|
return "none"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OtelTracingActive is true when spans should be started (enabled + non-none exporter).
|
||||||
|
func (c MultiAgentEinoCallbacksConfig) OtelTracingActive() bool {
|
||||||
|
if !c.Otel.Enabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return c.Otel.OtelExporterEffective() != "none"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MultiAgentEinoCallbacksOtelConfig) ServiceNameEffective() string {
|
||||||
|
s := strings.TrimSpace(c.ServiceName)
|
||||||
|
if s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return "cyberstrike-ai"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MultiAgentEinoCallbacksOtelConfig) SampleRatioEffective() float64 {
|
||||||
|
r := c.SampleRatio
|
||||||
|
if r <= 0 {
|
||||||
|
return 1.0
|
||||||
|
}
|
||||||
|
if r > 1 {
|
||||||
|
return 1.0
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksMaxInputSummaryRunes() int {
|
||||||
|
if c.MaxInputSummaryRunes > 0 {
|
||||||
|
return c.MaxInputSummaryRunes
|
||||||
|
}
|
||||||
|
return 400
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksMaxOutputSummaryRunes() int {
|
||||||
|
if c.MaxOutputSummaryRunes > 0 {
|
||||||
|
return c.MaxOutputSummaryRunes
|
||||||
|
}
|
||||||
|
return 400
|
||||||
}
|
}
|
||||||
|
|
||||||
// MultiAgentEinoMiddlewareConfig optional Eino ADK middleware and Deep / supervisor tuning.
|
// MultiAgentEinoMiddlewareConfig optional Eino ADK middleware and Deep / supervisor tuning.
|
||||||
@@ -90,7 +210,8 @@ type MultiAgentEinoMiddlewareConfig struct {
|
|||||||
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
||||||
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
||||||
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
||||||
// HistoryInputBudgetRatio caps pre-agent history tokens as max_total_tokens * ratio (default 0.35).
|
// HistoryInputBudgetRatio 已不影响 Eino:从 last_react 轨迹转 ADK 消息时**不再**按 token 比例裁剪(完整注入)。
|
||||||
|
// 字段仍保留,便于旧版 config 不报错;新部署可省略。
|
||||||
HistoryInputBudgetRatio float64 `yaml:"history_input_budget_ratio,omitempty" json:"history_input_budget_ratio,omitempty"`
|
HistoryInputBudgetRatio float64 `yaml:"history_input_budget_ratio,omitempty" json:"history_input_budget_ratio,omitempty"`
|
||||||
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
||||||
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
||||||
@@ -270,17 +391,31 @@ type MultiAgentAPIUpdate struct {
|
|||||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
||||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||||
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
|
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
|
||||||
ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"`
|
// 指针区分「JSON 未传该字段」与「传空数组要清空」;省略时不应覆盖 YAML 中的常驻工具白名单。
|
||||||
|
ToolSearchAlwaysVisibleTools *[]string `json:"tool_search_always_visible_tools,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RobotsConfig 机器人配置(企业微信、钉钉、飞书等)
|
// RobotsConfig 机器人配置(企业微信、钉钉、飞书、微信 iLink 等)
|
||||||
type RobotsConfig struct {
|
type RobotsConfig struct {
|
||||||
Session RobotSessionConfig `yaml:"session,omitempty" json:"session,omitempty"` // 机器人会话隔离策略
|
Session RobotSessionConfig `yaml:"session,omitempty" json:"session,omitempty"` // 机器人会话隔离策略
|
||||||
|
Wechat RobotWechatConfig `yaml:"wechat,omitempty" json:"wechat,omitempty"` // 微信(iLink 扫码绑定)
|
||||||
Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信
|
Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信
|
||||||
Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉
|
Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉
|
||||||
Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书
|
Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RobotWechatConfig 微信 iLink 机器人配置(个人微信 ClawBot / iLink 协议)
|
||||||
|
type RobotWechatConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||||
|
BotToken string `yaml:"bot_token,omitempty" json:"bot_token,omitempty"`
|
||||||
|
ILinkBotID string `yaml:"ilink_bot_id,omitempty" json:"ilink_bot_id,omitempty"`
|
||||||
|
ILinkUserID string `yaml:"ilink_user_id,omitempty" json:"ilink_user_id,omitempty"`
|
||||||
|
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://ilinkai.weixin.qq.com
|
||||||
|
BotType string `yaml:"bot_type,omitempty" json:"bot_type,omitempty"` // get_bot_qrcode 参数,默认 3
|
||||||
|
BotAgent string `yaml:"bot_agent,omitempty" json:"bot_agent,omitempty"` // base_info.bot_agent
|
||||||
|
GetUpdatesBuf string `yaml:"get_updates_buf,omitempty" json:"get_updates_buf,omitempty"` // 长轮询游标(运行时)
|
||||||
|
}
|
||||||
|
|
||||||
// RobotSessionConfig 机器人会话隔离策略
|
// RobotSessionConfig 机器人会话隔离策略
|
||||||
type RobotSessionConfig struct {
|
type RobotSessionConfig struct {
|
||||||
StrictUserIdentity *bool `yaml:"strict_user_identity,omitempty" json:"strict_user_identity,omitempty"` // true 时只允许真实用户标识,不允许会话/群 ID 兜底
|
StrictUserIdentity *bool `yaml:"strict_user_identity,omitempty" json:"strict_user_identity,omitempty"` // true 时只允许真实用户标识,不允许会话/群 ID 兜底
|
||||||
@@ -322,8 +457,17 @@ type RobotLarkConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
Host string `yaml:"host"`
|
Host string `yaml:"host" json:"host"`
|
||||||
Port int `yaml:"port"`
|
Port int `yaml:"port" json:"port"`
|
||||||
|
// TLSEnabled 为 true 时主 Web UI 使用 HTTPS;现代浏览器在同源下会协商 HTTP/2,缓解 HTTP/1.1 每源并发连接数限制。
|
||||||
|
TLSEnabled bool `yaml:"tls_enabled,omitempty" json:"tls_enabled,omitempty"`
|
||||||
|
// TLSCertPath / TLSKeyPath 非空时从 PEM 文件加载证书(生产环境推荐)。
|
||||||
|
TLSCertPath string `yaml:"tls_cert_path,omitempty" json:"tls_cert_path,omitempty"`
|
||||||
|
TLSKeyPath string `yaml:"tls_key_path,omitempty" json:"tls_key_path,omitempty"`
|
||||||
|
// TLSAutoSelfSign 为 true 且未配置有效证书路径时,启动时生成内存自签证书(仅本地/测试;浏览器会提示不受信任)。
|
||||||
|
TLSAutoSelfSign bool `yaml:"tls_auto_self_sign,omitempty" json:"tls_auto_self_sign,omitempty"`
|
||||||
|
// TLSHTTPRedirect 为 false 时禁用 HTTP→HTTPS 跳转;省略或为 true 且已启用 HTTPS 时,明文 HTTP 访问将 308 跳转到 HTTPS(同端口嗅探分流)。
|
||||||
|
TLSHTTPRedirect *bool `yaml:"tls_http_redirect,omitempty" json:"tls_http_redirect,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type LogConfig struct {
|
type LogConfig struct {
|
||||||
@@ -345,6 +489,48 @@ type OpenAIConfig struct {
|
|||||||
BaseURL string `yaml:"base_url" json:"base_url"`
|
BaseURL string `yaml:"base_url" json:"base_url"`
|
||||||
Model string `yaml:"model" json:"model"`
|
Model string `yaml:"model" json:"model"`
|
||||||
MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"`
|
MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"`
|
||||||
|
// Reasoning 控制 Eino ChatModel 的 thinking / reasoning_effort / output_config 等(仅 Eino 路径生效;原生 ReAct 忽略)。
|
||||||
|
Reasoning OpenAIReasoningConfig `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAIReasoningConfig 全局默认与网关 profile(对话页可通过 ChatRequest.reasoning 覆盖,受 AllowClientReasoning 约束)。
|
||||||
|
type OpenAIReasoningConfig struct {
|
||||||
|
// Mode: auto(默认)| on | off | default(与 auto 相同)。off 时不向模型附加推理扩展字段。
|
||||||
|
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
||||||
|
// Effort: low | medium | high | max;空表示不单独指定强度(各 profile 行为见 internal/reasoning)。
|
||||||
|
Effort string `yaml:"effort,omitempty" json:"effort,omitempty"`
|
||||||
|
// AllowClientReasoning 为 false 时忽略请求体 reasoning;nil 或未设置等同于 true。
|
||||||
|
AllowClientReasoning *bool `yaml:"allow_client_reasoning,omitempty" json:"allow_client_reasoning,omitempty"`
|
||||||
|
// Profile: auto | deepseek_compat | openai_compat | output_config_effort
|
||||||
|
Profile string `yaml:"profile,omitempty" json:"profile,omitempty"`
|
||||||
|
// ExtraRequestFields 合并进 Chat Completions 根 JSON(管理员用;与自动字段同名时后者覆盖)。
|
||||||
|
ExtraRequestFields map[string]interface{} `yaml:"extra_request_fields,omitempty" json:"extra_request_fields,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModeEffective returns auto when empty or default.
|
||||||
|
func (c OpenAIReasoningConfig) ModeEffective() string {
|
||||||
|
m := strings.ToLower(strings.TrimSpace(c.Mode))
|
||||||
|
if m == "" || m == "default" {
|
||||||
|
return "auto"
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProfileEffective returns auto when empty.
|
||||||
|
func (c OpenAIReasoningConfig) ProfileEffective() string {
|
||||||
|
p := strings.ToLower(strings.TrimSpace(c.Profile))
|
||||||
|
if p == "" {
|
||||||
|
return "auto"
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowClientReasoningEffective true when client may send ChatRequest.reasoning.
|
||||||
|
func (c OpenAIReasoningConfig) AllowClientReasoningEffective() bool {
|
||||||
|
if c.AllowClientReasoning == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return *c.AllowClientReasoning
|
||||||
}
|
}
|
||||||
|
|
||||||
type FofaConfig struct {
|
type FofaConfig struct {
|
||||||
|
|||||||
@@ -0,0 +1,46 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// MainWebUIUsesHTTPS 判断主 Web UI 是否以 HTTPS 监听(与 internal/app.prepareMainServerTLS 前置条件一致)。
|
||||||
|
func MainWebUIUsesHTTPS(s *ServerConfig) bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.TLSEnabled {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if s.TLSAutoSelfSign {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
cert := strings.TrimSpace(s.TLSCertPath)
|
||||||
|
key := strings.TrimSpace(s.TLSKeyPath)
|
||||||
|
return cert != "" && key != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerHTTPRedirectEnabled 是否在主站启用 HTTPS 时把明文 HTTP 请求重定向到 HTTPS(默认开启)。
|
||||||
|
func ServerHTTPRedirectEnabled(s *ServerConfig) bool {
|
||||||
|
if s == nil || !MainWebUIUsesHTTPS(s) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.TLSHTTPRedirect == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return *s.TLSHTTPRedirect
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyDevHTTPSBootstrap 供 --https / 一键脚本使用:强制开启主站 TLS。
|
||||||
|
// 若已配置 tls_cert_path 与 tls_key_path 则仅用 PEM,不开启自签;否则启用 tls_auto_self_sign(内存证书,仅本地测试)。
|
||||||
|
func ApplyDevHTTPSBootstrap(cfg *Config) {
|
||||||
|
if cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.Server.TLSEnabled = true
|
||||||
|
cert := strings.TrimSpace(cfg.Server.TLSCertPath)
|
||||||
|
key := strings.TrimSpace(cfg.Server.TLSKeyPath)
|
||||||
|
if cert != "" && key != "" {
|
||||||
|
cfg.Server.TLSAutoSelfSign = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.Server.TLSAutoSelfSign = true
|
||||||
|
}
|
||||||
@@ -25,14 +25,15 @@ type Conversation struct {
|
|||||||
|
|
||||||
// Message 消息
|
// Message 消息
|
||||||
type Message struct {
|
type Message struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
ConversationID string `json:"conversationId"`
|
ConversationID string `json:"conversationId"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
|
ReasoningContent string `json:"reasoningContent,omitempty"`
|
||||||
ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"`
|
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
|
||||||
CreatedAt time.Time `json:"createdAt"`
|
ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"`
|
||||||
UpdatedAt time.Time `json:"updatedAt"`
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
UpdatedAt time.Time `json:"updatedAt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateConversation 创建新对话
|
// CreateConversation 创建新对话
|
||||||
@@ -116,6 +117,7 @@ func (db *DB) GetConversationByWebshellConnectionID(connectionID string) (*Conve
|
|||||||
}
|
}
|
||||||
for i := range conv.Messages {
|
for i := range conv.Messages {
|
||||||
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
||||||
|
details = DedupeConsecutiveProcessDetails(details)
|
||||||
detailsJSON := make([]map[string]interface{}, len(details))
|
detailsJSON := make([]map[string]interface{}, len(details))
|
||||||
for j, detail := range details {
|
for j, detail := range details {
|
||||||
var data interface{}
|
var data interface{}
|
||||||
@@ -234,6 +236,7 @@ func (db *DB) GetConversation(id string) (*Conversation, error) {
|
|||||||
// 将过程详情附加到对应的消息上
|
// 将过程详情附加到对应的消息上
|
||||||
for i := range conv.Messages {
|
for i := range conv.Messages {
|
||||||
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
||||||
|
details = DedupeConsecutiveProcessDetails(details)
|
||||||
// 将ProcessDetail转换为JSON格式,以便前端使用
|
// 将ProcessDetail转换为JSON格式,以便前端使用
|
||||||
detailsJSON := make([]map[string]interface{}, len(details))
|
detailsJSON := make([]map[string]interface{}, len(details))
|
||||||
for j, detail := range details {
|
for j, detail := range details {
|
||||||
@@ -498,8 +501,8 @@ func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs [
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
"INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
"INSERT INTO messages (id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
id, conversationID, role, content, mcpIDsJSON, now, now,
|
id, conversationID, role, content, "", mcpIDsJSON, now, now,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("添加消息失败: %w", err)
|
return nil, fmt.Errorf("添加消息失败: %w", err)
|
||||||
@@ -523,10 +526,30 @@ func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs [
|
|||||||
return message, nil
|
return message, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateAssistantMessageFinalize 更新助手消息终态(正文、MCP id、思考链聚合文本,供无轨迹回退时回放)。
|
||||||
|
func (db *DB) UpdateAssistantMessageFinalize(messageID, content string, mcpExecutionIDs []string, reasoningContent string) error {
|
||||||
|
var mcpIDsJSON string
|
||||||
|
if len(mcpExecutionIDs) > 0 {
|
||||||
|
jsonData, err := json.Marshal(mcpExecutionIDs)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("序列化MCP执行ID失败: %w", err)
|
||||||
|
}
|
||||||
|
mcpIDsJSON = string(jsonData)
|
||||||
|
}
|
||||||
|
_, err := db.Exec(
|
||||||
|
"UPDATE messages SET content = ?, mcp_execution_ids = ?, reasoning_content = ?, updated_at = ? WHERE id = ?",
|
||||||
|
content, mcpIDsJSON, strings.TrimSpace(reasoningContent), time.Now(), messageID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("更新助手消息失败: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetMessages 获取对话的所有消息
|
// GetMessages 获取对话的所有消息
|
||||||
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||||
rows, err := db.Query(
|
rows, err := db.Query(
|
||||||
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
"SELECT id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
||||||
conversationID,
|
conversationID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -537,13 +560,17 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
|||||||
var messages []Message
|
var messages []Message
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var msg Message
|
var msg Message
|
||||||
|
var reasoning sql.NullString
|
||||||
var mcpIDsJSON sql.NullString
|
var mcpIDsJSON sql.NullString
|
||||||
var createdAt string
|
var createdAt string
|
||||||
var updatedAt sql.NullString
|
var updatedAt sql.NullString
|
||||||
|
|
||||||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
|
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &reasoning, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
|
||||||
return nil, fmt.Errorf("扫描消息失败: %w", err)
|
return nil, fmt.Errorf("扫描消息失败: %w", err)
|
||||||
}
|
}
|
||||||
|
if reasoning.Valid {
|
||||||
|
msg.ReasoningContent = reasoning.String
|
||||||
|
}
|
||||||
|
|
||||||
// 尝试多种时间格式解析
|
// 尝试多种时间格式解析
|
||||||
var err error
|
var err error
|
||||||
@@ -683,7 +710,7 @@ type ProcessDetail struct {
|
|||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
MessageID string `json:"messageId"`
|
MessageID string `json:"messageId"`
|
||||||
ConversationID string `json:"conversationId"`
|
ConversationID string `json:"conversationId"`
|
||||||
EventType string `json:"eventType"` // iteration, thinking, tool_calls_detected, tool_call, tool_result, progress, error
|
EventType string `json:"eventType"` // iteration, thinking, reasoning_chain, tool_calls_detected, tool_call, tool_result, progress, error
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Data string `json:"data"` // JSON格式的数据
|
Data string `json:"data"` // JSON格式的数据
|
||||||
CreatedAt time.Time `json:"createdAt"`
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
|||||||
@@ -594,6 +594,25 @@ func (db *DB) migrateMessagesTable() error {
|
|||||||
|
|
||||||
// 回填已有数据:让 updated_at 至少等于 created_at,避免前端出现空/当前时间回退。
|
// 回填已有数据:让 updated_at 至少等于 created_at,避免前端出现空/当前时间回退。
|
||||||
_, _ = db.Exec("UPDATE messages SET updated_at = created_at WHERE updated_at IS NULL OR updated_at = ''")
|
_, _ = db.Exec("UPDATE messages SET updated_at = created_at WHERE updated_at IS NULL OR updated_at = ''")
|
||||||
|
|
||||||
|
// reasoning_content:DeepSeek 思考模式 + 工具调用续跑;与 last_react_input 互补,供消息表回退路径回放
|
||||||
|
var rcColCount int
|
||||||
|
errRC := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name='reasoning_content'").Scan(&rcColCount)
|
||||||
|
if errRC != nil {
|
||||||
|
if _, addErr := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); addErr != nil {
|
||||||
|
errMsg := strings.ToLower(addErr.Error())
|
||||||
|
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||||
|
return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", addErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if rcColCount == 0 {
|
||||||
|
if _, err := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); err != nil {
|
||||||
|
errMsg := strings.ToLower(err.Error())
|
||||||
|
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||||
|
return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,28 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DedupeConsecutiveProcessDetails 去掉相邻且语义相同的过程详情(使用 DB 中 data 列原始 JSON 作指纹,避免 map 序列化键序不稳定)。
|
||||||
|
func DedupeConsecutiveProcessDetails(rows []ProcessDetail) []ProcessDetail {
|
||||||
|
if len(rows) < 2 {
|
||||||
|
return rows
|
||||||
|
}
|
||||||
|
out := make([]ProcessDetail, 0, len(rows))
|
||||||
|
var lastKey string
|
||||||
|
for _, d := range rows {
|
||||||
|
key := processDetailRowKey(d)
|
||||||
|
if len(out) > 0 && key != "" && key == lastKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, d)
|
||||||
|
lastKey = key
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func processDetailRowKey(d ProcessDetail) string {
|
||||||
|
return fmt.Sprintf("%s\x00%s\x00%s", d.EventType, strings.TrimSpace(d.Message), d.Data)
|
||||||
|
}
|
||||||
@@ -23,12 +23,16 @@ type ExecutionRecorder func(executionID string)
|
|||||||
const ToolErrorPrefix = "__CYBERSTRIKE_AI_TOOL_ERROR__\n"
|
const ToolErrorPrefix = "__CYBERSTRIKE_AI_TOOL_ERROR__\n"
|
||||||
|
|
||||||
// ToolsFromDefinitions 将单 Agent 使用的 OpenAI 风格工具定义转为 Eino InvokableTool,执行时走 Agent 的 MCP 路径。
|
// ToolsFromDefinitions 将单 Agent 使用的 OpenAI 风格工具定义转为 Eino InvokableTool,执行时走 Agent 的 MCP 路径。
|
||||||
|
// invokeNotify 可选:与 runEinoADKAgentLoop 共享,在 InvokableRun 返回时触发 UI 与 pending 清理(与 ADK Tool 事件去重)。
|
||||||
|
// einoAgentName 为该套工具所属 ChatModelAgent 的 Name(主代理或子代理 id),用于 SSE 上的 einoAgent 字段。
|
||||||
func ToolsFromDefinitions(
|
func ToolsFromDefinitions(
|
||||||
ag *agent.Agent,
|
ag *agent.Agent,
|
||||||
holder *ConversationHolder,
|
holder *ConversationHolder,
|
||||||
defs []agent.Tool,
|
defs []agent.Tool,
|
||||||
rec ExecutionRecorder,
|
rec ExecutionRecorder,
|
||||||
toolOutputChunk func(toolName, toolCallID, chunk string),
|
toolOutputChunk func(toolName, toolCallID, chunk string),
|
||||||
|
invokeNotify *ToolInvokeNotifyHolder,
|
||||||
|
einoAgentName string,
|
||||||
) ([]tool.BaseTool, error) {
|
) ([]tool.BaseTool, error) {
|
||||||
out := make([]tool.BaseTool, 0, len(defs))
|
out := make([]tool.BaseTool, 0, len(defs))
|
||||||
for _, d := range defs {
|
for _, d := range defs {
|
||||||
@@ -40,12 +44,14 @@ func ToolsFromDefinitions(
|
|||||||
return nil, fmt.Errorf("tool %q: %w", d.Function.Name, err)
|
return nil, fmt.Errorf("tool %q: %w", d.Function.Name, err)
|
||||||
}
|
}
|
||||||
out = append(out, &mcpBridgeTool{
|
out = append(out, &mcpBridgeTool{
|
||||||
info: info,
|
info: info,
|
||||||
name: d.Function.Name,
|
name: d.Function.Name,
|
||||||
agent: ag,
|
agent: ag,
|
||||||
holder: holder,
|
holder: holder,
|
||||||
record: rec,
|
record: rec,
|
||||||
chunk: toolOutputChunk,
|
chunk: toolOutputChunk,
|
||||||
|
invokeNotify: invokeNotify,
|
||||||
|
einoAgentName: strings.TrimSpace(einoAgentName),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
@@ -77,12 +83,14 @@ func toolInfoFromDefinition(d agent.Tool) (*schema.ToolInfo, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type mcpBridgeTool struct {
|
type mcpBridgeTool struct {
|
||||||
info *schema.ToolInfo
|
info *schema.ToolInfo
|
||||||
name string
|
name string
|
||||||
agent *agent.Agent
|
agent *agent.Agent
|
||||||
holder *ConversationHolder
|
holder *ConversationHolder
|
||||||
record ExecutionRecorder
|
record ExecutionRecorder
|
||||||
chunk func(toolName, toolCallID, chunk string)
|
chunk func(toolName, toolCallID, chunk string)
|
||||||
|
invokeNotify *ToolInvokeNotifyHolder
|
||||||
|
einoAgentName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mcpBridgeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
func (m *mcpBridgeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||||
@@ -90,8 +98,27 @@ func (m *mcpBridgeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
|||||||
return m.info, nil
|
return m.info, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mcpBridgeTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
func (m *mcpBridgeTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (out string, err error) {
|
||||||
_ = opts
|
_ = opts
|
||||||
|
toolCallID := compose.GetToolCallID(ctx)
|
||||||
|
defer func() {
|
||||||
|
if m.invokeNotify == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tid := strings.TrimSpace(toolCallID)
|
||||||
|
if tid == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
success := err == nil && !strings.HasPrefix(out, ToolErrorPrefix)
|
||||||
|
body := out
|
||||||
|
if err != nil {
|
||||||
|
success = false
|
||||||
|
} else if strings.HasPrefix(out, ToolErrorPrefix) {
|
||||||
|
success = false
|
||||||
|
body = strings.TrimPrefix(out, ToolErrorPrefix)
|
||||||
|
}
|
||||||
|
m.invokeNotify.Fire(tid, m.name, m.einoAgentName, success, body, err)
|
||||||
|
}()
|
||||||
return runMCPToolInvocation(ctx, m.agent, m.holder, m.name, argumentsInJSON, m.record, m.chunk)
|
return runMCPToolInvocation(ctx, m.agent, m.holder, m.name, argumentsInJSON, m.record, m.chunk)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,39 @@
|
|||||||
|
package einomcp
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
// ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP 桥在每次 InvokableRun 结束时 Fire,
|
||||||
|
// 用于在 ADK 未透出 schema.Tool 事件时仍推送 tool_result、清 pending,避免 UI 卡在「执行中」或迭代末 force-close。
|
||||||
|
type ToolInvokeNotifyHolder struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewToolInvokeNotifyHolder 创建可在 ToolsFromDefinitions 与 run loop 之间共享的 holder。
|
||||||
|
func NewToolInvokeNotifyHolder() *ToolInvokeNotifyHolder {
|
||||||
|
return &ToolInvokeNotifyHolder{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set 由 runEinoADKAgentLoop 在开始消费 iter 之前调用;可多次覆盖(通常仅一次)。
|
||||||
|
func (h *ToolInvokeNotifyHolder) Set(fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)) {
|
||||||
|
if h == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.fn = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fire 由 mcpBridgeTool 在工具调用返回时调用;若尚未 Set 或 toolCallID 为空则忽略。
|
||||||
|
func (h *ToolInvokeNotifyHolder) Fire(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||||
|
if h == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.mu.RLock()
|
||||||
|
fn := h.fn
|
||||||
|
h.mu.RUnlock()
|
||||||
|
if fn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fn(toolCallID, toolName, einoAgent, success, content, invokeErr)
|
||||||
|
}
|
||||||
@@ -0,0 +1,435 @@
|
|||||||
|
// Package einoobserve attaches CloudWeGo Eino [callbacks.Handler] to ADK Runner contexts for
|
||||||
|
// structured logging and optional SSE trace events (eino_trace_*).
|
||||||
|
package einoobserve
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/callbacks"
|
||||||
|
"github.com/cloudwego/eino/components"
|
||||||
|
"github.com/cloudwego/eino/components/model"
|
||||||
|
"github.com/cloudwego/eino/components/tool"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
|
"go.opentelemetry.io/otel/attribute"
|
||||||
|
"go.opentelemetry.io/otel/codes"
|
||||||
|
"go.opentelemetry.io/otel/trace"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ctxSpanKey struct{}
|
||||||
|
|
||||||
|
type ctxOtelSpanKey struct{}
|
||||||
|
|
||||||
|
// Params for attaching per-run callback instrumentation.
|
||||||
|
type Params struct {
|
||||||
|
Logger *zap.Logger
|
||||||
|
Progress func(eventType, message string, data interface{})
|
||||||
|
ConversationID string
|
||||||
|
OrchMode string
|
||||||
|
OrchestratorName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// AttachAgentRunCallbacks returns ctx wrapped with callbacks.InitCallbacks when enabled.
|
||||||
|
// Safe to call with nil cfg or disabled cfg (returns ctx unchanged).
|
||||||
|
func AttachAgentRunCallbacks(ctx context.Context, cfg *config.MultiAgentEinoCallbacksConfig, p Params) context.Context {
|
||||||
|
if ctx == nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
if cfg == nil || !cfg.Enabled {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
mode := cfg.EinoCallbacksModeEffective()
|
||||||
|
if mode == "off" {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
runID := uuid.New().String()
|
||||||
|
if p.Progress != nil && cfg.ShouldEmitEinoTraceSSE(mode) {
|
||||||
|
p.Progress("eino_trace_run", "Eino callbacks session", map[string]interface{}{
|
||||||
|
"runId": runID,
|
||||||
|
"conversationId": strings.TrimSpace(p.ConversationID),
|
||||||
|
"orchestration": strings.TrimSpace(p.OrchMode),
|
||||||
|
"orchestratorName": strings.TrimSpace(p.OrchestratorName),
|
||||||
|
"observeMode": mode,
|
||||||
|
"source": "eino_callbacks",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
h := &runHandler{
|
||||||
|
cfg: *cfg,
|
||||||
|
mode: mode,
|
||||||
|
params: p,
|
||||||
|
runID: runID,
|
||||||
|
}
|
||||||
|
b := callbacks.NewHandlerBuilder().
|
||||||
|
OnStartFn(h.onStart).
|
||||||
|
OnEndFn(h.onEnd).
|
||||||
|
OnErrorFn(h.onError)
|
||||||
|
if mode == "full" {
|
||||||
|
b = b.OnStartWithStreamInputFn(h.onStartStreamIn).OnEndWithStreamOutputFn(h.onEndStreamOut)
|
||||||
|
}
|
||||||
|
ri := &callbacks.RunInfo{
|
||||||
|
Name: "CyberStrikeADKRun",
|
||||||
|
Type: strings.TrimSpace(p.OrchMode),
|
||||||
|
Component: components.Component("AgentSession"),
|
||||||
|
}
|
||||||
|
return callbacks.InitCallbacks(ctx, ri, b.Build())
|
||||||
|
}
|
||||||
|
|
||||||
|
type runHandler struct {
|
||||||
|
cfg config.MultiAgentEinoCallbacksConfig
|
||||||
|
mode string
|
||||||
|
params Params
|
||||||
|
runID string
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
spanStack []string
|
||||||
|
seq atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *runHandler) genSpanID() string {
|
||||||
|
return fmt.Sprintf("%s-%d", h.runID, h.seq.Add(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *runHandler) popSpan() (id string) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
if len(h.spanStack) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
id = h.spanStack[len(h.spanStack)-1]
|
||||||
|
h.spanStack = h.spanStack[:len(h.spanStack)-1]
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
// popMatching removes the given id from the stack top if it matches; otherwise pops until empty or match (rare ordering mismatch).
|
||||||
|
func (h *runHandler) popMatching(want string) string {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
if want == "" {
|
||||||
|
if len(h.spanStack) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
id := h.spanStack[len(h.spanStack)-1]
|
||||||
|
h.spanStack = h.spanStack[:len(h.spanStack)-1]
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
for len(h.spanStack) > 0 {
|
||||||
|
top := h.spanStack[len(h.spanStack)-1]
|
||||||
|
h.spanStack = h.spanStack[:len(h.spanStack)-1]
|
||||||
|
if top == want {
|
||||||
|
return top
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return want
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
||||||
|
var parentID string
|
||||||
|
h.mu.Lock()
|
||||||
|
if len(h.spanStack) > 0 {
|
||||||
|
parentID = h.spanStack[len(h.spanStack)-1]
|
||||||
|
}
|
||||||
|
spanID := h.genSpanID()
|
||||||
|
h.spanStack = append(h.spanStack, spanID)
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
inSum := summarizeCallbackInput(input, h.cfg.EinoCallbacksMaxInputSummaryRunes())
|
||||||
|
if h.cfg.OtelTracingActive() {
|
||||||
|
tracer := otel.Tracer("cyberstrike/eino")
|
||||||
|
spanName := callbackSpanName(info)
|
||||||
|
var sp trace.Span
|
||||||
|
ctx, sp = tracer.Start(ctx, spanName,
|
||||||
|
trace.WithSpanKind(trace.SpanKindInternal),
|
||||||
|
trace.WithAttributes(
|
||||||
|
attribute.String("eino.component", string(info.Component)),
|
||||||
|
attribute.String("eino.name", info.Name),
|
||||||
|
attribute.String("eino.type", info.Type),
|
||||||
|
attribute.String("cyberstrike.run_id", h.runID),
|
||||||
|
attribute.String("cyberstrike.conversation_id", strings.TrimSpace(h.params.ConversationID)),
|
||||||
|
attribute.String("cyberstrike.orchestration", strings.TrimSpace(h.params.OrchMode)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if inSum != "" {
|
||||||
|
sp.SetAttributes(attribute.String("eino.input.summary", truncateForAttr(inSum, 256)))
|
||||||
|
}
|
||||||
|
ctx = context.WithValue(ctx, ctxOtelSpanKey{}, sp)
|
||||||
|
}
|
||||||
|
if h.params.Logger != nil {
|
||||||
|
fields := []zap.Field{
|
||||||
|
zap.String("runId", h.runID),
|
||||||
|
zap.String("spanId", spanID),
|
||||||
|
zap.String("parentSpanId", parentID),
|
||||||
|
zap.String("component", string(info.Component)),
|
||||||
|
zap.String("name", info.Name),
|
||||||
|
zap.String("type", info.Type),
|
||||||
|
zap.String("phase", "start"),
|
||||||
|
}
|
||||||
|
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
||||||
|
if sc := sp.SpanContext(); sc.IsValid() {
|
||||||
|
fields = append(fields,
|
||||||
|
zap.String("trace_id", sc.TraceID().String()),
|
||||||
|
zap.String("otel_span_id", sc.SpanID().String()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.cfg.ZapVerbose {
|
||||||
|
h.params.Logger.Debug("eino_callback", append(fields, zap.String("inputSummary", inSum))...)
|
||||||
|
} else {
|
||||||
|
h.params.Logger.Info("eino_callback", fields...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
|
||||||
|
h.params.Progress("eino_trace_start", "", map[string]interface{}{
|
||||||
|
"runId": h.runID,
|
||||||
|
"spanId": spanID,
|
||||||
|
"parentSpanId": parentID,
|
||||||
|
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||||
|
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||||
|
"component": string(info.Component),
|
||||||
|
"name": info.Name,
|
||||||
|
"type": info.Type,
|
||||||
|
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||||
|
"inputSummary": inSum,
|
||||||
|
"source": "eino_callbacks",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
ctx = context.WithValue(ctx, ctxSpanKey{}, spanID)
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *runHandler) onEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
|
||||||
|
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
|
||||||
|
if spanID == "" {
|
||||||
|
spanID = h.popSpan()
|
||||||
|
} else {
|
||||||
|
spanID = h.popMatching(spanID)
|
||||||
|
}
|
||||||
|
outSum := summarizeCallbackOutput(output, h.cfg.EinoCallbacksMaxOutputSummaryRunes())
|
||||||
|
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
||||||
|
if outSum != "" {
|
||||||
|
sp.SetAttributes(attribute.String("eino.output.summary", truncateForAttr(outSum, 256)))
|
||||||
|
}
|
||||||
|
sp.SetStatus(codes.Ok, "")
|
||||||
|
sp.End()
|
||||||
|
}
|
||||||
|
if h.params.Logger != nil {
|
||||||
|
fields := []zap.Field{
|
||||||
|
zap.String("runId", h.runID),
|
||||||
|
zap.String("spanId", spanID),
|
||||||
|
zap.String("component", string(info.Component)),
|
||||||
|
zap.String("name", info.Name),
|
||||||
|
zap.String("type", info.Type),
|
||||||
|
zap.String("phase", "end"),
|
||||||
|
}
|
||||||
|
if h.cfg.ZapVerbose {
|
||||||
|
h.params.Logger.Debug("eino_callback", append(fields, zap.String("outputSummary", outSum))...)
|
||||||
|
} else {
|
||||||
|
h.params.Logger.Info("eino_callback", fields...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
|
||||||
|
h.params.Progress("eino_trace_end", "", map[string]interface{}{
|
||||||
|
"runId": h.runID,
|
||||||
|
"spanId": spanID,
|
||||||
|
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||||
|
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||||
|
"component": string(info.Component),
|
||||||
|
"name": info.Name,
|
||||||
|
"type": info.Type,
|
||||||
|
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||||
|
"outputSummary": outSum,
|
||||||
|
"source": "eino_callbacks",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *runHandler) onError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
||||||
|
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
|
||||||
|
if spanID == "" {
|
||||||
|
spanID = h.popSpan()
|
||||||
|
} else {
|
||||||
|
spanID = h.popMatching(spanID)
|
||||||
|
}
|
||||||
|
msg := ""
|
||||||
|
if err != nil {
|
||||||
|
msg = truncateRunes(err.Error(), h.cfg.EinoCallbacksMaxOutputSummaryRunes())
|
||||||
|
}
|
||||||
|
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
||||||
|
if err != nil {
|
||||||
|
sp.RecordError(err)
|
||||||
|
}
|
||||||
|
sp.SetStatus(codes.Error, msg)
|
||||||
|
sp.End()
|
||||||
|
}
|
||||||
|
if h.params.Logger != nil {
|
||||||
|
h.params.Logger.Warn("eino_callback_error",
|
||||||
|
zap.String("runId", h.runID),
|
||||||
|
zap.String("spanId", spanID),
|
||||||
|
zap.String("component", string(info.Component)),
|
||||||
|
zap.String("name", info.Name),
|
||||||
|
zap.String("type", info.Type),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
|
||||||
|
h.params.Progress("eino_trace_error", msg, map[string]interface{}{
|
||||||
|
"runId": h.runID,
|
||||||
|
"spanId": spanID,
|
||||||
|
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||||
|
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||||
|
"component": string(info.Component),
|
||||||
|
"name": info.Name,
|
||||||
|
"type": info.Type,
|
||||||
|
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||||
|
"error": msg,
|
||||||
|
"source": "eino_callbacks",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *runHandler) onStartStreamIn(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
|
||||||
|
if input != nil {
|
||||||
|
input.Close()
|
||||||
|
}
|
||||||
|
if h.params.Logger != nil {
|
||||||
|
h.params.Logger.Debug("eino_callback_stream_in",
|
||||||
|
zap.String("runId", h.runID),
|
||||||
|
zap.String("component", string(info.Component)),
|
||||||
|
zap.String("name", info.Name),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *runHandler) onEndStreamOut(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
|
||||||
|
if output != nil {
|
||||||
|
output.Close()
|
||||||
|
}
|
||||||
|
if h.params.Logger != nil {
|
||||||
|
h.params.Logger.Debug("eino_callback_stream_out",
|
||||||
|
zap.String("runId", h.runID),
|
||||||
|
zap.String("component", string(info.Component)),
|
||||||
|
zap.String("name", info.Name),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
func callbackSpanName(info *callbacks.RunInfo) string {
|
||||||
|
if info == nil {
|
||||||
|
return "eino.callback"
|
||||||
|
}
|
||||||
|
comp := strings.TrimSpace(string(info.Component))
|
||||||
|
name := strings.TrimSpace(info.Name)
|
||||||
|
typ := strings.TrimSpace(info.Type)
|
||||||
|
if name != "" && comp != "" {
|
||||||
|
return comp + "/" + name
|
||||||
|
}
|
||||||
|
if typ != "" && comp != "" {
|
||||||
|
return comp + "[" + typ + "]"
|
||||||
|
}
|
||||||
|
if comp != "" {
|
||||||
|
return comp
|
||||||
|
}
|
||||||
|
return "eino.callback"
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateForAttr(s string, maxRunes int) string {
|
||||||
|
return truncateRunes(s, maxRunes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func summarizeCallbackInput(in callbacks.CallbackInput, maxRunes int) string {
|
||||||
|
if in == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if ai := adk.ConvAgentCallbackInput(in); ai != nil {
|
||||||
|
parts := []string{"agent"}
|
||||||
|
if ai.Input != nil {
|
||||||
|
parts = append(parts, fmt.Sprintf("messages=%d", len(ai.Input.Messages)))
|
||||||
|
}
|
||||||
|
if ai.ResumeInfo != nil {
|
||||||
|
parts = append(parts, "resume=true")
|
||||||
|
}
|
||||||
|
return strings.Join(parts, " ")
|
||||||
|
}
|
||||||
|
if mi := model.ConvCallbackInput(in); mi != nil {
|
||||||
|
return fmt.Sprintf("chatModel messages=%d tools=%d", len(mi.Messages), len(mi.Tools))
|
||||||
|
}
|
||||||
|
if ti := tool.ConvCallbackInput(in); ti != nil {
|
||||||
|
raw := ti.ArgumentsInJSON
|
||||||
|
return "tool args=" + truncateRunes(raw, maxRunes)
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(in)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("%T", in)
|
||||||
|
}
|
||||||
|
return truncateRunes(string(b), maxRunes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func summarizeCallbackOutput(out callbacks.CallbackOutput, maxRunes int) string {
|
||||||
|
if out == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if ao := adk.ConvAgentCallbackOutput(out); ao != nil {
|
||||||
|
return "agent_events=stream"
|
||||||
|
}
|
||||||
|
if mo := model.ConvCallbackOutput(out); mo != nil && mo.Message != nil {
|
||||||
|
s := ""
|
||||||
|
if mo.Message.Content != "" {
|
||||||
|
s = mo.Message.Content
|
||||||
|
}
|
||||||
|
if mo.TokenUsage != nil {
|
||||||
|
return fmt.Sprintf("tokens total=%d completion=%d prompt=%d text=%s",
|
||||||
|
mo.TokenUsage.TotalTokens, mo.TokenUsage.CompletionTokens, mo.TokenUsage.PromptTokens,
|
||||||
|
truncateRunes(s, minInt(120, maxRunes)))
|
||||||
|
}
|
||||||
|
return "assistant len=" + itoa(len(s))
|
||||||
|
}
|
||||||
|
if to := tool.ConvCallbackOutput(out); to != nil {
|
||||||
|
if to.Response != "" {
|
||||||
|
return truncateRunes(to.Response, maxRunes)
|
||||||
|
}
|
||||||
|
if to.ToolOutput != nil {
|
||||||
|
return "tool_result multimodal"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(out)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("%T", out)
|
||||||
|
}
|
||||||
|
return truncateRunes(string(b), maxRunes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func minInt(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func itoa(n int) string {
|
||||||
|
return fmt.Sprintf("%d", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateRunes(s string, maxRunes int) string {
|
||||||
|
if maxRunes <= 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
r := []rune(s)
|
||||||
|
if len(r) <= maxRunes {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return string(r[:maxRunes]) + "…"
|
||||||
|
}
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
package einoobserve
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAttachAgentRunCallbacks_Disabled(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
cfg := &config.MultiAgentEinoCallbacksConfig{Enabled: false}
|
||||||
|
out := AttachAgentRunCallbacks(ctx, cfg, Params{})
|
||||||
|
if out != ctx {
|
||||||
|
t.Fatalf("expected same ctx when disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTruncateRunes(t *testing.T) {
|
||||||
|
if got := truncateRunes("abc", 10); got != "abc" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
if got := truncateRunes("abcdefghij", 4); got != "abcd…" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
package einoobserve
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
"go.opentelemetry.io/otel"
|
||||||
|
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
|
||||||
|
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
|
||||||
|
"go.opentelemetry.io/otel/sdk/resource"
|
||||||
|
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||||
|
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
otelMu sync.Mutex
|
||||||
|
otelShutdown func(context.Context) error
|
||||||
|
otelInitialized bool
|
||||||
|
)
|
||||||
|
|
||||||
|
// InitOtelFromConfig installs the global OpenTelemetry TracerProvider when
|
||||||
|
// eino_callbacks.otel is enabled and exporter is not none. Safe to call multiple times.
|
||||||
|
func InitOtelFromConfig(cfg *config.MultiAgentEinoCallbacksConfig, log *zap.Logger) (shutdown func(context.Context) error, err error) {
|
||||||
|
shutdown = func(context.Context) error { return nil }
|
||||||
|
if cfg == nil || !cfg.OtelTracingActive() {
|
||||||
|
return shutdown, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
otelMu.Lock()
|
||||||
|
defer otelMu.Unlock()
|
||||||
|
if otelInitialized {
|
||||||
|
if otelShutdown != nil {
|
||||||
|
return otelShutdown, nil
|
||||||
|
}
|
||||||
|
return shutdown, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
oc := cfg.Otel
|
||||||
|
expKind := oc.OtelExporterEffective()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
var exporter sdktrace.SpanExporter
|
||||||
|
switch expKind {
|
||||||
|
case "stdout":
|
||||||
|
exporter, err = stdouttrace.New()
|
||||||
|
if err != nil {
|
||||||
|
return shutdown, fmt.Errorf("eino otel stdout exporter: %w", err)
|
||||||
|
}
|
||||||
|
case "otlphttp":
|
||||||
|
ep := strings.TrimSpace(oc.OTLPEndpoint)
|
||||||
|
if ep == "" {
|
||||||
|
ep = "localhost:4318"
|
||||||
|
}
|
||||||
|
exporter, err = otlptracehttp.New(ctx,
|
||||||
|
otlptracehttp.WithEndpoint(ep),
|
||||||
|
otlptracehttp.WithURLPath("/v1/traces"),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return shutdown, fmt.Errorf("eino otel otlphttp exporter: %w", err)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return shutdown, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err := resource.New(ctx,
|
||||||
|
resource.WithAttributes(
|
||||||
|
semconv.ServiceName(oc.ServiceNameEffective()),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return shutdown, fmt.Errorf("eino otel resource: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sampler := sdktrace.ParentBased(sdktrace.TraceIDRatioBased(oc.SampleRatioEffective()))
|
||||||
|
tp := sdktrace.NewTracerProvider(
|
||||||
|
sdktrace.WithBatcher(exporter),
|
||||||
|
sdktrace.WithResource(res),
|
||||||
|
sdktrace.WithSampler(sampler),
|
||||||
|
)
|
||||||
|
otel.SetTracerProvider(tp)
|
||||||
|
|
||||||
|
otelShutdown = tp.Shutdown
|
||||||
|
otelInitialized = true
|
||||||
|
if log != nil {
|
||||||
|
log.Info("eino otel: tracer provider initialized",
|
||||||
|
zap.String("exporter", expKind),
|
||||||
|
zap.String("service", oc.ServiceNameEffective()),
|
||||||
|
zap.Float64("sample_ratio", oc.SampleRatioEffective()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return otelShutdown, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShutdownOtel flushes and shuts down the global TracerProvider if it was installed.
|
||||||
|
func ShutdownOtel(ctx context.Context) error {
|
||||||
|
otelMu.Lock()
|
||||||
|
fn := otelShutdown
|
||||||
|
otelShutdown = nil
|
||||||
|
inited := otelInitialized
|
||||||
|
otelInitialized = false
|
||||||
|
otelMu.Unlock()
|
||||||
|
if !inited || fn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fn(ctx)
|
||||||
|
}
|
||||||
+144
-97
@@ -19,9 +19,11 @@ import (
|
|||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/reasoning"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/mcp/builtin"
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
"cyberstrike-ai/internal/multiagent"
|
"cyberstrike-ai/internal/multiagent"
|
||||||
|
"cyberstrike-ai/internal/openai"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/robfig/cron/v3"
|
"github.com/robfig/cron/v3"
|
||||||
@@ -201,6 +203,14 @@ type ChatAttachment struct {
|
|||||||
ServerPath string `json:"serverPath,omitempty"` // 已保存在 chat_uploads 下的绝对路径(由 POST /api/chat-uploads 返回)
|
ServerPath string `json:"serverPath,omitempty"` // 已保存在 chat_uploads 下的绝对路径(由 POST /api/chat-uploads 返回)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ChatReasoningRequest 对话页「模型推理」意图(仅 Eino 路径消费;原生 agent-loop 忽略)。
|
||||||
|
type ChatReasoningRequest struct {
|
||||||
|
// Mode: default(跟随系统)| off | on | auto
|
||||||
|
Mode string `json:"mode,omitempty"`
|
||||||
|
// Effort: low | medium | high | max;空表示不指定(由系统默认与各 profile 决定)。
|
||||||
|
Effort string `json:"effort,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
// ChatRequest 聊天请求
|
// ChatRequest 聊天请求
|
||||||
type ChatRequest struct {
|
type ChatRequest struct {
|
||||||
Message string `json:"message" binding:"required"`
|
Message string `json:"message" binding:"required"`
|
||||||
@@ -209,10 +219,18 @@ type ChatRequest struct {
|
|||||||
Attachments []ChatAttachment `json:"attachments,omitempty"`
|
Attachments []ChatAttachment `json:"attachments,omitempty"`
|
||||||
WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具
|
WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具
|
||||||
Hitl *HITLRequest `json:"hitl,omitempty"`
|
Hitl *HITLRequest `json:"hitl,omitempty"`
|
||||||
|
Reasoning *ChatReasoningRequest `json:"reasoning,omitempty"`
|
||||||
// Orchestration 仅对 /api/multi-agent、/api/multi-agent/stream:deep | plan_execute | supervisor;空则等同 deep。机器人/批量等无请求体时由服务端默认 deep。/api/eino-agent* 不使用此字段。
|
// Orchestration 仅对 /api/multi-agent、/api/multi-agent/stream:deep | plan_execute | supervisor;空则等同 deep。机器人/批量等无请求体时由服务端默认 deep。/api/eino-agent* 不使用此字段。
|
||||||
Orchestration string `json:"orchestration,omitempty"`
|
Orchestration string `json:"orchestration,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func chatReasoningToClientIntent(r *ChatReasoningRequest) *reasoning.ClientIntent {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &reasoning.ClientIntent{Mode: r.Mode, Effort: r.Effort}
|
||||||
|
}
|
||||||
|
|
||||||
type HITLRequest struct {
|
type HITLRequest struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
Mode string `json:"mode,omitempty"`
|
Mode string `json:"mode,omitempty"`
|
||||||
@@ -567,14 +585,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
|||||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||||
agentHistoryMessages = []agent.ChatMessage{}
|
agentHistoryMessages = []agent.ChatMessage{}
|
||||||
} else {
|
} else {
|
||||||
// 将数据库消息转换为Agent消息格式
|
agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
|
||||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
|
||||||
for _, msg := range historyMessages {
|
|
||||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
|
||||||
Role: msg.Role,
|
|
||||||
Content: msg.Content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -775,6 +786,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
|||||||
progressCallback,
|
progressCallback,
|
||||||
h.agentsMarkdownDir,
|
h.agentsMarkdownDir,
|
||||||
"deep",
|
"deep",
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
if errMA != nil {
|
if errMA != nil {
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
|
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
|
||||||
@@ -788,17 +800,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
|||||||
return "", conversationID, errMA
|
return "", conversationID, errMA
|
||||||
}
|
}
|
||||||
if assistantMessageID != "" {
|
if assistantMessageID != "" {
|
||||||
mcpIDsJSON := ""
|
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resultMA.Response, resultMA.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(resultMA.LastAgentTraceInput)); errU != nil {
|
||||||
if len(resultMA.MCPExecutionIDs) > 0 {
|
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
|
||||||
jsonData, _ := json.Marshal(resultMA.MCPExecutionIDs)
|
|
||||||
mcpIDsJSON = string(jsonData)
|
|
||||||
}
|
|
||||||
_, err = h.db.Exec(
|
|
||||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
|
||||||
resultMA.Response, mcpIDsJSON, time.Now(), assistantMessageID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(err))
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if _, err = h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil {
|
if _, err = h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil {
|
||||||
@@ -823,17 +826,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
|||||||
|
|
||||||
// 更新助手消息内容与 MCP 执行 ID(与 stream 一致)
|
// 更新助手消息内容与 MCP 执行 ID(与 stream 一致)
|
||||||
if assistantMessageID != "" {
|
if assistantMessageID != "" {
|
||||||
mcpIDsJSON := ""
|
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)); errU != nil {
|
||||||
if len(result.MCPExecutionIDs) > 0 {
|
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
|
||||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
|
||||||
mcpIDsJSON = string(jsonData)
|
|
||||||
}
|
|
||||||
_, err = h.db.Exec(
|
|
||||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
|
||||||
result.Response, mcpIDsJSON, time.Now(), assistantMessageID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(err))
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if _, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs); err != nil {
|
if _, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs); err != nil {
|
||||||
@@ -891,10 +885,12 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// thinking_stream_*:不逐条落库,按 streamId 聚合,在后续关键事件前补一条可持久化的 thinking
|
// thinking_stream_*(ReAct 等助手正文流)与 reasoning_chain_stream_*(Eino ReasoningContent):
|
||||||
|
// 不逐条落库,按 streamId 聚合,flush 时分别落 thinking / reasoning_chain。
|
||||||
type thinkingBuf struct {
|
type thinkingBuf struct {
|
||||||
b strings.Builder
|
b strings.Builder
|
||||||
meta map[string]interface{}
|
meta map[string]interface{}
|
||||||
|
persistAs string // "thinking" | "reasoning_chain"
|
||||||
}
|
}
|
||||||
thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf
|
thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf
|
||||||
flushedThinking := make(map[string]bool) // streamId -> flushed
|
flushedThinking := make(map[string]bool) // streamId -> flushed
|
||||||
@@ -948,8 +944,12 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
}
|
}
|
||||||
data[k] = v
|
data[k] = v
|
||||||
}
|
}
|
||||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "thinking", content, data); err != nil {
|
persist := tb.persistAs
|
||||||
h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", "thinking"))
|
if persist != "reasoning_chain" {
|
||||||
|
persist = "thinking"
|
||||||
|
}
|
||||||
|
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, persist, content, data); err != nil {
|
||||||
|
h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", persist))
|
||||||
}
|
}
|
||||||
flushedThinking[sid] = true
|
flushedThinking[sid] = true
|
||||||
}
|
}
|
||||||
@@ -1159,7 +1159,16 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if eventType == "response_delta" {
|
if eventType == "response_delta" {
|
||||||
respPlan.b.WriteString(message)
|
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||||
|
if acc, okAcc := dataMap[openai.SSEAccumulatedKey].(string); okAcc {
|
||||||
|
respPlan.b.Reset()
|
||||||
|
respPlan.b.WriteString(acc)
|
||||||
|
} else {
|
||||||
|
respPlan.b.WriteString(message)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
respPlan.b.WriteString(message)
|
||||||
|
}
|
||||||
if dataMap, ok := data.(map[string]interface{}); ok && respPlan.meta == nil {
|
if dataMap, ok := data.(map[string]interface{}); ok && respPlan.meta == nil {
|
||||||
respPlan.meta = make(map[string]interface{}, len(dataMap))
|
respPlan.meta = make(map[string]interface{}, len(dataMap))
|
||||||
for k, v := range dataMap {
|
for k, v := range dataMap {
|
||||||
@@ -1177,14 +1186,20 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 聚合 thinking_stream_*(ReasoningContent),不逐条落库
|
// 聚合 thinking_stream_* / reasoning_chain_stream_*,不逐条落库
|
||||||
if eventType == "thinking_stream_start" {
|
if eventType == "thinking_stream_start" || eventType == "reasoning_chain_stream_start" {
|
||||||
|
persistAs := "thinking"
|
||||||
|
if eventType == "reasoning_chain_stream_start" {
|
||||||
|
persistAs = "reasoning_chain"
|
||||||
|
}
|
||||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||||
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
||||||
tb := thinkingStreams[sid]
|
tb := thinkingStreams[sid]
|
||||||
if tb == nil {
|
if tb == nil {
|
||||||
tb = &thinkingBuf{meta: map[string]interface{}{}}
|
tb = &thinkingBuf{meta: map[string]interface{}{}, persistAs: persistAs}
|
||||||
thinkingStreams[sid] = tb
|
thinkingStreams[sid] = tb
|
||||||
|
} else {
|
||||||
|
tb.persistAs = persistAs
|
||||||
}
|
}
|
||||||
// 记录元信息(source/einoAgent/einoRole/iteration 等)
|
// 记录元信息(source/einoAgent/einoRole/iteration 等)
|
||||||
for k, v := range dataMap {
|
for k, v := range dataMap {
|
||||||
@@ -1194,16 +1209,26 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if eventType == "thinking_stream_delta" {
|
if eventType == "thinking_stream_delta" || eventType == "reasoning_chain_stream_delta" {
|
||||||
|
persistAs := "thinking"
|
||||||
|
if eventType == "reasoning_chain_stream_delta" {
|
||||||
|
persistAs = "reasoning_chain"
|
||||||
|
}
|
||||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||||
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
||||||
tb := thinkingStreams[sid]
|
tb := thinkingStreams[sid]
|
||||||
if tb == nil {
|
if tb == nil {
|
||||||
tb = &thinkingBuf{meta: map[string]interface{}{}}
|
tb = &thinkingBuf{meta: map[string]interface{}{}, persistAs: persistAs}
|
||||||
thinkingStreams[sid] = tb
|
thinkingStreams[sid] = tb
|
||||||
|
} else if tb.persistAs == "" {
|
||||||
|
tb.persistAs = persistAs
|
||||||
|
}
|
||||||
|
if acc, okAcc := dataMap[openai.SSEAccumulatedKey].(string); okAcc {
|
||||||
|
tb.b.Reset()
|
||||||
|
tb.b.WriteString(acc)
|
||||||
|
} else {
|
||||||
|
tb.b.WriteString(message)
|
||||||
}
|
}
|
||||||
// delta 片段直接拼接;message 本身就是 reasoning content
|
|
||||||
tb.b.WriteString(message)
|
|
||||||
// 有时 delta 先到 start 未到,补充元信息
|
// 有时 delta 先到 start 未到,补充元信息
|
||||||
for k, v := range dataMap {
|
for k, v := range dataMap {
|
||||||
tb.meta[k] = v
|
tb.meta[k] = v
|
||||||
@@ -1213,10 +1238,9 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 当 Agent 同时发送 thinking_stream_* 和 thinking(带同一 streamId)时,
|
// 当 Agent 同时发送 *_stream_* 与同名 streamId 的 thinking/reasoning_chain 时,
|
||||||
// thinking_stream_* 已经会在 flushThinkingStreams() 聚合落库;
|
// 流式聚合已会在 flushThinkingStreams() 落库;此处跳过逐条重复。
|
||||||
// 这里跳过同 streamId 的 thinking,避免 processDetails 双份展示。
|
if eventType == "thinking" || eventType == "reasoning_chain" {
|
||||||
if eventType == "thinking" {
|
|
||||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||||
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
||||||
if tb, exists := thinkingStreams[sid]; exists && tb != nil {
|
if tb, exists := thinkingStreams[sid]; exists && tb != nil {
|
||||||
@@ -1239,13 +1263,17 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
eventType != "response_start" &&
|
eventType != "response_start" &&
|
||||||
eventType != "response_delta" &&
|
eventType != "response_delta" &&
|
||||||
eventType != "tool_result_delta" &&
|
eventType != "tool_result_delta" &&
|
||||||
|
eventType != "eino_trace_run" &&
|
||||||
|
eventType != "eino_trace_start" &&
|
||||||
|
eventType != "eino_trace_end" &&
|
||||||
|
eventType != "eino_trace_error" &&
|
||||||
eventType != "eino_agent_reply_stream_start" &&
|
eventType != "eino_agent_reply_stream_start" &&
|
||||||
eventType != "eino_agent_reply_stream_delta" &&
|
eventType != "eino_agent_reply_stream_delta" &&
|
||||||
eventType != "eino_agent_reply_stream_end" {
|
eventType != "eino_agent_reply_stream_end" {
|
||||||
if eventType == "tool_result" {
|
if eventType == "tool_result" {
|
||||||
discardPlanningIfEchoesToolResult(&respPlan, data)
|
discardPlanningIfEchoesToolResult(&respPlan, data)
|
||||||
}
|
}
|
||||||
// 在关键过程事件落库前,先把「规划中」与 thinking_stream 落库
|
// 在关键过程事件落库前,先把「规划中」与聚合中的 thinking / reasoning_chain 流落库
|
||||||
flushResponsePlan()
|
flushResponsePlan()
|
||||||
flushThinkingStreams()
|
flushThinkingStreams()
|
||||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, eventType, message, data); err != nil {
|
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, eventType, message, data); err != nil {
|
||||||
@@ -1427,14 +1455,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
|||||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||||
agentHistoryMessages = []agent.ChatMessage{}
|
agentHistoryMessages = []agent.ChatMessage{}
|
||||||
} else {
|
} else {
|
||||||
// 将数据库消息转换为Agent消息格式
|
agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
|
||||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
|
||||||
for _, msg := range historyMessages {
|
|
||||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
|
||||||
Role: msg.Role,
|
|
||||||
Content: msg.Content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -1727,20 +1748,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
|||||||
|
|
||||||
// 更新助手消息内容
|
// 更新助手消息内容
|
||||||
if assistantMsg != nil {
|
if assistantMsg != nil {
|
||||||
_, err = h.db.Exec(
|
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)); errU != nil {
|
||||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
h.logger.Error("更新助手消息失败", zap.Error(errU))
|
||||||
result.Response,
|
|
||||||
func() string {
|
|
||||||
if len(result.MCPExecutionIDs) > 0 {
|
|
||||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
|
||||||
return string(jsonData)
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}(),
|
|
||||||
time.Now(), assistantMessageID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
h.logger.Error("更新助手消息失败", zap.Error(err))
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 如果之前创建失败,现在创建
|
// 如果之前创建失败,现在创建
|
||||||
@@ -1789,27 +1798,51 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
execID := h.tasks.ActiveMCPExecutionID(req.ConversationID)
|
execID := h.tasks.ActiveMCPExecutionID(req.ConversationID)
|
||||||
if execID == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "当前没有正在执行的 MCP 工具(例如模型尚在推理、尚未发起工具调用)。请等待工具开始执行后再试,或使用「彻底停止」结束整轮任务。"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
note := strings.TrimSpace(req.Reason)
|
note := strings.TrimSpace(req.Reason)
|
||||||
if !h.agent.CancelMCPToolExecutionWithNote(execID, note) {
|
if execID != "" {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
|
if !h.agent.CancelMCPToolExecutionWithNote(execID, note) {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Info("对话页仅终止当前 MCP 工具",
|
||||||
|
zap.String("conversationId", req.ConversationID),
|
||||||
|
zap.String("executionId", execID),
|
||||||
|
zap.Bool("hasNote", note != ""),
|
||||||
|
)
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": "tool_abort_requested",
|
||||||
|
"conversationId": req.ConversationID,
|
||||||
|
"executionId": execID,
|
||||||
|
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
|
||||||
|
"continueAfter": true,
|
||||||
|
"interruptWithNote": note != "",
|
||||||
|
"continueWithoutTool": false,
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.logger.Info("对话页仅终止当前 MCP 工具",
|
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
|
||||||
|
h.tasks.SetInterruptContinueNote(req.ConversationID, note)
|
||||||
|
ok, err := h.tasks.CancelTask(req.ConversationID, multiagent.ErrInterruptContinue)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("中断并继续(无工具)失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Info("对话页中断并继续(无 MCP 工具,将自动续跑)",
|
||||||
zap.String("conversationId", req.ConversationID),
|
zap.String("conversationId", req.ConversationID),
|
||||||
zap.String("executionId", execID),
|
|
||||||
zap.Bool("hasNote", note != ""),
|
zap.Bool("hasNote", note != ""),
|
||||||
)
|
)
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"status": "tool_abort_requested",
|
"status": "interrupt_continue_scheduled",
|
||||||
"conversationId": req.ConversationID,
|
"conversationId": req.ConversationID,
|
||||||
"executionId": execID,
|
"message": "已请求暂停当前推理;用户补充将合并到上下文并自动继续执行(无需整轮停止)。",
|
||||||
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
|
"continueAfter": true,
|
||||||
"continueAfter": true,
|
"interruptWithNote": note != "",
|
||||||
"interruptWithNote": note != "",
|
"continueWithoutTool": true,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -2640,12 +2673,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
var runErr error
|
var runErr error
|
||||||
switch {
|
switch {
|
||||||
case useBatchMulti:
|
case useBatchMulti:
|
||||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch)
|
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil)
|
||||||
case useEinoSingle:
|
case useEinoSingle:
|
||||||
if h.config == nil {
|
if h.config == nil {
|
||||||
runErr = fmt.Errorf("服务器配置未加载")
|
runErr = fmt.Errorf("服务器配置未加载")
|
||||||
} else {
|
} else {
|
||||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback)
|
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil)
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools)
|
result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools)
|
||||||
@@ -2744,17 +2777,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
|
|
||||||
// 更新助手消息内容
|
// 更新助手消息内容
|
||||||
if assistantMessageID != "" {
|
if assistantMessageID != "" {
|
||||||
mcpIDsJSON := ""
|
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
|
||||||
if len(mcpIDs) > 0 {
|
|
||||||
jsonData, _ := json.Marshal(mcpIDs)
|
|
||||||
mcpIDsJSON = string(jsonData)
|
|
||||||
}
|
|
||||||
if _, updateErr := h.db.Exec(
|
|
||||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
|
||||||
resText,
|
|
||||||
mcpIDsJSON,
|
|
||||||
time.Now(), assistantMessageID,
|
|
||||||
); updateErr != nil {
|
|
||||||
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||||
// 如果更新失败,尝试创建新消息
|
// 如果更新失败,尝试创建新消息
|
||||||
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
|
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
|
||||||
@@ -2846,6 +2869,10 @@ func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent
|
|||||||
if content, ok := msgMap["content"].(string); ok {
|
if content, ok := msgMap["content"].(string); ok {
|
||||||
msg.Content = content
|
msg.Content = content
|
||||||
}
|
}
|
||||||
|
// DeepSeek 思考模式:含工具调用的 assistant 须在后续请求中回传 reasoning_content
|
||||||
|
if rc, ok := msgMap["reasoning_content"].(string); ok && strings.TrimSpace(rc) != "" {
|
||||||
|
msg.ReasoningContent = rc
|
||||||
|
}
|
||||||
|
|
||||||
// 解析tool_calls(如果存在)
|
// 解析tool_calls(如果存在)
|
||||||
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
|
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
|
||||||
@@ -2901,6 +2928,11 @@ func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent
|
|||||||
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
|
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
|
||||||
msg.ToolCallID = toolCallID
|
msg.ToolCallID = toolCallID
|
||||||
}
|
}
|
||||||
|
if tn, ok := msgMap["tool_name"].(string); ok && strings.TrimSpace(tn) != "" {
|
||||||
|
msg.ToolName = strings.TrimSpace(tn)
|
||||||
|
} else if tn, ok := msgMap["name"].(string); ok && strings.TrimSpace(tn) != "" && strings.EqualFold(msg.Role, "tool") {
|
||||||
|
msg.ToolName = strings.TrimSpace(tn)
|
||||||
|
}
|
||||||
|
|
||||||
agentMessages = append(agentMessages, msg)
|
agentMessages = append(agentMessages, msg)
|
||||||
}
|
}
|
||||||
@@ -2946,3 +2978,18 @@ func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent
|
|||||||
)
|
)
|
||||||
return agentMessages, nil
|
return agentMessages, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// dbMessagesToAgentChatMessages maps DB rows to agent ChatMessage for history fallback
|
||||||
|
// (includes reasoning_content for DeepSeek thinking + tool replay).
|
||||||
|
func dbMessagesToAgentChatMessages(msgs []database.Message) []agent.ChatMessage {
|
||||||
|
out := make([]agent.ChatMessage, 0, len(msgs))
|
||||||
|
for i := range msgs {
|
||||||
|
m := msgs[i]
|
||||||
|
out = append(out, agent.ChatMessage{
|
||||||
|
Role: m.Role,
|
||||||
|
Content: m.Content,
|
||||||
|
ReasoningContent: m.ReasoningContent,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|||||||
+111
-15
@@ -206,6 +206,25 @@ func (h *ConfigHandler) SetRobotRestarter(restarter RobotRestarter) {
|
|||||||
h.robotRestarter = restarter
|
h.robotRestarter = restarter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ApplyWechatRobotBinding 微信 iLink 扫码绑定成功后写入配置并重启机器人连接
|
||||||
|
func (h *ConfigHandler) ApplyWechatRobotBinding(wc config.RobotWechatConfig) error {
|
||||||
|
h.mu.Lock()
|
||||||
|
wc.Enabled = true
|
||||||
|
h.config.Robots.Wechat = wc
|
||||||
|
h.mu.Unlock()
|
||||||
|
if err := h.saveConfig(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if h.robotRestarter != nil {
|
||||||
|
h.robotRestarter.RestartRobotConnections()
|
||||||
|
}
|
||||||
|
h.logger.Info("微信机器人绑定已保存",
|
||||||
|
zap.String("ilink_bot_id", wc.ILinkBotID),
|
||||||
|
zap.Bool("enabled", wc.Enabled),
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetConfigResponse 获取配置响应
|
// GetConfigResponse 获取配置响应
|
||||||
type GetConfigResponse struct {
|
type GetConfigResponse struct {
|
||||||
OpenAI config.OpenAIConfig `json:"openai"`
|
OpenAI config.OpenAIConfig `json:"openai"`
|
||||||
@@ -609,15 +628,46 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
|||||||
|
|
||||||
// UpdateConfigRequest 更新配置请求
|
// UpdateConfigRequest 更新配置请求
|
||||||
type UpdateConfigRequest struct {
|
type UpdateConfigRequest struct {
|
||||||
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
||||||
FOFA *config.FofaConfig `json:"fofa,omitempty"`
|
FOFA *config.FofaConfig `json:"fofa,omitempty"`
|
||||||
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
||||||
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
||||||
Agent *config.AgentConfig `json:"agent,omitempty"`
|
Agent *AgentConfigUpdate `json:"agent,omitempty"`
|
||||||
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
|
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
|
||||||
Robots *config.RobotsConfig `json:"robots,omitempty"`
|
Robots *config.RobotsConfig `json:"robots,omitempty"`
|
||||||
MultiAgent *config.MultiAgentAPIUpdate `json:"multi_agent,omitempty"`
|
MultiAgent *config.MultiAgentAPIUpdate `json:"multi_agent,omitempty"`
|
||||||
C2 *config.C2APIUpdate `json:"c2,omitempty"`
|
C2 *config.C2APIUpdate `json:"c2,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AgentConfigUpdate 用于 PATCH /api/config 的 agent 段:仅 JSON 中出现的字段(指针非 nil)覆盖内存配置。
|
||||||
|
// 避免旧版「整包替换 *AgentConfig」时,未传的整型字段被反序列化为 0 误覆盖(例如 tool_timeout_minutes 变成 0)。
|
||||||
|
type AgentConfigUpdate struct {
|
||||||
|
MaxIterations *int `json:"max_iterations,omitempty"`
|
||||||
|
LargeResultThreshold *int `json:"large_result_threshold,omitempty"`
|
||||||
|
ResultStorageDir *string `json:"result_storage_dir,omitempty"`
|
||||||
|
ToolTimeoutMinutes *int `json:"tool_timeout_minutes,omitempty"`
|
||||||
|
SystemPromptPath *string `json:"system_prompt_path,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) {
|
||||||
|
if dst == nil || src == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if src.MaxIterations != nil {
|
||||||
|
dst.MaxIterations = *src.MaxIterations
|
||||||
|
}
|
||||||
|
if src.LargeResultThreshold != nil {
|
||||||
|
dst.LargeResultThreshold = *src.LargeResultThreshold
|
||||||
|
}
|
||||||
|
if src.ResultStorageDir != nil {
|
||||||
|
dst.ResultStorageDir = *src.ResultStorageDir
|
||||||
|
}
|
||||||
|
if src.ToolTimeoutMinutes != nil {
|
||||||
|
dst.ToolTimeoutMinutes = *src.ToolTimeoutMinutes
|
||||||
|
}
|
||||||
|
if src.SystemPromptPath != nil {
|
||||||
|
dst.SystemPromptPath = *src.SystemPromptPath
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolEnableStatus 工具启用状态
|
// ToolEnableStatus 工具启用状态
|
||||||
@@ -664,12 +714,19 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新Agent配置
|
// 更新Agent配置(按字段合并,避免部分 JSON 把未出现的字段写成 0)
|
||||||
if req.Agent != nil {
|
if req.Agent != nil {
|
||||||
h.config.Agent = *req.Agent
|
applyAgentConfigUpdate(&h.config.Agent, req.Agent)
|
||||||
h.logger.Info("更新Agent配置",
|
h.logger.Info("更新Agent配置",
|
||||||
zap.Int("max_iterations", h.config.Agent.MaxIterations),
|
zap.Int("max_iterations", h.config.Agent.MaxIterations),
|
||||||
|
zap.Int("tool_timeout_minutes", h.config.Agent.ToolTimeoutMinutes),
|
||||||
)
|
)
|
||||||
|
if h.agent != nil && req.Agent.MaxIterations != nil {
|
||||||
|
h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations)
|
||||||
|
}
|
||||||
|
if h.mcpServer != nil {
|
||||||
|
h.mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(h.config.Agent.ToolTimeoutMinutes)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新Knowledge配置
|
// 更新Knowledge配置
|
||||||
@@ -697,6 +754,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
|||||||
if req.Robots != nil {
|
if req.Robots != nil {
|
||||||
h.config.Robots = *req.Robots
|
h.config.Robots = *req.Robots
|
||||||
h.logger.Info("更新机器人配置",
|
h.logger.Info("更新机器人配置",
|
||||||
|
zap.Bool("wechat_enabled", h.config.Robots.Wechat.Enabled),
|
||||||
zap.Bool("wecom_enabled", h.config.Robots.Wecom.Enabled),
|
zap.Bool("wecom_enabled", h.config.Robots.Wecom.Enabled),
|
||||||
zap.Bool("dingtalk_enabled", h.config.Robots.Dingtalk.Enabled),
|
zap.Bool("dingtalk_enabled", h.config.Robots.Dingtalk.Enabled),
|
||||||
zap.Bool("lark_enabled", h.config.Robots.Lark.Enabled),
|
zap.Bool("lark_enabled", h.config.Robots.Lark.Enabled),
|
||||||
@@ -717,7 +775,9 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
|||||||
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
|
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
|
||||||
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
|
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
|
||||||
}
|
}
|
||||||
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(req.MultiAgent.ToolSearchAlwaysVisibleTools)
|
if req.MultiAgent.ToolSearchAlwaysVisibleTools != nil {
|
||||||
|
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(*req.MultiAgent.ToolSearchAlwaysVisibleTools)
|
||||||
|
}
|
||||||
h.logger.Info("更新多代理配置",
|
h.logger.Info("更新多代理配置",
|
||||||
zap.Bool("enabled", h.config.MultiAgent.Enabled),
|
zap.Bool("enabled", h.config.MultiAgent.Enabled),
|
||||||
zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent),
|
zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent),
|
||||||
@@ -1116,6 +1176,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
|||||||
h.agent.UpdateToolDescriptionMode(h.config.Security.ToolDescriptionMode)
|
h.agent.UpdateToolDescriptionMode(h.config.Security.ToolDescriptionMode)
|
||||||
h.logger.Info("Agent配置已更新")
|
h.logger.Info("Agent配置已更新")
|
||||||
}
|
}
|
||||||
|
if h.mcpServer != nil {
|
||||||
|
h.mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(h.config.Agent.ToolTimeoutMinutes)
|
||||||
|
}
|
||||||
|
|
||||||
// 更新AttackChainHandler的OpenAI配置
|
// 更新AttackChainHandler的OpenAI配置
|
||||||
if h.attackChainHandler != nil {
|
if h.attackChainHandler != nil {
|
||||||
@@ -1181,7 +1244,7 @@ func (h *ConfigHandler) saveConfig() error {
|
|||||||
return fmt.Errorf("解析配置文件失败: %w", err)
|
return fmt.Errorf("解析配置文件失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
updateAgentConfig(root, h.config.Agent.MaxIterations)
|
updateAgentConfig(root, h.config.Agent)
|
||||||
updateMCPConfig(root, h.config.MCP)
|
updateMCPConfig(root, h.config.MCP)
|
||||||
updateOpenAIConfig(root, h.config.OpenAI)
|
updateOpenAIConfig(root, h.config.OpenAI)
|
||||||
updateFOFAConfig(root, h.config.FOFA)
|
updateFOFAConfig(root, h.config.FOFA)
|
||||||
@@ -1286,10 +1349,14 @@ func writeYAMLDocument(path string, doc *yaml.Node) error {
|
|||||||
return os.WriteFile(path, buf.Bytes(), 0644)
|
return os.WriteFile(path, buf.Bytes(), 0644)
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateAgentConfig(doc *yaml.Node, maxIterations int) {
|
func updateAgentConfig(doc *yaml.Node, agent config.AgentConfig) {
|
||||||
root := doc.Content[0]
|
root := doc.Content[0]
|
||||||
agentNode := ensureMap(root, "agent")
|
agentNode := ensureMap(root, "agent")
|
||||||
setIntInMap(agentNode, "max_iterations", maxIterations)
|
setIntInMap(agentNode, "max_iterations", agent.MaxIterations)
|
||||||
|
setIntInMap(agentNode, "tool_timeout_minutes", agent.ToolTimeoutMinutes)
|
||||||
|
setIntInMap(agentNode, "large_result_threshold", agent.LargeResultThreshold)
|
||||||
|
setStringInMap(agentNode, "result_storage_dir", agent.ResultStorageDir)
|
||||||
|
setStringInMap(agentNode, "system_prompt_path", agent.SystemPromptPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) {
|
func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) {
|
||||||
@@ -1312,6 +1379,19 @@ func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) {
|
|||||||
if cfg.MaxTotalTokens > 0 {
|
if cfg.MaxTotalTokens > 0 {
|
||||||
setIntInMap(openaiNode, "max_total_tokens", cfg.MaxTotalTokens)
|
setIntInMap(openaiNode, "max_total_tokens", cfg.MaxTotalTokens)
|
||||||
}
|
}
|
||||||
|
rn := ensureMap(openaiNode, "reasoning")
|
||||||
|
if strings.TrimSpace(cfg.Reasoning.Mode) != "" {
|
||||||
|
setStringInMap(rn, "mode", cfg.Reasoning.Mode)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(cfg.Reasoning.Effort) != "" {
|
||||||
|
setStringInMap(rn, "effort", cfg.Reasoning.Effort)
|
||||||
|
}
|
||||||
|
if cfg.Reasoning.AllowClientReasoning != nil {
|
||||||
|
setBoolInMap(rn, "allow_client_reasoning", *cfg.Reasoning.AllowClientReasoning)
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(cfg.Reasoning.Profile) != "" {
|
||||||
|
setStringInMap(rn, "profile", cfg.Reasoning.Profile)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateFOFAConfig(doc *yaml.Node, cfg config.FofaConfig) {
|
func updateFOFAConfig(doc *yaml.Node, cfg config.FofaConfig) {
|
||||||
@@ -1416,6 +1496,20 @@ func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
|||||||
root := doc.Content[0]
|
root := doc.Content[0]
|
||||||
robotsNode := ensureMap(root, "robots")
|
robotsNode := ensureMap(root, "robots")
|
||||||
|
|
||||||
|
if cfg.Session.StrictUserIdentity != nil {
|
||||||
|
sessionNode := ensureMap(robotsNode, "session")
|
||||||
|
setBoolInMap(sessionNode, "strict_user_identity", *cfg.Session.StrictUserIdentity)
|
||||||
|
}
|
||||||
|
|
||||||
|
wechatNode := ensureMap(robotsNode, "wechat")
|
||||||
|
setBoolInMap(wechatNode, "enabled", cfg.Wechat.Enabled)
|
||||||
|
setStringInMap(wechatNode, "bot_token", cfg.Wechat.BotToken)
|
||||||
|
setStringInMap(wechatNode, "ilink_bot_id", cfg.Wechat.ILinkBotID)
|
||||||
|
setStringInMap(wechatNode, "ilink_user_id", cfg.Wechat.ILinkUserID)
|
||||||
|
setStringInMap(wechatNode, "base_url", cfg.Wechat.BaseURL)
|
||||||
|
setStringInMap(wechatNode, "bot_type", cfg.Wechat.BotType)
|
||||||
|
setStringInMap(wechatNode, "bot_agent", cfg.Wechat.BotAgent)
|
||||||
|
|
||||||
wecomNode := ensureMap(robotsNode, "wecom")
|
wecomNode := ensureMap(robotsNode, "wecom")
|
||||||
setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled)
|
setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled)
|
||||||
setStringInMap(wecomNode, "token", cfg.Wecom.Token)
|
setStringInMap(wecomNode, "token", cfg.Wecom.Token)
|
||||||
@@ -1428,12 +1522,14 @@ func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
|||||||
setBoolInMap(dingtalkNode, "enabled", cfg.Dingtalk.Enabled)
|
setBoolInMap(dingtalkNode, "enabled", cfg.Dingtalk.Enabled)
|
||||||
setStringInMap(dingtalkNode, "client_id", cfg.Dingtalk.ClientID)
|
setStringInMap(dingtalkNode, "client_id", cfg.Dingtalk.ClientID)
|
||||||
setStringInMap(dingtalkNode, "client_secret", cfg.Dingtalk.ClientSecret)
|
setStringInMap(dingtalkNode, "client_secret", cfg.Dingtalk.ClientSecret)
|
||||||
|
setBoolInMap(dingtalkNode, "allow_conversation_id_fallback", cfg.Dingtalk.AllowConversationIDFallback)
|
||||||
|
|
||||||
larkNode := ensureMap(robotsNode, "lark")
|
larkNode := ensureMap(robotsNode, "lark")
|
||||||
setBoolInMap(larkNode, "enabled", cfg.Lark.Enabled)
|
setBoolInMap(larkNode, "enabled", cfg.Lark.Enabled)
|
||||||
setStringInMap(larkNode, "app_id", cfg.Lark.AppID)
|
setStringInMap(larkNode, "app_id", cfg.Lark.AppID)
|
||||||
setStringInMap(larkNode, "app_secret", cfg.Lark.AppSecret)
|
setStringInMap(larkNode, "app_secret", cfg.Lark.AppSecret)
|
||||||
setStringInMap(larkNode, "verify_token", cfg.Lark.VerifyToken)
|
setStringInMap(larkNode, "verify_token", cfg.Lark.VerifyToken)
|
||||||
|
setBoolInMap(larkNode, "allow_chat_id_fallback", cfg.Lark.AllowChatIDFallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
|
func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
|
||||||
|
|||||||
@@ -117,6 +117,8 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
details = database.DedupeConsecutiveProcessDetails(details)
|
||||||
|
|
||||||
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致)
|
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致)
|
||||||
out := make([]map[string]interface{}, 0, len(details))
|
out := make([]map[string]interface{}, 0, len(details))
|
||||||
for _, d := range details {
|
for _, d := range details {
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
sendEvent := func(eventType, message string, data interface{}) {
|
sendEvent := func(eventType, message string, data interface{}) {
|
||||||
if eventType == "error" && baseCtx != nil {
|
if eventType == "error" && baseCtx != nil {
|
||||||
cause := context.Cause(baseCtx)
|
cause := context.Cause(baseCtx)
|
||||||
if errors.Is(cause, ErrTaskCancelled) {
|
if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -175,29 +175,69 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
taskOwned = true
|
taskOwned = true
|
||||||
|
|
||||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
var cumulativeMCPExecutionIDs []string
|
||||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
|
||||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
|
||||||
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
|
|
||||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
|
||||||
})
|
|
||||||
|
|
||||||
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
for {
|
||||||
taskCtx,
|
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||||
h.config,
|
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||||
&h.config.MultiAgent,
|
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||||
h.agent,
|
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||||
h.logger,
|
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||||
conversationID,
|
})
|
||||||
curFinalMessage,
|
|
||||||
curHistory,
|
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
||||||
roleTools,
|
taskCtxLoop,
|
||||||
progressCallback,
|
h.config,
|
||||||
)
|
&h.config.MultiAgent,
|
||||||
timeoutCancel()
|
h.agent,
|
||||||
|
h.logger,
|
||||||
|
conversationID,
|
||||||
|
curFinalMessage,
|
||||||
|
curHistory,
|
||||||
|
roleTools,
|
||||||
|
progressCallback,
|
||||||
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
|
)
|
||||||
|
timeoutCancel()
|
||||||
|
|
||||||
|
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||||
|
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if runErr == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
if runErr != nil {
|
|
||||||
cause := context.Cause(baseCtx)
|
cause := context.Cause(baseCtx)
|
||||||
|
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||||
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||||
|
}
|
||||||
|
note := h.tasks.TakeInterruptContinueNote(conversationID)
|
||||||
|
icSummary := interruptContinueTimelineSummary(note)
|
||||||
|
progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"rawReason": strings.TrimSpace(note),
|
||||||
|
"emptyReason": strings.TrimSpace(note) == "",
|
||||||
|
"kind": "no_active_mcp_tool",
|
||||||
|
})
|
||||||
|
inject := formatInterruptContinueUserMessage(note)
|
||||||
|
// 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。
|
||||||
|
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||||
|
curHistory = hist
|
||||||
|
}
|
||||||
|
curFinalMessage = inject
|
||||||
|
sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "interrupt_continue",
|
||||||
|
})
|
||||||
|
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||||
|
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||||
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
|
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||||
}
|
}
|
||||||
@@ -258,18 +298,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if assistantMessageID != "" {
|
if assistantMessageID != "" {
|
||||||
mcpIDsJSON := ""
|
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||||
if len(result.MCPExecutionIDs) > 0 {
|
|
||||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
|
||||||
mcpIDsJSON = string(jsonData)
|
|
||||||
}
|
|
||||||
_, _ = h.db.Exec(
|
|
||||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
|
||||||
result.Response,
|
|
||||||
mcpIDsJSON,
|
|
||||||
time.Now(),
|
|
||||||
assistantMessageID,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||||
@@ -279,7 +308,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sendEvent("response", result.Response, map[string]interface{}{
|
sendEvent("response", result.Response, map[string]interface{}{
|
||||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
"mcpExecutionIds": cumulativeMCPExecutionIDs,
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"messageId": assistantMessageID,
|
"messageId": assistantMessageID,
|
||||||
"agentMode": "eino_single",
|
"agentMode": "eino_single",
|
||||||
@@ -337,6 +366,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
|||||||
prep.History,
|
prep.History,
|
||||||
prep.RoleTools,
|
prep.RoleTools,
|
||||||
progressCallback,
|
progressCallback,
|
||||||
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
)
|
)
|
||||||
if runErr != nil {
|
if runErr != nil {
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
@@ -347,18 +377,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if prep.AssistantMessageID != "" {
|
if prep.AssistantMessageID != "" {
|
||||||
mcpIDsJSON := ""
|
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||||
if len(result.MCPExecutionIDs) > 0 {
|
|
||||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
|
||||||
mcpIDsJSON = string(jsonData)
|
|
||||||
}
|
|
||||||
_, _ = h.db.Exec(
|
|
||||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
|
||||||
result.Response,
|
|
||||||
mcpIDsJSON,
|
|
||||||
time.Now(),
|
|
||||||
prep.AssistantMessageID,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||||
_ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
|
_ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
|
||||||
|
|||||||
+114
-48
@@ -63,7 +63,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
// 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。
|
// 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。
|
||||||
if eventType == "error" && baseCtx != nil {
|
if eventType == "error" && baseCtx != nil {
|
||||||
cause := context.Cause(baseCtx)
|
cause := context.Cause(baseCtx)
|
||||||
if errors.Is(cause, ErrTaskCancelled) {
|
if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -184,31 +184,72 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
taskOwned = true
|
taskOwned = true
|
||||||
|
|
||||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
||||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
var cumulativeMCPExecutionIDs []string
|
||||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
|
||||||
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
|
|
||||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
|
||||||
})
|
|
||||||
|
|
||||||
result, runErr = multiagent.RunDeepAgent(
|
for {
|
||||||
taskCtx,
|
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||||
h.config,
|
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||||
&h.config.MultiAgent,
|
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||||
h.agent,
|
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||||
h.logger,
|
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||||
conversationID,
|
})
|
||||||
curFinalMessage,
|
|
||||||
curHistory,
|
result, runErr = multiagent.RunDeepAgent(
|
||||||
roleTools,
|
taskCtxLoop,
|
||||||
progressCallback,
|
h.config,
|
||||||
h.agentsMarkdownDir,
|
&h.config.MultiAgent,
|
||||||
orch,
|
h.agent,
|
||||||
)
|
h.logger,
|
||||||
timeoutCancel()
|
conversationID,
|
||||||
|
curFinalMessage,
|
||||||
|
curHistory,
|
||||||
|
roleTools,
|
||||||
|
progressCallback,
|
||||||
|
h.agentsMarkdownDir,
|
||||||
|
orch,
|
||||||
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
|
)
|
||||||
|
timeoutCancel()
|
||||||
|
|
||||||
|
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||||
|
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if runErr == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
if runErr != nil {
|
|
||||||
cause := context.Cause(baseCtx)
|
cause := context.Cause(baseCtx)
|
||||||
|
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||||
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||||
|
}
|
||||||
|
note := h.tasks.TakeInterruptContinueNote(conversationID)
|
||||||
|
icSummary := interruptContinueTimelineSummary(note)
|
||||||
|
progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"rawReason": strings.TrimSpace(note),
|
||||||
|
"emptyReason": strings.TrimSpace(note) == "",
|
||||||
|
"kind": "no_active_mcp_tool",
|
||||||
|
})
|
||||||
|
inject := formatInterruptContinueUserMessage(note)
|
||||||
|
// 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。
|
||||||
|
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||||
|
curHistory = hist
|
||||||
|
}
|
||||||
|
curFinalMessage = inject
|
||||||
|
sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "interrupt_continue",
|
||||||
|
})
|
||||||
|
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||||
|
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||||
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
|
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||||
}
|
}
|
||||||
@@ -269,18 +310,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if assistantMessageID != "" {
|
if assistantMessageID != "" {
|
||||||
mcpIDsJSON := ""
|
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||||
if len(result.MCPExecutionIDs) > 0 {
|
|
||||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
|
||||||
mcpIDsJSON = string(jsonData)
|
|
||||||
}
|
|
||||||
_, _ = h.db.Exec(
|
|
||||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
|
||||||
result.Response,
|
|
||||||
mcpIDsJSON,
|
|
||||||
time.Now(),
|
|
||||||
assistantMessageID,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||||
@@ -294,7 +324,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
effectiveOrch = config.NormalizeMultiAgentOrchestration(o)
|
effectiveOrch = config.NormalizeMultiAgentOrchestration(o)
|
||||||
}
|
}
|
||||||
sendEvent("response", result.Response, map[string]interface{}{
|
sendEvent("response", result.Response, map[string]interface{}{
|
||||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
"mcpExecutionIds": cumulativeMCPExecutionIDs,
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"messageId": assistantMessageID,
|
"messageId": assistantMessageID,
|
||||||
"agentMode": "eino_" + effectiveOrch,
|
"agentMode": "eino_" + effectiveOrch,
|
||||||
@@ -350,6 +380,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
|||||||
progressCallback,
|
progressCallback,
|
||||||
h.agentsMarkdownDir,
|
h.agentsMarkdownDir,
|
||||||
strings.TrimSpace(req.Orchestration),
|
strings.TrimSpace(req.Orchestration),
|
||||||
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
)
|
)
|
||||||
if runErr != nil {
|
if runErr != nil {
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
@@ -365,18 +396,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if prep.AssistantMessageID != "" {
|
if prep.AssistantMessageID != "" {
|
||||||
mcpIDsJSON := ""
|
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||||
if len(result.MCPExecutionIDs) > 0 {
|
|
||||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
|
||||||
mcpIDsJSON = string(jsonData)
|
|
||||||
}
|
|
||||||
_, _ = h.db.Exec(
|
|
||||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
|
||||||
result.Response,
|
|
||||||
mcpIDsJSON,
|
|
||||||
time.Now(),
|
|
||||||
prep.AssistantMessageID,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||||
@@ -406,6 +426,52 @@ func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, res
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mergeMCPExecutionIDLists 去重合并多段 Run 的 MCP execution id(顺序:先 dst 后 more)。
|
||||||
|
func mergeMCPExecutionIDLists(dst []string, more []string) []string {
|
||||||
|
seen := make(map[string]struct{}, len(dst)+len(more))
|
||||||
|
out := make([]string, 0, len(dst)+len(more))
|
||||||
|
add := func(ids []string) {
|
||||||
|
for _, id := range ids {
|
||||||
|
id = strings.TrimSpace(id)
|
||||||
|
if id == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
out = append(out, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
add(dst)
|
||||||
|
add(more)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// interruptContinueTimelineSummary 时间线 / process_details 中展示的简短正文(完整模板已写入另一条用户消息)。
|
||||||
|
func interruptContinueTimelineSummary(note string) string {
|
||||||
|
note = strings.TrimSpace(note)
|
||||||
|
if note == "" {
|
||||||
|
return "用户选择「中断并继续」,未填写说明;已按默认渗透补充模板合并上下文并续跑。"
|
||||||
|
}
|
||||||
|
return "用户中断说明(原文):\n\n" + note
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatInterruptContinueUserMessage 将「中断并继续」弹窗中的说明格式化为新一轮 user 消息(渗透场景下强调路径补充与端口复扫)。
|
||||||
|
func formatInterruptContinueUserMessage(note string) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("【用户补充 / 中断后继续】\n")
|
||||||
|
if s := strings.TrimSpace(note); s != "" {
|
||||||
|
b.WriteString(s)
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
}
|
||||||
|
b.WriteString("【请在本轮落实】\n")
|
||||||
|
b.WriteString("- 将用户提供的接口路径、参数、业务变化纳入后续测试与推理。\n")
|
||||||
|
b.WriteString("- 若资产或目标信息有更新,请对目标重新执行端口/服务探测,再基于新结果规划下一步。\n")
|
||||||
|
b.WriteString("- 在已有轨迹基础上推进,避免无意义重复已完成的步骤。\n")
|
||||||
|
return strings.TrimSpace(b.String())
|
||||||
|
}
|
||||||
|
|
||||||
func multiAgentHTTPErrorStatus(err error) (int, string) {
|
func multiAgentHTTPErrorStatus(err error) (int, string) {
|
||||||
msg := err.Error()
|
msg := err.Error()
|
||||||
switch {
|
switch {
|
||||||
|
|||||||
@@ -55,13 +55,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
|
|||||||
if getErr != nil {
|
if getErr != nil {
|
||||||
agentHistoryMessages = []agent.ChatMessage{}
|
agentHistoryMessages = []agent.ChatMessage{}
|
||||||
} else {
|
} else {
|
||||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
|
||||||
for _, msg := range historyMessages {
|
|
||||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
|
||||||
Role: msg.Role,
|
|
||||||
Content: msg.Content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/multiagent"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrTaskCancelled 用户取消任务的错误
|
// ErrTaskCancelled 用户取消任务的错误
|
||||||
@@ -32,6 +34,9 @@ type AgentTask struct {
|
|||||||
// ActiveMCPExecutionID 当前正在执行的 MCP 工具 executionId(仅内存,供「中断并继续」= 仅掐当前工具)
|
// ActiveMCPExecutionID 当前正在执行的 MCP 工具 executionId(仅内存,供「中断并继续」= 仅掐当前工具)
|
||||||
ActiveMCPExecutionID string `json:"-"`
|
ActiveMCPExecutionID string `json:"-"`
|
||||||
|
|
||||||
|
// InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空)
|
||||||
|
InterruptContinueNote string `json:"-"`
|
||||||
|
|
||||||
cancel func(error)
|
cancel func(error)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,6 +70,50 @@ func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID str
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。
|
||||||
|
func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) {
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||||
|
t.InterruptContinueNote = note
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TakeInterruptContinueNote 读取并清空补充说明(续跑开始时调用一次)。
|
||||||
|
func (m *AgentTaskManager) TakeInterruptContinueNote(conversationID string) string {
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||||
|
n := t.InterruptContinueNote
|
||||||
|
t.InterruptContinueNote = ""
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// BindTaskCancel 在同一运行任务内替换与 context 绑定的 cancel 函数(用于中断后继续时换新 baseCtx)。
|
||||||
|
func (m *AgentTaskManager) BindTaskCancel(conversationID string, cancel context.CancelCauseFunc) {
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" || cancel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||||
|
t.cancel = func(err error) {
|
||||||
|
cancel(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ActiveMCPExecutionID 返回当前会话进行中的工具 executionId,无则空串。
|
// ActiveMCPExecutionID 返回当前会话进行中的工具 executionId,无则空串。
|
||||||
func (m *AgentTaskManager) ActiveMCPExecutionID(conversationID string) string {
|
func (m *AgentTaskManager) ActiveMCPExecutionID(conversationID string) string {
|
||||||
conversationID = strings.TrimSpace(conversationID)
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
@@ -210,8 +259,16 @@ func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool,
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
task.Status = "cancelling"
|
// ErrInterruptContinue:仅掐断当前推理步骤,随后由处理器续跑,不进入长时间「取消中」态。
|
||||||
task.CancellingAt = time.Now()
|
if cause != nil && errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||||
|
task.Status = "running"
|
||||||
|
} else {
|
||||||
|
task.Status = "cancelling"
|
||||||
|
task.CancellingAt = time.Now()
|
||||||
|
}
|
||||||
|
if cause != nil && errors.Is(cause, ErrTaskCancelled) {
|
||||||
|
task.InterruptContinueNote = ""
|
||||||
|
}
|
||||||
cancel := task.cancel
|
cancel := task.cancel
|
||||||
m.mu.Unlock()
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,293 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/robot/ilink"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const wechatLoginTTL = 5 * time.Minute
|
||||||
|
|
||||||
|
// WechatConfigSaver 绑定成功后写入配置并重启机器人连接
|
||||||
|
type WechatConfigSaver interface {
|
||||||
|
ApplyWechatRobotBinding(cfg config.RobotWechatConfig) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type wechatLoginSession struct {
|
||||||
|
QRCode string
|
||||||
|
QRCodeImgURL string
|
||||||
|
PendingVerify string
|
||||||
|
CurrentBaseURL string
|
||||||
|
StartedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// WechatRobotHandler 微信 iLink 机器人(扫码绑定 + 配置)
|
||||||
|
type WechatRobotHandler struct {
|
||||||
|
config *config.Config
|
||||||
|
configSaver WechatConfigSaver
|
||||||
|
logger *zap.Logger
|
||||||
|
mu sync.Mutex
|
||||||
|
logins map[string]*wechatLoginSession
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWechatRobotHandler 创建微信机器人处理器
|
||||||
|
func NewWechatRobotHandler(cfg *config.Config, saver WechatConfigSaver, logger *zap.Logger) *WechatRobotHandler {
|
||||||
|
return &WechatRobotHandler{
|
||||||
|
config: cfg,
|
||||||
|
configSaver: saver,
|
||||||
|
logger: logger,
|
||||||
|
logins: make(map[string]*wechatLoginSession),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *WechatRobotHandler) purgeExpiredLogins() {
|
||||||
|
now := time.Now()
|
||||||
|
for k, v := range h.logins {
|
||||||
|
if now.Sub(v.StartedAt) > wechatLoginTTL {
|
||||||
|
delete(h.logins, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *WechatRobotHandler) ilinkClient(baseURL string) *ilink.Client {
|
||||||
|
ver := h.config.Version
|
||||||
|
if ver == "" {
|
||||||
|
ver = "1.0.0"
|
||||||
|
}
|
||||||
|
ver = strings.TrimPrefix(strings.TrimSpace(ver), "v")
|
||||||
|
ver = strings.TrimPrefix(ver, "V")
|
||||||
|
wc := h.config.Robots.Wechat
|
||||||
|
return ilink.NewClient(baseURL, wc.BotToken, wc.BotAgent, ilink.BuildClientVersion(ver))
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleWechatQRCode POST /api/robot/wechat/qrcode — 生成绑定二维码
|
||||||
|
func (h *WechatRobotHandler) HandleWechatQRCode(c *gin.Context) {
|
||||||
|
h.mu.Lock()
|
||||||
|
h.purgeExpiredLogins()
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
var req struct {
|
||||||
|
BotType string `json:"bot_type"`
|
||||||
|
}
|
||||||
|
_ = c.ShouldBindJSON(&req)
|
||||||
|
|
||||||
|
botType := req.BotType
|
||||||
|
if botType == "" {
|
||||||
|
botType = h.config.Robots.Wechat.BotType
|
||||||
|
}
|
||||||
|
if botType == "" {
|
||||||
|
botType = ilink.DefaultBotType
|
||||||
|
}
|
||||||
|
baseURL := h.config.Robots.Wechat.BaseURL
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = ilink.DefaultBaseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
var localTokens []string
|
||||||
|
if t := h.config.Robots.Wechat.BotToken; t != "" {
|
||||||
|
localTokens = []string{t}
|
||||||
|
}
|
||||||
|
|
||||||
|
client := h.ilinkClient(baseURL)
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
qr, err := client.GetBotQRCode(ctx, botType, localTokens)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("获取微信二维码失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": "获取二维码失败: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if qr.QRCode == "" || qr.QRCodeImgContent == "" {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": "微信服务器未返回有效二维码"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionKey := uuid.New().String()
|
||||||
|
h.mu.Lock()
|
||||||
|
h.logins[sessionKey] = &wechatLoginSession{
|
||||||
|
QRCode: qr.QRCode,
|
||||||
|
QRCodeImgURL: qr.QRCodeImgContent,
|
||||||
|
CurrentBaseURL: baseURL,
|
||||||
|
StartedAt: time.Now(),
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
resp := gin.H{
|
||||||
|
"session_key": sessionKey,
|
||||||
|
"qrcode": qr.QRCode,
|
||||||
|
"qrcode_open_url": qr.QRCodeImgContent,
|
||||||
|
"message": "请使用微信扫描二维码并确认绑定",
|
||||||
|
}
|
||||||
|
if dataURL, err := ilink.QRCodeDataURL(qr.QRCodeImgContent, 256); err != nil {
|
||||||
|
h.logger.Warn("生成二维码图片失败", zap.Error(err))
|
||||||
|
} else {
|
||||||
|
resp["qrcode_image_data_url"] = dataURL
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleWechatQRCodeStatus GET /api/robot/wechat/qrcode/status — 轮询扫码状态
|
||||||
|
func (h *WechatRobotHandler) HandleWechatQRCodeStatus(c *gin.Context) {
|
||||||
|
sessionKey := c.Query("session_key")
|
||||||
|
verifyCode := c.Query("verify_code")
|
||||||
|
if sessionKey == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 session_key"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.mu.Lock()
|
||||||
|
sess, ok := h.logins[sessionKey]
|
||||||
|
h.mu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期,请重新生成二维码"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if time.Since(sess.StartedAt) > wechatLoginTTL {
|
||||||
|
h.mu.Lock()
|
||||||
|
delete(h.logins, sessionKey)
|
||||||
|
h.mu.Unlock()
|
||||||
|
c.JSON(http.StatusGone, gin.H{"error": "二维码已过期,请重新生成"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := sess.CurrentBaseURL
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = ilink.DefaultBaseURL
|
||||||
|
}
|
||||||
|
vc := verifyCode
|
||||||
|
if vc == "" {
|
||||||
|
vc = sess.PendingVerify
|
||||||
|
}
|
||||||
|
|
||||||
|
client := h.ilinkClient(baseURL)
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 40*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
st, err := client.GetQRCodeStatus(ctx, sess.QRCode, vc)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("轮询微信二维码状态失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch st.Status {
|
||||||
|
case "wait", "scaned":
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": st.Status})
|
||||||
|
return
|
||||||
|
case "need_verifycode":
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": st.Status,
|
||||||
|
"message": "请在手机微信查看配对数字,并在下方输入",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
case "scaned_but_redirect":
|
||||||
|
if st.RedirectHost != "" {
|
||||||
|
h.mu.Lock()
|
||||||
|
if s, ok := h.logins[sessionKey]; ok {
|
||||||
|
s.CurrentBaseURL = "https://" + st.RedirectHost
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": st.Status})
|
||||||
|
return
|
||||||
|
case "binded_redirect":
|
||||||
|
h.mu.Lock()
|
||||||
|
delete(h.logins, sessionKey)
|
||||||
|
h.mu.Unlock()
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": st.Status,
|
||||||
|
"already_connected": true,
|
||||||
|
"message": "该微信已绑定过,无需重复绑定",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
case "confirmed":
|
||||||
|
if st.BotToken == "" || st.ILinkBotID == "" {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": "绑定确认成功但缺少 bot_token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
saveBase := st.BaseURL
|
||||||
|
if saveBase == "" {
|
||||||
|
saveBase = baseURL
|
||||||
|
}
|
||||||
|
wc := h.config.Robots.Wechat
|
||||||
|
wc.Enabled = true
|
||||||
|
wc.BotToken = st.BotToken
|
||||||
|
wc.ILinkBotID = st.ILinkBotID
|
||||||
|
wc.ILinkUserID = st.ILinkUserID
|
||||||
|
wc.BaseURL = saveBase
|
||||||
|
if wc.BotType == "" {
|
||||||
|
wc.BotType = ilink.DefaultBotType
|
||||||
|
}
|
||||||
|
if wc.BotAgent == "" {
|
||||||
|
wc.BotAgent = ilink.DefaultBotAgent
|
||||||
|
}
|
||||||
|
if h.configSaver != nil {
|
||||||
|
if err := h.configSaver.ApplyWechatRobotBinding(wc); err != nil {
|
||||||
|
h.logger.Warn("保存微信机器人配置失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
h.config.Robots.Wechat = wc
|
||||||
|
}
|
||||||
|
h.mu.Lock()
|
||||||
|
delete(h.logins, sessionKey)
|
||||||
|
h.mu.Unlock()
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": "confirmed",
|
||||||
|
"message": "绑定成功,微信机器人已启用",
|
||||||
|
"ilink_bot_id": st.ILinkBotID,
|
||||||
|
"ilink_user_id": st.ILinkUserID,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": st.Status})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleWechatVerifyCode POST /api/robot/wechat/qrcode/verify — 提交手机配对数字
|
||||||
|
func (h *WechatRobotHandler) HandleWechatVerifyCode(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
SessionKey string `json:"session_key"`
|
||||||
|
VerifyCode string `json:"verify_code"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil || req.SessionKey == "" || req.VerifyCode == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "需要 session_key 与 verify_code"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.mu.Lock()
|
||||||
|
sess, ok := h.logins[req.SessionKey]
|
||||||
|
if ok {
|
||||||
|
sess.PendingVerify = req.VerifyCode
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"message": "已提交配对码,请继续等待绑定"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleWechatStatus GET /api/robot/wechat/status — 当前绑定状态(供前端展示)
|
||||||
|
func (h *WechatRobotHandler) HandleWechatStatus(c *gin.Context) {
|
||||||
|
wc := h.config.Robots.Wechat
|
||||||
|
bound := wc.BotToken != "" && wc.ILinkBotID != ""
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"enabled": wc.Enabled,
|
||||||
|
"bound": bound,
|
||||||
|
"ilink_bot_id": wc.ILinkBotID,
|
||||||
|
"ilink_user_id": wc.ILinkUserID,
|
||||||
|
"base_url": wc.BaseURL,
|
||||||
|
})
|
||||||
|
}
|
||||||
+81
-1
@@ -44,6 +44,10 @@ type Server struct {
|
|||||||
runningCancels map[string]context.CancelFunc
|
runningCancels map[string]context.CancelFunc
|
||||||
runningCancelsMu sync.Mutex
|
runningCancelsMu sync.Mutex
|
||||||
abortUserNotes map[string]string // 监控页终止时附带的用户说明,与 executionID 对应
|
abortUserNotes map[string]string // 监控页终止时附带的用户说明,与 executionID 对应
|
||||||
|
// httpToolTimeoutMinutes 同步 agent.tool_timeout_minutes,用于 POST /api/mcp 的 tools/call(不经 Agent 包装的路径)。
|
||||||
|
// nil 表示未配置,沿用默认 30 分钟;指向 0 表示不限制;>0 为分钟数。
|
||||||
|
httpToolTimeoutMinutes *int
|
||||||
|
httpToolTimeoutMu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type sseClient struct {
|
type sseClient struct {
|
||||||
@@ -90,6 +94,39 @@ func NewServerWithStorage(logger *zap.Logger, storage MonitorStorage) *Server {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConfigureHTTPToolCallTimeoutFromAgentMinutes 将 agent.tool_timeout_minutes 同步到经 HTTP POST /api/mcp 触发的 tools/call。
|
||||||
|
// minutes<=0 表示不设置硬性截止时间(与配置「0 不限制」一致);minutes>0 为该次调用的最长等待时间。
|
||||||
|
// 未调用前对 tools/call 使用默认 30 分钟(与历史硬编码一致)。
|
||||||
|
func (s *Server) ConfigureHTTPToolCallTimeoutFromAgentMinutes(minutes int) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
v := minutes
|
||||||
|
if v < 0 {
|
||||||
|
v = 0
|
||||||
|
}
|
||||||
|
s.httpToolTimeoutMu.Lock()
|
||||||
|
defer s.httpToolTimeoutMu.Unlock()
|
||||||
|
s.httpToolTimeoutMinutes = &v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) effectiveHTTPToolCallDeadline() (context.Context, context.CancelFunc) {
|
||||||
|
const defaultDur = 30 * time.Minute
|
||||||
|
if s == nil {
|
||||||
|
return context.WithTimeout(context.Background(), defaultDur)
|
||||||
|
}
|
||||||
|
s.httpToolTimeoutMu.RLock()
|
||||||
|
mPtr := s.httpToolTimeoutMinutes
|
||||||
|
s.httpToolTimeoutMu.RUnlock()
|
||||||
|
if mPtr == nil {
|
||||||
|
return context.WithTimeout(context.Background(), defaultDur)
|
||||||
|
}
|
||||||
|
if *mPtr <= 0 {
|
||||||
|
return context.WithCancel(context.Background())
|
||||||
|
}
|
||||||
|
return context.WithTimeout(context.Background(), time.Duration(*mPtr)*time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterTool 注册工具
|
// RegisterTool 注册工具
|
||||||
func (s *Server) RegisterTool(tool Tool, handler ToolHandler) {
|
func (s *Server) RegisterTool(tool Tool, handler ToolHandler) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
@@ -457,7 +494,7 @@ func (s *Server) handleCallTool(msg *Message) *Message {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
baseCtx, timeoutCancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
baseCtx, timeoutCancel := s.effectiveHTTPToolCallDeadline()
|
||||||
defer timeoutCancel()
|
defer timeoutCancel()
|
||||||
execCtx, runCancel := context.WithCancel(baseCtx)
|
execCtx, runCancel := context.WithCancel(baseCtx)
|
||||||
s.registerRunningCancel(executionID, runCancel)
|
s.registerRunningCancel(executionID, runCancel)
|
||||||
@@ -883,6 +920,49 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]
|
|||||||
return finalResult, executionID, nil
|
return finalResult, executionID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致),
|
||||||
|
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。
|
||||||
|
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||||
|
if s == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if args == nil {
|
||||||
|
args = map[string]interface{}{}
|
||||||
|
}
|
||||||
|
executionID := uuid.New().String()
|
||||||
|
now := time.Now()
|
||||||
|
failed := invokeErr != nil
|
||||||
|
exec := &ToolExecution{
|
||||||
|
ID: executionID,
|
||||||
|
ToolName: toolName,
|
||||||
|
Arguments: args,
|
||||||
|
StartTime: now,
|
||||||
|
EndTime: &now,
|
||||||
|
Duration: 0,
|
||||||
|
}
|
||||||
|
if failed {
|
||||||
|
exec.Status = "failed"
|
||||||
|
exec.Error = invokeErr.Error()
|
||||||
|
if strings.TrimSpace(resultText) != "" {
|
||||||
|
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
exec.Status = "completed"
|
||||||
|
text := resultText
|
||||||
|
if strings.TrimSpace(text) == "" {
|
||||||
|
text = "(无输出)"
|
||||||
|
}
|
||||||
|
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: text}}}
|
||||||
|
}
|
||||||
|
if s.storage != nil {
|
||||||
|
if err := s.storage.SaveToolExecution(exec); err != nil {
|
||||||
|
s.logger.Warn("RecordCompletedToolInvocation 保存失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.updateStats(toolName, failed)
|
||||||
|
return executionID
|
||||||
|
}
|
||||||
|
|
||||||
// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长
|
// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长
|
||||||
func (s *Server) cleanupOldExecutions() {
|
func (s *Server) cleanupOldExecutions() {
|
||||||
if len(s.executions) <= s.maxExecutionsInMemory {
|
if len(s.executions) <= s.maxExecutionsInMemory {
|
||||||
|
|||||||
@@ -11,8 +11,13 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
|
"cyberstrike-ai/internal/einoobserve"
|
||||||
|
"cyberstrike-ai/internal/openai"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
@@ -20,7 +25,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// normalizeStreamingDelta 将可能是“累计片段”的 chunk 归一化为“纯增量”。
|
// normalizeStreamingDelta 将可能是“累计片段”的 chunk 归一化为“纯增量”。
|
||||||
// 一些模型/桥接层在流式过程中会重复发送已输出前缀,前端若直接 buffer+=chunk 会出现“结巴”重复。
|
// 一些模型/桥接层在流式过程中会重复发送已输出前缀,前端若直接 buffer+=chunk 会出现重复文本。
|
||||||
|
//
|
||||||
|
// 注意:与 internal/openai.normalizeStreamingDelta 保持一致。
|
||||||
func normalizeStreamingDelta(current, incoming string) (next, delta string) {
|
func normalizeStreamingDelta(current, incoming string) (next, delta string) {
|
||||||
if incoming == "" {
|
if incoming == "" {
|
||||||
return current, ""
|
return current, ""
|
||||||
@@ -28,31 +35,22 @@ func normalizeStreamingDelta(current, incoming string) (next, delta string) {
|
|||||||
if current == "" {
|
if current == "" {
|
||||||
return incoming, incoming
|
return incoming, incoming
|
||||||
}
|
}
|
||||||
if incoming == current {
|
if strings.HasPrefix(incoming, current) && len(incoming) > len(current) {
|
||||||
return current, ""
|
|
||||||
}
|
|
||||||
// incoming 是累计全文(包含 current 前缀)
|
|
||||||
if strings.HasPrefix(incoming, current) {
|
|
||||||
return incoming, incoming[len(current):]
|
return incoming, incoming[len(current):]
|
||||||
}
|
}
|
||||||
// incoming 完全是已输出尾部重发
|
if incoming == current && utf8.RuneCountInString(current) > 1 {
|
||||||
if strings.HasSuffix(current, incoming) {
|
|
||||||
return current, ""
|
return current, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理边界重叠:current 后缀与 incoming 前缀重叠,只追加非重叠部分。
|
|
||||||
max := len(current)
|
|
||||||
if len(incoming) < max {
|
|
||||||
max = len(incoming)
|
|
||||||
}
|
|
||||||
for overlap := max; overlap > 0; overlap-- {
|
|
||||||
if current[len(current)-overlap:] == incoming[:overlap] {
|
|
||||||
return current + incoming[overlap:], incoming[overlap:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return current + incoming, incoming
|
return current + incoming, incoming
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isInterruptContinue(ctx context.Context) bool {
|
||||||
|
if ctx == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return errors.Is(context.Cause(ctx), ErrInterruptContinue)
|
||||||
|
}
|
||||||
|
|
||||||
func isEinoIterationLimitError(err error) bool {
|
func isEinoIterationLimitError(err error) bool {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false
|
||||||
@@ -83,10 +81,25 @@ type einoADKRunLoopArgs struct {
|
|||||||
McpIDsMu *sync.Mutex
|
McpIDsMu *sync.Mutex
|
||||||
McpIDs *[]string
|
McpIDs *[]string
|
||||||
|
|
||||||
|
// FilesystemMonitorAgent / FilesystemMonitorRecord 非 nil 时,将 Eino ADK filesystem 中间件工具(ls/read_file/write_file/edit_file/glob/grep)
|
||||||
|
// 在完成时写入 MCP 监控;execute 仍由 eino_execute_monitor 记录,此处跳过。
|
||||||
|
FilesystemMonitorAgent *agent.Agent
|
||||||
|
FilesystemMonitorRecord einomcp.ExecutionRecorder
|
||||||
|
|
||||||
|
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,MCP 桥 Fire 以补全 tool_result。
|
||||||
|
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
|
||||||
|
|
||||||
DA adk.Agent
|
DA adk.Agent
|
||||||
|
|
||||||
// EmptyResponseMessage 当未捕获到助手正文时的占位(多代理与单代理文案不同)。
|
// EmptyResponseMessage 当未捕获到助手正文时的占位(多代理与单代理文案不同)。
|
||||||
EmptyResponseMessage string
|
EmptyResponseMessage string
|
||||||
|
|
||||||
|
// ModelFacingTrace 可选:由各 ChatModelAgent Handlers 链末尾中间件写入「即将送入模型」的消息快照;
|
||||||
|
// 非空时优先用于 LastAgentTraceInput 序列化,使续跑与 summarization/reduction 后的上下文一致。
|
||||||
|
ModelFacingTrace *modelFacingTraceHolder
|
||||||
|
|
||||||
|
// EinoCallbacks 可选:为 ADK Runner 注入 eino [callbacks] 全链路观测(见 internal/einoobserve)。
|
||||||
|
EinoCallbacks *config.MultiAgentEinoCallbacksConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs []adk.Message) (*RunResult, error) {
|
func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs []adk.Message) (*RunResult, error) {
|
||||||
@@ -164,6 +177,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
var einoMainRound int
|
var einoMainRound int
|
||||||
var einoLastAgent string
|
var einoLastAgent string
|
||||||
subAgentToolStep := make(map[string]int)
|
subAgentToolStep := make(map[string]int)
|
||||||
|
// mainAgentToolStep:主代理每次工具调用批次递增,供 UI 显示「第 N 轮」(单代理无子代理切换时原先会一直停在第 1 轮)。
|
||||||
|
mainAgentToolStep := make(map[string]int)
|
||||||
pendingByID := make(map[string]toolCallPendingInfo)
|
pendingByID := make(map[string]toolCallPendingInfo)
|
||||||
pendingQueueByAgent := make(map[string][]string)
|
pendingQueueByAgent := make(map[string][]string)
|
||||||
markPending := func(tc toolCallPendingInfo) {
|
markPending := func(tc toolCallPendingInfo) {
|
||||||
@@ -224,6 +239,82 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
pendingQueueByAgent = make(map[string][]string)
|
pendingQueueByAgent = make(map[string][]string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 最近一次成功的 Eino filesystem execute 的标准输出(trim):用于抑制模型紧接着复述同一字符串时的重复「助手输出」时间线。
|
||||||
|
var executeStdoutDupMu sync.Mutex
|
||||||
|
var pendingExecuteStdoutDup string
|
||||||
|
recordPendingExecuteStdoutDup := func(toolName, stdout string, isErr bool) {
|
||||||
|
if isErr || !strings.EqualFold(strings.TrimSpace(toolName), "execute") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t := strings.TrimSpace(stdout)
|
||||||
|
if t == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
executeStdoutDupMu.Lock()
|
||||||
|
pendingExecuteStdoutDup = t
|
||||||
|
executeStdoutDupMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
var toolResultSent sync.Map // toolCallID -> struct{};与 ADK Tool 消息去重,避免 bridge 与事件流各推一次
|
||||||
|
if args.ToolInvokeNotify != nil {
|
||||||
|
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||||
|
tid := strings.TrimSpace(toolCallID)
|
||||||
|
removePendingByID(tid)
|
||||||
|
if tid == "" || progress == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, loaded := toolResultSent.LoadOrStore(tid, struct{}{}); loaded {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
isErr := !success || invokeErr != nil
|
||||||
|
body := content
|
||||||
|
if invokeErr != nil {
|
||||||
|
// 保留已流式累计的 stdout(如 execute 超时前的一半输出),避免 tool_result 只剩错误串、模型与 UI 丢失上下文
|
||||||
|
tail := friendlyEinoExecuteInvokeTail(invokeErr)
|
||||||
|
// execute 流式包装可能已把超时句写入 content(供 ADK tool 与流式 delta);勿重复拼接
|
||||||
|
if tail != "" && strings.Contains(content, tail) {
|
||||||
|
body = content
|
||||||
|
} else if strings.TrimSpace(content) != "" {
|
||||||
|
body = strings.TrimRight(content, "\n") + "\n\n" + tail
|
||||||
|
} else {
|
||||||
|
body = tail
|
||||||
|
}
|
||||||
|
isErr = true
|
||||||
|
}
|
||||||
|
recordPendingExecuteStdoutDup(toolName, body, isErr)
|
||||||
|
preview := body
|
||||||
|
if len(preview) > 200 {
|
||||||
|
preview = preview[:200] + "..."
|
||||||
|
}
|
||||||
|
agentTag := strings.TrimSpace(einoAgent)
|
||||||
|
if agentTag == "" {
|
||||||
|
agentTag = orchestratorName
|
||||||
|
}
|
||||||
|
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), map[string]interface{}{
|
||||||
|
"toolName": toolName,
|
||||||
|
"success": !isErr,
|
||||||
|
"isError": isErr,
|
||||||
|
"result": body,
|
||||||
|
"resultPreview": preview,
|
||||||
|
"toolCallId": tid,
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"einoAgent": agentTag,
|
||||||
|
"einoRole": einoRoleTag(agentTag),
|
||||||
|
"source": "eino",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.EinoCallbacks != nil {
|
||||||
|
ctx = einoobserve.AttachAgentRunCallbacks(ctx, args.EinoCallbacks, einoobserve.Params{
|
||||||
|
Logger: logger,
|
||||||
|
Progress: progress,
|
||||||
|
ConversationID: conversationID,
|
||||||
|
OrchMode: orchMode,
|
||||||
|
OrchestratorName: orchestratorName,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
runnerCfg := adk.RunnerConfig{
|
runnerCfg := adk.RunnerConfig{
|
||||||
Agent: da,
|
Agent: da,
|
||||||
EnableStreaming: true,
|
EnableStreaming: true,
|
||||||
@@ -352,7 +443,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
ids := snapshotMCPIDs()
|
ids := snapshotMCPIDs()
|
||||||
return buildEinoRunResultFromAccumulated(
|
return buildEinoRunResultFromAccumulated(
|
||||||
orchMode, runAccumulatedMsgs, lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, true,
|
orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs),
|
||||||
|
lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, true,
|
||||||
), runErr
|
), runErr
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -362,10 +454,18 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
flushAllPendingAsFailed(ctx.Err())
|
flushAllPendingAsFailed(ctx.Err())
|
||||||
if progress != nil {
|
if progress != nil {
|
||||||
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{
|
if isInterruptContinue(ctx) {
|
||||||
"conversationId": conversationID,
|
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
|
||||||
"source": "eino",
|
"conversationId": conversationID,
|
||||||
})
|
"source": "eino",
|
||||||
|
"kind": "interrupt_continue",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return takePartial(ctx.Err())
|
return takePartial(ctx.Err())
|
||||||
default:
|
default:
|
||||||
@@ -379,10 +479,18 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||||
flushAllPendingAsFailed(ctxErr)
|
flushAllPendingAsFailed(ctxErr)
|
||||||
if progress != nil {
|
if progress != nil {
|
||||||
progress("error", ctxErr.Error(), map[string]interface{}{
|
if isInterruptContinue(ctx) {
|
||||||
"conversationId": conversationID,
|
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
|
||||||
"source": "eino",
|
"conversationId": conversationID,
|
||||||
})
|
"source": "eino",
|
||||||
|
"kind": "interrupt_continue",
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
progress("error", ctxErr.Error(), map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return takePartial(ctxErr)
|
return takePartial(ctxErr)
|
||||||
}
|
}
|
||||||
@@ -423,8 +531,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if streamsMainAssistant(ev.AgentName) {
|
if streamsMainAssistant(ev.AgentName) {
|
||||||
|
mainIterKey := einoMainIterationKey(iterEinoAgent, orchestratorName)
|
||||||
if einoMainRound == 0 {
|
if einoMainRound == 0 {
|
||||||
einoMainRound = 1
|
einoMainRound = 1
|
||||||
|
mainAgentToolStep[mainIterKey] = 1
|
||||||
progress("iteration", "", map[string]interface{}{
|
progress("iteration", "", map[string]interface{}{
|
||||||
"iteration": 1,
|
"iteration": 1,
|
||||||
"einoScope": "main",
|
"einoScope": "main",
|
||||||
@@ -434,17 +544,26 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"source": "eino",
|
"source": "eino",
|
||||||
})
|
})
|
||||||
} else if einoLastAgent != "" && !streamsMainAssistant(einoLastAgent) {
|
} else if einoLastAgent != "" {
|
||||||
einoMainRound++
|
needBump := false
|
||||||
progress("iteration", "", map[string]interface{}{
|
if !streamsMainAssistant(einoLastAgent) {
|
||||||
"iteration": einoMainRound,
|
needBump = true // 子代理 → 主代理
|
||||||
"einoScope": "main",
|
} else if einoLastAgent != ev.AgentName {
|
||||||
"einoRole": "orchestrator",
|
needBump = true // plan_execute:planner ↔ executor 等主代理切换
|
||||||
"einoAgent": iterEinoAgent,
|
}
|
||||||
"orchestration": orchMode,
|
if needBump {
|
||||||
"conversationId": conversationID,
|
einoMainRound++
|
||||||
"source": "eino",
|
mainAgentToolStep[mainIterKey] = einoMainRound
|
||||||
})
|
progress("iteration", "", map[string]interface{}{
|
||||||
|
"iteration": einoMainRound,
|
||||||
|
"einoScope": "main",
|
||||||
|
"einoRole": "orchestrator",
|
||||||
|
"einoAgent": iterEinoAgent,
|
||||||
|
"orchestration": orchMode,
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
einoLastAgent = ev.AgentName
|
einoLastAgent = ev.AgentName
|
||||||
@@ -467,98 +586,193 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
var subAssistantBuf string
|
var subAssistantBuf string
|
||||||
var subReplyStreamID string
|
var subReplyStreamID string
|
||||||
var mainAssistantBuf string
|
var mainAssistantBuf string
|
||||||
|
// 已通过 response_delta 推到前端的正文(与 monitor.js normalizeStreamingDeltaJs 累积一致)
|
||||||
|
var mainAssistWireAccum string
|
||||||
|
var mainAssistDupTarget string // 非空表示本段主助手流需缓冲至 EOF,与 execute 输出比对去重
|
||||||
var reasoningBuf string
|
var reasoningBuf string
|
||||||
|
var prevReasoningDisplay string // UI 用:剥离 Claude 内部 signature 尾缀后的累计展示
|
||||||
var streamRecvErr error
|
var streamRecvErr error
|
||||||
|
type streamMsg struct {
|
||||||
|
chunk *schema.Message
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
recvCh := make(chan streamMsg, 8)
|
||||||
|
go func() {
|
||||||
|
defer close(recvCh)
|
||||||
|
for {
|
||||||
|
ch, rerr := mv.MessageStream.Recv()
|
||||||
|
recvCh <- streamMsg{chunk: ch, err: rerr}
|
||||||
|
if rerr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
streamRecvLoop:
|
||||||
for {
|
for {
|
||||||
chunk, rerr := mv.MessageStream.Recv()
|
select {
|
||||||
if rerr != nil {
|
case <-ctx.Done():
|
||||||
if errors.Is(rerr, io.EOF) {
|
streamRecvErr = ctx.Err()
|
||||||
break
|
break streamRecvLoop
|
||||||
|
case sm, ok := <-recvCh:
|
||||||
|
if !ok {
|
||||||
|
break streamRecvLoop
|
||||||
}
|
}
|
||||||
if logger != nil {
|
chunk, rerr := sm.chunk, sm.err
|
||||||
logger.Warn("eino stream recv error, flushing incomplete stream",
|
if rerr != nil {
|
||||||
zap.Error(rerr),
|
if errors.Is(rerr, io.EOF) {
|
||||||
zap.String("agent", ev.AgentName),
|
break streamRecvLoop
|
||||||
zap.Int("toolFragments", len(toolStreamFragments)))
|
|
||||||
}
|
|
||||||
streamRecvErr = rerr
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if chunk == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if progress != nil && strings.TrimSpace(chunk.ReasoningContent) != "" {
|
|
||||||
var reasoningDelta string
|
|
||||||
reasoningBuf, reasoningDelta = normalizeStreamingDelta(reasoningBuf, chunk.ReasoningContent)
|
|
||||||
if reasoningDelta != "" {
|
|
||||||
if reasoningStreamID == "" {
|
|
||||||
reasoningStreamID = fmt.Sprintf("eino-reasoning-%s-%d", conversationID, atomic.AddInt64(&reasoningStreamSeq, 1))
|
|
||||||
progress("thinking_stream_start", " ", map[string]interface{}{
|
|
||||||
"streamId": reasoningStreamID,
|
|
||||||
"source": "eino",
|
|
||||||
"einoAgent": ev.AgentName,
|
|
||||||
"einoRole": einoRoleTag(ev.AgentName),
|
|
||||||
"orchestration": orchMode,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
progress("thinking_stream_delta", reasoningDelta, map[string]interface{}{
|
if logger != nil {
|
||||||
"streamId": reasoningStreamID,
|
logger.Warn("eino stream recv error, flushing incomplete stream",
|
||||||
})
|
zap.Error(rerr),
|
||||||
|
zap.String("agent", ev.AgentName),
|
||||||
|
zap.Int("toolFragments", len(toolStreamFragments)))
|
||||||
|
}
|
||||||
|
streamRecvErr = rerr
|
||||||
|
break streamRecvLoop
|
||||||
}
|
}
|
||||||
}
|
if chunk == nil {
|
||||||
if chunk.Content != "" {
|
continue
|
||||||
if progress != nil && streamsMainAssistant(ev.AgentName) {
|
}
|
||||||
var contentDelta string
|
if progress != nil && strings.TrimSpace(chunk.ReasoningContent) != "" {
|
||||||
mainAssistantBuf, contentDelta = normalizeStreamingDelta(mainAssistantBuf, chunk.Content)
|
var reasoningDelta string
|
||||||
if contentDelta != "" {
|
reasoningBuf, reasoningDelta = normalizeStreamingDelta(reasoningBuf, chunk.ReasoningContent)
|
||||||
if !streamHeaderSent {
|
if reasoningDelta != "" {
|
||||||
progress("response_start", "", map[string]interface{}{
|
fullDisplay := openai.DisplayReasoningContent(reasoningBuf)
|
||||||
"conversationId": conversationID,
|
var displayDelta string
|
||||||
"mcpExecutionIds": snapshotMCPIDs(),
|
if strings.HasPrefix(fullDisplay, prevReasoningDisplay) {
|
||||||
"messageGeneratedBy": "eino:" + ev.AgentName,
|
displayDelta = fullDisplay[len(prevReasoningDisplay):]
|
||||||
"einoRole": "orchestrator",
|
} else {
|
||||||
"einoAgent": ev.AgentName,
|
displayDelta = fullDisplay
|
||||||
"orchestration": orchMode,
|
|
||||||
})
|
|
||||||
streamHeaderSent = true
|
|
||||||
}
|
}
|
||||||
progress("response_delta", contentDelta, map[string]interface{}{
|
prevReasoningDisplay = fullDisplay
|
||||||
"conversationId": conversationID,
|
if displayDelta != "" {
|
||||||
"mcpExecutionIds": snapshotMCPIDs(),
|
if reasoningStreamID == "" {
|
||||||
"einoRole": "orchestrator",
|
reasoningStreamID = fmt.Sprintf("eino-reasoning-%s-%d", conversationID, atomic.AddInt64(&reasoningStreamSeq, 1))
|
||||||
"einoAgent": ev.AgentName,
|
progress("reasoning_chain_stream_start", " ", map[string]interface{}{
|
||||||
"orchestration": orchMode,
|
"streamId": reasoningStreamID,
|
||||||
})
|
"source": "eino",
|
||||||
}
|
"einoAgent": ev.AgentName,
|
||||||
} else if !streamsMainAssistant(ev.AgentName) {
|
"einoRole": einoRoleTag(ev.AgentName),
|
||||||
var subDelta string
|
"orchestration": orchMode,
|
||||||
subAssistantBuf, subDelta = normalizeStreamingDelta(subAssistantBuf, chunk.Content)
|
|
||||||
if subDelta != "" {
|
|
||||||
if progress != nil {
|
|
||||||
if subReplyStreamID == "" {
|
|
||||||
subReplyStreamID = fmt.Sprintf("eino-sub-reply-%s-%d", conversationID, atomic.AddInt64(&einoSubReplyStreamSeq, 1))
|
|
||||||
progress("eino_agent_reply_stream_start", "", map[string]interface{}{
|
|
||||||
"streamId": subReplyStreamID,
|
|
||||||
"einoAgent": ev.AgentName,
|
|
||||||
"einoRole": "sub",
|
|
||||||
"conversationId": conversationID,
|
|
||||||
"source": "eino",
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
progress("eino_agent_reply_stream_delta", subDelta, map[string]interface{}{
|
progress("reasoning_chain_stream_delta", displayDelta, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
"streamId": subReplyStreamID,
|
"streamId": reasoningStreamID,
|
||||||
"conversationId": conversationID,
|
}, fullDisplay))
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
if chunk.Content != "" {
|
||||||
if len(chunk.ToolCalls) > 0 {
|
if progress != nil && streamsMainAssistant(ev.AgentName) {
|
||||||
toolStreamFragments = append(toolStreamFragments, chunk.ToolCalls...)
|
var contentDelta string
|
||||||
|
mainAssistantBuf, contentDelta = normalizeStreamingDelta(mainAssistantBuf, chunk.Content)
|
||||||
|
if contentDelta != "" {
|
||||||
|
if mainAssistDupTarget == "" {
|
||||||
|
executeStdoutDupMu.Lock()
|
||||||
|
if pendingExecuteStdoutDup != "" {
|
||||||
|
mainAssistDupTarget = pendingExecuteStdoutDup
|
||||||
|
}
|
||||||
|
executeStdoutDupMu.Unlock()
|
||||||
|
}
|
||||||
|
if mainAssistDupTarget != "" {
|
||||||
|
// 已展示过 tool_result,缓冲全文;EOF 后与 execute 输出相同则不再发助手流
|
||||||
|
} else {
|
||||||
|
if !streamHeaderSent {
|
||||||
|
progress("response_start", "", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"mcpExecutionIds": snapshotMCPIDs(),
|
||||||
|
"messageGeneratedBy": "eino:" + ev.AgentName,
|
||||||
|
"einoRole": "orchestrator",
|
||||||
|
"einoAgent": ev.AgentName,
|
||||||
|
"orchestration": orchMode,
|
||||||
|
})
|
||||||
|
streamHeaderSent = true
|
||||||
|
}
|
||||||
|
progress("response_delta", contentDelta, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"mcpExecutionIds": snapshotMCPIDs(),
|
||||||
|
"einoRole": "orchestrator",
|
||||||
|
"einoAgent": ev.AgentName,
|
||||||
|
"orchestration": orchMode,
|
||||||
|
}, mainAssistantBuf))
|
||||||
|
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, contentDelta)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if !streamsMainAssistant(ev.AgentName) {
|
||||||
|
var subDelta string
|
||||||
|
subAssistantBuf, subDelta = normalizeStreamingDelta(subAssistantBuf, chunk.Content)
|
||||||
|
if subDelta != "" {
|
||||||
|
if progress != nil {
|
||||||
|
if subReplyStreamID == "" {
|
||||||
|
subReplyStreamID = fmt.Sprintf("eino-sub-reply-%s-%d", conversationID, atomic.AddInt64(&einoSubReplyStreamSeq, 1))
|
||||||
|
progress("eino_agent_reply_stream_start", "", map[string]interface{}{
|
||||||
|
"streamId": subReplyStreamID,
|
||||||
|
"einoAgent": ev.AgentName,
|
||||||
|
"einoRole": "sub",
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
progress("eino_agent_reply_stream_delta", subDelta, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
|
"streamId": subReplyStreamID,
|
||||||
|
"conversationId": conversationID,
|
||||||
|
}, subAssistantBuf))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(chunk.ToolCalls) > 0 {
|
||||||
|
toolStreamFragments = append(toolStreamFragments, chunk.ToolCalls...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if streamsMainAssistant(ev.AgentName) {
|
if streamsMainAssistant(ev.AgentName) {
|
||||||
if s := strings.TrimSpace(mainAssistantBuf); s != "" {
|
s := strings.TrimSpace(mainAssistantBuf)
|
||||||
|
if mainAssistDupTarget != "" {
|
||||||
|
executeStdoutDupMu.Lock()
|
||||||
|
pendingExecuteStdoutDup = ""
|
||||||
|
executeStdoutDupMu.Unlock()
|
||||||
|
if s != "" && s == mainAssistDupTarget {
|
||||||
|
// 与刚展示的 execute 结果完全一致:不再发助手流式事件,仍写入轨迹与最终回复字段
|
||||||
|
lastAssistant = s
|
||||||
|
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
|
||||||
|
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||||
|
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(s)
|
||||||
|
}
|
||||||
|
} else if s != "" {
|
||||||
|
if progress != nil {
|
||||||
|
// 仅用 TrimSpace 与 execute 比对;推到 UI 的必须是 mainAssistantBuf,
|
||||||
|
// 否则尾部空白/换行与已流式前缀不一致时,前端 normalize 会走拼接路径造成叠字。
|
||||||
|
_, eofTail := normalizeStreamingDelta(mainAssistWireAccum, mainAssistantBuf)
|
||||||
|
if eofTail != "" {
|
||||||
|
if !streamHeaderSent {
|
||||||
|
progress("response_start", "", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"mcpExecutionIds": snapshotMCPIDs(),
|
||||||
|
"messageGeneratedBy": "eino:" + ev.AgentName,
|
||||||
|
"einoRole": "orchestrator",
|
||||||
|
"einoAgent": ev.AgentName,
|
||||||
|
"orchestration": orchMode,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
progress("response_delta", eofTail, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"mcpExecutionIds": snapshotMCPIDs(),
|
||||||
|
"einoRole": "orchestrator",
|
||||||
|
"einoAgent": ev.AgentName,
|
||||||
|
"orchestration": orchMode,
|
||||||
|
}, mainAssistantBuf))
|
||||||
|
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, eofTail)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lastAssistant = s
|
||||||
|
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
|
||||||
|
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||||
|
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if s != "" {
|
||||||
lastAssistant = s
|
lastAssistant = s
|
||||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
|
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
|
||||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||||
@@ -588,10 +802,17 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
var lastToolChunk *schema.Message
|
var lastToolChunk *schema.Message
|
||||||
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
|
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
|
||||||
lastToolChunk = &schema.Message{ToolCalls: merged}
|
lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged})
|
||||||
|
}
|
||||||
|
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
|
||||||
|
// 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。
|
||||||
|
if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 {
|
||||||
|
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls))
|
||||||
}
|
}
|
||||||
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
|
|
||||||
if streamRecvErr != nil {
|
if streamRecvErr != nil {
|
||||||
|
if isInterruptContinue(ctx) {
|
||||||
|
return takePartial(streamRecvErr)
|
||||||
|
}
|
||||||
if progress != nil {
|
if progress != nil {
|
||||||
progress("eino_stream_error", streamRecvErr.Error(), map[string]interface{}{
|
progress("eino_stream_error", streamRecvErr.Error(), map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
@@ -612,11 +833,11 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
|
runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
|
||||||
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
|
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
|
||||||
|
|
||||||
if mv.Role == schema.Assistant {
|
if mv.Role == schema.Assistant {
|
||||||
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
||||||
progress("thinking", strings.TrimSpace(msg.ReasoningContent), map[string]interface{}{
|
progress("reasoning_chain", openai.DisplayReasoningContent(strings.TrimSpace(msg.ReasoningContent)), map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"source": "eino",
|
"source": "eino",
|
||||||
"einoAgent": ev.AgentName,
|
"einoAgent": ev.AgentName,
|
||||||
@@ -627,26 +848,42 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
body := strings.TrimSpace(msg.Content)
|
body := strings.TrimSpace(msg.Content)
|
||||||
if body != "" {
|
if body != "" {
|
||||||
if streamsMainAssistant(ev.AgentName) {
|
if streamsMainAssistant(ev.AgentName) {
|
||||||
if progress != nil {
|
executeStdoutDupMu.Lock()
|
||||||
progress("response_start", "", map[string]interface{}{
|
dup := pendingExecuteStdoutDup
|
||||||
"conversationId": conversationID,
|
if dup != "" && body == dup {
|
||||||
"mcpExecutionIds": snapshotMCPIDs(),
|
pendingExecuteStdoutDup = ""
|
||||||
"messageGeneratedBy": "eino:" + ev.AgentName,
|
executeStdoutDupMu.Unlock()
|
||||||
"einoRole": "orchestrator",
|
lastAssistant = body
|
||||||
"einoAgent": ev.AgentName,
|
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||||
"orchestration": orchMode,
|
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(body)
|
||||||
})
|
}
|
||||||
progress("response_delta", body, map[string]interface{}{
|
// 非流式:与 execute 输出相同则跳过助手通道展示(msg 已在上方写入 runAccumulatedMsgs)
|
||||||
"conversationId": conversationID,
|
} else {
|
||||||
"mcpExecutionIds": snapshotMCPIDs(),
|
if dup != "" {
|
||||||
"einoRole": "orchestrator",
|
pendingExecuteStdoutDup = ""
|
||||||
"einoAgent": ev.AgentName,
|
}
|
||||||
"orchestration": orchMode,
|
executeStdoutDupMu.Unlock()
|
||||||
})
|
if progress != nil {
|
||||||
}
|
progress("response_start", "", map[string]interface{}{
|
||||||
lastAssistant = body
|
"conversationId": conversationID,
|
||||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
"mcpExecutionIds": snapshotMCPIDs(),
|
||||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(body)
|
"messageGeneratedBy": "eino:" + ev.AgentName,
|
||||||
|
"einoRole": "orchestrator",
|
||||||
|
"einoAgent": ev.AgentName,
|
||||||
|
"orchestration": orchMode,
|
||||||
|
})
|
||||||
|
progress("response_delta", body, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"mcpExecutionIds": snapshotMCPIDs(),
|
||||||
|
"einoRole": "orchestrator",
|
||||||
|
"einoAgent": ev.AgentName,
|
||||||
|
"orchestration": orchMode,
|
||||||
|
}, body))
|
||||||
|
}
|
||||||
|
lastAssistant = body
|
||||||
|
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||||
|
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(body)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else if progress != nil {
|
} else if progress != nil {
|
||||||
progress("eino_agent_reply", body, map[string]interface{}{
|
progress("eino_agent_reply", body, map[string]interface{}{
|
||||||
@@ -702,12 +939,19 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
removePendingByID(toolCallID)
|
|
||||||
}
|
}
|
||||||
if toolCallID != "" {
|
if toolCallID != "" {
|
||||||
|
removePendingByID(toolCallID)
|
||||||
|
if _, loaded := toolResultSent.LoadOrStore(toolCallID, struct{}{}); loaded {
|
||||||
|
// ToolInvokeNotify 可能已推过 tool_result(如 execute 流式包装里 Fire 仅携带截断后的 stdout),
|
||||||
|
// 此处仍应用 ADK Tool 消息中的完整内容刷新去重基准,避免模型复述全文时与截断串比对失败而重复展示「助手输出」。
|
||||||
|
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
data["toolCallId"] = toolCallID
|
data["toolCallId"] = toolCallID
|
||||||
}
|
}
|
||||||
|
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
||||||
|
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
|
||||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
|
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -717,26 +961,52 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
mcpIDsMu.Unlock()
|
mcpIDsMu.Unlock()
|
||||||
|
|
||||||
out := buildEinoRunResultFromAccumulated(
|
out := buildEinoRunResultFromAccumulated(
|
||||||
orchMode, runAccumulatedMsgs, lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
|
orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs),
|
||||||
|
lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
|
||||||
)
|
)
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func persistTraceSource(args *einoADKRunLoopArgs, fallback []adk.Message) []adk.Message {
|
||||||
|
if args != nil && args.ModelFacingTrace != nil {
|
||||||
|
if snap := args.ModelFacingTrace.Snapshot(); len(snap) > 0 {
|
||||||
|
return snap
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
func einoPartialRunLastOutputHint() string {
|
func einoPartialRunLastOutputHint() string {
|
||||||
return "[执行未正常结束(用户停止、超时或异常)。续跑时请基于上文已产生的工具与结果继续,勿重复已完成步骤。]\n" +
|
return "[执行未正常结束(用户停止、超时或异常)。续跑时请基于上文已产生的工具与结果继续,勿重复已完成步骤。]\n" +
|
||||||
"[Run ended abnormally; continue from the trace above without repeating completed steps.]"
|
"[Run ended abnormally; continue from the trace above without repeating completed steps.]"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// friendlyEinoExecuteInvokeTail 将 Eino execute 等非 MCP 路径的结尾错误转成简短提示;其它情况保留原 error 文本。
|
||||||
|
func friendlyEinoExecuteInvokeTail(invokeErr error) string {
|
||||||
|
if invokeErr == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||||
|
return einoExecuteTimeoutUserHint()
|
||||||
|
}
|
||||||
|
return "[执行未正常结束] " + invokeErr.Error()
|
||||||
|
}
|
||||||
|
|
||||||
func buildEinoRunResultFromAccumulated(
|
func buildEinoRunResultFromAccumulated(
|
||||||
orchMode string,
|
orchMode string,
|
||||||
runAccumulatedMsgs []adk.Message,
|
runAccumulatedMsgs []adk.Message,
|
||||||
|
persistMsgs []adk.Message,
|
||||||
lastAssistant string,
|
lastAssistant string,
|
||||||
lastPlanExecuteExecutor string,
|
lastPlanExecuteExecutor string,
|
||||||
emptyHint string,
|
emptyHint string,
|
||||||
mcpIDs []string,
|
mcpIDs []string,
|
||||||
partial bool,
|
partial bool,
|
||||||
) *RunResult {
|
) *RunResult {
|
||||||
histJSON, _ := json.Marshal(runAccumulatedMsgs)
|
traceForJSON := persistMsgs
|
||||||
|
if len(traceForJSON) == 0 {
|
||||||
|
traceForJSON = runAccumulatedMsgs
|
||||||
|
}
|
||||||
|
histJSON, _ := json.Marshal(traceForJSON)
|
||||||
cleaned := strings.TrimSpace(lastAssistant)
|
cleaned := strings.TrimSpace(lastAssistant)
|
||||||
if orchMode == "plan_execute" {
|
if orchMode == "plan_execute" {
|
||||||
if e := strings.TrimSpace(lastPlanExecuteExecutor); e != "" {
|
if e := strings.TrimSpace(lastPlanExecuteExecutor); e != "" {
|
||||||
@@ -745,6 +1015,11 @@ func buildEinoRunResultFromAccumulated(
|
|||||||
cleaned = UnwrapPlanExecuteUserText(cleaned)
|
cleaned = UnwrapPlanExecuteUserText(cleaned)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if cleaned == "" {
|
||||||
|
if fb := strings.TrimSpace(einoExtractFallbackAssistantFromMsgs(runAccumulatedMsgs)); fb != "" {
|
||||||
|
cleaned = fb
|
||||||
|
}
|
||||||
|
}
|
||||||
cleaned = dedupeRepeatedParagraphs(cleaned, 80)
|
cleaned = dedupeRepeatedParagraphs(cleaned, 80)
|
||||||
cleaned = dedupeParagraphsByLineFingerprint(cleaned, 100)
|
cleaned = dedupeParagraphsByLineFingerprint(cleaned, 100)
|
||||||
// 防止超长响应导致 JSON 序列化慢或 OOM(多代理拼接大量工具输出时可能触发)。
|
// 防止超长响应导致 JSON 序列化慢或 OOM(多代理拼接大量工具输出时可能触发)。
|
||||||
@@ -771,6 +1046,79 @@ func buildEinoRunResultFromAccumulated(
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// einoExtractFallbackAssistantFromMsgs 在「主通道未产出助手正文」时,从 Eino ADK 轨迹中回填用户可见回复。
|
||||||
|
// 典型场景:监督者仅调用 exit(final_result 落在 Tool 消息中),或工具结果已写入历史但 lastAssistant 未更新。
|
||||||
|
//
|
||||||
|
// 优先级:最后一次 exit 工具输出 → 最后一条含 exit 的助手 tool_calls 参数中的 final_result。
|
||||||
|
func einoExtractFallbackAssistantFromMsgs(msgs []adk.Message) string {
|
||||||
|
for i := len(msgs) - 1; i >= 0; i-- {
|
||||||
|
m := msgs[i]
|
||||||
|
if m == nil || m.Role != schema.Tool {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(m.ToolName), adk.ToolInfoExit.Name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
content := strings.TrimSpace(m.Content)
|
||||||
|
if content == "" || strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return content
|
||||||
|
}
|
||||||
|
for i := len(msgs) - 1; i >= 0; i-- {
|
||||||
|
m := msgs[i]
|
||||||
|
if m == nil || m.Role != schema.Assistant {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s := einoExtractExitFinalFromAssistantToolCalls(m); s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func einoExtractExitFinalFromAssistantToolCalls(msg *schema.Message) string {
|
||||||
|
if msg == nil || len(msg.ToolCalls) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for i := len(msg.ToolCalls) - 1; i >= 0; i-- {
|
||||||
|
tc := msg.ToolCalls[i]
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(tc.Function.Name), adk.ToolInfoExit.Name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s := einoParseExitFinalResultArguments(tc.Function.Arguments); s != "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func einoParseExitFinalResultArguments(arguments string) string {
|
||||||
|
arguments = strings.TrimSpace(arguments)
|
||||||
|
if arguments == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var wrap struct {
|
||||||
|
FinalResult json.RawMessage `json:"final_result"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(arguments), &wrap); err != nil || len(wrap.FinalResult) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(wrap.FinalResult, &s); err == nil {
|
||||||
|
return strings.TrimSpace(s)
|
||||||
|
}
|
||||||
|
var anyVal interface{}
|
||||||
|
if err := json.Unmarshal(wrap.FinalResult, &anyVal); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(anyVal)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(string(b))
|
||||||
|
}
|
||||||
|
|
||||||
func buildEinoCheckpointID(orchMode string) string {
|
func buildEinoCheckpointID(orchMode string) string {
|
||||||
mode := sanitizeEinoPathSegment(strings.TrimSpace(orchMode))
|
mode := sanitizeEinoPathSegment(strings.TrimSpace(orchMode))
|
||||||
if mode == "" {
|
if mode == "" {
|
||||||
|
|||||||
@@ -0,0 +1,31 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/einomcp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId),
|
||||||
|
// 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片。
|
||||||
|
func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(command, stdout string, success bool, invokeErr error) {
|
||||||
|
return func(command, stdout string, success bool, invokeErr error) {
|
||||||
|
if ag == nil || recorder == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
if !success {
|
||||||
|
if invokeErr != nil {
|
||||||
|
err = invokeErr
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("execute failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
args := map[string]interface{}{"command": command}
|
||||||
|
id := ag.RecordLocalToolExecution("execute", args, stdout, err)
|
||||||
|
if id != "" {
|
||||||
|
recorder(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,20 +2,58 @@ package multiagent
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/einomcp"
|
||||||
"cyberstrike-ai/internal/security"
|
"cyberstrike-ai/internal/security"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/adk/filesystem"
|
"github.com/cloudwego/eino/adk/filesystem"
|
||||||
|
"github.com/cloudwego/eino/compose"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// prependPythonUnbufferedEnv 为 /bin/sh -c 注入 PYTHONUNBUFFERED=1。
|
||||||
|
// eino-ext local 对流式 stdout 使用 bufio 按「行」推送;python3 写管道时默认块缓冲,print 长期留在用户态缓冲,
|
||||||
|
// 管道里收不到换行,表现为长时间无输出直至超时或退出。若命令里已出现 PYTHONUNBUFFERED 则不再覆盖。
|
||||||
|
func prependPythonUnbufferedEnv(shellCommand string) string {
|
||||||
|
if strings.TrimSpace(shellCommand) == "" {
|
||||||
|
return shellCommand
|
||||||
|
}
|
||||||
|
if strings.Contains(strings.ToUpper(shellCommand), "PYTHONUNBUFFERED") {
|
||||||
|
return shellCommand
|
||||||
|
}
|
||||||
|
return "export PYTHONUNBUFFERED=1\n" + shellCommand
|
||||||
|
}
|
||||||
|
|
||||||
|
// einoExecuteTimeoutUserHint 与写入 ADK 工具消息(模型可见)及 SSE tool_result 尾标一致。
|
||||||
|
func einoExecuteTimeoutUserHint() string {
|
||||||
|
return "已超时终止 · Timed out"
|
||||||
|
}
|
||||||
|
|
||||||
// einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShell(cloudwego eino-ext local.Local)。
|
// einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShell(cloudwego eino-ext local.Local)。
|
||||||
// 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连,
|
// 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连,
|
||||||
// streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。
|
// streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。
|
||||||
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
|
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
|
||||||
|
//
|
||||||
|
// 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire,
|
||||||
|
// 保证 run loop 在模型开始下一轮输出前已记录 execute 结果(用于 UI 与「重复助手复述」去重)。
|
||||||
|
//
|
||||||
|
// 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire;
|
||||||
|
// 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。
|
||||||
type einoStreamingShellWrap struct {
|
type einoStreamingShellWrap struct {
|
||||||
inner filesystem.StreamingShell
|
inner filesystem.StreamingShell
|
||||||
|
invokeNotify *einomcp.ToolInvokeNotifyHolder
|
||||||
|
einoAgentName string
|
||||||
|
// outputChunk 可选;非 nil 时在收到内层 ExecuteResponse 片段时推送,与 MCP 工具的 tool_result_delta 一致(需有效 toolCallId)。
|
||||||
|
outputChunk func(toolName, toolCallID, chunk string)
|
||||||
|
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
|
||||||
|
toolTimeoutMinutes int
|
||||||
|
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
|
||||||
|
recordMonitor func(command, stdout string, success bool, invokeErr error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
@@ -26,8 +64,123 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
return w.inner.ExecuteStreaming(ctx, nil)
|
return w.inner.ExecuteStreaming(ctx, nil)
|
||||||
}
|
}
|
||||||
req := *input
|
req := *input
|
||||||
|
userCmd := strings.TrimSpace(req.Command)
|
||||||
if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround {
|
if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround {
|
||||||
req.RunInBackendGround = true
|
req.RunInBackendGround = true
|
||||||
}
|
}
|
||||||
return w.inner.ExecuteStreaming(ctx, &req)
|
req.Command = prependPythonUnbufferedEnv(req.Command)
|
||||||
|
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
||||||
|
agentTag := strings.TrimSpace(w.einoAgentName)
|
||||||
|
|
||||||
|
execCtx := ctx
|
||||||
|
var execCancel context.CancelFunc
|
||||||
|
if w.toolTimeoutMinutes > 0 {
|
||||||
|
execCtx, execCancel = context.WithTimeout(ctx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
sr, err := w.inner.ExecuteStreaming(execCtx, &req)
|
||||||
|
if err != nil {
|
||||||
|
if execCancel != nil {
|
||||||
|
execCancel()
|
||||||
|
}
|
||||||
|
if w.recordMonitor != nil {
|
||||||
|
w.recordMonitor(userCmd, "", false, err)
|
||||||
|
}
|
||||||
|
if w.invokeNotify != nil && tid != "" {
|
||||||
|
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if sr == nil || w.invokeNotify == nil || tid == "" {
|
||||||
|
if execCancel != nil {
|
||||||
|
execCancel()
|
||||||
|
}
|
||||||
|
return sr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
||||||
|
|
||||||
|
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, tctx context.Context) {
|
||||||
|
defer inner.Close()
|
||||||
|
if cancel != nil {
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
const maxCapture = 16 * 1024
|
||||||
|
success := true
|
||||||
|
var invokeErr error
|
||||||
|
exitCode := 0
|
||||||
|
hasExitCode := false
|
||||||
|
|
||||||
|
for {
|
||||||
|
resp, rerr := inner.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
success = false
|
||||||
|
invokeErr = rerr
|
||||||
|
_ = outW.Send(nil, rerr)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if resp != nil {
|
||||||
|
if resp.ExitCode != nil {
|
||||||
|
hasExitCode = true
|
||||||
|
exitCode = *resp.ExitCode
|
||||||
|
}
|
||||||
|
var appended string
|
||||||
|
if remain := maxCapture - sb.Len(); remain > 0 {
|
||||||
|
out := resp.Output
|
||||||
|
if len(out) > remain {
|
||||||
|
out = out[:remain]
|
||||||
|
}
|
||||||
|
sb.WriteString(out)
|
||||||
|
appended = out
|
||||||
|
}
|
||||||
|
// 仅推送写入 sb 的片段,与末尾 Fire/recordMonitor 的截断累计一致,避免最终 tool_result 短于已展示增量。
|
||||||
|
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
||||||
|
w.outputChunk("execute", tid, appended)
|
||||||
|
}
|
||||||
|
if outW.Send(resp, nil) {
|
||||||
|
success = false
|
||||||
|
invokeErr = fmt.Errorf("execute stream closed by consumer")
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if success && hasExitCode && exitCode != 0 {
|
||||||
|
success = false
|
||||||
|
invokeErr = fmt.Errorf("execute exited with code %d", exitCode)
|
||||||
|
}
|
||||||
|
// WithTimeout 触发后,子进程常被信号结束,local 侧多报 exit -1 / canceled,错误链里不一定带 DeadlineExceeded。
|
||||||
|
// 用执行所用 ctx 归一化,便于 UI 展示「超时」而非含糊的 -1。
|
||||||
|
if tctx != nil && errors.Is(tctx.Err(), context.DeadlineExceeded) {
|
||||||
|
success = false
|
||||||
|
invokeErr = context.DeadlineExceeded
|
||||||
|
}
|
||||||
|
// ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。
|
||||||
|
if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||||
|
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: hint}, nil)
|
||||||
|
if w.outputChunk != nil && tid != "" {
|
||||||
|
w.outputChunk("execute", tid, hint)
|
||||||
|
}
|
||||||
|
if remain := maxCapture - sb.Len(); remain > 0 {
|
||||||
|
h := hint
|
||||||
|
if len(h) > remain {
|
||||||
|
h = h[:remain]
|
||||||
|
}
|
||||||
|
sb.WriteString(h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if w.recordMonitor != nil {
|
||||||
|
w.recordMonitor(command, sb.String(), success, invokeErr)
|
||||||
|
}
|
||||||
|
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
||||||
|
outW.Close()
|
||||||
|
}(sr, userCmd, execCancel, execCtx)
|
||||||
|
|
||||||
|
return outR, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEinoExtractFallbackAssistantFromMsgs_exitToolMessage(t *testing.T) {
|
||||||
|
u := schema.UserMessage("hi")
|
||||||
|
tm := schema.ToolMessage("answer for user", "call-exit-1")
|
||||||
|
tm.ToolName = "exit"
|
||||||
|
if got := einoExtractFallbackAssistantFromMsgs([]*schema.Message{u, tm}); got != "answer for user" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoExtractFallbackAssistantFromMsgs_lastExitWins(t *testing.T) {
|
||||||
|
msgs := []*schema.Message{
|
||||||
|
schema.UserMessage("hi"),
|
||||||
|
toolExitMsg("first", "c1"),
|
||||||
|
toolExitMsg("second", "c2"),
|
||||||
|
}
|
||||||
|
if got := einoExtractFallbackAssistantFromMsgs(msgs); got != "second" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoExtractFallbackAssistantFromMsgs_fromAssistantToolCalls(t *testing.T) {
|
||||||
|
m := schema.AssistantMessage("", []schema.ToolCall{{
|
||||||
|
ID: "x",
|
||||||
|
Type: "function",
|
||||||
|
Function: schema.FunctionCall{
|
||||||
|
Name: "exit",
|
||||||
|
Arguments: `{"final_result":"from args"}`,
|
||||||
|
},
|
||||||
|
}})
|
||||||
|
if got := einoExtractFallbackAssistantFromMsgs([]*schema.Message{m}); got != "from args" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoExtractFallbackAssistantFromMsgs_prefersToolOverEarlierAssistant(t *testing.T) {
|
||||||
|
asst := schema.AssistantMessage("", []schema.ToolCall{{
|
||||||
|
ID: "x",
|
||||||
|
Type: "function",
|
||||||
|
Function: schema.FunctionCall{
|
||||||
|
Name: "exit",
|
||||||
|
Arguments: `{"final_result":"from args"}`,
|
||||||
|
},
|
||||||
|
}})
|
||||||
|
tool := toolExitMsg("from tool", "c1")
|
||||||
|
if got := einoExtractFallbackAssistantFromMsgs([]*schema.Message{asst, tool}); got != "from tool" {
|
||||||
|
t.Fatalf("got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolExitMsg(content, callID string) *schema.Message {
|
||||||
|
m := schema.ToolMessage(content, callID)
|
||||||
|
m.ToolName = "exit"
|
||||||
|
return m
|
||||||
|
}
|
||||||
@@ -0,0 +1,101 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/einomcp"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
// einoADKFilesystemToolNames 与 cloudwego/eino/adk/middlewares/filesystem 默认 ToolName* 一致。
|
||||||
|
// execute 已由 eino_execute_monitor 落库,此处不包含。
|
||||||
|
var einoADKFilesystemToolNames = map[string]struct{}{
|
||||||
|
"ls": {},
|
||||||
|
"read_file": {},
|
||||||
|
"write_file": {},
|
||||||
|
"edit_file": {},
|
||||||
|
"glob": {},
|
||||||
|
"grep": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
func isBuiltinEinoADKFilesystemToolName(name string) bool {
|
||||||
|
n := strings.ToLower(strings.TrimSpace(name))
|
||||||
|
_, ok := einoADKFilesystemToolNames[n]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolCallArgsFromAccumulated(msgs []adk.Message, toolCallID, expectToolName string) map[string]interface{} {
|
||||||
|
tid := strings.TrimSpace(toolCallID)
|
||||||
|
expect := strings.TrimSpace(expectToolName)
|
||||||
|
for i := len(msgs) - 1; i >= 0; i-- {
|
||||||
|
m := msgs[i]
|
||||||
|
if m == nil || m.Role != schema.Assistant || len(m.ToolCalls) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for j := len(m.ToolCalls) - 1; j >= 0; j-- {
|
||||||
|
tc := m.ToolCalls[j]
|
||||||
|
if tid != "" && strings.TrimSpace(tc.ID) != tid {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fn := strings.TrimSpace(tc.Function.Name)
|
||||||
|
if expect != "" && !strings.EqualFold(fn, expect) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
raw := strings.TrimSpace(tc.Function.Arguments)
|
||||||
|
if raw == "" {
|
||||||
|
return map[string]interface{}{}
|
||||||
|
}
|
||||||
|
var args map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(raw), &args); err != nil {
|
||||||
|
return map[string]interface{}{"arguments_raw": raw}
|
||||||
|
}
|
||||||
|
if args == nil {
|
||||||
|
return map[string]interface{}{}
|
||||||
|
}
|
||||||
|
return args
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return map[string]interface{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// recordEinoADKFilesystemToolMonitor 将 Eino ADK filesystem 中间件工具结果写入 MCP 监控(与 execute / MCP 桥芯片一致)。
|
||||||
|
func recordEinoADKFilesystemToolMonitor(
|
||||||
|
ag *agent.Agent,
|
||||||
|
rec einomcp.ExecutionRecorder,
|
||||||
|
toolName string,
|
||||||
|
toolCallID string,
|
||||||
|
msgs []adk.Message,
|
||||||
|
resultText string,
|
||||||
|
isErr bool,
|
||||||
|
) {
|
||||||
|
if ag == nil || rec == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
name := strings.TrimSpace(toolName)
|
||||||
|
if name == "" || strings.EqualFold(name, "execute") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !isBuiltinEinoADKFilesystemToolName(name) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
args := toolCallArgsFromAccumulated(msgs, toolCallID, name)
|
||||||
|
storedName := "eino_fs::" + strings.ToLower(name)
|
||||||
|
var invErr error
|
||||||
|
if isErr {
|
||||||
|
t := strings.TrimSpace(resultText)
|
||||||
|
if t == "" {
|
||||||
|
invErr = errors.New("tool error")
|
||||||
|
} else {
|
||||||
|
invErr = errors.New(t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr)
|
||||||
|
if id != "" {
|
||||||
|
rec(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -161,6 +161,8 @@ func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddl
|
|||||||
}
|
}
|
||||||
|
|
||||||
// prependEinoMiddlewares returns handlers to prepend (outermost first) and optionally replaces tools when tool_search is used.
|
// prependEinoMiddlewares returns handlers to prepend (outermost first) and optionally replaces tools when tool_search is used.
|
||||||
|
// toolSearchActive is true when the toolsearch middleware was mounted (dynamic tools split off); callers should pass this to
|
||||||
|
// injectToolNamesOnlyInstruction — tool_search is not part of the pre-middleware tools list, so name-scanning alone cannot detect it.
|
||||||
func prependEinoMiddlewares(
|
func prependEinoMiddlewares(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
mw *config.MultiAgentEinoMiddlewareConfig,
|
mw *config.MultiAgentEinoMiddlewareConfig,
|
||||||
@@ -170,16 +172,16 @@ func prependEinoMiddlewares(
|
|||||||
skillsRoot string,
|
skillsRoot string,
|
||||||
conversationID string,
|
conversationID string,
|
||||||
logger *zap.Logger,
|
logger *zap.Logger,
|
||||||
) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, err error) {
|
) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, toolSearchActive bool, err error) {
|
||||||
if mw == nil {
|
if mw == nil {
|
||||||
return tools, nil, nil
|
return tools, nil, false, nil
|
||||||
}
|
}
|
||||||
outTools = tools
|
outTools = tools
|
||||||
|
|
||||||
if mw.PatchToolCallsEffective() {
|
if mw.PatchToolCallsEffective() {
|
||||||
patchMW, perr := patchtoolcalls.New(ctx, &patchtoolcalls.Config{})
|
patchMW, perr := patchtoolcalls.New(ctx, &patchtoolcalls.Config{})
|
||||||
if perr != nil {
|
if perr != nil {
|
||||||
return nil, nil, fmt.Errorf("patchtoolcalls: %w", perr)
|
return nil, nil, false, fmt.Errorf("patchtoolcalls: %w", perr)
|
||||||
}
|
}
|
||||||
extraHandlers = append(extraHandlers, patchMW)
|
extraHandlers = append(extraHandlers, patchMW)
|
||||||
}
|
}
|
||||||
@@ -190,7 +192,7 @@ func prependEinoMiddlewares(
|
|||||||
} else {
|
} else {
|
||||||
redMW, rerr := buildReductionMiddleware(ctx, *mw, conversationID, einoLoc, logger)
|
redMW, rerr := buildReductionMiddleware(ctx, *mw, conversationID, einoLoc, logger)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
return nil, nil, rerr
|
return nil, nil, false, rerr
|
||||||
}
|
}
|
||||||
extraHandlers = append(extraHandlers, redMW)
|
extraHandlers = append(extraHandlers, redMW)
|
||||||
}
|
}
|
||||||
@@ -209,10 +211,11 @@ func prependEinoMiddlewares(
|
|||||||
if split && len(dynamic) > 0 {
|
if split && len(dynamic) > 0 {
|
||||||
ts, terr := toolsearch.New(ctx, &toolsearch.Config{DynamicTools: dynamic})
|
ts, terr := toolsearch.New(ctx, &toolsearch.Config{DynamicTools: dynamic})
|
||||||
if terr != nil {
|
if terr != nil {
|
||||||
return nil, nil, fmt.Errorf("toolsearch: %w", terr)
|
return nil, nil, false, fmt.Errorf("toolsearch: %w", terr)
|
||||||
}
|
}
|
||||||
extraHandlers = append(extraHandlers, ts)
|
extraHandlers = append(extraHandlers, ts)
|
||||||
outTools = static
|
outTools = static
|
||||||
|
toolSearchActive = true
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
logger.Info("eino middleware: tool_search enabled",
|
logger.Info("eino middleware: tool_search enabled",
|
||||||
zap.Int("static_tools", len(static)),
|
zap.Int("static_tools", len(static)),
|
||||||
@@ -233,12 +236,12 @@ func prependEinoMiddlewares(
|
|||||||
}
|
}
|
||||||
baseDir := filepath.Join(skillsRoot, rel, sanitizeEinoPathSegment(conversationID))
|
baseDir := filepath.Join(skillsRoot, rel, sanitizeEinoPathSegment(conversationID))
|
||||||
if mk := os.MkdirAll(baseDir, 0o755); mk != nil {
|
if mk := os.MkdirAll(baseDir, 0o755); mk != nil {
|
||||||
return nil, nil, fmt.Errorf("plantask mkdir: %w", mk)
|
return nil, nil, toolSearchActive, fmt.Errorf("plantask mkdir: %w", mk)
|
||||||
}
|
}
|
||||||
ptBE := &localPlantaskBackend{Local: einoLoc}
|
ptBE := &localPlantaskBackend{Local: einoLoc}
|
||||||
pt, perr := plantask.New(ctx, &plantask.Config{Backend: ptBE, BaseDir: baseDir})
|
pt, perr := plantask.New(ctx, &plantask.Config{Backend: ptBE, BaseDir: baseDir})
|
||||||
if perr != nil {
|
if perr != nil {
|
||||||
return nil, nil, fmt.Errorf("plantask: %w", perr)
|
return nil, nil, toolSearchActive, fmt.Errorf("plantask: %w", perr)
|
||||||
}
|
}
|
||||||
extraHandlers = append(extraHandlers, pt)
|
extraHandlers = append(extraHandlers, pt)
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
@@ -247,7 +250,7 @@ func prependEinoMiddlewares(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return outTools, extraHandlers, nil
|
return outTools, extraHandlers, toolSearchActive, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry *adk.ModelRetryConfig, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry *adk.ModelRetryConfig, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
||||||
|
|||||||
@@ -0,0 +1,84 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
)
|
||||||
|
|
||||||
|
// modelFacingTraceHolder 保存「即将送入 ChatModel」的消息快照(已走 summarization / reduction / orphan 修剪等),
|
||||||
|
// 用于 last_react_input 落库,使续跑与「上下文压缩后」的模型视角一致,而非仅依赖事件流 append 的 runAccumulatedMsgs。
|
||||||
|
type modelFacingTraceHolder struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
// msgs 为深拷贝后的切片,避免框架后续原地修改污染快照
|
||||||
|
msgs []adk.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func newModelFacingTraceHolder() *modelFacingTraceHolder {
|
||||||
|
return &modelFacingTraceHolder{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Snapshot 返回当前快照的再一次深拷贝(供序列化落库,避免与 holder 互斥长期持锁)。
|
||||||
|
func (h *modelFacingTraceHolder) Snapshot() []adk.Message {
|
||||||
|
if h == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
return cloneADKMessagesForTrace(h.msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *modelFacingTraceHolder) storeFromState(state *adk.ChatModelAgentState) {
|
||||||
|
if h == nil || state == nil || len(state.Messages) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cloned := cloneADKMessagesForTrace(state.Messages)
|
||||||
|
if len(cloned) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.mu.Lock()
|
||||||
|
h.msgs = cloned
|
||||||
|
h.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneADKMessagesForTrace(msgs []adk.Message) []adk.Message {
|
||||||
|
if len(msgs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(msgs)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var out []adk.Message
|
||||||
|
if err := json.Unmarshal(b, &out); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelFacingTraceMiddleware 必须在 Handlers 链中处于 **BeforeModel 最后**(telemetry 之后),
|
||||||
|
// 此时 state.Messages 即为本次 LLM 调用的最终入参。
|
||||||
|
type modelFacingTraceMiddleware struct {
|
||||||
|
adk.BaseChatModelAgentMiddleware
|
||||||
|
holder *modelFacingTraceHolder
|
||||||
|
}
|
||||||
|
|
||||||
|
func newModelFacingTraceMiddleware(holder *modelFacingTraceHolder) adk.ChatModelAgentMiddleware {
|
||||||
|
if holder == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &modelFacingTraceMiddleware{holder: holder}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *modelFacingTraceMiddleware) BeforeModelRewriteState(
|
||||||
|
ctx context.Context,
|
||||||
|
state *adk.ChatModelAgentState,
|
||||||
|
mc *adk.ModelContext,
|
||||||
|
) (context.Context, *adk.ChatModelAgentState, error) {
|
||||||
|
if m.holder != nil && state != nil {
|
||||||
|
m.holder.storeFromState(state)
|
||||||
|
}
|
||||||
|
return ctx, state, nil
|
||||||
|
}
|
||||||
@@ -41,6 +41,8 @@ type PlanExecuteRootArgs struct {
|
|||||||
FilesystemMiddleware adk.ChatModelAgentMiddleware
|
FilesystemMiddleware adk.ChatModelAgentMiddleware
|
||||||
// PlannerReplannerRewriteHandlers applies BeforeModelRewriteState pipeline for planner/replanner input.
|
// PlannerReplannerRewriteHandlers applies BeforeModelRewriteState pipeline for planner/replanner input.
|
||||||
PlannerReplannerRewriteHandlers []adk.ChatModelAgentMiddleware
|
PlannerReplannerRewriteHandlers []adk.ChatModelAgentMiddleware
|
||||||
|
// ModelFacingTrace 可选:由 Executor Handlers 链末尾写入,供 last_react 与 summarization 后上下文对齐。
|
||||||
|
ModelFacingTrace *modelFacingTraceHolder
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPlanExecuteRoot 返回 plan → execute → replan 预置编排根节点(与 Deep / Supervisor 并列)。
|
// NewPlanExecuteRoot 返回 plan → execute → replan 预置编排根节点(与 Deep / Supervisor 并列)。
|
||||||
@@ -101,6 +103,11 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
|||||||
if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil {
|
if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil {
|
||||||
execHandlers = append(execHandlers, teleMw)
|
execHandlers = append(execHandlers, teleMw)
|
||||||
}
|
}
|
||||||
|
if a.ModelFacingTrace != nil {
|
||||||
|
if capMw := newModelFacingTraceMiddleware(a.ModelFacingTrace); capMw != nil {
|
||||||
|
execHandlers = append(execHandlers, capMw)
|
||||||
|
}
|
||||||
|
}
|
||||||
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
|
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
|
||||||
Model: a.ExecModel,
|
Model: a.ExecModel,
|
||||||
ToolsConfig: a.ToolsCfg,
|
ToolsConfig: a.ToolsCfg,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
|
"cyberstrike-ai/internal/reasoning"
|
||||||
|
|
||||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
@@ -37,6 +38,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
history []agent.ChatMessage,
|
history []agent.ChatMessage,
|
||||||
roleTools []string,
|
roleTools []string,
|
||||||
progress func(eventType, message string, data interface{}),
|
progress func(eventType, message string, data interface{}),
|
||||||
|
reasoningClient *reasoning.ClientIntent,
|
||||||
) (*RunResult, error) {
|
) (*RunResult, error) {
|
||||||
if appCfg == nil || ag == nil {
|
if appCfg == nil || ag == nil {
|
||||||
return nil, fmt.Errorf("eino single: 配置或 Agent 为空")
|
return nil, fmt.Errorf("eino single: 配置或 Agent 为空")
|
||||||
@@ -86,13 +88,15 @@ func RunEinoSingleChatModelAgent(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
|
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
||||||
mainDefs := ag.ToolsForRole(roleTools)
|
mainDefs := ag.ToolsForRole(roleTools)
|
||||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk)
|
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk, toolInvokeNotify, einoSingleAgentName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
mainToolsForCfg, mainOrchestratorPre, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
mainToolsForCfg, mainOrchestratorPre, singleToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("eino single eino 中间件: %w", err)
|
return nil, fmt.Errorf("eino single eino 中间件: %w", err)
|
||||||
}
|
}
|
||||||
@@ -119,6 +123,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
Model: appCfg.OpenAI.Model,
|
Model: appCfg.OpenAI.Model,
|
||||||
HTTPClient: httpClient,
|
HTTPClient: httpClient,
|
||||||
}
|
}
|
||||||
|
reasoning.ApplyToEinoChatModelConfig(baseModelCfg, &appCfg.OpenAI, reasoningClient)
|
||||||
|
|
||||||
mainModel, err := einoopenai.NewChatModel(ctx, baseModelCfg)
|
mainModel, err := einoopenai.NewChatModel(ctx, baseModelCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -130,13 +135,15 @@ func RunEinoSingleChatModelAgent(
|
|||||||
return nil, fmt.Errorf("eino single summarization: %w", err)
|
return nil, fmt.Errorf("eino single summarization: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
handlers := make([]adk.ChatModelAgentMiddleware, 0, 4)
|
modelFacingTrace := newModelFacingTraceHolder()
|
||||||
|
|
||||||
|
handlers := make([]adk.ChatModelAgentMiddleware, 0, 8)
|
||||||
if len(mainOrchestratorPre) > 0 {
|
if len(mainOrchestratorPre) > 0 {
|
||||||
handlers = append(handlers, mainOrchestratorPre...)
|
handlers = append(handlers, mainOrchestratorPre...)
|
||||||
}
|
}
|
||||||
if einoSkillMW != nil {
|
if einoSkillMW != nil {
|
||||||
if einoFSTools && einoLoc != nil {
|
if einoFSTools && einoLoc != nil {
|
||||||
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc)
|
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
||||||
if fsErr != nil {
|
if fsErr != nil {
|
||||||
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
||||||
}
|
}
|
||||||
@@ -148,6 +155,9 @@ func RunEinoSingleChatModelAgent(
|
|||||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "eino_single"); teleMw != nil {
|
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "eino_single"); teleMw != nil {
|
||||||
handlers = append(handlers, teleMw)
|
handlers = append(handlers, teleMw)
|
||||||
}
|
}
|
||||||
|
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||||
|
handlers = append(handlers, capMw)
|
||||||
|
}
|
||||||
|
|
||||||
maxIter := ma.MaxIteration
|
maxIter := ma.MaxIteration
|
||||||
if maxIter <= 0 {
|
if maxIter <= 0 {
|
||||||
@@ -162,28 +172,21 @@ func RunEinoSingleChatModelAgent(
|
|||||||
Tools: mainToolsForCfg,
|
Tools: mainToolsForCfg,
|
||||||
UnknownToolsHandler: einomcp.UnknownToolReminderHandler(),
|
UnknownToolsHandler: einomcp.UnknownToolReminderHandler(),
|
||||||
ToolCallMiddlewares: []compose.ToolMiddleware{
|
ToolCallMiddlewares: []compose.ToolMiddleware{
|
||||||
{Invokable: hitlToolCallMiddleware()},
|
hitlToolCallMiddleware(),
|
||||||
{Invokable: softRecoveryToolCallMiddleware()},
|
softRecoveryToolMiddleware(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
EmitInternalEvents: true,
|
EmitInternalEvents: true,
|
||||||
}
|
}
|
||||||
ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools)
|
ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools, singleToolSearchActive)
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
names := collectToolNames(ctx, mainTools)
|
names := collectToolNames(ctx, mainTools)
|
||||||
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
||||||
hasToolSearch := false
|
|
||||||
for _, n := range names {
|
|
||||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
|
||||||
hasToolSearch = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.Info("eino tool-name injection",
|
logger.Info("eino tool-name injection",
|
||||||
zap.String("scope", "eino_single"),
|
zap.String("scope", "eino_single"),
|
||||||
zap.Int("tool_names", len(names)),
|
zap.Int("tool_names", len(names)),
|
||||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||||
zap.Bool("has_tool_search", hasToolSearch),
|
zap.Bool("tool_search_middleware", singleToolSearchActive),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -221,18 +224,23 @@ func RunEinoSingleChatModelAgent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
return runEinoADKAgentLoop(ctx, &einoADKRunLoopArgs{
|
return runEinoADKAgentLoop(ctx, &einoADKRunLoopArgs{
|
||||||
OrchMode: "eino_single",
|
OrchMode: "eino_single",
|
||||||
OrchestratorName: einoSingleAgentName,
|
OrchestratorName: einoSingleAgentName,
|
||||||
ConversationID: conversationID,
|
ConversationID: conversationID,
|
||||||
Progress: progress,
|
Progress: progress,
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
SnapshotMCPIDs: snapshotMCPIDs,
|
SnapshotMCPIDs: snapshotMCPIDs,
|
||||||
StreamsMainAssistant: streamsMainAssistant,
|
StreamsMainAssistant: streamsMainAssistant,
|
||||||
EinoRoleTag: einoRoleTag,
|
EinoRoleTag: einoRoleTag,
|
||||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||||
McpIDsMu: &mcpIDsMu,
|
McpIDsMu: &mcpIDsMu,
|
||||||
McpIDs: &mcpIDs,
|
McpIDs: &mcpIDs,
|
||||||
DA: chatAgent,
|
FilesystemMonitorAgent: ag,
|
||||||
|
FilesystemMonitorRecord: recorder,
|
||||||
|
ToolInvokeNotify: toolInvokeNotify,
|
||||||
|
DA: chatAgent,
|
||||||
|
ModelFacingTrace: modelFacingTrace,
|
||||||
|
EinoCallbacks: &ma.EinoCallbacks,
|
||||||
EmptyResponseMessage: "(Eino ADK single-agent session completed but no assistant text was captured. Check process details or logs.) " +
|
EmptyResponseMessage: "(Eino ADK single-agent session completed but no assistant text was captured. Check process details or logs.) " +
|
||||||
"(Eino ADK 单代理会话已完成,但未捕获到助手文本输出。请查看过程详情或日志。)",
|
"(Eino ADK 单代理会话已完成,但未捕获到助手文本输出。请查看过程详情或日志。)",
|
||||||
}, baseMsgs)
|
}, baseMsgs)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/einomcp"
|
||||||
|
|
||||||
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
|
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
@@ -75,12 +76,35 @@ func prepareEinoSkills(
|
|||||||
// subAgentFilesystemMiddleware returns filesystem middleware for a sub-agent when Deep itself
|
// subAgentFilesystemMiddleware returns filesystem middleware for a sub-agent when Deep itself
|
||||||
// does not set Backend (fsTools false on orchestrator) but we still want tools on subs — not used;
|
// does not set Backend (fsTools false on orchestrator) but we still want tools on subs — not used;
|
||||||
// when orchestrator has Backend, builtin FS is only on outer agent; subs need explicit FS for parity.
|
// when orchestrator has Backend, builtin FS is only on outer agent; subs need explicit FS for parity.
|
||||||
func subAgentFilesystemMiddleware(ctx context.Context, loc *localbk.Local) (adk.ChatModelAgentMiddleware, error) {
|
func subAgentFilesystemMiddleware(
|
||||||
|
ctx context.Context,
|
||||||
|
loc *localbk.Local,
|
||||||
|
invokeNotify *einomcp.ToolInvokeNotifyHolder,
|
||||||
|
einoAgentName string,
|
||||||
|
recordMonitor func(command, stdout string, success bool, invokeErr error),
|
||||||
|
toolTimeoutMinutes int,
|
||||||
|
outputChunk func(toolName, toolCallID, chunk string),
|
||||||
|
) (adk.ChatModelAgentMiddleware, error) {
|
||||||
if loc == nil {
|
if loc == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return filesystem.New(ctx, &filesystem.MiddlewareConfig{
|
return filesystem.New(ctx, &filesystem.MiddlewareConfig{
|
||||||
Backend: loc,
|
Backend: loc,
|
||||||
StreamingShell: &einoStreamingShellWrap{inner: loc},
|
StreamingShell: &einoStreamingShellWrap{
|
||||||
|
inner: loc,
|
||||||
|
invokeNotify: invokeNotify,
|
||||||
|
einoAgentName: strings.TrimSpace(einoAgentName),
|
||||||
|
outputChunk: outputChunk,
|
||||||
|
recordMonitor: recordMonitor,
|
||||||
|
toolTimeoutMinutes: toolTimeoutMinutes,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// agentToolTimeoutMinutes 返回 agent.tool_timeout_minutes(与 executeToolViaMCP 一致);cfg 为 nil 时 0。
|
||||||
|
func agentToolTimeoutMinutes(cfg *config.Config) int {
|
||||||
|
if cfg == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return cfg.Agent.ToolTimeoutMinutes
|
||||||
|
}
|
||||||
|
|||||||
@@ -214,7 +214,7 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
|
|||||||
selectedCount++
|
selectedCount++
|
||||||
}
|
}
|
||||||
|
|
||||||
// 还原时间顺序
|
// 还原时间顺序。round 内为原始 *schema.Message 指针,保留 ReasoningContent(DeepSeek 工具续跑所必需)。
|
||||||
selectedMsgs := make([]adk.Message, 0, 8)
|
selectedMsgs := make([]adk.Message, 0, 8)
|
||||||
for i := len(selectedRoundsReverse) - 1; i >= 0; i-- {
|
for i := len(selectedRoundsReverse) - 1; i >= 0; i-- {
|
||||||
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
|
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
|
||||||
|
|||||||
@@ -9,34 +9,43 @@ import (
|
|||||||
|
|
||||||
// injectToolNamesOnlyInstruction prepends a compact tool-name-only section into
|
// injectToolNamesOnlyInstruction prepends a compact tool-name-only section into
|
||||||
// the system instruction so the model can reference current callable names.
|
// the system instruction so the model can reference current callable names.
|
||||||
func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, tools []tool.BaseTool) string {
|
// toolSearchMiddlewareActive must be true when prependEinoMiddlewares mounted toolsearch (dynamic tools); do not infer this
|
||||||
|
// by scanning tool names — tool_search is injected by middleware and is usually absent from the pre-split tools list.
|
||||||
|
func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, tools []tool.BaseTool, toolSearchMiddlewareActive bool) string {
|
||||||
names := collectToolNames(ctx, tools)
|
names := collectToolNames(ctx, tools)
|
||||||
if len(names) == 0 {
|
if len(names) == 0 {
|
||||||
return strings.TrimSpace(instruction)
|
return strings.TrimSpace(instruction)
|
||||||
}
|
}
|
||||||
hasToolSearch := false
|
hasToolSearch := toolSearchMiddlewareActive
|
||||||
for _, n := range names {
|
if !hasToolSearch {
|
||||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
for _, n := range names {
|
||||||
hasToolSearch = true
|
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||||
break
|
hasToolSearch = true
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
sb.WriteString("以下是当前会话中可调用的工具名称列表(仅名称,无参数定义):\n")
|
sb.WriteString("以下是当前会话绑定的工具名称索引(仅名称,无参数 JSON Schema)。\n")
|
||||||
|
sb.WriteString("说明:若启用了 tool_search,则列表里可能含「非常驻」工具——它们不一定出现在当前轮次下发给模型的工具定义中;在未看到该工具的完整 schema 前,禁止凭名称臆测参数。\n")
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
sb.WriteString("- ")
|
sb.WriteString("- ")
|
||||||
sb.WriteString(name)
|
sb.WriteString(name)
|
||||||
sb.WriteByte('\n')
|
sb.WriteByte('\n')
|
||||||
}
|
}
|
||||||
sb.WriteString("\n使用规则:\n")
|
sb.WriteString("\n使用规则:\n")
|
||||||
sb.WriteString("1) 上述仅为名称列表,不包含参数定义。\n")
|
sb.WriteString("1) 上表仅为名称索引,不含参数定义。禁止猜测参数名、类型、枚举取值或是否必填。\n")
|
||||||
if hasToolSearch {
|
if hasToolSearch {
|
||||||
sb.WriteString("2) 在调用具体工具前,应先使用 tool_search 查看工具详情与参数要求,再发起调用。\n")
|
sb.WriteString("【强制 / 最高优先级】本会话已启用 tool_search(动态工具池)。凡名称索引里出现、但你在「当前请求所附 tools 定义」中看不到其完整参数 schema 的工具,一律必须先调用 tool_search;为省 token 或赶进度而跳过 tool_search、直接调用业务工具,属于明确禁止的错误流程。\n")
|
||||||
|
sb.WriteString("2) 默认策略:只要对目标工具的参数定义有任何不确定,就先 tool_search;宁可多一次 tool_search,也不要在未见 schema 时盲调业务工具。\n")
|
||||||
|
sb.WriteString("3) 调用顺序:先 tool_search(唯一必填参数 regex_pattern:按工具名匹配的正则,如子串 nuclei 或 ^exact_tool_name$)→ 在后续轮次确认目标工具已出现在 tools 列表且已阅读其 schema → 再发起对该工具的真实调用。\n")
|
||||||
|
sb.WriteString("4) tool_search 的返回仅为匹配到的工具名列表;schema 在解锁后的下一轮才会下发。禁止在 schema 未出现时编造 JSON 参数。\n")
|
||||||
|
sb.WriteString("5) 不要臆造不存在的工具名。\n\n")
|
||||||
} else {
|
} else {
|
||||||
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求;不确定时先澄清再调用。\n")
|
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求(以当前请求中的工具定义为准);不确定时先澄清再调用。\n")
|
||||||
|
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
|
||||||
}
|
}
|
||||||
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
|
|
||||||
if s := strings.TrimSpace(instruction); s != "" {
|
if s := strings.TrimSpace(instruction); s != "" {
|
||||||
sb.WriteString(s)
|
sb.WriteString(s)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
"github.com/cloudwego/eino/compose"
|
"github.com/cloudwego/eino/compose"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
type hitlInterceptorKey struct{}
|
type hitlInterceptorKey struct{}
|
||||||
@@ -41,7 +42,31 @@ func WithHITLToolInterceptor(ctx context.Context, fn HITLToolInterceptor) contex
|
|||||||
return context.WithValue(ctx, hitlInterceptorKey{}, fn)
|
return context.WithValue(ctx, hitlInterceptorKey{}, fn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func hitlToolCallMiddleware() compose.InvokableToolMiddleware {
|
// hitlToolCallMiddleware 同时注册 Invokable 与 Streamable。
|
||||||
|
// Eino filesystem 的 execute 为流式工具(StreamableTool),仅挂 Invokable 时人机协同不会拦截,会直接执行。
|
||||||
|
func hitlToolCallMiddleware() compose.ToolMiddleware {
|
||||||
|
return compose.ToolMiddleware{
|
||||||
|
Invokable: hitlInvokableToolCallMiddleware(),
|
||||||
|
Streamable: hitlStreamableToolCallMiddleware(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hitlClearReturnDirectlyIfTransfer(ctx context.Context, toolName string) {
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(toolName), adk.TransferToAgentToolName) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = compose.ProcessState[*adk.State](ctx, func(_ context.Context, st *adk.State) error {
|
||||||
|
if st == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
st.ReturnDirectlyToolCallID = ""
|
||||||
|
st.HasReturnDirectly = false
|
||||||
|
st.ReturnDirectlyEvent = nil
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func hitlInvokableToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||||
return func(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint {
|
return func(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint {
|
||||||
return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||||
if input != nil {
|
if input != nil {
|
||||||
@@ -55,17 +80,7 @@ func hitlToolCallMiddleware() compose.InvokableToolMiddleware {
|
|||||||
// transfer_to_agent 在 Eino 中标记为 returnDirectly:工具成功后 ReAct 子图会直接 END,
|
// transfer_to_agent 在 Eino 中标记为 returnDirectly:工具成功后 ReAct 子图会直接 END,
|
||||||
// 并依赖真实工具内的 SendToolGenAction 触发移交。HITL 拒绝时不会执行真实工具,
|
// 并依赖真实工具内的 SendToolGenAction 触发移交。HITL 拒绝时不会执行真实工具,
|
||||||
// 若仍走 returnDirectly 分支,监督者会在无 Transfer 动作的情况下结束,模型不再迭代。
|
// 若仍走 returnDirectly 分支,监督者会在无 Transfer 动作的情况下结束,模型不再迭代。
|
||||||
if strings.EqualFold(strings.TrimSpace(input.Name), adk.TransferToAgentToolName) {
|
hitlClearReturnDirectlyIfTransfer(ctx, input.Name)
|
||||||
_ = compose.ProcessState[*adk.State](ctx, func(_ context.Context, st *adk.State) error {
|
|
||||||
if st == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
st.ReturnDirectlyToolCallID = ""
|
|
||||||
st.HasReturnDirectly = false
|
|
||||||
st.ReturnDirectlyEvent = nil
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return &compose.ToolOutput{Result: msg}, nil
|
return &compose.ToolOutput{Result: msg}, nil
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -79,3 +94,30 @@ func hitlToolCallMiddleware() compose.InvokableToolMiddleware {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hitlStreamableToolCallMiddleware() compose.StreamableToolMiddleware {
|
||||||
|
return func(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint {
|
||||||
|
return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) {
|
||||||
|
if input != nil {
|
||||||
|
if fn, ok := ctx.Value(hitlInterceptorKey{}).(HITLToolInterceptor); ok && fn != nil {
|
||||||
|
edited, err := fn(ctx, input.Name, input.Arguments)
|
||||||
|
if err != nil {
|
||||||
|
if IsHumanRejectError(err) {
|
||||||
|
msg := fmt.Sprintf("[HITL Reject] Tool '%s' was rejected by human reviewer. Reason: %s\nPlease adjust parameters/plan and continue without this call.",
|
||||||
|
input.Name, strings.TrimSpace(err.Error()))
|
||||||
|
hitlClearReturnDirectlyIfTransfer(ctx, input.Name)
|
||||||
|
return &compose.StreamToolOutput{
|
||||||
|
Result: schema.StreamReaderFromArray([]string{msg}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if edited != "" {
|
||||||
|
input.Arguments = edited
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return next(ctx, input)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
|
||||||
|
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
|
||||||
|
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Eino execute 去重分支 EOF flush 须以 mainAssistantBuf 为基准计算 tail,
|
||||||
|
// 若误用 TrimSpace(mainAssistantBuf),会与已推前缀在空白处失配,normalize 走拼接路径叠字。
|
||||||
|
func TestNormalizeStreamingDelta_eofTailUsesRawBufNotTrim(t *testing.T) {
|
||||||
|
wireAccum := "phrase "
|
||||||
|
rawFull := "phrase \n"
|
||||||
|
_, tail := normalizeStreamingDelta(wireAccum, rawFull)
|
||||||
|
if want := "\n"; tail != want {
|
||||||
|
t.Fatalf("tail=%q want %q", tail, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
nextWrong, badTail := normalizeStreamingDelta(wireAccum, strings.TrimSpace(rawFull))
|
||||||
|
if badTail != "phrase" || nextWrong != "phrase phrase" {
|
||||||
|
t.Fatalf("trimmed full vs wire prefix mismatch should concat-append; got next=%q badTail=%q", nextWrong, badTail)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AggregatedReasoningFromTraceJSON concatenates non-empty assistant `reasoning_content`
|
||||||
|
// fields from last_react-style JSON (slice of message objects) in document order.
|
||||||
|
// Used to persist on the single assistant bubble row for audit and for GetMessages fallback
|
||||||
|
// when the full trace JSON is unavailable. For strict per-message replay, prefer last_react_input.
|
||||||
|
func AggregatedReasoningFromTraceJSON(traceJSON string) string {
|
||||||
|
traceJSON = strings.TrimSpace(traceJSON)
|
||||||
|
if traceJSON == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var arr []map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(traceJSON), &arr); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var b strings.Builder
|
||||||
|
for _, m := range arr {
|
||||||
|
role, _ := m["role"].(string)
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(role), "assistant") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rc := reasoningContentFromMessageMap(m)
|
||||||
|
if rc == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if b.Len() > 0 {
|
||||||
|
b.WriteByte('\n')
|
||||||
|
}
|
||||||
|
b.WriteString(rc)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func reasoningContentFromMessageMap(m map[string]interface{}) string {
|
||||||
|
if m == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
switch v := m["reasoning_content"].(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(v)
|
||||||
|
case nil:
|
||||||
|
return ""
|
||||||
|
default:
|
||||||
|
return strings.TrimSpace(fmt.Sprint(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestAggregatedReasoningFromTraceJSON(t *testing.T) {
|
||||||
|
const j = `[
|
||||||
|
{"role":"user","content":"hi"},
|
||||||
|
{"role":"assistant","content":"c1","reasoning_content":"r1","tool_calls":[{"id":"1","type":"function","function":{"name":"f","arguments":"{}"}}]},
|
||||||
|
{"role":"tool","tool_call_id":"1","content":"out"},
|
||||||
|
{"role":"assistant","content":"c2","reasoning_content":"r2"}
|
||||||
|
]`
|
||||||
|
got := AggregatedReasoningFromTraceJSON(j)
|
||||||
|
want := "r1\nr2"
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if AggregatedReasoningFromTraceJSON("") != "" || AggregatedReasoningFromTraceJSON("[]") != "" {
|
||||||
|
t.Fatal("empty expected")
|
||||||
|
}
|
||||||
|
}
|
||||||
+150
-110
@@ -17,6 +17,7 @@ import (
|
|||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
|
"cyberstrike-ai/internal/reasoning"
|
||||||
|
|
||||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
@@ -48,6 +49,7 @@ type toolCallPendingInfo struct {
|
|||||||
|
|
||||||
// RunDeepAgent 使用 Eino 多代理预置编排执行一轮对话(deep / plan_execute / supervisor;流式事件通过 progress 回调输出)。
|
// RunDeepAgent 使用 Eino 多代理预置编排执行一轮对话(deep / plan_execute / supervisor;流式事件通过 progress 回调输出)。
|
||||||
// orchestrationOverride 非空时优先(如聊天/WebShell 请求体);否则用 multi_agent.orchestration(遗留 yaml);皆空则按 deep。
|
// orchestrationOverride 非空时优先(如聊天/WebShell 请求体);否则用 multi_agent.orchestration(遗留 yaml);皆空则按 deep。
|
||||||
|
// reasoningClient 来自 ChatRequest.reasoning;可为 nil(机器人/批量等走全局 openai.reasoning)。
|
||||||
func RunDeepAgent(
|
func RunDeepAgent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
appCfg *config.Config,
|
appCfg *config.Config,
|
||||||
@@ -61,6 +63,7 @@ func RunDeepAgent(
|
|||||||
progress func(eventType, message string, data interface{}),
|
progress func(eventType, message string, data interface{}),
|
||||||
agentsMarkdownDir string,
|
agentsMarkdownDir string,
|
||||||
orchestrationOverride string,
|
orchestrationOverride string,
|
||||||
|
reasoningClient *reasoning.ClientIntent,
|
||||||
) (*RunResult, error) {
|
) (*RunResult, error) {
|
||||||
if appCfg == nil || ma == nil || ag == nil {
|
if appCfg == nil || ma == nil || ag == nil {
|
||||||
return nil, fmt.Errorf("multiagent: 配置或 Agent 为空")
|
return nil, fmt.Errorf("multiagent: 配置或 Agent 为空")
|
||||||
@@ -110,6 +113,7 @@ func RunDeepAgent(
|
|||||||
mcpIDs = append(mcpIDs, id)
|
mcpIDs = append(mcpIDs, id)
|
||||||
mcpIDsMu.Unlock()
|
mcpIDsMu.Unlock()
|
||||||
}
|
}
|
||||||
|
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
||||||
|
|
||||||
// 与单代理流式一致:在 response_start / response_delta 的 data 中带当前 mcpExecutionIds,供主聊天绑定复制与展示。
|
// 与单代理流式一致:在 response_start / response_delta 的 data 中带当前 mcpExecutionIds,供主聊天绑定复制与展示。
|
||||||
snapshotMCPIDs := func() []string {
|
snapshotMCPIDs := func() []string {
|
||||||
@@ -120,6 +124,7 @@ func RunDeepAgent(
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
mainDefs := ag.ToolsForRole(roleTools)
|
mainDefs := ag.ToolsForRole(roleTools)
|
||||||
toolOutputChunk := func(toolName, toolCallID, chunk string) {
|
toolOutputChunk := func(toolName, toolCallID, chunk string) {
|
||||||
// When toolCallId is missing, frontend ignores tool_result_delta.
|
// When toolCallId is missing, frontend ignores tool_result_delta.
|
||||||
@@ -137,16 +142,6 @@ func RunDeepAgent(
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
mainToolsForCfg, mainOrchestratorPre, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 30 * time.Minute,
|
Timeout: 30 * time.Minute,
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
@@ -171,6 +166,7 @@ func RunDeepAgent(
|
|||||||
Model: appCfg.OpenAI.Model,
|
Model: appCfg.OpenAI.Model,
|
||||||
HTTPClient: httpClient,
|
HTTPClient: httpClient,
|
||||||
}
|
}
|
||||||
|
reasoning.ApplyToEinoChatModelConfig(baseModelCfg, &appCfg.OpenAI, reasoningClient)
|
||||||
|
|
||||||
deepMaxIter := ma.MaxIteration
|
deepMaxIter := ma.MaxIteration
|
||||||
if deepMaxIter <= 0 {
|
if deepMaxIter <= 0 {
|
||||||
@@ -222,12 +218,12 @@ func RunDeepAgent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
subDefs := ag.ToolsForRole(roleTools)
|
subDefs := ag.ToolsForRole(roleTools)
|
||||||
subTools, err := einomcp.ToolsFromDefinitions(ag, holder, subDefs, recorder, toolOutputChunk)
|
subTools, err := einomcp.ToolsFromDefinitions(ag, holder, subDefs, recorder, toolOutputChunk, toolInvokeNotify, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("子代理 %q 工具: %w", id, err)
|
return nil, fmt.Errorf("子代理 %q 工具: %w", id, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
subToolsForCfg, subPre, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWSub, subTools, einoLoc, skillsRoot, conversationID, logger)
|
subToolsForCfg, subPre, subToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWSub, subTools, einoLoc, skillsRoot, conversationID, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("子代理 %q eino 中间件: %w", id, err)
|
return nil, fmt.Errorf("子代理 %q eino 中间件: %w", id, err)
|
||||||
}
|
}
|
||||||
@@ -248,7 +244,7 @@ func RunDeepAgent(
|
|||||||
}
|
}
|
||||||
if einoSkillMW != nil {
|
if einoSkillMW != nil {
|
||||||
if einoFSTools && einoLoc != nil {
|
if einoFSTools && einoLoc != nil {
|
||||||
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc)
|
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
||||||
if fsErr != nil {
|
if fsErr != nil {
|
||||||
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
|
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
|
||||||
}
|
}
|
||||||
@@ -264,23 +260,16 @@ func RunDeepAgent(
|
|||||||
subHandlers = append(subHandlers, teleMw)
|
subHandlers = append(subHandlers, teleMw)
|
||||||
}
|
}
|
||||||
|
|
||||||
subInstrFinal := injectToolNamesOnlyInstruction(ctx, instr, subTools)
|
subInstrFinal := injectToolNamesOnlyInstruction(ctx, instr, subTools, subToolSearchActive)
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
subNames := collectToolNames(ctx, subTools)
|
subNames := collectToolNames(ctx, subTools)
|
||||||
mountedNames := collectToolNames(ctx, subToolsForCfg)
|
mountedNames := collectToolNames(ctx, subToolsForCfg)
|
||||||
hasToolSearch := false
|
|
||||||
for _, n := range subNames {
|
|
||||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
|
||||||
hasToolSearch = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.Info("eino tool-name injection",
|
logger.Info("eino tool-name injection",
|
||||||
zap.String("scope", "sub_agent"),
|
zap.String("scope", "sub_agent"),
|
||||||
zap.String("agent", id),
|
zap.String("agent", id),
|
||||||
zap.Int("tool_names", len(subNames)),
|
zap.Int("tool_names", len(subNames)),
|
||||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||||
zap.Bool("has_tool_search", hasToolSearch),
|
zap.Bool("tool_search_middleware", subToolSearchActive),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
|
sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
|
||||||
@@ -293,8 +282,8 @@ func RunDeepAgent(
|
|||||||
Tools: subToolsForCfg,
|
Tools: subToolsForCfg,
|
||||||
UnknownToolsHandler: einomcp.UnknownToolReminderHandler(),
|
UnknownToolsHandler: einomcp.UnknownToolReminderHandler(),
|
||||||
ToolCallMiddlewares: []compose.ToolMiddleware{
|
ToolCallMiddlewares: []compose.ToolMiddleware{
|
||||||
{Invokable: hitlToolCallMiddleware()},
|
hitlToolCallMiddleware(),
|
||||||
{Invokable: softRecoveryToolCallMiddleware()},
|
softRecoveryToolMiddleware(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
EmitInternalEvents: true,
|
EmitInternalEvents: true,
|
||||||
@@ -319,6 +308,8 @@ func RunDeepAgent(
|
|||||||
return nil, fmt.Errorf("多代理主 summarization 中间件: %w", err)
|
return nil, fmt.Errorf("多代理主 summarization 中间件: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
modelFacingTrace := newModelFacingTraceHolder()
|
||||||
|
|
||||||
// 与 deep.Config.Name / supervisor 主代理 Name 一致。
|
// 与 deep.Config.Name / supervisor 主代理 Name 一致。
|
||||||
orchestratorName := "cyberstrike-deep"
|
orchestratorName := "cyberstrike-deep"
|
||||||
orchDescription := "Coordinates specialist agents and MCP tools for authorized security testing."
|
orchDescription := "Coordinates specialist agents and MCP tools for authorized security testing."
|
||||||
@@ -338,23 +329,26 @@ func RunDeepAgent(
|
|||||||
orchDescription = d
|
orchDescription = d
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools)
|
|
||||||
|
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk, toolInvokeNotify, orchestratorName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mainToolsForCfg, mainOrchestratorPre, mainToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools, mainToolSearchActive)
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
mainNames := collectToolNames(ctx, mainTools)
|
mainNames := collectToolNames(ctx, mainTools)
|
||||||
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
||||||
hasToolSearch := false
|
|
||||||
for _, n := range mainNames {
|
|
||||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
|
||||||
hasToolSearch = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.Info("eino tool-name injection",
|
logger.Info("eino tool-name injection",
|
||||||
zap.String("scope", "orchestrator"),
|
zap.String("scope", "orchestrator"),
|
||||||
zap.String("orchestration", orchMode),
|
zap.String("orchestration", orchMode),
|
||||||
zap.Int("tool_names", len(mainNames)),
|
zap.Int("tool_names", len(mainNames)),
|
||||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||||
zap.Bool("has_tool_search", hasToolSearch),
|
zap.Bool("tool_search_middleware", mainToolSearchActive),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -381,7 +375,14 @@ func RunDeepAgent(
|
|||||||
var deepShell filesystem.StreamingShell
|
var deepShell filesystem.StreamingShell
|
||||||
if einoLoc != nil && einoFSTools {
|
if einoLoc != nil && einoFSTools {
|
||||||
deepBackend = einoLoc
|
deepBackend = einoLoc
|
||||||
deepShell = einoLoc
|
deepShell = &einoStreamingShellWrap{
|
||||||
|
inner: einoLoc,
|
||||||
|
invokeNotify: toolInvokeNotify,
|
||||||
|
einoAgentName: orchestratorName,
|
||||||
|
outputChunk: toolOutputChunk,
|
||||||
|
recordMonitor: einoExecMonitor,
|
||||||
|
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。
|
// noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。
|
||||||
@@ -400,6 +401,9 @@ func RunDeepAgent(
|
|||||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil {
|
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil {
|
||||||
deepHandlers = append(deepHandlers, teleMw)
|
deepHandlers = append(deepHandlers, teleMw)
|
||||||
}
|
}
|
||||||
|
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||||
|
deepHandlers = append(deepHandlers, capMw)
|
||||||
|
}
|
||||||
|
|
||||||
supHandlers := []adk.ChatModelAgentMiddleware{}
|
supHandlers := []adk.ChatModelAgentMiddleware{}
|
||||||
if len(mainOrchestratorPre) > 0 {
|
if len(mainOrchestratorPre) > 0 {
|
||||||
@@ -413,14 +417,17 @@ func RunDeepAgent(
|
|||||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil {
|
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil {
|
||||||
supHandlers = append(supHandlers, teleMw)
|
supHandlers = append(supHandlers, teleMw)
|
||||||
}
|
}
|
||||||
|
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||||
|
supHandlers = append(supHandlers, capMw)
|
||||||
|
}
|
||||||
|
|
||||||
mainToolsCfg := adk.ToolsConfig{
|
mainToolsCfg := adk.ToolsConfig{
|
||||||
ToolsNodeConfig: compose.ToolsNodeConfig{
|
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||||
Tools: mainToolsForCfg,
|
Tools: mainToolsForCfg,
|
||||||
UnknownToolsHandler: einomcp.UnknownToolReminderHandler(),
|
UnknownToolsHandler: einomcp.UnknownToolReminderHandler(),
|
||||||
ToolCallMiddlewares: []compose.ToolMiddleware{
|
ToolCallMiddlewares: []compose.ToolMiddleware{
|
||||||
{Invokable: hitlToolCallMiddleware()},
|
hitlToolCallMiddleware(),
|
||||||
{Invokable: softRecoveryToolCallMiddleware()},
|
softRecoveryToolMiddleware(),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
EmitInternalEvents: true,
|
EmitInternalEvents: true,
|
||||||
@@ -438,7 +445,7 @@ func RunDeepAgent(
|
|||||||
// 构建 filesystem 中间件(与 Deep sub-agent 一致)
|
// 构建 filesystem 中间件(与 Deep sub-agent 一致)
|
||||||
var peFsMw adk.ChatModelAgentMiddleware
|
var peFsMw adk.ChatModelAgentMiddleware
|
||||||
if einoSkillMW != nil && einoFSTools && einoLoc != nil {
|
if einoSkillMW != nil && einoFSTools && einoLoc != nil {
|
||||||
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc)
|
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
|
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
|
||||||
}
|
}
|
||||||
@@ -458,6 +465,7 @@ func RunDeepAgent(
|
|||||||
ExecPreMiddlewares: mainOrchestratorPre,
|
ExecPreMiddlewares: mainOrchestratorPre,
|
||||||
SkillMiddleware: einoSkillMW,
|
SkillMiddleware: einoSkillMW,
|
||||||
FilesystemMiddleware: peFsMw,
|
FilesystemMiddleware: peFsMw,
|
||||||
|
ModelFacingTrace: modelFacingTrace,
|
||||||
PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{
|
PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{
|
||||||
mainSumMw,
|
mainSumMw,
|
||||||
// 孤儿 tool 消息兜底:必须挂在 summarization 之后、telemetry 之前。
|
// 孤儿 tool 消息兜底:必须挂在 summarization 之后、telemetry 之前。
|
||||||
@@ -549,95 +557,100 @@ func RunDeepAgent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
return runEinoADKAgentLoop(ctx, &einoADKRunLoopArgs{
|
return runEinoADKAgentLoop(ctx, &einoADKRunLoopArgs{
|
||||||
OrchMode: orchMode,
|
OrchMode: orchMode,
|
||||||
OrchestratorName: orchestratorName,
|
OrchestratorName: orchestratorName,
|
||||||
ConversationID: conversationID,
|
ConversationID: conversationID,
|
||||||
Progress: progress,
|
Progress: progress,
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
SnapshotMCPIDs: snapshotMCPIDs,
|
SnapshotMCPIDs: snapshotMCPIDs,
|
||||||
StreamsMainAssistant: streamsMainAssistant,
|
StreamsMainAssistant: streamsMainAssistant,
|
||||||
EinoRoleTag: einoRoleTag,
|
EinoRoleTag: einoRoleTag,
|
||||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||||
McpIDsMu: &mcpIDsMu,
|
McpIDsMu: &mcpIDsMu,
|
||||||
McpIDs: &mcpIDs,
|
McpIDs: &mcpIDs,
|
||||||
DA: da,
|
FilesystemMonitorAgent: ag,
|
||||||
|
FilesystemMonitorRecord: recorder,
|
||||||
|
ToolInvokeNotify: toolInvokeNotify,
|
||||||
|
DA: da,
|
||||||
|
ModelFacingTrace: modelFacingTrace,
|
||||||
|
EinoCallbacks: &ma.EinoCallbacks,
|
||||||
EmptyResponseMessage: "(Eino multi-agent orchestration completed but no assistant text was captured. Check process details or logs.) " +
|
EmptyResponseMessage: "(Eino multi-agent orchestration completed but no assistant text was captured. Check process details or logs.) " +
|
||||||
"(Eino 多代理编排已完成,但未捕获到助手文本输出。请查看过程详情或日志。)",
|
"(Eino 多代理编排已完成,但未捕获到助手文本输出。请查看过程详情或日志。)",
|
||||||
}, baseMsgs)
|
}, baseMsgs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func chatToolCallsToSchema(tcs []agent.ToolCall) []schema.ToolCall {
|
||||||
|
if len(tcs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]schema.ToolCall, 0, len(tcs))
|
||||||
|
for _, tc := range tcs {
|
||||||
|
if strings.TrimSpace(tc.ID) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
argsStr := ""
|
||||||
|
if tc.Function.Arguments != nil {
|
||||||
|
b, err := json.Marshal(tc.Function.Arguments)
|
||||||
|
if err == nil {
|
||||||
|
argsStr = string(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
typ := tc.Type
|
||||||
|
if typ == "" {
|
||||||
|
typ = "function"
|
||||||
|
}
|
||||||
|
out = append(out, schema.ToolCall{
|
||||||
|
ID: tc.ID,
|
||||||
|
Type: typ,
|
||||||
|
Function: schema.FunctionCall{
|
||||||
|
Name: tc.Function.Name,
|
||||||
|
Arguments: argsStr,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// historyToMessages 将轨迹恢复的 ChatMessage 转为 Eino ADK 消息:**不裁剪条数、不按 token 预算截断**,
|
||||||
|
// 并保留 user / assistant(含仅 tool_calls)/ tool,与库中 last_react 轨迹一致。
|
||||||
func historyToMessages(history []agent.ChatMessage, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) []adk.Message {
|
func historyToMessages(history []agent.ChatMessage, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) []adk.Message {
|
||||||
|
_ = appCfg
|
||||||
|
_ = mwCfg
|
||||||
if len(history) == 0 {
|
if len(history) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// Keep a bounded tail first; then enforce a token budget.
|
raw := make([]adk.Message, 0, len(history))
|
||||||
const maxHistoryMessages = 200
|
for _, h := range history {
|
||||||
start := 0
|
role := strings.ToLower(strings.TrimSpace(h.Role))
|
||||||
if len(history) > maxHistoryMessages {
|
switch role {
|
||||||
start = len(history) - maxHistoryMessages
|
|
||||||
}
|
|
||||||
raw := make([]adk.Message, 0, len(history[start:]))
|
|
||||||
for _, h := range history[start:] {
|
|
||||||
switch h.Role {
|
|
||||||
case "user":
|
case "user":
|
||||||
if strings.TrimSpace(h.Content) != "" {
|
if strings.TrimSpace(h.Content) != "" {
|
||||||
raw = append(raw, schema.UserMessage(h.Content))
|
raw = append(raw, schema.UserMessage(h.Content))
|
||||||
}
|
}
|
||||||
case "assistant":
|
case "assistant":
|
||||||
if strings.TrimSpace(h.Content) == "" && len(h.ToolCalls) > 0 {
|
toolSchema := chatToolCallsToSchema(h.ToolCalls)
|
||||||
|
hasRC := strings.TrimSpace(h.ReasoningContent) != ""
|
||||||
|
if len(toolSchema) > 0 || strings.TrimSpace(h.Content) != "" || hasRC {
|
||||||
|
am := schema.AssistantMessage(h.Content, toolSchema)
|
||||||
|
if hasRC {
|
||||||
|
am.ReasoningContent = strings.TrimSpace(h.ReasoningContent)
|
||||||
|
}
|
||||||
|
raw = append(raw, am)
|
||||||
|
}
|
||||||
|
case "tool":
|
||||||
|
if strings.TrimSpace(h.ToolCallID) == "" && strings.TrimSpace(h.Content) == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(h.Content) != "" {
|
var opts []schema.ToolMessageOption
|
||||||
raw = append(raw, schema.AssistantMessage(h.Content, nil))
|
if tn := strings.TrimSpace(h.ToolName); tn != "" {
|
||||||
|
opts = append(opts, schema.WithToolName(tn))
|
||||||
}
|
}
|
||||||
|
raw = append(raw, schema.ToolMessage(h.Content, h.ToolCallID, opts...))
|
||||||
default:
|
default:
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(raw) == 0 {
|
return raw
|
||||||
return raw
|
|
||||||
}
|
|
||||||
maxTotal := 120000
|
|
||||||
modelName := "gpt-4o"
|
|
||||||
if appCfg != nil {
|
|
||||||
if appCfg.OpenAI.MaxTotalTokens > 0 {
|
|
||||||
maxTotal = appCfg.OpenAI.MaxTotalTokens
|
|
||||||
}
|
|
||||||
if m := strings.TrimSpace(appCfg.OpenAI.Model); m != "" {
|
|
||||||
modelName = m
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ratio := 0.35
|
|
||||||
if mwCfg != nil {
|
|
||||||
ratio = mwCfg.HistoryInputBudgetRatioEffective()
|
|
||||||
}
|
|
||||||
budget := int(float64(maxTotal) * ratio)
|
|
||||||
if budget < 4096 {
|
|
||||||
budget = 4096
|
|
||||||
}
|
|
||||||
tc := agent.NewTikTokenCounter()
|
|
||||||
outRev := make([]adk.Message, 0, len(raw))
|
|
||||||
used := 0
|
|
||||||
for i := len(raw) - 1; i >= 0; i-- {
|
|
||||||
msg := raw[i]
|
|
||||||
n, err := tc.Count(modelName, string(msg.Role)+"\n"+msg.Content)
|
|
||||||
if err != nil {
|
|
||||||
n = (len(msg.Content) + 3) / 4
|
|
||||||
}
|
|
||||||
if n <= 0 {
|
|
||||||
n = 1
|
|
||||||
}
|
|
||||||
if used+n > budget {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
used += n
|
|
||||||
outRev = append(outRev, msg)
|
|
||||||
}
|
|
||||||
out := make([]adk.Message, 0, len(outRev))
|
|
||||||
for i := len(outRev) - 1; i >= 0; i-- {
|
|
||||||
out = append(out, outRev[i])
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// mergeStreamingToolCallFragments 将流式多帧的 ToolCall 按 index 合并 arguments(与 schema.concatToolCalls 行为一致)。
|
// mergeStreamingToolCallFragments 将流式多帧的 ToolCall 按 index 合并 arguments(与 schema.concatToolCalls 行为一致)。
|
||||||
@@ -724,12 +737,23 @@ func toolCallsRichSignature(msg *schema.Message) string {
|
|||||||
return base + "|" + strings.Join(parts, ";")
|
return base + "|" + strings.Join(parts, ";")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func einoMainIterationKey(agentName, orchestratorName string) string {
|
||||||
|
key := strings.TrimSpace(agentName)
|
||||||
|
if key == "" {
|
||||||
|
key = strings.TrimSpace(orchestratorName)
|
||||||
|
}
|
||||||
|
if key == "" {
|
||||||
|
return "_main"
|
||||||
|
}
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
func tryEmitToolCallsOnce(
|
func tryEmitToolCallsOnce(
|
||||||
msg *schema.Message,
|
msg *schema.Message,
|
||||||
agentName, orchestratorName, conversationID string,
|
agentName, orchestratorName, conversationID, orchMode string,
|
||||||
progress func(string, string, interface{}),
|
progress func(string, string, interface{}),
|
||||||
seen map[string]struct{},
|
seen map[string]struct{},
|
||||||
subAgentToolStep map[string]int,
|
subAgentToolStep, mainAgentToolStep map[string]int,
|
||||||
markPending func(toolCallPendingInfo),
|
markPending func(toolCallPendingInfo),
|
||||||
) {
|
) {
|
||||||
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil || seen == nil {
|
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil || seen == nil {
|
||||||
@@ -743,14 +767,14 @@ func tryEmitToolCallsOnce(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
seen[sig] = struct{}{}
|
seen[sig] = struct{}{}
|
||||||
emitToolCallsFromMessage(msg, agentName, orchestratorName, conversationID, progress, subAgentToolStep, markPending)
|
emitToolCallsFromMessage(msg, agentName, orchestratorName, conversationID, orchMode, progress, subAgentToolStep, mainAgentToolStep, markPending)
|
||||||
}
|
}
|
||||||
|
|
||||||
func emitToolCallsFromMessage(
|
func emitToolCallsFromMessage(
|
||||||
msg *schema.Message,
|
msg *schema.Message,
|
||||||
agentName, orchestratorName, conversationID string,
|
agentName, orchestratorName, conversationID, orchMode string,
|
||||||
progress func(string, string, interface{}),
|
progress func(string, string, interface{}),
|
||||||
subAgentToolStep map[string]int,
|
subAgentToolStep, mainAgentToolStep map[string]int,
|
||||||
markPending func(toolCallPendingInfo),
|
markPending func(toolCallPendingInfo),
|
||||||
) {
|
) {
|
||||||
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil {
|
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil {
|
||||||
@@ -771,6 +795,22 @@ func emitToolCallsFromMessage(
|
|||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"source": "eino",
|
"source": "eino",
|
||||||
})
|
})
|
||||||
|
} else if mainAgentToolStep != nil {
|
||||||
|
key := einoMainIterationKey(agentName, orchestratorName)
|
||||||
|
mainAgentToolStep[key]++
|
||||||
|
n := mainAgentToolStep[key]
|
||||||
|
// 第 1 轮已在主代理进入时发出;此后每次工具批次对应新一轮 ReAct(与子代理按工具计步一致)。
|
||||||
|
if n > 1 {
|
||||||
|
progress("iteration", "", map[string]interface{}{
|
||||||
|
"iteration": n,
|
||||||
|
"einoScope": "main",
|
||||||
|
"einoRole": "orchestrator",
|
||||||
|
"einoAgent": agentName,
|
||||||
|
"orchestration": orchMode,
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
role := "orchestrator"
|
role := "orchestrator"
|
||||||
if isSubToolRound {
|
if isSubToolRound {
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/agent"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHistoryToMessagesPreservesReasoningContent(t *testing.T) {
|
||||||
|
h := []agent.ChatMessage{
|
||||||
|
{Role: "user", Content: "u"},
|
||||||
|
{Role: "assistant", Content: "c", ReasoningContent: "r1", ToolCalls: []agent.ToolCall{{ID: "t1", Type: "function", Function: agent.FunctionCall{Name: "f", Arguments: map[string]interface{}{}}}}},
|
||||||
|
}
|
||||||
|
msgs := historyToMessages(h, nil, nil)
|
||||||
|
if len(msgs) != 2 {
|
||||||
|
t.Fatalf("len=%d", len(msgs))
|
||||||
|
}
|
||||||
|
am := msgs[1]
|
||||||
|
if am.ReasoningContent != "r1" || am.Content != "c" {
|
||||||
|
t.Fatalf("got reasoning=%q content=%q", am.ReasoningContent, am.Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/compose"
|
"github.com/cloudwego/eino/compose"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
// softRecoveryToolCallMiddleware returns an InvokableToolMiddleware that catches
|
// softRecoveryToolCallMiddleware returns an InvokableToolMiddleware that catches
|
||||||
@@ -16,8 +17,9 @@ import (
|
|||||||
// returned to the LLM. This allows the model to self-correct within the same
|
// returned to the LLM. This allows the model to self-correct within the same
|
||||||
// iteration rather than crashing the entire graph and requiring a full replay.
|
// iteration rather than crashing the entire graph and requiring a full replay.
|
||||||
//
|
//
|
||||||
// Without this middleware, a JSON parse failure in any tool's InvokableRun propagates
|
// Without Invokable (+ Streamable where applicable) registration, a JSON parse failure
|
||||||
// as a hard error through the Eino ToolsNode → [NodeRunError] → ev.Err, which
|
// in InvokableRun / StreamableRun propagates as a hard error through the Eino ToolsNode
|
||||||
|
// → [NodeRunError] → ev.Err, which
|
||||||
// either triggers the full-replay retry loop (expensive) or terminates the run
|
// either triggers the full-replay retry loop (expensive) or terminates the run
|
||||||
// entirely once retries are exhausted. With it, the LLM simply sees an error message
|
// entirely once retries are exhausted. With it, the LLM simply sees an error message
|
||||||
// in the tool result and can adjust its next tool call accordingly.
|
// in the tool result and can adjust its next tool call accordingly.
|
||||||
@@ -39,6 +41,44 @@ func softRecoveryToolCallMiddleware() compose.InvokableToolMiddleware {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// softRecoveryStreamableToolCallMiddleware mirrors softRecoveryToolCallMiddleware for
|
||||||
|
// tools that implement StreamableTool only (e.g. Eino ADK filesystem execute).
|
||||||
|
// Eino applies Invokable vs Streamable middleware to disjoint code paths in ToolsNode;
|
||||||
|
// registering only Invokable leaves streaming tools uncovered — empty/malformed JSON
|
||||||
|
// then fails inside [LocalStreamFunc] before the inner endpoint runs.
|
||||||
|
func softRecoveryStreamableToolCallMiddleware() compose.StreamableToolMiddleware {
|
||||||
|
return func(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint {
|
||||||
|
return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) {
|
||||||
|
out, err := next(ctx, input)
|
||||||
|
if err == nil {
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
if !isSoftRecoverableToolError(err) {
|
||||||
|
return out, err
|
||||||
|
}
|
||||||
|
toolName := ""
|
||||||
|
args := ""
|
||||||
|
if input != nil {
|
||||||
|
toolName = input.Name
|
||||||
|
args = input.Arguments
|
||||||
|
}
|
||||||
|
msg := buildSoftRecoveryMessage(toolName, args, err)
|
||||||
|
return &compose.StreamToolOutput{
|
||||||
|
Result: schema.StreamReaderFromArray([]string{msg}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// softRecoveryToolMiddleware returns a ToolMiddleware with both Invokable and Streamable
|
||||||
|
// soft recovery (same semantics as hitlToolCallMiddleware bundling).
|
||||||
|
func softRecoveryToolMiddleware() compose.ToolMiddleware {
|
||||||
|
return compose.ToolMiddleware{
|
||||||
|
Invokable: softRecoveryToolCallMiddleware(),
|
||||||
|
Streamable: softRecoveryStreamableToolCallMiddleware(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// isSoftRecoverableToolError determines whether a tool execution error should be
|
// isSoftRecoverableToolError determines whether a tool execution error should be
|
||||||
// silently converted to a tool-result message rather than crashing the graph.
|
// silently converted to a tool-result message rather than crashing the graph.
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/compose"
|
"github.com/cloudwego/eino/compose"
|
||||||
@@ -108,6 +110,39 @@ func TestSoftRecoveryToolCallMiddleware_PassesThrough(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSoftRecoveryStreamableToolCallMiddleware_LocalStreamFuncJSONError(t *testing.T) {
|
||||||
|
mw := softRecoveryStreamableToolCallMiddleware()
|
||||||
|
next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) {
|
||||||
|
return nil, errors.New(`[LocalStreamFunc] failed to unmarshal arguments in json, toolName=execute, err="Syntax error no sources available, the input json is empty`)
|
||||||
|
}
|
||||||
|
wrapped := mw(next)
|
||||||
|
out, err := wrapped(context.Background(), &compose.ToolInput{
|
||||||
|
Name: "execute",
|
||||||
|
Arguments: "",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected nil error (soft recovery), got: %v", err)
|
||||||
|
}
|
||||||
|
if out == nil || out.Result == nil {
|
||||||
|
t.Fatal("expected stream result")
|
||||||
|
}
|
||||||
|
var sb strings.Builder
|
||||||
|
for {
|
||||||
|
chunk, rerr := out.Result.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("recv: %v", rerr)
|
||||||
|
}
|
||||||
|
sb.WriteString(chunk)
|
||||||
|
}
|
||||||
|
text := sb.String()
|
||||||
|
if !containsAll(text, "[Tool Error]", "execute", "JSON") {
|
||||||
|
t.Fatalf("recovery message missing expected content: %s", text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSoftRecoveryToolCallMiddleware_ConvertsJSONError(t *testing.T) {
|
func TestSoftRecoveryToolCallMiddleware_ConvertsJSONError(t *testing.T) {
|
||||||
mw := softRecoveryToolCallMiddleware()
|
mw := softRecoveryToolCallMiddleware()
|
||||||
next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ package openai
|
|||||||
// Stream: Claude SSE (event: content_block_delta / message_delta) → OpenAI SSE 格式
|
// Stream: Claude SSE (event: content_block_delta / message_delta) → OpenAI SSE 格式
|
||||||
// Auth: Bearer → x-api-key
|
// Auth: Bearer → x-api-key
|
||||||
// Tools: OpenAI tools[] → Claude tools[] (input_schema)
|
// Tools: OpenAI tools[] → Claude tools[] (input_schema)
|
||||||
|
//
|
||||||
|
// Extended thinking: 顶层 `thinking` 从 OpenAI 请求体透传;响应中 `thinking` block 映射为
|
||||||
|
// `reasoning_content`(可读前缀 + 内部 JSON 尾缀以保留 signature,供多轮工具续跑;UI 用 openai.DisplayReasoningContent 剥离)。
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
@@ -38,6 +41,7 @@ type claudeRequest struct {
|
|||||||
Messages []claudeMessage `json:"messages"`
|
Messages []claudeMessage `json:"messages"`
|
||||||
Tools []claudeTool `json:"tools,omitempty"`
|
Tools []claudeTool `json:"tools,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
Thinking json.RawMessage `json:"thinking,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type claudeMessage struct {
|
type claudeMessage struct {
|
||||||
@@ -76,6 +80,10 @@ type claudeContentBlock struct {
|
|||||||
// text block
|
// text block
|
||||||
Text string `json:"text,omitempty"`
|
Text string `json:"text,omitempty"`
|
||||||
|
|
||||||
|
// thinking block (extended thinking)
|
||||||
|
Thinking string `json:"thinking,omitempty"`
|
||||||
|
Signature string `json:"signature,omitempty"`
|
||||||
|
|
||||||
// tool_use block (assistant 返回)
|
// tool_use block (assistant 返回)
|
||||||
ID string `json:"id,omitempty"`
|
ID string `json:"id,omitempty"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
@@ -176,7 +184,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) {
|
|||||||
|
|
||||||
// tool_calls (assistant 消息中包含工具调用)
|
// tool_calls (assistant 消息中包含工具调用)
|
||||||
if role == "assistant" {
|
if role == "assistant" {
|
||||||
|
rc, _ := mm["reasoning_content"].(string)
|
||||||
|
_, thinkingReplay := parseClaudeReasoningAssistantBlocks(rc)
|
||||||
|
|
||||||
var blocks []claudeContentBlock
|
var blocks []claudeContentBlock
|
||||||
|
for _, tb := range thinkingReplay {
|
||||||
|
blocks = append(blocks, tb)
|
||||||
|
}
|
||||||
if content != "" {
|
if content != "" {
|
||||||
blocks = append(blocks, claudeContentBlock{Type: "text", Text: content})
|
blocks = append(blocks, claudeContentBlock{Type: "text", Text: content})
|
||||||
}
|
}
|
||||||
@@ -290,6 +304,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extended thinking (Anthropic top-level); merged from Eino ExtraFields / admin extras.
|
||||||
|
if th, ok := oai["thinking"]; ok && th != nil {
|
||||||
|
if raw, err := json.Marshal(th); err == nil && len(raw) > 0 && string(raw) != "null" {
|
||||||
|
req.Thinking = json.RawMessage(raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -318,9 +339,12 @@ func claudeToOpenAIResponseJSON(claudeBody []byte) ([]byte, error) {
|
|||||||
|
|
||||||
var textContent string
|
var textContent string
|
||||||
var toolCalls []interface{}
|
var toolCalls []interface{}
|
||||||
|
var thinkingBlocks []claudeContentBlock
|
||||||
|
|
||||||
for _, block := range cr.Content {
|
for _, block := range cr.Content {
|
||||||
switch block.Type {
|
switch block.Type {
|
||||||
|
case "thinking":
|
||||||
|
thinkingBlocks = append(thinkingBlocks, block)
|
||||||
case "text":
|
case "text":
|
||||||
textContent += block.Text
|
textContent += block.Text
|
||||||
case "tool_use":
|
case "tool_use":
|
||||||
@@ -344,6 +368,18 @@ func claudeToOpenAIResponseJSON(claudeBody []byte) ([]byte, error) {
|
|||||||
if len(toolCalls) > 0 {
|
if len(toolCalls) > 0 {
|
||||||
message["tool_calls"] = toolCalls
|
message["tool_calls"] = toolCalls
|
||||||
}
|
}
|
||||||
|
if len(thinkingBlocks) > 0 {
|
||||||
|
var parts []string
|
||||||
|
for _, tb := range thinkingBlocks {
|
||||||
|
if strings.TrimSpace(tb.Thinking) != "" {
|
||||||
|
parts = append(parts, tb.Thinking)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rc := appendClaudeReasoningRoundTrip(strings.Join(parts, "\n\n"), thinkingBlocks)
|
||||||
|
if rc != "" {
|
||||||
|
message["reasoning_content"] = rc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
choice := map[string]interface{}{
|
choice := map[string]interface{}{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
@@ -901,8 +937,16 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
|||||||
|
|
||||||
reader := bufio.NewReader(resp.Body)
|
reader := bufio.NewReader(resp.Body)
|
||||||
blockToToolIndex := make(map[int]int)
|
blockToToolIndex := make(map[int]int)
|
||||||
|
blockIndexToType := make(map[int]string)
|
||||||
nextToolIndex := 0
|
nextToolIndex := 0
|
||||||
|
|
||||||
|
type thinkingAcc struct {
|
||||||
|
text strings.Builder
|
||||||
|
sig strings.Builder
|
||||||
|
}
|
||||||
|
thinkingByIndex := make(map[int]*thinkingAcc)
|
||||||
|
var finishedThinking []claudeContentBlock
|
||||||
|
|
||||||
for {
|
for {
|
||||||
line, readErr := reader.ReadString('\n')
|
line, readErr := reader.ReadString('\n')
|
||||||
if readErr != nil {
|
if readErr != nil {
|
||||||
@@ -947,6 +991,11 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
|||||||
blockIdx := int(blockIdxFlt)
|
blockIdx := int(blockIdxFlt)
|
||||||
cb, _ := event["content_block"].(map[string]interface{})
|
cb, _ := event["content_block"].(map[string]interface{})
|
||||||
bt, _ := cb["type"].(string)
|
bt, _ := cb["type"].(string)
|
||||||
|
blockIndexToType[blockIdx] = bt
|
||||||
|
|
||||||
|
if bt == "thinking" {
|
||||||
|
thinkingByIndex[blockIdx] = &thinkingAcc{}
|
||||||
|
}
|
||||||
|
|
||||||
if bt == "tool_use" {
|
if bt == "tool_use" {
|
||||||
id, _ := cb["id"].(string)
|
id, _ := cb["id"].(string)
|
||||||
@@ -986,7 +1035,35 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
|||||||
delta, _ := event["delta"].(map[string]interface{})
|
delta, _ := event["delta"].(map[string]interface{})
|
||||||
dt, _ := delta["type"].(string)
|
dt, _ := delta["type"].(string)
|
||||||
|
|
||||||
if dt == "text_delta" {
|
if dt == "thinking_delta" {
|
||||||
|
tPart, _ := delta["thinking"].(string)
|
||||||
|
if tPart != "" {
|
||||||
|
if acc := thinkingByIndex[blockIdx]; acc != nil {
|
||||||
|
acc.text.WriteString(tPart)
|
||||||
|
}
|
||||||
|
oaiChunk := map[string]interface{}{
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"reasoning_content": tPart,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(oaiChunk)
|
||||||
|
if !writeLine("data: " + string(b) + "\n\n") {
|
||||||
|
pw.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if dt == "signature_delta" {
|
||||||
|
sigPart, _ := delta["signature"].(string)
|
||||||
|
if sigPart != "" {
|
||||||
|
if acc := thinkingByIndex[blockIdx]; acc != nil {
|
||||||
|
acc.sig.WriteString(sigPart)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if dt == "text_delta" {
|
||||||
text, _ := delta["text"].(string)
|
text, _ := delta["text"].(string)
|
||||||
oaiChunk := map[string]interface{}{
|
oaiChunk := map[string]interface{}{
|
||||||
"choices": []map[string]interface{}{
|
"choices": []map[string]interface{}{
|
||||||
@@ -1031,6 +1108,21 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "content_block_stop":
|
||||||
|
blockIdxFlt, _ := event["index"].(float64)
|
||||||
|
blockIdx := int(blockIdxFlt)
|
||||||
|
bt := blockIndexToType[blockIdx]
|
||||||
|
if bt == "thinking" {
|
||||||
|
if acc := thinkingByIndex[blockIdx]; acc != nil {
|
||||||
|
finishedThinking = append(finishedThinking, claudeContentBlock{
|
||||||
|
Type: "thinking",
|
||||||
|
Thinking: acc.text.String(),
|
||||||
|
Signature: acc.sig.String(),
|
||||||
|
})
|
||||||
|
delete(thinkingByIndex, blockIdx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
case "message_delta":
|
case "message_delta":
|
||||||
d, _ := event["delta"].(map[string]interface{})
|
d, _ := event["delta"].(map[string]interface{})
|
||||||
if sr, ok := d["stop_reason"].(string); ok {
|
if sr, ok := d["stop_reason"].(string); ok {
|
||||||
@@ -1051,6 +1143,25 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
|||||||
}
|
}
|
||||||
|
|
||||||
case "message_stop":
|
case "message_stop":
|
||||||
|
if len(finishedThinking) > 0 {
|
||||||
|
suffix := appendClaudeReasoningRoundTrip("", finishedThinking)
|
||||||
|
if strings.TrimSpace(suffix) != "" {
|
||||||
|
oaiChunk := map[string]interface{}{
|
||||||
|
"choices": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"delta": map[string]interface{}{
|
||||||
|
"reasoning_content": suffix,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(oaiChunk)
|
||||||
|
if !writeLine("data: " + string(b) + "\n\n") {
|
||||||
|
pw.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
writeLine("data: [DONE]\n\n")
|
writeLine("data: [DONE]\n\n")
|
||||||
pw.Close()
|
pw.Close()
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -0,0 +1,81 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// claudeReasoningRoundTripSep separates human-readable reasoning from a JSON payload of
|
||||||
|
// Anthropic thinking blocks (with signatures) for multi-turn extended thinking + tools.
|
||||||
|
// Not shown in UI (see DisplayReasoningContent).
|
||||||
|
const claudeReasoningRoundTripSep = "\n---CSAI_CLAUDE_THINKING_BLOCKS---\n"
|
||||||
|
|
||||||
|
// DisplayReasoningContent returns reasoning text suitable for the UI (strips internal
|
||||||
|
// Claude round-trip JSON suffix). Safe for DeepSeek/plain reasoning strings (no-op).
|
||||||
|
func DisplayReasoningContent(s string) string {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
i := strings.LastIndex(s, claudeReasoningRoundTripSep)
|
||||||
|
if i < 0 {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(s[:i])
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendClaudeReasoningRoundTrip(display string, blocks []claudeContentBlock) string {
|
||||||
|
var payload []map[string]string
|
||||||
|
for _, b := range blocks {
|
||||||
|
if b.Type != "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payload = append(payload, map[string]string{
|
||||||
|
"type": b.Type,
|
||||||
|
"thinking": b.Thinking,
|
||||||
|
"signature": b.Signature,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if len(payload) == 0 {
|
||||||
|
return strings.TrimSpace(display)
|
||||||
|
}
|
||||||
|
js, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return strings.TrimSpace(display)
|
||||||
|
}
|
||||||
|
d := strings.TrimSpace(display)
|
||||||
|
if d == "" {
|
||||||
|
return claudeReasoningRoundTripSep + string(js)
|
||||||
|
}
|
||||||
|
return d + claudeReasoningRoundTripSep + string(js)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseClaudeReasoningAssistantBlocks extracts Anthropic thinking blocks from an OpenAI-style
|
||||||
|
// reasoning_content string. When no suffix is present, blocks is nil (caller must not invent signatures).
|
||||||
|
func parseClaudeReasoningAssistantBlocks(reasoningContent string) (display string, blocks []claudeContentBlock) {
|
||||||
|
reasoningContent = strings.TrimSpace(reasoningContent)
|
||||||
|
if reasoningContent == "" {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
idx := strings.LastIndex(reasoningContent, claudeReasoningRoundTripSep)
|
||||||
|
if idx < 0 {
|
||||||
|
return reasoningContent, nil
|
||||||
|
}
|
||||||
|
display = strings.TrimSpace(reasoningContent[:idx])
|
||||||
|
jsonPart := strings.TrimSpace(reasoningContent[idx+len(claudeReasoningRoundTripSep):])
|
||||||
|
var arr []struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Thinking string `json:"thinking"`
|
||||||
|
Signature string `json:"signature"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal([]byte(jsonPart), &arr); err != nil {
|
||||||
|
return reasoningContent, nil
|
||||||
|
}
|
||||||
|
for _, x := range arr {
|
||||||
|
if x.Type != "thinking" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
blocks = append(blocks, claudeContentBlock{Type: "thinking", Thinking: x.Thinking, Signature: x.Signature})
|
||||||
|
}
|
||||||
|
return display, blocks
|
||||||
|
}
|
||||||
@@ -0,0 +1,102 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDisplayReasoningContent(t *testing.T) {
|
||||||
|
raw := "hello" + claudeReasoningRoundTripSep + `[{"type":"thinking","thinking":"x","signature":"sig"}]`
|
||||||
|
if d := DisplayReasoningContent(raw); d != "hello" {
|
||||||
|
t.Fatalf("got %q", d)
|
||||||
|
}
|
||||||
|
if DisplayReasoningContent("plain") != "plain" {
|
||||||
|
t.Fatal()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendParseClaudeReasoningRoundTrip(t *testing.T) {
|
||||||
|
blocks := []claudeContentBlock{
|
||||||
|
{Type: "thinking", Thinking: "a", Signature: "s1"},
|
||||||
|
{Type: "thinking", Thinking: "b", Signature: "s2"},
|
||||||
|
}
|
||||||
|
s := appendClaudeReasoningRoundTrip("sum", blocks)
|
||||||
|
if !strings.Contains(s, claudeReasoningRoundTripSep) {
|
||||||
|
t.Fatal("missing sep")
|
||||||
|
}
|
||||||
|
display, back := parseClaudeReasoningAssistantBlocks(s)
|
||||||
|
if display != "sum" || len(back) != 2 {
|
||||||
|
t.Fatalf("display=%q len=%d", display, len(back))
|
||||||
|
}
|
||||||
|
if back[0].Signature != "s1" || back[1].Thinking != "b" {
|
||||||
|
t.Fatalf("%+v", back)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertOpenAIToClaude_AssistantReasoningReplay(t *testing.T) {
|
||||||
|
rc := appendClaudeReasoningRoundTrip("vis", []claudeContentBlock{
|
||||||
|
{Type: "thinking", Thinking: "t1", Signature: "sig1"},
|
||||||
|
})
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"model": "claude-3-5-sonnet-latest",
|
||||||
|
"messages": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "out",
|
||||||
|
"reasoning_content": rc,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req, err := convertOpenAIToClaude(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(req.Messages) != 1 {
|
||||||
|
t.Fatalf("messages=%d", len(req.Messages))
|
||||||
|
}
|
||||||
|
blocks := req.Messages[0].Content.Blocks
|
||||||
|
if len(blocks) < 2 {
|
||||||
|
t.Fatalf("blocks=%d", len(blocks))
|
||||||
|
}
|
||||||
|
if blocks[0].Type != "thinking" || blocks[0].Signature != "sig1" {
|
||||||
|
t.Fatalf("first block %+v", blocks[0])
|
||||||
|
}
|
||||||
|
foundText := false
|
||||||
|
for _, b := range blocks {
|
||||||
|
if b.Type == "text" && b.Text == "out" {
|
||||||
|
foundText = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundText {
|
||||||
|
t.Fatalf("blocks=%+v", blocks)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeToOpenAIResponseJSON_Thinking(t *testing.T) {
|
||||||
|
claudeBody := []byte(`{
|
||||||
|
"id":"msg_1","type":"message","role":"assistant","model":"x","stop_reason":"end_turn",
|
||||||
|
"content":[
|
||||||
|
{"type":"thinking","thinking":"step","signature":"sigx"},
|
||||||
|
{"type":"text","text":"hi"}
|
||||||
|
]
|
||||||
|
}`)
|
||||||
|
oai, err := claudeToOpenAIResponseJSON(claudeBody)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
var wrap map[string]interface{}
|
||||||
|
if err := json.Unmarshal(oai, &wrap); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
choices := wrap["choices"].([]interface{})
|
||||||
|
ch0 := choices[0].(map[string]interface{})
|
||||||
|
msg := ch0["message"].(map[string]interface{})
|
||||||
|
rc, _ := msg["reasoning_content"].(string)
|
||||||
|
if !strings.Contains(rc, "step") || !strings.Contains(rc, claudeReasoningRoundTripSep) {
|
||||||
|
t.Fatalf("reasoning_content=%q", rc)
|
||||||
|
}
|
||||||
|
if msg["content"] != "hi" {
|
||||||
|
t.Fatal()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestNormalizeStreamingDelta_RepeatedCharBoundary(t *testing.T) {
|
||||||
|
// 流式在重复数字边界分片:不得把 "43" 的首字符与 "194" 尾字符误合并。
|
||||||
|
cur, d := normalizeStreamingDelta("https://x:194", "43")
|
||||||
|
if want := "https://x:19443"; cur != want {
|
||||||
|
t.Fatalf("next: want %q got %q", want, cur)
|
||||||
|
}
|
||||||
|
if d != "43" {
|
||||||
|
t.Fatalf("delta: want %q got %q", "43", d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeStreamingDelta_CumulativePrefix(t *testing.T) {
|
||||||
|
cur, d := normalizeStreamingDelta("今天", "今天天气")
|
||||||
|
if cur != "今天天气" || d != "天气" {
|
||||||
|
t.Fatalf("got cur=%q d=%q", cur, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeStreamingDelta_FullRetransmit(t *testing.T) {
|
||||||
|
cur, d := normalizeStreamingDelta("今天", "今天")
|
||||||
|
if d != "" || cur != "今天" {
|
||||||
|
t.Fatalf("got cur=%q d=%q", cur, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeStreamingDelta_SingleRuneRepeated(t *testing.T) {
|
||||||
|
cur, d := normalizeStreamingDelta("呀", "呀")
|
||||||
|
if want := "呀呀"; cur != want {
|
||||||
|
t.Fatalf("next: want %q got %q", want, cur)
|
||||||
|
}
|
||||||
|
if d != "呀" {
|
||||||
|
t.Fatalf("delta: want %q got %q", "呀", d)
|
||||||
|
}
|
||||||
|
cur, d = normalizeStreamingDelta("4", "4")
|
||||||
|
if want := "44"; cur != want {
|
||||||
|
t.Fatalf("next: want %q got %q", want, cur)
|
||||||
|
}
|
||||||
|
if d != "4" {
|
||||||
|
t.Fatalf("delta: want %q got %q", "4", d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeStreamingDelta_CumulativeExtendsNumber(t *testing.T) {
|
||||||
|
// 已缓冲 "194" 后收到累计串 "19443"(注意 "1943" 并非 "19443" 的前缀,不能靠误写的中间态测 HasPrefix)。
|
||||||
|
cur, d := normalizeStreamingDelta("194", "19443")
|
||||||
|
if want := "19443"; cur != want {
|
||||||
|
t.Fatalf("next: want %q got %q", want, cur)
|
||||||
|
}
|
||||||
|
if d != "43" {
|
||||||
|
t.Fatalf("delta: want %q got %q", "43", d)
|
||||||
|
}
|
||||||
|
}
|
||||||
+12
-17
@@ -10,6 +10,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
@@ -34,7 +35,15 @@ func (e *APIError) Error() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// normalizeStreamingDelta 将可能是“累计片段/重发片段”的内容归一化为“纯增量”。
|
// normalizeStreamingDelta 将可能是“累计片段/重发片段”的内容归一化为“纯增量”。
|
||||||
// 部分兼容网关会返回累计 content;若直接 append 会出现重复文本(结巴)。
|
// 部分兼容网关会返回累计 content;若直接 append 会出现重复文本。
|
||||||
|
//
|
||||||
|
// 注意:
|
||||||
|
// - 不做「任意后缀与前缀重叠」合并;流式可能在重复字符边界分片("194"+"43"→"19443")。
|
||||||
|
// - HasPrefix 仅在 incoming 严格长于 current 时视为累计全文,否则会把分片产生的第二个相同
|
||||||
|
// 单字/单码点(叠字、44、22 等)误判为「整段重复」而吞字。
|
||||||
|
// - incoming==current 仅当 current 长度 >1 个码点时才视为整包重发;单码点重复必须走拼接。
|
||||||
|
// - 不再使用「current 以 incoming 结尾则丢弃」:否则 "1943"+"43" 会误吞增量(19443 显示成 1943)。
|
||||||
|
// 若网关重复发送尾部片段,应重复送完整累计串,由 HasPrefix 分支去重。
|
||||||
func normalizeStreamingDelta(current, incoming string) (next, delta string) {
|
func normalizeStreamingDelta(current, incoming string) (next, delta string) {
|
||||||
if incoming == "" {
|
if incoming == "" {
|
||||||
return current, ""
|
return current, ""
|
||||||
@@ -42,26 +51,12 @@ func normalizeStreamingDelta(current, incoming string) (next, delta string) {
|
|||||||
if current == "" {
|
if current == "" {
|
||||||
return incoming, incoming
|
return incoming, incoming
|
||||||
}
|
}
|
||||||
if incoming == current {
|
if strings.HasPrefix(incoming, current) && len(incoming) > len(current) {
|
||||||
return current, ""
|
|
||||||
}
|
|
||||||
if strings.HasPrefix(incoming, current) {
|
|
||||||
return incoming, incoming[len(current):]
|
return incoming, incoming[len(current):]
|
||||||
}
|
}
|
||||||
if strings.HasSuffix(current, incoming) {
|
if incoming == current && utf8.RuneCountInString(current) > 1 {
|
||||||
return current, ""
|
return current, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// 边界重叠:current 后缀与 incoming 前缀重合,仅追加非重叠部分。
|
|
||||||
max := len(current)
|
|
||||||
if len(incoming) < max {
|
|
||||||
max = len(incoming)
|
|
||||||
}
|
|
||||||
for overlap := max; overlap > 0; overlap-- {
|
|
||||||
if current[len(current)-overlap:] == incoming[:overlap] {
|
|
||||||
return current + incoming[overlap:], incoming[overlap:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return current + incoming, incoming
|
return current + incoming, incoming
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
// SSEAccumulatedKey 为 SSE progress 事件 data 中的服务端权威流式全文快照字段。
|
||||||
|
// 前端应优先用该字段更新 buffer,避免对 delta 二次 normalize 导致叠字。
|
||||||
|
const SSEAccumulatedKey = "accumulated"
|
||||||
|
|
||||||
|
// WithSSEAccumulated 在 progress data 中附带当前流式累计全文(权威快照)。
|
||||||
|
func WithSSEAccumulated(data map[string]interface{}, accumulated string) map[string]interface{} {
|
||||||
|
if data == nil {
|
||||||
|
data = make(map[string]interface{}, 1)
|
||||||
|
}
|
||||||
|
data[SSEAccumulatedKey] = accumulated
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeStreamingDelta 将可能是“累计片段/重发片段”的内容归一化为“纯增量”。
|
||||||
|
// 与 unexported normalizeStreamingDelta 相同,供 agent / multiagent 等包在发 SSE 前累计正文。
|
||||||
|
func NormalizeStreamingDelta(current, incoming string) (next, delta string) {
|
||||||
|
return normalizeStreamingDelta(current, incoming)
|
||||||
|
}
|
||||||
@@ -0,0 +1,250 @@
|
|||||||
|
// Package reasoning maps user/config intent to CloudWeGo Eino OpenAI ChatModel fields
|
||||||
|
// (ReasoningEffort, ExtraFields such as thinking / reasoning_effort / output_config).
|
||||||
|
package reasoning
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientIntent is optional per-request override from ChatRequest.reasoning.
|
||||||
|
type ClientIntent struct {
|
||||||
|
Mode string
|
||||||
|
Effort string
|
||||||
|
}
|
||||||
|
|
||||||
|
type wireProfile int
|
||||||
|
|
||||||
|
const (
|
||||||
|
wireNone wireProfile = iota
|
||||||
|
wireClaude
|
||||||
|
wireDeepseek
|
||||||
|
wireOpenAI
|
||||||
|
wireOutputConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
// ApplyToEinoChatModelConfig merges reasoning-related options into cfg.
|
||||||
|
// Precondition: cfg already has APIKey, BaseURL, Model, HTTPClient set.
|
||||||
|
func ApplyToEinoChatModelConfig(cfg *einoopenai.ChatModelConfig, oa *config.OpenAIConfig, client *ClientIntent) {
|
||||||
|
if cfg == nil || oa == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sr := &oa.Reasoning
|
||||||
|
allowClient := sr.AllowClientReasoningEffective()
|
||||||
|
mode := effectiveMode(sr, client, allowClient)
|
||||||
|
|
||||||
|
// Claude (Anthropic): merge admin extras first; optional extended thinking maps to top-level `thinking`
|
||||||
|
// (see internal/openai convertOpenAIToClaude). DeepSeek/OpenAI-style fields are not sent.
|
||||||
|
if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") ||
|
||||||
|
strings.EqualFold(strings.TrimSpace(oa.Provider), "anthropic") {
|
||||||
|
if len(sr.ExtraRequestFields) > 0 {
|
||||||
|
if cfg.ExtraFields == nil {
|
||||||
|
cfg.ExtraFields = make(map[string]any)
|
||||||
|
}
|
||||||
|
for k, v := range sr.ExtraRequestFields {
|
||||||
|
cfg.ExtraFields[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mode == "off" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
applyClaudeExtendedThinking(cfg, mode, effectiveEffort(sr, client, allowClient), oa.Model)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if mode == "off" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
effort := effectiveEffort(sr, client, allowClient)
|
||||||
|
prof := resolveWireProfile(oa, sr)
|
||||||
|
|
||||||
|
// Admin-defined extra root fields (merged first; automatic keys may follow).
|
||||||
|
if len(sr.ExtraRequestFields) > 0 {
|
||||||
|
if cfg.ExtraFields == nil {
|
||||||
|
cfg.ExtraFields = make(map[string]any)
|
||||||
|
}
|
||||||
|
for k, v := range sr.ExtraRequestFields {
|
||||||
|
cfg.ExtraFields[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch prof {
|
||||||
|
case wireClaude, wireNone:
|
||||||
|
return
|
||||||
|
case wireDeepseek:
|
||||||
|
applyDeepseek(cfg, mode, effort)
|
||||||
|
case wireOutputConfig:
|
||||||
|
applyOutputConfigEffort(cfg, mode, effort)
|
||||||
|
default: // wireOpenAI
|
||||||
|
applyOpenAICompat(cfg, mode, effort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyClaudeExtendedThinking sets Anthropic Messages API `thinking` when absent from ExtraRequestFields.
|
||||||
|
// Uses adaptive + summarized display by default (per Anthropic guidance for Claude 4.x); Sonnet 3.7 uses enabled+budget.
|
||||||
|
func applyClaudeExtendedThinking(cfg *einoopenai.ChatModelConfig, mode, effort, model string) {
|
||||||
|
if cfg == nil || mode == "off" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cfg.ExtraFields == nil {
|
||||||
|
cfg.ExtraFields = make(map[string]any)
|
||||||
|
}
|
||||||
|
if _, exists := cfg.ExtraFields["thinking"]; exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m := strings.ToLower(strings.TrimSpace(model))
|
||||||
|
thinking := map[string]any{
|
||||||
|
"type": "adaptive",
|
||||||
|
"display": "summarized",
|
||||||
|
}
|
||||||
|
// Sonnet 3.7: manual extended thinking is the documented path.
|
||||||
|
if strings.Contains(m, "claude-3-7-sonnet") || strings.Contains(m, "3-7-sonnet") || strings.Contains(m, "sonnet-3.7") {
|
||||||
|
thinking = map[string]any{
|
||||||
|
"type": "enabled",
|
||||||
|
"budget_tokens": 10000,
|
||||||
|
"display": "summarized",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Opus 4.7+: manual enabled+budget rejected — keep adaptive only.
|
||||||
|
if strings.Contains(m, "opus-4-7") || strings.Contains(m, "opus-4.7") {
|
||||||
|
thinking = map[string]any{
|
||||||
|
"type": "adaptive",
|
||||||
|
"display": "summarized",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = effort // reserved: map to Anthropic effort / output_config when API stabilizes in one place
|
||||||
|
cfg.ExtraFields["thinking"] = thinking
|
||||||
|
}
|
||||||
|
|
||||||
|
func effectiveMode(sr *config.OpenAIReasoningConfig, client *ClientIntent, allowClient bool) string {
|
||||||
|
server := strings.ToLower(strings.TrimSpace(sr.ModeEffective()))
|
||||||
|
if server == "" || server == "default" {
|
||||||
|
server = "auto"
|
||||||
|
}
|
||||||
|
if !allowClient || client == nil {
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
cm := strings.ToLower(strings.TrimSpace(client.Mode))
|
||||||
|
if cm == "" || cm == "default" {
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
return cm
|
||||||
|
}
|
||||||
|
|
||||||
|
func effectiveEffort(sr *config.OpenAIReasoningConfig, client *ClientIntent, allowClient bool) string {
|
||||||
|
se := normalizeEffort(sr.Effort)
|
||||||
|
if !allowClient || client == nil {
|
||||||
|
return se
|
||||||
|
}
|
||||||
|
ce := normalizeEffort(client.Effort)
|
||||||
|
if ce != "" {
|
||||||
|
return ce
|
||||||
|
}
|
||||||
|
return se
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeEffort(s string) string {
|
||||||
|
e := strings.ToLower(strings.TrimSpace(s))
|
||||||
|
switch e {
|
||||||
|
case "low", "medium", "high", "max":
|
||||||
|
return e
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveWireProfile(oa *config.OpenAIConfig, sr *config.OpenAIReasoningConfig) wireProfile {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") {
|
||||||
|
return wireClaude
|
||||||
|
}
|
||||||
|
p := strings.ToLower(strings.TrimSpace(sr.ProfileEffective()))
|
||||||
|
switch p {
|
||||||
|
case "output_config", "output_config_effort":
|
||||||
|
return wireOutputConfig
|
||||||
|
case "openai", "openai_compat":
|
||||||
|
return wireOpenAI
|
||||||
|
case "deepseek", "deepseek_compat":
|
||||||
|
return wireDeepseek
|
||||||
|
case "auto", "":
|
||||||
|
bu := strings.ToLower(oa.BaseURL)
|
||||||
|
mo := strings.ToLower(oa.Model)
|
||||||
|
if strings.Contains(bu, "deepseek") || strings.Contains(mo, "deepseek") {
|
||||||
|
return wireDeepseek
|
||||||
|
}
|
||||||
|
return wireOpenAI
|
||||||
|
default:
|
||||||
|
return wireOpenAI
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyDeepseek(cfg *einoopenai.ChatModelConfig, mode, effort string) {
|
||||||
|
// auto: enable thinking for DeepSeek line; on: same; auto without effort still opens thinking.
|
||||||
|
if mode == "off" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if mode == "auto" || mode == "on" {
|
||||||
|
if cfg.ExtraFields == nil {
|
||||||
|
cfg.ExtraFields = make(map[string]any)
|
||||||
|
}
|
||||||
|
cfg.ExtraFields["thinking"] = map[string]any{"type": "enabled"}
|
||||||
|
}
|
||||||
|
if effort != "" {
|
||||||
|
if cfg.ExtraFields == nil {
|
||||||
|
cfg.ExtraFields = make(map[string]any)
|
||||||
|
}
|
||||||
|
cfg.ExtraFields["reasoning_effort"] = effortStringForAPI(effort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyOpenAICompat(cfg *einoopenai.ChatModelConfig, mode, effort string) {
|
||||||
|
if mode == "auto" && effort == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
e := effort
|
||||||
|
if mode == "on" && e == "" {
|
||||||
|
e = "medium"
|
||||||
|
}
|
||||||
|
if e == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if e == "max" {
|
||||||
|
if cfg.ExtraFields == nil {
|
||||||
|
cfg.ExtraFields = make(map[string]any)
|
||||||
|
}
|
||||||
|
cfg.ExtraFields["reasoning_effort"] = "max"
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch e {
|
||||||
|
case "low":
|
||||||
|
cfg.ReasoningEffort = einoopenai.ReasoningEffortLevelLow
|
||||||
|
case "medium":
|
||||||
|
cfg.ReasoningEffort = einoopenai.ReasoningEffortLevelMedium
|
||||||
|
case "high":
|
||||||
|
cfg.ReasoningEffort = einoopenai.ReasoningEffortLevelHigh
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyOutputConfigEffort(cfg *einoopenai.ChatModelConfig, mode, effort string) {
|
||||||
|
if mode == "auto" && effort == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
e := effort
|
||||||
|
if mode == "on" && e == "" {
|
||||||
|
e = "high"
|
||||||
|
}
|
||||||
|
if e == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if cfg.ExtraFields == nil {
|
||||||
|
cfg.ExtraFields = make(map[string]any)
|
||||||
|
}
|
||||||
|
cfg.ExtraFields["output_config"] = map[string]any{"effort": effortStringForAPI(e)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func effortStringForAPI(e string) string {
|
||||||
|
// Gateways expect lowercase strings; "max" kept as max.
|
||||||
|
return strings.ToLower(strings.TrimSpace(e))
|
||||||
|
}
|
||||||
@@ -0,0 +1,316 @@
|
|||||||
|
package ilink
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultBaseURL = "https://ilinkai.weixin.qq.com"
|
||||||
|
DefaultBotType = "3"
|
||||||
|
DefaultBotAgent = "CyberStrikeAI/1.0"
|
||||||
|
ILinkAppID = "bot"
|
||||||
|
QRLongPollTimeout = 35 * time.Second
|
||||||
|
APIDefaultTimeout = 15 * time.Second
|
||||||
|
GetUpdatesTimeout = 35 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client 微信 iLink Bot HTTP 客户端(与 @tencent-weixin/openclaw-weixin 协议兼容)
|
||||||
|
type Client struct {
|
||||||
|
BaseURL string
|
||||||
|
BotToken string
|
||||||
|
BotAgent string
|
||||||
|
ClientVersion uint32
|
||||||
|
HTTP *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(baseURL, botToken, botAgent string, clientVersion uint32) *Client {
|
||||||
|
base := strings.TrimSpace(baseURL)
|
||||||
|
if base == "" {
|
||||||
|
base = DefaultBaseURL
|
||||||
|
}
|
||||||
|
agent := strings.TrimSpace(botAgent)
|
||||||
|
if agent == "" {
|
||||||
|
agent = DefaultBotAgent
|
||||||
|
}
|
||||||
|
return &Client{
|
||||||
|
BaseURL: strings.TrimRight(base, "/"),
|
||||||
|
BotToken: strings.TrimSpace(botToken),
|
||||||
|
BotAgent: sanitizeBotAgent(agent),
|
||||||
|
ClientVersion: clientVersion,
|
||||||
|
HTTP: &http.Client{Timeout: 0},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildClientVersion 将 semver 编码为 iLink-App-ClientVersion(0x00MMNNPP)
|
||||||
|
func BuildClientVersion(version string) uint32 {
|
||||||
|
parts := strings.Split(version, ".")
|
||||||
|
parse := func(i int) int {
|
||||||
|
if i >= len(parts) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
n, _ := strconv.Atoi(strings.TrimSpace(parts[i]))
|
||||||
|
if n < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
major := parse(0) & 0xff
|
||||||
|
minor := parse(1) & 0xff
|
||||||
|
patch := parse(2) & 0xff
|
||||||
|
return uint32((major << 16) | (minor << 8) | patch)
|
||||||
|
}
|
||||||
|
|
||||||
|
type baseInfo struct {
|
||||||
|
ChannelVersion string `json:"channel_version"`
|
||||||
|
BotAgent string `json:"bot_agent"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) buildBaseInfo() baseInfo {
|
||||||
|
return baseInfo{
|
||||||
|
ChannelVersion: "1.0.0",
|
||||||
|
BotAgent: c.BotAgent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomWechatUIN() string {
|
||||||
|
var b [4]byte
|
||||||
|
_, _ = rand.Read(b[:])
|
||||||
|
u := uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||||
|
return base64.StdEncoding.EncodeToString([]byte(strconv.FormatUint(uint64(u), 10)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) commonHeaders() http.Header {
|
||||||
|
h := http.Header{}
|
||||||
|
h.Set("iLink-App-Id", ILinkAppID)
|
||||||
|
h.Set("iLink-App-ClientVersion", strconv.FormatUint(uint64(c.ClientVersion), 10))
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) authHeaders() http.Header {
|
||||||
|
h := c.commonHeaders()
|
||||||
|
h.Set("Content-Type", "application/json")
|
||||||
|
h.Set("AuthorizationType", "ilink_bot_token")
|
||||||
|
h.Set("X-WECHAT-UIN", randomWechatUIN())
|
||||||
|
if c.BotToken != "" {
|
||||||
|
h.Set("Authorization", "Bearer "+c.BotToken)
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) endpointURL(path string) (string, error) {
|
||||||
|
u, err := url.Parse(c.BaseURL + "/")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
ref, err := url.Parse(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return u.ResolveReference(ref).String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) doRequest(ctx context.Context, method, path string, body []byte, headers http.Header, timeout time.Duration) ([]byte, error) {
|
||||||
|
reqURL, err := c.endpointURL(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var bodyReader io.Reader
|
||||||
|
if len(body) > 0 {
|
||||||
|
bodyReader = bytes.NewReader(body)
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, reqURL, bodyReader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for k, vs := range headers {
|
||||||
|
for _, v := range vs {
|
||||||
|
req.Header.Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
client := c.HTTP
|
||||||
|
if client == nil {
|
||||||
|
client = http.DefaultClient
|
||||||
|
}
|
||||||
|
if timeout > 0 {
|
||||||
|
ctx2, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
req = req.WithContext(ctx2)
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
raw, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("ilink %s %s: %d %s", method, path, resp.StatusCode, string(raw))
|
||||||
|
}
|
||||||
|
return raw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QRCodeResponse 获取二维码响应
|
||||||
|
type QRCodeResponse struct {
|
||||||
|
QRCode string `json:"qrcode"`
|
||||||
|
QRCodeImgContent string `json:"qrcode_img_content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBotQRCode 获取绑定二维码
|
||||||
|
func (c *Client) GetBotQRCode(ctx context.Context, botType string, localTokenList []string) (*QRCodeResponse, error) {
|
||||||
|
if strings.TrimSpace(botType) == "" {
|
||||||
|
botType = DefaultBotType
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"local_token_list": localTokenList,
|
||||||
|
})
|
||||||
|
path := "ilink/bot/get_bot_qrcode?bot_type=" + url.QueryEscape(botType)
|
||||||
|
raw, err := c.doRequest(ctx, http.MethodPost, path, body, c.authHeaders(), APIDefaultTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var out QRCodeResponse
|
||||||
|
if err := json.Unmarshal(raw, &out); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QRStatusResponse 二维码状态轮询响应
|
||||||
|
type QRStatusResponse struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
BotToken string `json:"bot_token"`
|
||||||
|
ILinkBotID string `json:"ilink_bot_id"`
|
||||||
|
ILinkUserID string `json:"ilink_user_id"`
|
||||||
|
BaseURL string `json:"baseurl"`
|
||||||
|
RedirectHost string `json:"redirect_host"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQRCodeStatus 长轮询二维码扫码状态
|
||||||
|
func (c *Client) GetQRCodeStatus(ctx context.Context, qrcode, verifyCode string) (*QRStatusResponse, error) {
|
||||||
|
path := "ilink/bot/get_qrcode_status?qrcode=" + url.QueryEscape(qrcode)
|
||||||
|
if verifyCode != "" {
|
||||||
|
path += "&verify_code=" + url.QueryEscape(verifyCode)
|
||||||
|
}
|
||||||
|
raw, err := c.doRequest(ctx, http.MethodGet, path, nil, c.commonHeaders(), QRLongPollTimeout)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return &QRStatusResponse{Status: "wait"}, nil
|
||||||
|
}
|
||||||
|
return &QRStatusResponse{Status: "wait"}, nil
|
||||||
|
}
|
||||||
|
var out QRStatusResponse
|
||||||
|
if err := json.Unmarshal(raw, &out); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageItem 消息内容项
|
||||||
|
type MessageItem struct {
|
||||||
|
Type int `json:"type"`
|
||||||
|
TextItem *struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"text_item,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WeixinMessage 入站消息
|
||||||
|
type WeixinMessage struct {
|
||||||
|
FromUserID string `json:"from_user_id"`
|
||||||
|
MessageType int `json:"message_type"`
|
||||||
|
MessageState int `json:"message_state"`
|
||||||
|
ItemList []MessageItem `json:"item_list"`
|
||||||
|
ContextToken string `json:"context_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUpdatesResponse 长轮询消息响应
|
||||||
|
type GetUpdatesResponse struct {
|
||||||
|
Ret int `json:"ret"`
|
||||||
|
ErrCode int `json:"errcode"`
|
||||||
|
ErrMsg string `json:"errmsg"`
|
||||||
|
Msgs []WeixinMessage `json:"msgs"`
|
||||||
|
GetUpdatesBuf string `json:"get_updates_buf"`
|
||||||
|
LongPollingTimeoutMs int `json:"longpolling_timeout_ms"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUpdates 长轮询获取新消息
|
||||||
|
func (c *Client) GetUpdates(ctx context.Context, getUpdatesBuf string) (*GetUpdatesResponse, error) {
|
||||||
|
body, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"get_updates_buf": getUpdatesBuf,
|
||||||
|
"base_info": c.buildBaseInfo(),
|
||||||
|
})
|
||||||
|
raw, err := c.doRequest(ctx, http.MethodPost, "ilink/bot/getupdates", body, c.authHeaders(), GetUpdatesTimeout)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return &GetUpdatesResponse{Ret: 0, GetUpdatesBuf: getUpdatesBuf}, nil
|
||||||
|
}
|
||||||
|
return &GetUpdatesResponse{Ret: 0, GetUpdatesBuf: getUpdatesBuf}, nil
|
||||||
|
}
|
||||||
|
var out GetUpdatesResponse
|
||||||
|
if err := json.Unmarshal(raw, &out); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendTextMessage 发送文本回复
|
||||||
|
func (c *Client) SendTextMessage(ctx context.Context, toUserID, contextToken, text, clientID string) error {
|
||||||
|
if clientID == "" {
|
||||||
|
clientID = randomClientID()
|
||||||
|
}
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"msg": map[string]interface{}{
|
||||||
|
"to_user_id": toUserID,
|
||||||
|
"client_id": clientID,
|
||||||
|
"message_type": 2,
|
||||||
|
"message_state": 2,
|
||||||
|
"context_token": contextToken,
|
||||||
|
"item_list": []map[string]interface{}{
|
||||||
|
{"type": 1, "text_item": map[string]string{"text": text}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"base_info": c.buildBaseInfo(),
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(payload)
|
||||||
|
_, err := c.doRequest(ctx, http.MethodPost, "ilink/bot/sendmessage", body, c.authHeaders(), APIDefaultTimeout)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomClientID() string {
|
||||||
|
var b [8]byte
|
||||||
|
_, _ = rand.Read(b[:])
|
||||||
|
return fmt.Sprintf("%x", b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeBotAgent(raw string) string {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return DefaultBotAgent
|
||||||
|
}
|
||||||
|
if len(raw) > 256 {
|
||||||
|
return raw[:256]
|
||||||
|
}
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractText 从消息中提取首条文本
|
||||||
|
func ExtractText(msg WeixinMessage) string {
|
||||||
|
for _, item := range msg.ItemList {
|
||||||
|
if item.Type == 1 && item.TextItem != nil {
|
||||||
|
return strings.TrimSpace(item.TextItem.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
package ilink
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/skip2/go-qrcode"
|
||||||
|
)
|
||||||
|
|
||||||
|
// QRCodeDataURL 将扫码内容(一般为 liteapp 链接)编码为 PNG data URL,供 Web 端展示。
|
||||||
|
// qrcode_img_content 不是图片直链,不能用作 <img src>。
|
||||||
|
func QRCodeDataURL(content string, size int) (string, error) {
|
||||||
|
content = strings.TrimSpace(content)
|
||||||
|
if content == "" {
|
||||||
|
return "", fmt.Errorf("empty qr content")
|
||||||
|
}
|
||||||
|
if size <= 0 {
|
||||||
|
size = 256
|
||||||
|
}
|
||||||
|
png, err := qrcode.Encode(content, qrcode.Medium, size)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return "data:image/png;base64," + base64.StdEncoding.EncodeToString(png), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
package robot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/robot/ilink"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
wechatReconnectInitial = 5 * time.Second
|
||||||
|
wechatReconnectMax = 60 * time.Second
|
||||||
|
wechatPlatform = "wechat"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StartWechat 启动微信 iLink 长轮询(无需公网回调),收到消息后调用 handler 并回复。
|
||||||
|
func StartWechat(ctx context.Context, robotsCfg config.RobotsConfig, h MessageHandler, appVersion string, logger *zap.Logger) {
|
||||||
|
cfg := robotsCfg.Wechat
|
||||||
|
if !cfg.Enabled || cfg.BotToken == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go runWechatLoop(ctx, cfg, h, appVersion, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runWechatLoop(ctx context.Context, cfg config.RobotWechatConfig, h MessageHandler, appVersion string, logger *zap.Logger) {
|
||||||
|
backoff := wechatReconnectInitial
|
||||||
|
for {
|
||||||
|
err := runWechatPoll(ctx, cfg, h, appVersion, logger)
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
logger.Info("微信 iLink 长轮询已按配置关闭")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("微信 iLink 长轮询异常,将自动重连", zap.Error(err), zap.Duration("retry_after", backoff))
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(backoff):
|
||||||
|
if backoff < wechatReconnectMax {
|
||||||
|
backoff *= 2
|
||||||
|
if backoff > wechatReconnectMax {
|
||||||
|
backoff = wechatReconnectMax
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runWechatPoll(ctx context.Context, cfg config.RobotWechatConfig, h MessageHandler, appVersion string, logger *zap.Logger) error {
|
||||||
|
client := ilink.NewClient(cfg.BaseURL, cfg.BotToken, cfg.BotAgent, ilink.BuildClientVersion(appVersion))
|
||||||
|
buf := cfg.GetUpdatesBuf
|
||||||
|
logger.Info("微信 iLink 长轮询已启动", zap.String("ilink_bot_id", cfg.ILinkBotID))
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
resp, err := client.GetUpdates(ctx, buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if resp.ErrCode != 0 && resp.Ret != 0 {
|
||||||
|
logger.Warn("微信 getUpdates 返回错误", zap.Int("errcode", resp.ErrCode), zap.String("errmsg", resp.ErrMsg))
|
||||||
|
}
|
||||||
|
if resp.GetUpdatesBuf != "" {
|
||||||
|
buf = resp.GetUpdatesBuf
|
||||||
|
}
|
||||||
|
for _, msg := range resp.Msgs {
|
||||||
|
if msg.MessageType != 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
text := ilink.ExtractText(msg)
|
||||||
|
if text == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
userID := strings.TrimSpace(msg.FromUserID)
|
||||||
|
if userID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Info("微信收到消息", zap.String("from", userID), zap.String("content", text))
|
||||||
|
reply := h.HandleMessage(wechatPlatform, userID, text)
|
||||||
|
if strings.TrimSpace(reply) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := client.SendTextMessage(ctx, userID, msg.ContextToken, reply, ""); err != nil {
|
||||||
|
logger.Warn("微信发送回复失败", zap.String("to", userID), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -153,6 +153,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
|||||||
// 执行命令
|
// 执行命令
|
||||||
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
||||||
applyDefaultTerminalEnv(cmd)
|
applyDefaultTerminalEnv(cmd)
|
||||||
|
_ = prepareShellCmdSession(cmd)
|
||||||
|
|
||||||
e.logger.Info("执行安全工具",
|
e.logger.Info("执行安全工具",
|
||||||
zap.String("tool", toolName),
|
zap.String("tool", toolName),
|
||||||
@@ -163,13 +164,14 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
|||||||
var err error
|
var err error
|
||||||
// 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。
|
// 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。
|
||||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||||
output, err = streamCommandOutput(cmd, cb)
|
output, err = streamCommandOutput(ctx, cmd, cb)
|
||||||
if err != nil && shouldRetryWithPTY(output) {
|
if err != nil && shouldRetryWithPTY(output) {
|
||||||
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
|
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
|
||||||
zap.String("tool", toolName),
|
zap.String("tool", toolName),
|
||||||
)
|
)
|
||||||
cmd2 := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
cmd2 := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
||||||
applyDefaultTerminalEnv(cmd2)
|
applyDefaultTerminalEnv(cmd2)
|
||||||
|
_ = prepareShellCmdSession(cmd2)
|
||||||
output, err = runCommandWithPTY(ctx, cmd2, cb)
|
output, err = runCommandWithPTY(ctx, cmd2, cb)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -182,6 +184,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
|||||||
)
|
)
|
||||||
cmd2 := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
cmd2 := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
||||||
applyDefaultTerminalEnv(cmd2)
|
applyDefaultTerminalEnv(cmd2)
|
||||||
|
_ = prepareShellCmdSession(cmd2)
|
||||||
output, err = runCommandWithPTY(ctx, cmd2, nil)
|
output, err = runCommandWithPTY(ctx, cmd2, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -837,6 +840,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
} else {
|
} else {
|
||||||
cmd = exec.CommandContext(ctx, shell, "-c", command)
|
cmd = exec.CommandContext(ctx, shell, "-c", command)
|
||||||
}
|
}
|
||||||
|
applyDefaultTerminalEnv(cmd)
|
||||||
|
_ = prepareShellCmdSession(cmd)
|
||||||
|
|
||||||
// 执行命令
|
// 执行命令
|
||||||
e.logger.Info("执行系统命令",
|
e.logger.Info("执行系统命令",
|
||||||
@@ -865,6 +870,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
} else {
|
} else {
|
||||||
pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand)
|
pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand)
|
||||||
}
|
}
|
||||||
|
applyDefaultTerminalEnv(pidCmd)
|
||||||
|
_ = prepareShellCmdSession(pidCmd)
|
||||||
|
|
||||||
// 获取stdout管道
|
// 获取stdout管道
|
||||||
stdout, err := pidCmd.StdoutPipe()
|
stdout, err := pidCmd.StdoutPipe()
|
||||||
@@ -976,7 +983,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
var err error
|
var err error
|
||||||
// 若上层提供工具输出增量回调,则边执行边流式读取。
|
// 若上层提供工具输出增量回调,则边执行边流式读取。
|
||||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||||
output, err = streamCommandOutput(cmd, cb)
|
output, err = streamCommandOutput(ctx, cmd, cb)
|
||||||
if err != nil && shouldRetryWithPTY(output) {
|
if err != nil && shouldRetryWithPTY(output) {
|
||||||
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
|
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
|
||||||
cmd2 := exec.CommandContext(ctx, shell, "-c", command)
|
cmd2 := exec.CommandContext(ctx, shell, "-c", command)
|
||||||
@@ -984,6 +991,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
cmd2.Dir = workDir
|
cmd2.Dir = workDir
|
||||||
}
|
}
|
||||||
applyDefaultTerminalEnv(cmd2)
|
applyDefaultTerminalEnv(cmd2)
|
||||||
|
_ = prepareShellCmdSession(cmd2)
|
||||||
output, err = runCommandWithPTY(ctx, cmd2, cb)
|
output, err = runCommandWithPTY(ctx, cmd2, cb)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -997,6 +1005,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
cmd2.Dir = workDir
|
cmd2.Dir = workDir
|
||||||
}
|
}
|
||||||
applyDefaultTerminalEnv(cmd2)
|
applyDefaultTerminalEnv(cmd2)
|
||||||
|
_ = prepareShellCmdSession(cmd2)
|
||||||
output, err = runCommandWithPTY(ctx, cmd2, nil)
|
output, err = runCommandWithPTY(ctx, cmd2, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1034,8 +1043,11 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
|||||||
}
|
}
|
||||||
|
|
||||||
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
|
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
|
||||||
// 保持输出内容完整拼接返回,并用 cb(chunk) 向上层持续推送。
|
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。
|
||||||
func streamCommandOutput(cmd *exec.Cmd, cb ToolOutputCallback) (string, error) {
|
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback) (string, error) {
|
||||||
|
if err := prepareShellCmdSession(cmd); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
stdoutPipe, err := cmd.StdoutPipe()
|
stdoutPipe, err := cmd.StdoutPipe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -1051,18 +1063,27 @@ func streamCommandOutput(cmd *exec.Cmd, cb ToolOutputCallback) (string, error) {
|
|||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
stopWatch := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
terminateCmdTree(cmd)
|
||||||
|
case <-stopWatch:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer close(stopWatch)
|
||||||
|
|
||||||
chunks := make(chan string, 64)
|
chunks := make(chan string, 64)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
readFn := func(r io.Reader) {
|
readFn := func(r io.Reader) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
br := bufio.NewReader(r)
|
buf := make([]byte, 8192)
|
||||||
for {
|
for {
|
||||||
s, readErr := br.ReadString('\n')
|
n, readErr := r.Read(buf)
|
||||||
if s != "" {
|
if n > 0 {
|
||||||
chunks <- s
|
chunks <- string(buf[:n])
|
||||||
}
|
}
|
||||||
if readErr != nil {
|
if readErr != nil {
|
||||||
// EOF 正常结束
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1158,12 +1179,14 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
|
|||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
// PTY 方案为类 Unix;Windows 走原逻辑
|
// PTY 方案为类 Unix;Windows 走原逻辑
|
||||||
if cb != nil {
|
if cb != nil {
|
||||||
return streamCommandOutput(cmd, cb)
|
return streamCommandOutput(ctx, cmd, cb)
|
||||||
}
|
}
|
||||||
|
_ = prepareShellCmdSession(cmd)
|
||||||
out, err := cmd.CombinedOutput()
|
out, err := cmd.CombinedOutput()
|
||||||
return string(out), err
|
return string(out), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_ = prepareShellCmdSession(cmd)
|
||||||
ptmx, err := pty.Start(cmd)
|
ptmx, err := pty.Start(cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
@@ -1176,9 +1199,7 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
|
|||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
_ = ptmx.Close() // 触发读退出
|
_ = ptmx.Close() // 触发读退出
|
||||||
if cmd.Process != nil {
|
terminateCmdTree(cmd)
|
||||||
_ = cmd.Process.Kill()
|
|
||||||
}
|
|
||||||
case <-done:
|
case <-done:
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
@@ -0,0 +1,31 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package security
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// prepareShellCmdSession 让 shell 子进程在独立会话中运行,便于超时/取消时整组 SIGKILL(含子进程)。
|
||||||
|
func prepareShellCmdSession(cmd *exec.Cmd) error {
|
||||||
|
if cmd == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cmd.SysProcAttr == nil {
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||||
|
}
|
||||||
|
cmd.SysProcAttr.Setsid = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// terminateCmdTree 尽力终止 cmd 及其进程组(Unix 下 Setsid 后 PGID == 首进程 PID)。
|
||||||
|
func terminateCmdTree(cmd *exec.Cmd) {
|
||||||
|
if cmd == nil || cmd.Process == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pid := cmd.Process.Pid
|
||||||
|
if err := syscall.Kill(-pid, syscall.SIGKILL); err != nil {
|
||||||
|
_ = cmd.Process.Kill()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package security
|
||||||
|
|
||||||
|
import "os/exec"
|
||||||
|
|
||||||
|
func prepareShellCmdSession(cmd *exec.Cmd) error {
|
||||||
|
_ = cmd
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func terminateCmdTree(cmd *exec.Cmd) {
|
||||||
|
if cmd == nil || cmd.Process == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = cmd.Process.Kill()
|
||||||
|
}
|
||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
### What it does
|
### What it does
|
||||||
|
|
||||||
- Configure **Host / Port / Password** and choose **Single-Agent** or **Multi-Agent**
|
- Configure **Host / Port / HTTPS / Password** and choose an agent mode
|
||||||
- Click **Validate** to login (`POST /api/auth/login`) and verify token (`GET /api/auth/validate`)
|
- Click **Validate** to login (`POST /api/auth/login`) and verify token (`GET /api/auth/validate`)
|
||||||
- Right-click any HTTP message in Burp and send it to CyberStrikeAI for **streaming web pentest**
|
- Right-click any HTTP message in Burp and send it to CyberStrikeAI for **streaming web pentest**
|
||||||
- Keep a **test history sidebar** (searchable) so you can revisit previous runs
|
- Keep a **test history sidebar** (searchable) so you can revisit previous runs
|
||||||
@@ -63,6 +63,7 @@ If you already have Gradle available, you can still use `build.gradle` to build.
|
|||||||
|
|
||||||
### Notes
|
### Notes
|
||||||
|
|
||||||
- This extension connects to your CyberStrikeAI server (default is `http://127.0.0.1:8080`).
|
- Default connection is `https://127.0.0.1:8080` (**HTTPS** checked). Self-signed / local certs are trusted automatically (no import).
|
||||||
|
- Uncheck **HTTPS** only if your server runs plain HTTP.
|
||||||
- It uses **Bearer Token** authentication obtained from the configured password.
|
- It uses **Bearer Token** authentication obtained from the configured password.
|
||||||
|
|
||||||
|
|||||||
@@ -81,7 +81,8 @@ cd plugins/burp-suite/cyberstrikeai-burp-extension
|
|||||||
2) 填写:
|
2) 填写:
|
||||||
- **Host**:例如 `127.0.0.1`
|
- **Host**:例如 `127.0.0.1`
|
||||||
- **Port**:例如 `8080`
|
- **Port**:例如 `8080`
|
||||||
- **Password**:你的 CyberStrikeAI 登录密码(对应服务端 `config.yaml` 的 `auth.password`)
|
- **HTTPS**:默认勾选(对接 `config.yaml` 中 `tls_enabled` / 自签证书);插件会自动信任本地自签证书,无需导入
|
||||||
|
- **Password**:你的 CyberStrikeAI 登录密码(对应服务端 `auth.password`)
|
||||||
- **Agent mode**:选择 `Single Agent` 或 `Multi Agent`
|
- **Agent mode**:选择 `Single Agent` 或 `Multi Agent`
|
||||||
3) 点击 **Validate**
|
3) 点击 **Validate**
|
||||||
- 成功:状态显示 `OK (token saved)`
|
- 成功:状态显示 `OK (token saved)`
|
||||||
@@ -94,8 +95,9 @@ cd plugins/burp-suite/cyberstrikeai-burp-extension
|
|||||||
|
|
||||||
- **Validate 失败 / 401**
|
- **Validate 失败 / 401**
|
||||||
- 确认密码是否正确(服务端 `auth.password`)
|
- 确认密码是否正确(服务端 `auth.password`)
|
||||||
- 确认 IP/端口是否能访问(例如浏览器能打开 `http://IP:PORT/`)
|
- 确认 IP/端口是否能访问(例如浏览器能打开 `https://IP:PORT/`)
|
||||||
- 若服务器启用了反向代理/HTTPS,需要把插件里 baseUrl 改成对应协议与端口(当前插件默认使用 `http://`)
|
- 服务端启用 TLS 时勾选 **HTTPS**(默认已勾选);自签证书无需手动导入
|
||||||
|
- 若仍为纯 HTTP 部署,取消勾选 **HTTPS**
|
||||||
|
|
||||||
- **选择 Multi Agent 后提示“多代理未启用”**
|
- **选择 Multi Agent 后提示“多代理未启用”**
|
||||||
- 服务端需要开启:`config.yaml` 中 `multi_agent.enabled: true`
|
- 服务端需要开启:`config.yaml` 中 `multi_agent.enabled: true`
|
||||||
|
|||||||
BIN
Binary file not shown.
+52
-11
@@ -73,15 +73,34 @@ public class BurpExtender implements IBurpExtender, IContextMenuFactory {
|
|||||||
public void onEvent(String type, String message, String rawJson) {
|
public void onEvent(String type, String message, String rawJson) {
|
||||||
if (type == null) type = "";
|
if (type == null) type = "";
|
||||||
switch (type) {
|
switch (type) {
|
||||||
|
case "response_start":
|
||||||
|
tab.appendProgressToRun(runId, "\n\n[主回复]\n");
|
||||||
|
break;
|
||||||
case "response_delta":
|
case "response_delta":
|
||||||
case "eino_agent_reply_stream_delta":
|
if (message != null && !message.isEmpty()) {
|
||||||
tab.appendFinalToRun(runId, message);
|
tab.appendFinalToRun(runId, message);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case "response":
|
case "response":
|
||||||
tab.appendFinalToRun(runId, "\n\n--- Final Response ---\n");
|
|
||||||
tab.appendFinalToRun(runId, message);
|
tab.appendFinalToRun(runId, message);
|
||||||
tab.setFinalResponse(runId, message);
|
tab.setFinalResponse(runId, message);
|
||||||
break;
|
break;
|
||||||
|
case "eino_agent_reply_stream_start":
|
||||||
|
tab.appendProgressToRun(runId, "\n\n[子代理回复]\n");
|
||||||
|
break;
|
||||||
|
case "eino_agent_reply_stream_delta":
|
||||||
|
if (message != null && !message.isEmpty()) {
|
||||||
|
tab.appendProgressToRun(runId, message);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case "eino_agent_reply_stream_end":
|
||||||
|
tab.appendProgressToRun(runId, "\n");
|
||||||
|
break;
|
||||||
|
case "eino_agent_reply":
|
||||||
|
if (message != null && !message.isEmpty()) {
|
||||||
|
tab.appendProgressToRun(runId, "\n\n[子代理回复]\n" + message + "\n");
|
||||||
|
}
|
||||||
|
break;
|
||||||
case "progress":
|
case "progress":
|
||||||
tab.appendProgressToRun(runId, "\n[progress] " + message + "\n");
|
tab.appendProgressToRun(runId, "\n[progress] " + message + "\n");
|
||||||
tab.setRunStatus(runId, "running");
|
tab.setRunStatus(runId, "running");
|
||||||
@@ -94,21 +113,40 @@ public class BurpExtender implements IBurpExtender, IContextMenuFactory {
|
|||||||
tab.appendProgressToRun(runId, "\n[error] " + message + "\n");
|
tab.appendProgressToRun(runId, "\n[error] " + message + "\n");
|
||||||
tab.setRunStatus(runId, "error");
|
tab.setRunStatus(runId, "error");
|
||||||
break;
|
break;
|
||||||
|
case "reasoning_chain_stream_start":
|
||||||
|
tab.appendProgressToRun(runId, "\n\n[推理过程]\n");
|
||||||
|
break;
|
||||||
|
case "reasoning_chain_stream_delta":
|
||||||
|
if (message != null && !message.isEmpty()) {
|
||||||
|
tab.appendProgressToRun(runId, message);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case "reasoning_chain_stream_end":
|
||||||
|
tab.appendProgressToRun(runId, "\n");
|
||||||
|
break;
|
||||||
|
case "reasoning_chain":
|
||||||
|
if (message != null && !message.isEmpty()) {
|
||||||
|
String streamId = rawJson != null ? SimpleJson.extractStringField(rawJson, "streamId") : "";
|
||||||
|
if (streamId == null || streamId.isEmpty()) {
|
||||||
|
tab.appendProgressToRun(runId, "\n\n[推理过程]\n" + message + "\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
case "thinking_stream_start":
|
case "thinking_stream_start":
|
||||||
if (tab.isShowDebugEvents()) {
|
if (tab.isShowDebugEvents()) {
|
||||||
tab.resetThinkingStream(runId);
|
tab.resetThinkingStream(runId);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case "thinking_stream_delta":
|
case "thinking_stream_delta":
|
||||||
|
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
|
||||||
|
tab.appendProgressToRun(runId, message);
|
||||||
|
}
|
||||||
|
break;
|
||||||
case "tool_call":
|
case "tool_call":
|
||||||
case "tool_result":
|
case "tool_result":
|
||||||
case "tool_result_delta":
|
case "tool_result_delta":
|
||||||
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
|
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
|
||||||
if ("thinking_stream_delta".equals(type)) {
|
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
|
||||||
tab.appendThinkingDelta(runId, message);
|
|
||||||
} else {
|
|
||||||
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case "conversation":
|
case "conversation":
|
||||||
@@ -125,7 +163,9 @@ public class BurpExtender implements IBurpExtender, IContextMenuFactory {
|
|||||||
case "done":
|
case "done":
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
|
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()
|
||||||
|
&& !type.endsWith("_stream_delta") && !type.endsWith("_stream_start")
|
||||||
|
&& !type.endsWith("_stream_end")) {
|
||||||
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
|
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
@@ -134,8 +174,9 @@ public class BurpExtender implements IBurpExtender, IContextMenuFactory {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onError(String message, Exception e) {
|
public void onError(String message, Exception e) {
|
||||||
tab.appendProgressToRun(runId, "\n[error] " + message + "\n");
|
boolean cancelled = message != null && message.toLowerCase().contains("cancel");
|
||||||
tab.setRunStatus(runId, "error");
|
tab.appendProgressToRun(runId, cancelled ? "\n[info] " + message + "\n" : "\n[error] " + message + "\n");
|
||||||
|
tab.setRunStatus(runId, cancelled ? "cancelled" : "error");
|
||||||
callbacks.printError("CyberStrikeAI stream error: " + message);
|
callbacks.printError("CyberStrikeAI stream error: " + message);
|
||||||
if (e != null) {
|
if (e != null) {
|
||||||
callbacks.printError(e.toString());
|
callbacks.printError(e.toString());
|
||||||
|
|||||||
+127
-11
@@ -2,17 +2,29 @@ package burp;
|
|||||||
|
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.io.InterruptedIOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.io.InputStreamReader;
|
import java.io.InputStreamReader;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
import java.net.HttpURLConnection;
|
import java.net.HttpURLConnection;
|
||||||
|
import java.net.SocketTimeoutException;
|
||||||
import java.net.URL;
|
import java.net.URL;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
final class CyberStrikeAIClient {
|
final class CyberStrikeAIClient {
|
||||||
|
|
||||||
|
private static final int AUTH_CONNECT_TIMEOUT_MS = 4_000;
|
||||||
|
private static final int AUTH_READ_TIMEOUT_MS = 5_000;
|
||||||
|
/** login + validate 整段上限,避免两次读超时叠加拖到半分钟 */
|
||||||
|
private static final int AUTH_OVERALL_TIMEOUT_MS = 10_000;
|
||||||
|
private static final int DEFAULT_READ_TIMEOUT_MS = 15_000;
|
||||||
|
|
||||||
|
private final AtomicReference<HttpURLConnection> activeConnection = new AtomicReference<>();
|
||||||
|
private final AtomicReference<Thread> activeThread = new AtomicReference<>();
|
||||||
|
|
||||||
static final class Config {
|
static final class Config {
|
||||||
final String baseUrl; // e.g. http://127.0.0.1:8080
|
final String baseUrl; // e.g. http://127.0.0.1:8080
|
||||||
final String password;
|
final String password;
|
||||||
@@ -49,15 +61,97 @@ final class CyberStrikeAIClient {
|
|||||||
void onDone();
|
void onDone();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
boolean hasActiveRequest() {
|
||||||
|
return activeConnection.get() != null;
|
||||||
|
}
|
||||||
|
|
||||||
|
void cancelActiveRequest() {
|
||||||
|
HttpURLConnection conn = activeConnection.getAndSet(null);
|
||||||
|
if (conn != null) {
|
||||||
|
try {
|
||||||
|
conn.disconnect();
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Thread t = activeThread.getAndSet(null);
|
||||||
|
if (t != null) {
|
||||||
|
t.interrupt();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
String loginAndValidate(Config cfg) throws IOException {
|
String loginAndValidate(Config cfg) throws IOException {
|
||||||
String token = login(cfg.baseUrl, cfg.password);
|
Thread worker = Thread.currentThread();
|
||||||
validate(cfg.baseUrl, token);
|
java.util.Timer deadline = new java.util.Timer("CyberStrikeAI-AuthDeadline", true);
|
||||||
return token;
|
deadline.schedule(new java.util.TimerTask() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
worker.interrupt();
|
||||||
|
cancelActiveRequest();
|
||||||
|
}
|
||||||
|
}, AUTH_OVERALL_TIMEOUT_MS);
|
||||||
|
try {
|
||||||
|
String token = login(cfg.baseUrl, cfg.password);
|
||||||
|
if (Thread.interrupted()) {
|
||||||
|
throw timeoutIOException();
|
||||||
|
}
|
||||||
|
validate(cfg.baseUrl, token);
|
||||||
|
if (Thread.interrupted()) {
|
||||||
|
throw timeoutIOException();
|
||||||
|
}
|
||||||
|
return token;
|
||||||
|
} catch (SocketTimeoutException e) {
|
||||||
|
throw timeoutIOException();
|
||||||
|
} finally {
|
||||||
|
deadline.cancel();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static IOException timeoutIOException() {
|
||||||
|
return new IOException("Connection timed out (~" + (AUTH_OVERALL_TIMEOUT_MS / 1000)
|
||||||
|
+ "s). Check host/port and HTTPS checkbox.");
|
||||||
|
}
|
||||||
|
|
||||||
|
private void trackConnection(HttpURLConnection conn) {
|
||||||
|
activeThread.set(Thread.currentThread());
|
||||||
|
activeConnection.set(conn);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void releaseConnection(HttpURLConnection conn) {
|
||||||
|
if (activeConnection.compareAndSet(conn, null)) {
|
||||||
|
activeThread.set(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean isCancelled(Throwable e) {
|
||||||
|
if (e == null) {
|
||||||
|
return Thread.currentThread().isInterrupted();
|
||||||
|
}
|
||||||
|
if (Thread.currentThread().isInterrupted()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (e instanceof InterruptedIOException) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (e instanceof SocketTimeoutException) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
Throwable cause = e.getCause();
|
||||||
|
if (cause != null && cause != e) {
|
||||||
|
return isCancelled(cause);
|
||||||
|
}
|
||||||
|
String msg = e.getMessage();
|
||||||
|
return msg != null && (
|
||||||
|
msg.toLowerCase().contains("cancel")
|
||||||
|
|| msg.toLowerCase().contains("abort")
|
||||||
|
|| msg.toLowerCase().contains("closed")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String login(String baseUrl, String password) throws IOException {
|
private String login(String baseUrl, String password) throws IOException {
|
||||||
URL url = new URL(baseUrl + "/api/auth/login");
|
URL url = new URL(baseUrl + "/api/auth/login");
|
||||||
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
|
HttpURLConnection conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, AUTH_READ_TIMEOUT_MS);
|
||||||
|
trackConnection(conn);
|
||||||
|
try {
|
||||||
conn.setRequestMethod("POST");
|
conn.setRequestMethod("POST");
|
||||||
conn.setDoOutput(true);
|
conn.setDoOutput(true);
|
||||||
conn.setRequestProperty("Content-Type", "application/json");
|
conn.setRequestProperty("Content-Type", "application/json");
|
||||||
@@ -92,11 +186,16 @@ final class CyberStrikeAIClient {
|
|||||||
throw new IOException("Login response missing token. Check backend address and credentials.");
|
throw new IOException("Login response missing token. Check backend address and credentials.");
|
||||||
}
|
}
|
||||||
return token;
|
return token;
|
||||||
|
} finally {
|
||||||
|
releaseConnection(conn);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void validate(String baseUrl, String token) throws IOException {
|
private void validate(String baseUrl, String token) throws IOException {
|
||||||
URL url = new URL(baseUrl + "/api/auth/validate");
|
URL url = new URL(baseUrl + "/api/auth/validate");
|
||||||
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
|
HttpURLConnection conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, AUTH_READ_TIMEOUT_MS);
|
||||||
|
trackConnection(conn);
|
||||||
|
try {
|
||||||
conn.setRequestMethod("GET");
|
conn.setRequestMethod("GET");
|
||||||
conn.setRequestProperty("Authorization", "Bearer " + token);
|
conn.setRequestProperty("Authorization", "Bearer " + token);
|
||||||
int code = conn.getResponseCode();
|
int code = conn.getResponseCode();
|
||||||
@@ -104,6 +203,9 @@ final class CyberStrikeAIClient {
|
|||||||
if (code < 200 || code >= 300) {
|
if (code < 200 || code >= 300) {
|
||||||
throw new IOException("Validate failed (" + code + "): " + resp);
|
throw new IOException("Validate failed (" + code + "): " + resp);
|
||||||
}
|
}
|
||||||
|
} finally {
|
||||||
|
releaseConnection(conn);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void streamTest(Config cfg, String token, String message, StreamListener listener) {
|
void streamTest(Config cfg, String token, String message, StreamListener listener) {
|
||||||
@@ -117,11 +219,12 @@ final class CyberStrikeAIClient {
|
|||||||
payload.put("orchestration", cfg.agentMode.orchestration);
|
payload.put("orchestration", cfg.agentMode.orchestration);
|
||||||
}
|
}
|
||||||
|
|
||||||
new Thread(() -> {
|
Thread worker = new Thread(() -> {
|
||||||
HttpURLConnection conn = null;
|
HttpURLConnection conn = null;
|
||||||
try {
|
try {
|
||||||
URL url = new URL(urlStr);
|
URL url = new URL(urlStr);
|
||||||
conn = (HttpURLConnection) url.openConnection();
|
conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, 0);
|
||||||
|
trackConnection(conn);
|
||||||
conn.setRequestMethod("POST");
|
conn.setRequestMethod("POST");
|
||||||
conn.setDoOutput(true);
|
conn.setDoOutput(true);
|
||||||
conn.setRequestProperty("Content-Type", "application/json");
|
conn.setRequestProperty("Content-Type", "application/json");
|
||||||
@@ -142,6 +245,9 @@ final class CyberStrikeAIClient {
|
|||||||
try (BufferedReader br = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
|
try (BufferedReader br = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
|
||||||
String line;
|
String line;
|
||||||
while ((line = br.readLine()) != null) {
|
while ((line = br.readLine()) != null) {
|
||||||
|
if (Thread.currentThread().isInterrupted()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
// SSE format: "data: {json}"
|
// SSE format: "data: {json}"
|
||||||
if (line.startsWith("data:")) {
|
if (line.startsWith("data:")) {
|
||||||
String json = line.substring("data:".length()).trim();
|
String json = line.substring("data:".length()).trim();
|
||||||
@@ -156,15 +262,25 @@ final class CyberStrikeAIClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
listener.onDone();
|
if (Thread.currentThread().isInterrupted()) {
|
||||||
|
listener.onError("Cancelled.", null);
|
||||||
|
} else {
|
||||||
|
listener.onDone();
|
||||||
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
listener.onError(e.getMessage(), e);
|
if (isCancelled(e)) {
|
||||||
|
listener.onError("Cancelled.", e);
|
||||||
|
} else {
|
||||||
|
listener.onError(e.getMessage(), e);
|
||||||
|
}
|
||||||
} finally {
|
} finally {
|
||||||
if (conn != null) {
|
if (conn != null) {
|
||||||
|
releaseConnection(conn);
|
||||||
conn.disconnect();
|
conn.disconnect();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}, "CyberStrikeAI-Stream").start();
|
}, "CyberStrikeAI-Stream");
|
||||||
|
worker.start();
|
||||||
}
|
}
|
||||||
|
|
||||||
void cancelByConversationId(String baseUrl, String token, String conversationId) throws IOException {
|
void cancelByConversationId(String baseUrl, String token, String conversationId) throws IOException {
|
||||||
@@ -172,7 +288,7 @@ final class CyberStrikeAIClient {
|
|||||||
throw new IOException("Missing conversationId.");
|
throw new IOException("Missing conversationId.");
|
||||||
}
|
}
|
||||||
URL url = new URL(baseUrl + "/api/agent-loop/cancel");
|
URL url = new URL(baseUrl + "/api/agent-loop/cancel");
|
||||||
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
|
HttpURLConnection conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, AUTH_READ_TIMEOUT_MS);
|
||||||
conn.setRequestMethod("POST");
|
conn.setRequestMethod("POST");
|
||||||
conn.setDoOutput(true);
|
conn.setDoOutput(true);
|
||||||
conn.setRequestProperty("Content-Type", "application/json");
|
conn.setRequestProperty("Content-Type", "application/json");
|
||||||
|
|||||||
+130
-34
@@ -14,6 +14,7 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
|
|
||||||
private final JTextField hostField = new JTextField("127.0.0.1");
|
private final JTextField hostField = new JTextField("127.0.0.1");
|
||||||
private final JTextField portField = new JTextField("8080");
|
private final JTextField portField = new JTextField("8080");
|
||||||
|
private final JCheckBox useHttpsBox = new JCheckBox("HTTPS", true);
|
||||||
private final JPasswordField passwordField = new JPasswordField();
|
private final JPasswordField passwordField = new JPasswordField();
|
||||||
private final JComboBox<String> agentModeBox = new JComboBox<>(new String[]{
|
private final JComboBox<String> agentModeBox = new JComboBox<>(new String[]{
|
||||||
"Native ReAct", "Eino Single (ADK)", "Deep (DeepAgent)", "Plan-Execute", "Supervisor"
|
"Native ReAct", "Eino Single (ADK)", "Deep (DeepAgent)", "Plan-Execute", "Supervisor"
|
||||||
@@ -29,6 +30,10 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
|
|
||||||
private final JTextArea progressArea = new JTextArea();
|
private final JTextArea progressArea = new JTextArea();
|
||||||
private final JTextArea finalRawArea = new JTextArea(); // raw final stream / final response
|
private final JTextArea finalRawArea = new JTextArea(); // raw final stream / final response
|
||||||
|
private JScrollPane progressScrollPane;
|
||||||
|
private JScrollPane finalRawScrollPane;
|
||||||
|
/** 距底部在此像素内视为「跟随滚动」,否则用户上拉阅读时不抢滚动条 */
|
||||||
|
private static final int SCROLL_FOLLOW_THRESHOLD_PX = 48;
|
||||||
private final JEditorPane markdownPane = new JEditorPane("text/html", "");
|
private final JEditorPane markdownPane = new JEditorPane("text/html", "");
|
||||||
private final CardLayout outputCardsLayout = new CardLayout();
|
private final CardLayout outputCardsLayout = new CardLayout();
|
||||||
private final JPanel outputCards = new JPanel(outputCardsLayout);
|
private final JPanel outputCards = new JPanel(outputCardsLayout);
|
||||||
@@ -41,6 +46,7 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
|
|
||||||
private final CyberStrikeAIClient client = new CyberStrikeAIClient();
|
private final CyberStrikeAIClient client = new CyberStrikeAIClient();
|
||||||
private final AtomicReference<String> tokenRef = new AtomicReference<>("");
|
private final AtomicReference<String> tokenRef = new AtomicReference<>("");
|
||||||
|
private final AtomicReference<Thread> validateThreadRef = new AtomicReference<>();
|
||||||
|
|
||||||
private final DefaultListModel<TestRun> testListModel = new DefaultListModel<>();
|
private final DefaultListModel<TestRun> testListModel = new DefaultListModel<>();
|
||||||
private final JList<TestRun> testList = new JList<>(testListModel);
|
private final JList<TestRun> testList = new JList<>(testListModel);
|
||||||
@@ -107,6 +113,8 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
row1.add(hostField);
|
row1.add(hostField);
|
||||||
row1.add(new JLabel("Port"));
|
row1.add(new JLabel("Port"));
|
||||||
row1.add(portField);
|
row1.add(portField);
|
||||||
|
useHttpsBox.setToolTipText("Use https:// for CyberStrikeAI (self-signed certs are trusted automatically)");
|
||||||
|
row1.add(useHttpsBox);
|
||||||
row1.add(new JLabel("Password"));
|
row1.add(new JLabel("Password"));
|
||||||
row1.add(passwordField);
|
row1.add(passwordField);
|
||||||
row1.add(validateButton);
|
row1.add(validateButton);
|
||||||
@@ -186,15 +194,22 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
configureTextArea(requestArea, false);
|
configureTextArea(requestArea, false);
|
||||||
configureTextArea(responseArea, false);
|
configureTextArea(responseArea, false);
|
||||||
|
|
||||||
outputCards.add(new JScrollPane(finalRawArea), "raw");
|
finalRawScrollPane = new JScrollPane(finalRawArea);
|
||||||
|
finalRawScrollPane.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER);
|
||||||
|
finalRawScrollPane.getVerticalScrollBar().setUnitIncrement(16);
|
||||||
|
outputCards.add(finalRawScrollPane, "raw");
|
||||||
outputCards.add(new JScrollPane(markdownPane), "md");
|
outputCards.add(new JScrollPane(markdownPane), "md");
|
||||||
|
|
||||||
outputRoot.add(buildOutputHeader(), BorderLayout.NORTH);
|
outputRoot.add(buildOutputHeader(), BorderLayout.NORTH);
|
||||||
outputRoot.add(buildOutputBody(), BorderLayout.CENTER);
|
outputRoot.add(buildOutputBody(), BorderLayout.CENTER);
|
||||||
|
|
||||||
rightTabs.addTab("Output", outputRoot);
|
rightTabs.addTab("Output", outputRoot);
|
||||||
rightTabs.addTab("Request", new JScrollPane(requestArea));
|
JScrollPane requestScroll = new JScrollPane(requestArea);
|
||||||
rightTabs.addTab("Response", new JScrollPane(responseArea));
|
requestScroll.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER);
|
||||||
|
rightTabs.addTab("Request", requestScroll);
|
||||||
|
JScrollPane responseScroll = new JScrollPane(responseArea);
|
||||||
|
responseScroll.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER);
|
||||||
|
rightTabs.addTab("Response", responseScroll);
|
||||||
return rightTabs;
|
return rightTabs;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,12 +225,13 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private JComponent buildOutputBody() {
|
private JComponent buildOutputBody() {
|
||||||
JScrollPane progressScroll = new JScrollPane(progressArea);
|
progressScrollPane = new JScrollPane(progressArea);
|
||||||
progressScroll.setBorder(BorderFactory.createTitledBorder("Progress"));
|
progressScrollPane.setBorder(BorderFactory.createTitledBorder("Progress"));
|
||||||
progressScroll.getVerticalScrollBar().setUnitIncrement(16);
|
progressScrollPane.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER);
|
||||||
|
progressScrollPane.getVerticalScrollBar().setUnitIncrement(16);
|
||||||
|
|
||||||
JPanel empty = new JPanel();
|
JPanel empty = new JPanel();
|
||||||
progressContainer.add(progressScroll, "show");
|
progressContainer.add(progressScrollPane, "show");
|
||||||
progressContainer.add(empty, "hide");
|
progressContainer.add(empty, "hide");
|
||||||
((CardLayout) progressContainer.getLayout()).show(progressContainer, "show");
|
((CardLayout) progressContainer.getLayout()).show(progressContainer, "show");
|
||||||
|
|
||||||
@@ -259,10 +275,27 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
return split;
|
return split;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static boolean isScrollNearBottom(JScrollPane scrollPane) {
|
||||||
|
if (scrollPane == null) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
JScrollBar bar = scrollPane.getVerticalScrollBar();
|
||||||
|
int max = Math.max(0, bar.getMaximum() - bar.getVisibleAmount());
|
||||||
|
return bar.getValue() >= max - SCROLL_FOLLOW_THRESHOLD_PX;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void scrollPaneToBottom(JScrollPane scrollPane) {
|
||||||
|
if (scrollPane == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
JScrollBar bar = scrollPane.getVerticalScrollBar();
|
||||||
|
bar.setValue(bar.getMaximum());
|
||||||
|
}
|
||||||
|
|
||||||
private static void configureTextArea(JTextArea area, boolean monospaced) {
|
private static void configureTextArea(JTextArea area, boolean monospaced) {
|
||||||
area.setEditable(false);
|
area.setEditable(false);
|
||||||
area.setLineWrap(false);
|
area.setLineWrap(true);
|
||||||
area.setWrapStyleWord(false);
|
area.setWrapStyleWord(true);
|
||||||
if (monospaced) {
|
if (monospaced) {
|
||||||
area.setFont(new Font(Font.MONOSPACED, Font.PLAIN, 12));
|
area.setFont(new Font(Font.MONOSPACED, Font.PLAIN, 12));
|
||||||
} else {
|
} else {
|
||||||
@@ -381,24 +414,44 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
|
|
||||||
private void wireActions() {
|
private void wireActions() {
|
||||||
validateButton.addActionListener(e -> {
|
validateButton.addActionListener(e -> {
|
||||||
validateButton.setEnabled(false);
|
if ("Cancel".equals(validateButton.getText())) {
|
||||||
|
cancelValidateInProgress();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
validateButton.setText("Cancel");
|
||||||
|
validateButton.setEnabled(true);
|
||||||
|
stopButton.setEnabled(true);
|
||||||
statusLabel.setText("Validating...");
|
statusLabel.setText("Validating...");
|
||||||
log("Validating connection...");
|
log("Validating connection... (max ~10s; click Cancel or Stop to abort)");
|
||||||
new Thread(() -> {
|
Thread worker = new Thread(() -> {
|
||||||
try {
|
try {
|
||||||
CyberStrikeAIClient.Config cfg = currentConfig();
|
CyberStrikeAIClient.Config cfg = currentConfig();
|
||||||
String token = client.loginAndValidate(cfg);
|
String token = client.loginAndValidate(cfg);
|
||||||
|
if (Thread.currentThread().isInterrupted()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
tokenRef.set(token);
|
tokenRef.set(token);
|
||||||
SwingUtilities.invokeLater(() -> statusLabel.setText("OK (token saved)"));
|
SwingUtilities.invokeLater(() -> statusLabel.setText("OK (token saved)"));
|
||||||
log("Validation OK.");
|
log("Validation OK.");
|
||||||
} catch (Exception ex) {
|
} catch (Exception ex) {
|
||||||
tokenRef.set("");
|
tokenRef.set("");
|
||||||
SwingUtilities.invokeLater(() -> statusLabel.setText("Failed: " + ex.getMessage()));
|
if (Thread.currentThread().isInterrupted()) {
|
||||||
log("Validation failed: " + ex.getMessage());
|
SwingUtilities.invokeLater(() -> statusLabel.setText("Cancelled"));
|
||||||
|
log("Validation cancelled.");
|
||||||
|
} else {
|
||||||
|
SwingUtilities.invokeLater(() -> statusLabel.setText("Failed: " + ex.getMessage()));
|
||||||
|
log("Validation failed: " + ex.getMessage());
|
||||||
|
}
|
||||||
} finally {
|
} finally {
|
||||||
SwingUtilities.invokeLater(() -> validateButton.setEnabled(true));
|
validateThreadRef.set(null);
|
||||||
|
SwingUtilities.invokeLater(() -> {
|
||||||
|
validateButton.setText("Validate");
|
||||||
|
validateButton.setEnabled(true);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}, "CyberStrikeAI-Validate").start();
|
}, "CyberStrikeAI-Validate");
|
||||||
|
validateThreadRef.set(worker);
|
||||||
|
worker.start();
|
||||||
});
|
});
|
||||||
|
|
||||||
clearButton.addActionListener(e -> {
|
clearButton.addActionListener(e -> {
|
||||||
@@ -435,10 +488,23 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
});
|
});
|
||||||
|
|
||||||
stopButton.addActionListener(e -> {
|
stopButton.addActionListener(e -> {
|
||||||
|
if ("Cancel".equals(validateButton.getText())) {
|
||||||
|
cancelValidateInProgress();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
String runId = selectedRunId;
|
String runId = selectedRunId;
|
||||||
|
if (runId != null && client.hasActiveRequest()) {
|
||||||
|
client.cancelActiveRequest();
|
||||||
|
appendProgressToRun(runId, "\n[info] Stream stopped.\n");
|
||||||
|
setRunStatus(runId, "cancelled");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (runId == null) return;
|
if (runId == null) return;
|
||||||
TestRun run = runs.get(runId);
|
TestRun run = runs.get(runId);
|
||||||
if (run == null) return;
|
if (run == null) return;
|
||||||
|
|
||||||
String token = getToken();
|
String token = getToken();
|
||||||
if (token == null || token.trim().isEmpty()) {
|
if (token == null || token.trim().isEmpty()) {
|
||||||
appendProgressToRun(runId, "\n[error] Not validated.\n");
|
appendProgressToRun(runId, "\n[error] Not validated.\n");
|
||||||
@@ -483,7 +549,8 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
String host = hostField.getText().trim();
|
String host = hostField.getText().trim();
|
||||||
String port = portField.getText().trim();
|
String port = portField.getText().trim();
|
||||||
String password = new String(passwordField.getPassword());
|
String password = new String(passwordField.getPassword());
|
||||||
String baseUrl = "http://" + host + ":" + port;
|
String scheme = useHttpsBox.isSelected() ? "https" : "http";
|
||||||
|
String baseUrl = scheme + "://" + host + ":" + port;
|
||||||
int idx = agentModeBox.getSelectedIndex();
|
int idx = agentModeBox.getSelectedIndex();
|
||||||
CyberStrikeAIClient.AgentMode mode = (idx >= 0 && idx < AGENT_MODES.length)
|
CyberStrikeAIClient.AgentMode mode = (idx >= 0 && idx < AGENT_MODES.length)
|
||||||
? AGENT_MODES[idx]
|
? AGENT_MODES[idx]
|
||||||
@@ -567,10 +634,31 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
run.progressBuffer.append(s);
|
run.progressBuffer.append(s);
|
||||||
}
|
}
|
||||||
if (runId.equals(selectedRunId)) {
|
if (runId.equals(selectedRunId)) {
|
||||||
SwingUtilities.invokeLater(() -> {
|
SwingUtilities.invokeLater(() -> appendProgressUi(s, false));
|
||||||
progressArea.append(s);
|
}
|
||||||
progressArea.setCaretPosition(progressArea.getDocument().getLength());
|
}
|
||||||
});
|
|
||||||
|
private void appendProgressUi(String s, boolean forceFollow) {
|
||||||
|
JScrollBar bar = progressScrollPane != null ? progressScrollPane.getVerticalScrollBar() : null;
|
||||||
|
int scrollBefore = bar != null ? bar.getValue() : 0;
|
||||||
|
boolean follow = forceFollow || isScrollNearBottom(progressScrollPane);
|
||||||
|
progressArea.append(s);
|
||||||
|
if (follow) {
|
||||||
|
scrollPaneToBottom(progressScrollPane);
|
||||||
|
} else if (bar != null) {
|
||||||
|
bar.setValue(scrollBefore);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void appendFinalUi(String s, boolean forceFollow) {
|
||||||
|
JScrollBar bar = finalRawScrollPane != null ? finalRawScrollPane.getVerticalScrollBar() : null;
|
||||||
|
int scrollBefore = bar != null ? bar.getValue() : 0;
|
||||||
|
boolean follow = forceFollow || isScrollNearBottom(finalRawScrollPane);
|
||||||
|
finalRawArea.append(s);
|
||||||
|
if (follow) {
|
||||||
|
scrollPaneToBottom(finalRawScrollPane);
|
||||||
|
} else if (bar != null) {
|
||||||
|
bar.setValue(scrollBefore);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -620,10 +708,7 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
run.finalBuffer.append(s);
|
run.finalBuffer.append(s);
|
||||||
}
|
}
|
||||||
if (runId.equals(selectedRunId)) {
|
if (runId.equals(selectedRunId)) {
|
||||||
SwingUtilities.invokeLater(() -> {
|
SwingUtilities.invokeLater(() -> appendFinalUi(s, false));
|
||||||
finalRawArea.append(s);
|
|
||||||
finalRawArea.setCaretPosition(finalRawArea.getDocument().getLength());
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -656,9 +741,9 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
}
|
}
|
||||||
SwingUtilities.invokeLater(() -> {
|
SwingUtilities.invokeLater(() -> {
|
||||||
progressArea.setText(progress);
|
progressArea.setText(progress);
|
||||||
progressArea.setCaretPosition(progressArea.getDocument().getLength());
|
scrollPaneToBottom(progressScrollPane);
|
||||||
finalRawArea.setText(fin);
|
finalRawArea.setText(fin);
|
||||||
finalRawArea.setCaretPosition(finalRawArea.getDocument().getLength());
|
scrollPaneToBottom(finalRawScrollPane);
|
||||||
requestArea.setText(run.requestRaw == null ? "" : run.requestRaw);
|
requestArea.setText(run.requestRaw == null ? "" : run.requestRaw);
|
||||||
responseArea.setText(run.responseRaw == null ? "" : run.responseRaw);
|
responseArea.setText(run.responseRaw == null ? "" : run.responseRaw);
|
||||||
refreshOutputView();
|
refreshOutputView();
|
||||||
@@ -682,25 +767,36 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
|
|
||||||
void clearAndShowStreamHeader(String title) {
|
void clearAndShowStreamHeader(String title) {
|
||||||
SwingUtilities.invokeLater(() -> {
|
SwingUtilities.invokeLater(() -> {
|
||||||
progressArea.setText("");
|
progressArea.setText("[*] " + title + "\n\n");
|
||||||
finalRawArea.setText(title + "\n\n");
|
finalRawArea.setText("");
|
||||||
|
markdownPane.setText("");
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Legacy helpers kept for Validate logging
|
// Legacy helpers kept for Validate logging
|
||||||
void appendStreamLine(String s) {
|
void appendStreamLine(String s) {
|
||||||
if (s == null) return;
|
if (s == null) return;
|
||||||
SwingUtilities.invokeLater(() -> {
|
SwingUtilities.invokeLater(() -> appendProgressUi(s + "\n", false));
|
||||||
progressArea.append(s);
|
|
||||||
progressArea.append("\n");
|
|
||||||
progressArea.setCaretPosition(progressArea.getDocument().getLength());
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void log(String s) {
|
private void log(String s) {
|
||||||
appendStreamLine("[*] " + s);
|
appendStreamLine("[*] " + s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void cancelValidateInProgress() {
|
||||||
|
client.cancelActiveRequest();
|
||||||
|
Thread t = validateThreadRef.getAndSet(null);
|
||||||
|
if (t != null) {
|
||||||
|
t.interrupt();
|
||||||
|
}
|
||||||
|
SwingUtilities.invokeLater(() -> {
|
||||||
|
statusLabel.setText("Cancelled");
|
||||||
|
validateButton.setText("Validate");
|
||||||
|
validateButton.setEnabled(true);
|
||||||
|
});
|
||||||
|
log("Validation cancelled.");
|
||||||
|
}
|
||||||
|
|
||||||
private void applyFilter() {
|
private void applyFilter() {
|
||||||
String q = searchField.getText();
|
String q = searchField.getText();
|
||||||
if (q == null) q = "";
|
if (q == null) q = "";
|
||||||
|
|||||||
@@ -0,0 +1,149 @@
|
|||||||
|
package burp;
|
||||||
|
|
||||||
|
import javax.net.ssl.HostnameVerifier;
|
||||||
|
import javax.net.ssl.HttpsURLConnection;
|
||||||
|
import javax.net.ssl.SSLSocketFactory;
|
||||||
|
import javax.net.ssl.SSLContext;
|
||||||
|
import javax.net.ssl.TrustManager;
|
||||||
|
import javax.net.ssl.X509TrustManager;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.net.HttpURLConnection;
|
||||||
|
import java.net.InetSocketAddress;
|
||||||
|
import java.net.Socket;
|
||||||
|
import java.net.URL;
|
||||||
|
import java.security.cert.X509Certificate;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Opens HTTPS connections without validating server certificates (self-signed / local dev).
|
||||||
|
* Applied per-connection only; does not change JVM-wide defaults for other Burp components.
|
||||||
|
*/
|
||||||
|
final class SslTrustAll {
|
||||||
|
|
||||||
|
private static volatile SSLSocketFactory socketFactory;
|
||||||
|
private static final HostnameVerifier TRUST_ALL_HOSTS = (hostname, session) -> true;
|
||||||
|
|
||||||
|
private SslTrustAll() {
|
||||||
|
}
|
||||||
|
|
||||||
|
static HttpURLConnection open(URL url) throws IOException {
|
||||||
|
return open(url, 5_000, 30_000);
|
||||||
|
}
|
||||||
|
|
||||||
|
static HttpURLConnection open(URL url, int connectTimeoutMs, int readTimeoutMs) throws IOException {
|
||||||
|
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
|
||||||
|
conn.setConnectTimeout(connectTimeoutMs);
|
||||||
|
conn.setReadTimeout(readTimeoutMs);
|
||||||
|
if (conn instanceof HttpsURLConnection) {
|
||||||
|
HttpsURLConnection https = (HttpsURLConnection) conn;
|
||||||
|
https.setSSLSocketFactory(new TimeoutSslSocketFactory(socketFactory(), connectTimeoutMs, readTimeoutMs));
|
||||||
|
https.setHostnameVerifier(TRUST_ALL_HOSTS);
|
||||||
|
}
|
||||||
|
return conn;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static SSLSocketFactory socketFactory() {
|
||||||
|
SSLSocketFactory sf = socketFactory;
|
||||||
|
if (sf != null) {
|
||||||
|
return sf;
|
||||||
|
}
|
||||||
|
synchronized (SslTrustAll.class) {
|
||||||
|
sf = socketFactory;
|
||||||
|
if (sf != null) {
|
||||||
|
return sf;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
TrustManager[] trustAll = new TrustManager[]{
|
||||||
|
new X509TrustManager() {
|
||||||
|
@Override
|
||||||
|
public X509Certificate[] getAcceptedIssuers() {
|
||||||
|
return new X509Certificate[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void checkClientTrusted(X509Certificate[] chain, String authType) {
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void checkServerTrusted(X509Certificate[] chain, String authType) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
SSLContext ctx = SSLContext.getInstance("TLS");
|
||||||
|
ctx.init(null, trustAll, new java.security.SecureRandom());
|
||||||
|
sf = ctx.getSocketFactory();
|
||||||
|
socketFactory = sf;
|
||||||
|
return sf;
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException("Failed to initialize trust-all TLS", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Ensures TCP connect + socket read respect timeouts (plain HttpURLConnection SSL can hang longer). */
|
||||||
|
private static final class TimeoutSslSocketFactory extends SSLSocketFactory {
|
||||||
|
private final SSLSocketFactory delegate;
|
||||||
|
private final int connectTimeoutMs;
|
||||||
|
private final int readTimeoutMs;
|
||||||
|
|
||||||
|
TimeoutSslSocketFactory(SSLSocketFactory delegate, int connectTimeoutMs, int readTimeoutMs) {
|
||||||
|
this.delegate = delegate;
|
||||||
|
this.connectTimeoutMs = connectTimeoutMs;
|
||||||
|
this.readTimeoutMs = readTimeoutMs;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String[] getDefaultCipherSuites() {
|
||||||
|
return delegate.getDefaultCipherSuites();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String[] getSupportedCipherSuites() {
|
||||||
|
return delegate.getSupportedCipherSuites();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Socket createSocket() throws IOException {
|
||||||
|
return tune(delegate.createSocket());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Socket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException {
|
||||||
|
return tune(delegate.createSocket(s, host, port, autoClose));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Socket createSocket(String host, int port) throws IOException {
|
||||||
|
Socket plain = new Socket();
|
||||||
|
plain.connect(new InetSocketAddress(host, port), connectTimeoutMs);
|
||||||
|
return tune(delegate.createSocket(plain, host, port, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Socket createSocket(String host, int port, java.net.InetAddress localHost, int localPort) throws IOException {
|
||||||
|
Socket plain = new Socket();
|
||||||
|
plain.bind(new InetSocketAddress(localHost, localPort));
|
||||||
|
plain.connect(new InetSocketAddress(host, port), connectTimeoutMs);
|
||||||
|
return tune(delegate.createSocket(plain, host, port, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Socket createSocket(java.net.InetAddress host, int port) throws IOException {
|
||||||
|
Socket plain = new Socket();
|
||||||
|
plain.connect(new InetSocketAddress(host, port), connectTimeoutMs);
|
||||||
|
return tune(delegate.createSocket(plain, host.getHostName(), port, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Socket createSocket(java.net.InetAddress address, int port, java.net.InetAddress localAddress, int localPort) throws IOException {
|
||||||
|
Socket plain = new Socket();
|
||||||
|
plain.bind(new InetSocketAddress(localAddress, localPort));
|
||||||
|
plain.connect(new InetSocketAddress(address, port), connectTimeoutMs);
|
||||||
|
return tune(delegate.createSocket(plain, address.getHostName(), port, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Socket tune(Socket socket) throws IOException {
|
||||||
|
socket.setSoTimeout(readTimeoutMs);
|
||||||
|
return socket;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Binary file not shown.
BIN
Binary file not shown.
+4
@@ -1,12 +1,16 @@
|
|||||||
|
burp/SslTrustAll.class
|
||||||
|
burp/SslTrustAll$TimeoutSslSocketFactory.class
|
||||||
burp/CyberStrikeAIClient$StreamListener.class
|
burp/CyberStrikeAIClient$StreamListener.class
|
||||||
burp/CyberStrikeAIClient$Config.class
|
burp/CyberStrikeAIClient$Config.class
|
||||||
burp/CyberStrikeAIClient$AgentMode.class
|
burp/CyberStrikeAIClient$AgentMode.class
|
||||||
burp/MarkdownRenderer.class
|
burp/MarkdownRenderer.class
|
||||||
burp/SimpleJson.class
|
burp/SimpleJson.class
|
||||||
burp/CyberStrikeAIClient.class
|
burp/CyberStrikeAIClient.class
|
||||||
|
burp/CyberStrikeAIClient$1.class
|
||||||
burp/CyberStrikeAITab$DotIcon.class
|
burp/CyberStrikeAITab$DotIcon.class
|
||||||
burp/CyberStrikeAITab.class
|
burp/CyberStrikeAITab.class
|
||||||
burp/CyberStrikeAITab$1.class
|
burp/CyberStrikeAITab$1.class
|
||||||
|
burp/SslTrustAll$1.class
|
||||||
burp/BurpExtender$1.class
|
burp/BurpExtender$1.class
|
||||||
burp/BurpExtender.class
|
burp/BurpExtender.class
|
||||||
burp/CyberStrikeAITab$TestRun.class
|
burp/CyberStrikeAITab$TestRun.class
|
||||||
|
|||||||
+1
@@ -4,3 +4,4 @@
|
|||||||
/Users/temp/Downloads/CyberStrikeAI-main/plugins/burp-suite/cyberstrikeai-burp-extension/src/main/java/burp/HttpMessageFormatter.java
|
/Users/temp/Downloads/CyberStrikeAI-main/plugins/burp-suite/cyberstrikeai-burp-extension/src/main/java/burp/HttpMessageFormatter.java
|
||||||
/Users/temp/Downloads/CyberStrikeAI-main/plugins/burp-suite/cyberstrikeai-burp-extension/src/main/java/burp/MarkdownRenderer.java
|
/Users/temp/Downloads/CyberStrikeAI-main/plugins/burp-suite/cyberstrikeai-burp-extension/src/main/java/burp/MarkdownRenderer.java
|
||||||
/Users/temp/Downloads/CyberStrikeAI-main/plugins/burp-suite/cyberstrikeai-burp-extension/src/main/java/burp/SimpleJson.java
|
/Users/temp/Downloads/CyberStrikeAI-main/plugins/burp-suite/cyberstrikeai-burp-extension/src/main/java/burp/SimpleJson.java
|
||||||
|
/Users/temp/Downloads/CyberStrikeAI-main/plugins/burp-suite/cyberstrikeai-burp-extension/src/main/java/burp/SslTrustAll.java
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ show_progress() {
|
|||||||
echo ""
|
echo ""
|
||||||
echo "=========================================="
|
echo "=========================================="
|
||||||
echo " CyberStrikeAI 一键部署启动脚本"
|
echo " CyberStrikeAI 一键部署启动脚本"
|
||||||
|
echo " (默认 HTTPS 自签证书;纯 HTTP 请用: $0 --http)"
|
||||||
echo "=========================================="
|
echo "=========================================="
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
@@ -353,7 +354,18 @@ need_rebuild() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 主流程
|
# 主流程
|
||||||
|
# 默认启动主站 HTTPS(--https 传给二进制);传 --http 则走明文 HTTP。
|
||||||
main() {
|
main() {
|
||||||
|
USE_HTTPS=1
|
||||||
|
FORWARD_ARGS=()
|
||||||
|
for arg in "$@"; do
|
||||||
|
if [ "$arg" = "--http" ]; then
|
||||||
|
USE_HTTPS=0
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
FORWARD_ARGS+=("$arg")
|
||||||
|
done
|
||||||
|
|
||||||
# 环境检查
|
# 环境检查
|
||||||
info "检查运行环境..."
|
info "检查运行环境..."
|
||||||
check_python
|
check_python
|
||||||
@@ -377,13 +389,30 @@ main() {
|
|||||||
# 启动服务器
|
# 启动服务器
|
||||||
success "所有准备工作完成!"
|
success "所有准备工作完成!"
|
||||||
echo ""
|
echo ""
|
||||||
info "启动 CyberStrikeAI 服务器..."
|
if [ "$USE_HTTPS" -eq 1 ]; then
|
||||||
|
info "启动 CyberStrikeAI 服务器(HTTPS + HTTP/2,自签证书)..."
|
||||||
|
note "纯 HTTP 启动请使用: $0 --http"
|
||||||
|
else
|
||||||
|
info "启动 CyberStrikeAI 服务器(HTTP)..."
|
||||||
|
fi
|
||||||
echo "=========================================="
|
echo "=========================================="
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# 运行服务器
|
# 始终传入项目根目录下的 config.yaml,避免 cwd 不在项目根时找不到配置;额外参数仍可追加(如再次 -config 覆盖,以 Go flag 后写为准)。
|
||||||
exec "./$BINARY_NAME"
|
if [ "$USE_HTTPS" -eq 1 ]; then
|
||||||
|
if [ "${#FORWARD_ARGS[@]}" -gt 0 ]; then
|
||||||
|
exec "./$BINARY_NAME" -config "$CONFIG_FILE" --https "${FORWARD_ARGS[@]}"
|
||||||
|
else
|
||||||
|
exec "./$BINARY_NAME" -config "$CONFIG_FILE" --https
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
if [ "${#FORWARD_ARGS[@]}" -gt 0 ]; then
|
||||||
|
exec "./$BINARY_NAME" -config "$CONFIG_FILE" "${FORWARD_ARGS[@]}"
|
||||||
|
else
|
||||||
|
exec "./$BINARY_NAME" -config "$CONFIG_FILE"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
# 执行主流程
|
# 执行主流程(支持参数,如: ./run.sh --http)
|
||||||
main
|
main "$@"
|
||||||
|
|||||||
@@ -440,6 +440,230 @@ args:
|
|||||||
print("Body: <empty>")
|
print("Body: <empty>")
|
||||||
|
|
||||||
|
|
||||||
|
def compile_response_filter(pattern: str, ignore_case: bool):
|
||||||
|
flags = 0
|
||||||
|
if ignore_case:
|
||||||
|
flags |= re.IGNORECASE
|
||||||
|
try:
|
||||||
|
return re.compile(pattern, flags)
|
||||||
|
except re.error as exc:
|
||||||
|
print(f"Invalid response_filter regex: {exc}", file=sys.stderr)
|
||||||
|
sys.exit(2)
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_utf8(text: str, max_bytes: int) -> Tuple[str, bool]:
|
||||||
|
if max_bytes <= 0 or not text:
|
||||||
|
return text, False
|
||||||
|
encoded = text.encode("utf-8", errors="replace")
|
||||||
|
if len(encoded) <= max_bytes:
|
||||||
|
return text, False
|
||||||
|
truncated = encoded[:max_bytes].decode("utf-8", errors="ignore")
|
||||||
|
return truncated, True
|
||||||
|
|
||||||
|
|
||||||
|
def cap_line_entries(entries: List[Tuple[int, str]], max_lines: int) -> Tuple[List[Tuple[int, str]], bool]:
|
||||||
|
if max_lines <= 0 or len(entries) <= max_lines:
|
||||||
|
return entries, False
|
||||||
|
return entries[:max_lines], True
|
||||||
|
|
||||||
|
|
||||||
|
def expand_line_context(line_numbers: List[int], total_lines: int, context: int) -> List[int]:
|
||||||
|
if context <= 0:
|
||||||
|
return sorted(set(line_numbers))
|
||||||
|
included = set()
|
||||||
|
for num in line_numbers:
|
||||||
|
start = max(1, num - context)
|
||||||
|
end = min(total_lines, num + context)
|
||||||
|
for i in range(start, end + 1):
|
||||||
|
included.add(i)
|
||||||
|
return sorted(included)
|
||||||
|
|
||||||
|
|
||||||
|
def format_line_entries(lines: List[str], indices: List[int], ellipsis_gaps: bool = True) -> str:
|
||||||
|
if not indices:
|
||||||
|
return ""
|
||||||
|
chunks = []
|
||||||
|
prev = None
|
||||||
|
for num in indices:
|
||||||
|
if ellipsis_gaps and prev is not None and num > prev + 1:
|
||||||
|
chunks.append(" ...")
|
||||||
|
chunks.append(f" L{num}: {lines[num - 1]}")
|
||||||
|
prev = num
|
||||||
|
return "\n".join(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_body_by_lines(
|
||||||
|
lines: List[str],
|
||||||
|
compiled: "re.Pattern",
|
||||||
|
invert: bool,
|
||||||
|
context_lines: int,
|
||||||
|
max_lines: int,
|
||||||
|
) -> Tuple[str, Dict[str, object]]:
|
||||||
|
matched_nums = []
|
||||||
|
for idx, line in enumerate(lines, start=1):
|
||||||
|
hit = compiled.search(line) is not None
|
||||||
|
if invert:
|
||||||
|
hit = not hit
|
||||||
|
if hit:
|
||||||
|
matched_nums.append(idx)
|
||||||
|
total = len(lines)
|
||||||
|
meta = {
|
||||||
|
"mode": "line",
|
||||||
|
"total_lines": total,
|
||||||
|
"matched_lines": len(matched_nums),
|
||||||
|
"invert": invert,
|
||||||
|
"truncated": False,
|
||||||
|
"byte_truncated": False,
|
||||||
|
}
|
||||||
|
if not matched_nums:
|
||||||
|
return "", meta
|
||||||
|
display_nums = expand_line_context(matched_nums, total, context_lines)
|
||||||
|
entries = [(n, lines[n - 1]) for n in display_nums]
|
||||||
|
entries, line_capped = cap_line_entries(entries, max_lines)
|
||||||
|
meta["truncated"] = line_capped
|
||||||
|
meta["display_lines"] = len(entries)
|
||||||
|
return format_line_entries(lines, [n for n, _ in entries], ellipsis_gaps=context_lines > 0), meta
|
||||||
|
|
||||||
|
|
||||||
|
def filter_body_multiline(
|
||||||
|
text: str,
|
||||||
|
compiled: "re.Pattern",
|
||||||
|
invert: bool,
|
||||||
|
max_lines: int,
|
||||||
|
dotall: bool,
|
||||||
|
) -> Tuple[str, Dict[str, object]]:
|
||||||
|
flags = compiled.flags
|
||||||
|
if dotall:
|
||||||
|
pattern = re.compile(compiled.pattern, flags | re.DOTALL | re.MULTILINE)
|
||||||
|
else:
|
||||||
|
pattern = re.compile(compiled.pattern, flags | re.MULTILINE)
|
||||||
|
matches = list(pattern.finditer(text))
|
||||||
|
if invert:
|
||||||
|
if matches:
|
||||||
|
return "", {"mode": "multiline" if not dotall else "full", "total_lines": text.count("\n") + (1 if text else 0), "matched_lines": 0, "invert": True, "truncated": False, "byte_truncated": False}
|
||||||
|
output = text
|
||||||
|
meta = {"mode": "multiline" if not dotall else "full", "matched_lines": 1, "invert": True, "truncated": False, "byte_truncated": False}
|
||||||
|
lines = text.splitlines()
|
||||||
|
if max_lines > 0 and len(lines) > max_lines:
|
||||||
|
output = "\n".join(lines[:max_lines])
|
||||||
|
meta["truncated"] = True
|
||||||
|
meta["total_lines"] = len(lines)
|
||||||
|
meta["display_lines"] = min(len(lines), max_lines) if max_lines > 0 else len(lines)
|
||||||
|
return output, meta
|
||||||
|
chunks = []
|
||||||
|
for match in matches:
|
||||||
|
snippet = match.group(0)
|
||||||
|
if "\n" in snippet:
|
||||||
|
snippet = snippet.replace("\n", "\\n")
|
||||||
|
start_line = text.count("\n", 0, match.start()) + 1
|
||||||
|
chunks.append((start_line, f" @{start_line}: {snippet}"))
|
||||||
|
entries, line_capped = cap_line_entries(chunks, max_lines if max_lines > 0 else len(chunks))
|
||||||
|
meta = {
|
||||||
|
"mode": "multiline" if not dotall else "full",
|
||||||
|
"total_lines": text.count("\n") + (1 if text else 0),
|
||||||
|
"matched_lines": len(matches),
|
||||||
|
"invert": False,
|
||||||
|
"truncated": line_capped,
|
||||||
|
"byte_truncated": False,
|
||||||
|
"display_lines": len(entries),
|
||||||
|
}
|
||||||
|
return "\n".join(line for _, line in entries), meta
|
||||||
|
|
||||||
|
|
||||||
|
def apply_body_limits_plain(text: str, max_lines: int) -> Tuple[str, Dict[str, object]]:
|
||||||
|
lines = text.splitlines()
|
||||||
|
meta = {
|
||||||
|
"mode": "plain",
|
||||||
|
"total_lines": len(lines),
|
||||||
|
"matched_lines": len(lines),
|
||||||
|
"invert": False,
|
||||||
|
"truncated": False,
|
||||||
|
"byte_truncated": False,
|
||||||
|
"display_lines": len(lines),
|
||||||
|
}
|
||||||
|
output = text
|
||||||
|
if max_lines > 0 and len(lines) > max_lines:
|
||||||
|
output = "\n".join(lines[:max_lines])
|
||||||
|
meta["truncated"] = True
|
||||||
|
meta["display_lines"] = max_lines
|
||||||
|
return output, meta
|
||||||
|
|
||||||
|
|
||||||
|
def format_response_body_output(
|
||||||
|
decoded_body: str,
|
||||||
|
filter_pattern: str,
|
||||||
|
filter_mode: str,
|
||||||
|
filter_invert: bool,
|
||||||
|
filter_ignore_case: bool,
|
||||||
|
max_lines: int,
|
||||||
|
max_bytes: int,
|
||||||
|
preview_lines: int,
|
||||||
|
context_lines: int,
|
||||||
|
compiled_filter=None,
|
||||||
|
) -> Tuple[str, Dict[str, object]]:
|
||||||
|
text = decoded_body.rstrip("\r\n")
|
||||||
|
if not text:
|
||||||
|
return "", {"mode": "empty", "total_lines": 0, "matched_lines": 0, "invert": filter_invert, "truncated": False, "byte_truncated": False, "display_lines": 0}
|
||||||
|
|
||||||
|
lines = text.splitlines()
|
||||||
|
mode = (filter_mode or "line").strip().lower()
|
||||||
|
if mode not in {"line", "multiline", "full"}:
|
||||||
|
mode = "line"
|
||||||
|
|
||||||
|
if filter_pattern:
|
||||||
|
compiled = compiled_filter or compile_response_filter(filter_pattern, filter_ignore_case)
|
||||||
|
if mode == "line":
|
||||||
|
output, meta = filter_body_by_lines(lines, compiled, filter_invert, context_lines, max_lines)
|
||||||
|
else:
|
||||||
|
output, meta = filter_body_multiline(text, compiled, filter_invert, max_lines, dotall=(mode == "full"))
|
||||||
|
meta["filter_pattern"] = filter_pattern
|
||||||
|
if not output and not filter_invert:
|
||||||
|
preview = min(max(preview_lines, 0), len(lines))
|
||||||
|
if preview > 0:
|
||||||
|
preview_text = format_line_entries(lines, list(range(1, preview + 1)), ellipsis_gaps=False)
|
||||||
|
preview_text, byte_truncated = truncate_utf8(preview_text, max_bytes)
|
||||||
|
return preview_text, {
|
||||||
|
**meta,
|
||||||
|
"preview": True,
|
||||||
|
"matched_lines": 0,
|
||||||
|
"display_lines": preview,
|
||||||
|
"byte_truncated": byte_truncated,
|
||||||
|
}
|
||||||
|
return "", {**meta, "preview": False, "matched_lines": 0, "display_lines": 0}
|
||||||
|
else:
|
||||||
|
output, meta = apply_body_limits_plain(text, max_lines)
|
||||||
|
|
||||||
|
output, byte_truncated = truncate_utf8(output, max_bytes)
|
||||||
|
if byte_truncated:
|
||||||
|
meta["byte_truncated"] = True
|
||||||
|
return output, meta
|
||||||
|
|
||||||
|
|
||||||
|
def print_response_body_summary(meta: Dict[str, object]):
|
||||||
|
mode = meta.get("mode")
|
||||||
|
if mode == "empty":
|
||||||
|
return
|
||||||
|
parts = [f"mode={mode}"]
|
||||||
|
if meta.get("filter_pattern"):
|
||||||
|
parts.append(f"pattern={meta['filter_pattern']!r}")
|
||||||
|
if meta.get("invert"):
|
||||||
|
parts.append("invert=true")
|
||||||
|
total = meta.get("total_lines")
|
||||||
|
matched = meta.get("matched_lines")
|
||||||
|
displayed = meta.get("display_lines")
|
||||||
|
if total is not None and matched is not None:
|
||||||
|
parts.append(f"matched {matched}/{total} lines")
|
||||||
|
if displayed is not None:
|
||||||
|
parts.append(f"showing {displayed}")
|
||||||
|
if meta.get("preview"):
|
||||||
|
parts.append("preview on zero match")
|
||||||
|
if meta.get("truncated"):
|
||||||
|
parts.append("line cap applied")
|
||||||
|
if meta.get("byte_truncated"):
|
||||||
|
parts.append("byte cap applied")
|
||||||
|
print(f"[body] {' | '.join(parts)}")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Pure Python HTTP testing helper powered by httpx")
|
parser = argparse.ArgumentParser(description="Pure Python HTTP testing helper powered by httpx")
|
||||||
parser.add_argument("--url", required=True)
|
parser.add_argument("--url", required=True)
|
||||||
@@ -466,6 +690,16 @@ args:
|
|||||||
parser.add_argument("--debug", dest="debug", action="store_true")
|
parser.add_argument("--debug", dest="debug", action="store_true")
|
||||||
parser.add_argument("--response-encoding", dest="response_encoding", default="")
|
parser.add_argument("--response-encoding", dest="response_encoding", default="")
|
||||||
parser.add_argument("--download", dest="download", default="")
|
parser.add_argument("--download", dest="download", default="")
|
||||||
|
parser.add_argument("--response-filter", dest="response_filter", default="")
|
||||||
|
parser.add_argument("--response-filter-mode", dest="response_filter_mode", default="line")
|
||||||
|
parser.add_argument("--response-filter-invert", dest="response_filter_invert", action="store_true")
|
||||||
|
parser.add_argument("--no-response-filter-invert", dest="response_filter_invert", action="store_false")
|
||||||
|
parser.add_argument("--response-filter-ignore-case", dest="response_filter_ignore_case", action="store_true")
|
||||||
|
parser.add_argument("--no-response-filter-ignore-case", dest="response_filter_ignore_case", action="store_false")
|
||||||
|
parser.add_argument("--response-max-lines", dest="response_max_lines", type=int, default=0)
|
||||||
|
parser.add_argument("--response-max-bytes", dest="response_max_bytes", type=int, default=0)
|
||||||
|
parser.add_argument("--response-preview-lines", dest="response_preview_lines", type=int, default=5)
|
||||||
|
parser.add_argument("--response-context-lines", dest="response_context_lines", type=int, default=0)
|
||||||
parser.set_defaults(
|
parser.set_defaults(
|
||||||
include_headers=False,
|
include_headers=False,
|
||||||
auto_encode_url=False,
|
auto_encode_url=False,
|
||||||
@@ -475,9 +709,22 @@ args:
|
|||||||
show_command=False,
|
show_command=False,
|
||||||
show_summary=False,
|
show_summary=False,
|
||||||
debug=False,
|
debug=False,
|
||||||
|
response_filter_invert=False,
|
||||||
|
response_filter_ignore_case=False,
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
response_filter = (args.response_filter or "").strip()
|
||||||
|
response_max_lines = max(0, args.response_max_lines or 0)
|
||||||
|
response_max_bytes = max(0, args.response_max_bytes or 0)
|
||||||
|
response_preview_lines = max(0, args.response_preview_lines if args.response_preview_lines is not None else 5)
|
||||||
|
response_context_lines = max(0, args.response_context_lines or 0)
|
||||||
|
compiled_response_filter = None
|
||||||
|
if response_filter:
|
||||||
|
compiled_response_filter = compile_response_filter(
|
||||||
|
response_filter, args.response_filter_ignore_case
|
||||||
|
)
|
||||||
|
|
||||||
repeat = max(1, args.repeat)
|
repeat = max(1, args.repeat)
|
||||||
try:
|
try:
|
||||||
delay_between = float(args.delay or "0")
|
delay_between = float(args.delay or "0")
|
||||||
@@ -648,9 +895,37 @@ args:
|
|||||||
for key, value in response.headers.items():
|
for key, value in response.headers.items():
|
||||||
print(f"{key}: {value}")
|
print(f"{key}: {value}")
|
||||||
print("")
|
print("")
|
||||||
output_body = decoded_body.rstrip()
|
output_body, body_output_meta = format_response_body_output(
|
||||||
|
decoded_body,
|
||||||
|
response_filter,
|
||||||
|
args.response_filter_mode,
|
||||||
|
args.response_filter_invert,
|
||||||
|
args.response_filter_ignore_case,
|
||||||
|
response_max_lines,
|
||||||
|
response_max_bytes,
|
||||||
|
response_preview_lines,
|
||||||
|
response_context_lines,
|
||||||
|
compiled_filter=compiled_response_filter,
|
||||||
|
)
|
||||||
|
has_filter_or_cap = bool(
|
||||||
|
response_filter or response_max_lines > 0 or response_max_bytes > 0
|
||||||
|
)
|
||||||
|
if has_filter_or_cap and body_output_meta.get("mode") != "empty":
|
||||||
|
print_response_body_summary(body_output_meta)
|
||||||
|
if body_output_meta.get("preview") and not body_output_meta.get("matched_lines"):
|
||||||
|
print("[body] no regex match; showing preview:")
|
||||||
if output_body:
|
if output_body:
|
||||||
print(output_body)
|
print(output_body)
|
||||||
|
if body_output_meta.get("truncated") or body_output_meta.get("byte_truncated"):
|
||||||
|
omitted = (body_output_meta.get("total_lines") or 0) - (
|
||||||
|
body_output_meta.get("display_lines") or 0
|
||||||
|
)
|
||||||
|
if omitted > 0:
|
||||||
|
print(f"[body] ... {omitted} more line(s) omitted (use --download for full body)")
|
||||||
|
elif body_output_meta.get("mode") == "empty":
|
||||||
|
print("[no body]")
|
||||||
|
elif response_filter and not body_output_meta.get("preview"):
|
||||||
|
print("[body] no lines matched filter")
|
||||||
else:
|
else:
|
||||||
print("[no body]")
|
print("[no body]")
|
||||||
|
|
||||||
@@ -729,6 +1004,13 @@ description: |
|
|||||||
- 连接探针:在无代理场景下额外进行 DNS/TCP/TLS 探测,粗粒度复刻 curl -w 指标
|
- 连接探针:在无代理场景下额外进行 DNS/TCP/TLS 探测,粗粒度复刻 curl -w 指标
|
||||||
- 可重复观测:repeat/delay + TTFB/total/speed_download 统计,便于盲注/时序测试
|
- 可重复观测:repeat/delay + TTFB/total/speed_download 统计,便于盲注/时序测试
|
||||||
- 扩展开关:additional_args 解析 http2/cert/verify/trust_env/max_redirects 等 httpx 选项
|
- 扩展开关:additional_args 解析 http2/cert/verify/trust_env/max_redirects 等 httpx 选项
|
||||||
|
- 响应体瘦身:response_filter 按行/块正则提取,配合 max_lines/max_bytes 限制 stdout,降低 Agent token 消耗
|
||||||
|
|
||||||
|
**响应过滤最佳实践:**
|
||||||
|
- 大页面/HTML:用 `response_filter` 抓 error|exception|password|token|uid 等关键字行
|
||||||
|
- 无 filter 时:设 `response_max_lines=80` 或 `response_max_bytes=8192` 防止整页灌入上下文
|
||||||
|
- 0 命中:自动预览前 `response_preview_lines` 行,避免误判「空响应」
|
||||||
|
- 完整留存:大 body 用 `download` 落盘,stdout 只保留摘要行
|
||||||
parameters:
|
parameters:
|
||||||
- name: "url"
|
- name: "url"
|
||||||
type: "string"
|
type: "string"
|
||||||
@@ -836,6 +1118,56 @@ parameters:
|
|||||||
description: "强制响应解码使用的编码(如GBK),覆盖自动探测"
|
description: "强制响应解码使用的编码(如GBK),覆盖自动探测"
|
||||||
required: false
|
required: false
|
||||||
flag: "--response-encoding"
|
flag: "--response-encoding"
|
||||||
|
- name: "response_filter"
|
||||||
|
type: "string"
|
||||||
|
description: |
|
||||||
|
响应体正则过滤(仅影响 stdout,不影响 --download 与指标)。
|
||||||
|
默认 line 模式按行匹配;示例:'(error|exception|SQL|password|token|uid)'。
|
||||||
|
与 response_max_lines/response_max_bytes 配合可显著减少 token 消耗。
|
||||||
|
required: false
|
||||||
|
flag: "--response-filter"
|
||||||
|
- name: "response_filter_mode"
|
||||||
|
type: "string"
|
||||||
|
description: "过滤模式:line(按行,默认)、multiline(跨行块)、full(整段 DOTALL 匹配)"
|
||||||
|
required: false
|
||||||
|
default: "line"
|
||||||
|
flag: "--response-filter-mode"
|
||||||
|
- name: "response_filter_invert"
|
||||||
|
type: "bool"
|
||||||
|
description: "反向过滤:输出不匹配 regex 的行(用于剔除 HTML 噪音)"
|
||||||
|
required: false
|
||||||
|
default: false
|
||||||
|
flag: "--response-filter-invert"
|
||||||
|
- name: "response_filter_ignore_case"
|
||||||
|
type: "bool"
|
||||||
|
description: "正则忽略大小写"
|
||||||
|
required: false
|
||||||
|
default: false
|
||||||
|
flag: "--response-filter-ignore-case"
|
||||||
|
- name: "response_max_lines"
|
||||||
|
type: "int"
|
||||||
|
description: "stdout 最多输出行数(0=不限制);有 filter 时限制命中行数,无 filter 时截断全文"
|
||||||
|
required: false
|
||||||
|
default: 0
|
||||||
|
flag: "--response-max-lines"
|
||||||
|
- name: "response_max_bytes"
|
||||||
|
type: "int"
|
||||||
|
description: "stdout 响应体 UTF-8 字节上限(0=不限制),超出部分截断"
|
||||||
|
required: false
|
||||||
|
default: 0
|
||||||
|
flag: "--response-max-bytes"
|
||||||
|
- name: "response_preview_lines"
|
||||||
|
type: "int"
|
||||||
|
description: "filter 零命中时预览的前 N 行(默认 5,0=不预览)"
|
||||||
|
required: false
|
||||||
|
default: 5
|
||||||
|
flag: "--response-preview-lines"
|
||||||
|
- name: "response_context_lines"
|
||||||
|
type: "int"
|
||||||
|
description: "line 模式下命中行上下各保留 N 行上下文(类似 grep -C)"
|
||||||
|
required: false
|
||||||
|
default: 0
|
||||||
|
flag: "--response-context-lines"
|
||||||
- name: "action"
|
- name: "action"
|
||||||
type: "string"
|
type: "string"
|
||||||
description: "保留字段:标识调用意图(request, spider等),脚本内部不使用"
|
description: "保留字段:标识调用意图(request, spider等),脚本内部不使用"
|
||||||
|
|||||||
+8
-23
@@ -8,11 +8,8 @@ set -euo pipefail
|
|||||||
# - data/
|
# - data/
|
||||||
# - venv/ (disabled with --no-venv)
|
# - venv/ (disabled with --no-venv)
|
||||||
# - tools/ (user extensions; never overwritten by upgrade)
|
# - tools/ (user extensions; never overwritten by upgrade)
|
||||||
#
|
|
||||||
# Optional preserves (may overwrite upstream updates):
|
|
||||||
# - roles/
|
# - roles/
|
||||||
# - skills/
|
# - skills/
|
||||||
# Enable with --preserve-custom
|
|
||||||
|
|
||||||
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
cd "$ROOT_DIR"
|
cd "$ROOT_DIR"
|
||||||
@@ -28,7 +25,6 @@ BACKUP_BASE_DIR="$ROOT_DIR/.upgrade-backup"
|
|||||||
GITHUB_REPO="Ed1s0nZ/CyberStrikeAI"
|
GITHUB_REPO="Ed1s0nZ/CyberStrikeAI"
|
||||||
|
|
||||||
TAG=""
|
TAG=""
|
||||||
PRESERVE_CUSTOM=0
|
|
||||||
PRESERVE_VENV=1
|
PRESERVE_VENV=1
|
||||||
STOP_SERVICE=1
|
STOP_SERVICE=1
|
||||||
FORCE_STOP=0
|
FORCE_STOP=0
|
||||||
@@ -37,14 +33,12 @@ YES=0
|
|||||||
usage() {
|
usage() {
|
||||||
cat <<EOF
|
cat <<EOF
|
||||||
Usage:
|
Usage:
|
||||||
./upgrade.sh [--tag vX.Y.Z] [--preserve-custom] [--no-venv] [--no-stop]
|
./upgrade.sh [--tag vX.Y.Z] [--no-venv] [--no-stop]
|
||||||
[--force-stop] [--yes]
|
[--force-stop] [--yes]
|
||||||
|
|
||||||
Options:
|
Options:
|
||||||
--tag <tag> Specify GitHub Release tag (e.g. v1.3.28).
|
--tag <tag> Specify GitHub Release tag (e.g. v1.3.28).
|
||||||
If omitted, the script uses the latest release.
|
If omitted, the script uses the latest release.
|
||||||
--preserve-custom Preserve roles/skills (may overwrite upstream files).
|
|
||||||
tools/ is always preserved. Use with caution.
|
|
||||||
--no-venv Do not preserve venv/ (Python deps will be re-installed).
|
--no-venv Do not preserve venv/ (Python deps will be re-installed).
|
||||||
--no-stop Do not try to stop the running service.
|
--no-stop Do not try to stop the running service.
|
||||||
--force-stop If no process matching current directory is found, also stop
|
--force-stop If no process matching current directory is found, also stop
|
||||||
@@ -52,7 +46,7 @@ Options:
|
|||||||
--yes Do not ask for confirmation.
|
--yes Do not ask for confirmation.
|
||||||
|
|
||||||
Description:
|
Description:
|
||||||
The script backs up config.yaml/data/tools/ (and optionally venv/roles/skills) to
|
The script backs up config.yaml/data/tools/roles/skills/ (and optionally venv/) to
|
||||||
.upgrade-backup/
|
.upgrade-backup/
|
||||||
EOF
|
EOF
|
||||||
}
|
}
|
||||||
@@ -177,11 +171,7 @@ confirm_or_exit() {
|
|||||||
info " - Preserve venv/: no (will remove old venv and re-install deps)"
|
info " - Preserve venv/: no (will remove old venv and re-install deps)"
|
||||||
fi
|
fi
|
||||||
info " - Preserve tools/: yes (always)"
|
info " - Preserve tools/: yes (always)"
|
||||||
if [[ "$PRESERVE_CUSTOM" -eq 1 ]]; then
|
info " - Preserve roles/skills: yes (always)"
|
||||||
info " - Preserve roles/skills: yes (may overwrite upstream updates)"
|
|
||||||
else
|
|
||||||
info " - Preserve roles/skills: no (will use upstream versions)"
|
|
||||||
fi
|
|
||||||
info " - Stop service: ${STOP_SERVICE}"
|
info " - Stop service: ${STOP_SERVICE}"
|
||||||
echo ""
|
echo ""
|
||||||
read -r -p "Continue? (y/N) " ans
|
read -r -p "Continue? (y/N) " ans
|
||||||
@@ -299,11 +289,8 @@ sync_code() {
|
|||||||
|
|
||||||
# User tool extensions: never replace or delete during upgrade.
|
# User tool extensions: never replace or delete during upgrade.
|
||||||
rsync_excludes+=( "--exclude=tools/" )
|
rsync_excludes+=( "--exclude=tools/" )
|
||||||
|
rsync_excludes+=( "--exclude=roles/" )
|
||||||
if [[ "$PRESERVE_CUSTOM" -eq 1 ]]; then
|
rsync_excludes+=( "--exclude=skills/" )
|
||||||
rsync_excludes+=( "--exclude=roles/" )
|
|
||||||
rsync_excludes+=( "--exclude=skills/" )
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Ensure this upgrade script itself is not deleted.
|
# Ensure this upgrade script itself is not deleted.
|
||||||
rsync_excludes+=( "--exclude=upgrade.sh" )
|
rsync_excludes+=( "--exclude=upgrade.sh" )
|
||||||
@@ -324,10 +311,6 @@ main() {
|
|||||||
TAG="${2:-}"
|
TAG="${2:-}"
|
||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
--preserve-custom)
|
|
||||||
PRESERVE_CUSTOM=1
|
|
||||||
shift 1
|
|
||||||
;;
|
|
||||||
--no-venv)
|
--no-venv)
|
||||||
PRESERVE_VENV=0
|
PRESERVE_VENV=0
|
||||||
shift 1
|
shift 1
|
||||||
@@ -384,8 +367,10 @@ main() {
|
|||||||
if [[ -d "$ROOT_DIR/tools" ]]; then
|
if [[ -d "$ROOT_DIR/tools" ]]; then
|
||||||
backup_dir_tgz "tools" "$ROOT_DIR/tools"
|
backup_dir_tgz "tools" "$ROOT_DIR/tools"
|
||||||
fi
|
fi
|
||||||
if [[ "$PRESERVE_CUSTOM" -eq 1 ]]; then
|
if [[ -d "$ROOT_DIR/roles" ]]; then
|
||||||
backup_dir_tgz "roles" "$ROOT_DIR/roles"
|
backup_dir_tgz "roles" "$ROOT_DIR/roles"
|
||||||
|
fi
|
||||||
|
if [[ -d "$ROOT_DIR/skills" ]]; then
|
||||||
backup_dir_tgz "skills" "$ROOT_DIR/skills"
|
backup_dir_tgz "skills" "$ROOT_DIR/skills"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
+34
-3
@@ -260,8 +260,14 @@
|
|||||||
gap: 12px;
|
gap: 12px;
|
||||||
justify-content: center;
|
justify-content: center;
|
||||||
flex-wrap: wrap;
|
flex-wrap: wrap;
|
||||||
|
max-width: 420px;
|
||||||
|
margin-inline: auto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.c2-actions > button {
|
||||||
|
flex: 1;
|
||||||
|
min-width: min(100%, 160px);
|
||||||
|
}
|
||||||
/* ============================================================================
|
/* ============================================================================
|
||||||
Listener Cards
|
Listener Cards
|
||||||
============================================================================ */
|
============================================================================ */
|
||||||
@@ -851,10 +857,35 @@
|
|||||||
background: var(--c2-surface);
|
background: var(--c2-surface);
|
||||||
border-radius: var(--c2-radius);
|
border-radius: var(--c2-radius);
|
||||||
border: 1px solid var(--c2-border);
|
border: 1px solid var(--c2-border);
|
||||||
overflow: hidden;
|
overflow-x: auto;
|
||||||
|
overflow-y: visible;
|
||||||
}
|
}
|
||||||
|
|
||||||
.c2-task-table { width: 100%; border-collapse: collapse; }
|
/* 操作列:仅占按钮宽度,避免 100% 表格把余白摊到最右列 */
|
||||||
|
.c2-task-table th.c2-task-table-col-actions,
|
||||||
|
.c2-task-table td.c2-task-table-col-actions {
|
||||||
|
width: 1%;
|
||||||
|
white-space: nowrap;
|
||||||
|
text-align: right;
|
||||||
|
vertical-align: middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
.c2-task-table-actions {
|
||||||
|
display: inline-flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: flex-end;
|
||||||
|
gap: 6px;
|
||||||
|
flex-wrap: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.c2-task-table-actions .btn-small,
|
||||||
|
.c2-task-table-actions .btn-sm {
|
||||||
|
min-height: 30px;
|
||||||
|
min-width: 52px;
|
||||||
|
justify-content: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.c2-task-table { width: 100%; border-collapse: collapse; table-layout: auto; }
|
||||||
|
|
||||||
.c2-task-table th {
|
.c2-task-table th {
|
||||||
text-align: left;
|
text-align: left;
|
||||||
@@ -1255,7 +1286,7 @@
|
|||||||
display: flex;
|
display: flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
justify-content: center;
|
justify-content: center;
|
||||||
z-index: 1000;
|
z-index: 10050;
|
||||||
padding: 24px;
|
padding: 24px;
|
||||||
animation: c2-fade-in 0.15s ease-out;
|
animation: c2-fade-in 0.15s ease-out;
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user