mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-06 06:13:58 +02:00
Compare commits
193 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 97834c162e | |||
| 9276f2f144 | |||
| a454cada6a | |||
| 99b53d4fbc | |||
| a43a9deaea | |||
| ce88da84c9 | |||
| 15855c7073 | |||
| 43eb3e546b | |||
| 2d52c9b6ac | |||
| d5401b8b4c | |||
| 5fd4393a2e | |||
| a049f6b5c2 | |||
| acba8e5a39 | |||
| f826b91362 | |||
| 98c2de2a60 | |||
| 1c4d4b305b | |||
| f210ac9a03 | |||
| 6685076dfb | |||
| 7f322653f6 | |||
| 66ac2f1357 | |||
| c446e22d0c | |||
| 0358d3a67d | |||
| 9b82f265fd | |||
| 3d9cae58e4 | |||
| 1f1eadee5e | |||
| 0569255189 | |||
| 8ccf90d067 | |||
| b3be89f47d | |||
| b9bf8f62d4 | |||
| 05ca0c1480 | |||
| 47a4f3fc5b | |||
| a3b378ae9e | |||
| a904d26e78 | |||
| 7ba7476c4f | |||
| ae25a243ac | |||
| 23bd6288ff | |||
| fef21d3a24 | |||
| 933bba4517 | |||
| e1d65437cc | |||
| 9325aed1eb | |||
| dee2b3ab42 | |||
| a69bc93fa1 | |||
| b1a620bfce | |||
| 61b164eec2 | |||
| ba77e1837e | |||
| eacad60fd6 | |||
| 70bf5c93bf | |||
| 08bd278d8c | |||
| 22746d64a3 | |||
| 199392a5d5 | |||
| aafb4cb584 | |||
| 96e3dd397c | |||
| ec0f17145b | |||
| ed53da0999 | |||
| dc440fc511 | |||
| 009ae59033 | |||
| f348b3245a | |||
| 0018c5219c | |||
| 01a3e3677a | |||
| a12ecdb46f | |||
| 9f59230d74 | |||
| 085c6a1c72 | |||
| 7b3860971f | |||
| f6f7b7b237 | |||
| d5cf4b3b16 | |||
| 3e58d8355b | |||
| eb01ade63b | |||
| d1dc15fa44 | |||
| 73a39ef868 | |||
| a022baef03 | |||
| 59312d428e | |||
| 951d14ef14 | |||
| 0eb22da6e9 | |||
| 5fd9ef0514 | |||
| 9a4f3c7d35 | |||
| ead2ce3ecc | |||
| 8733f3a2d2 | |||
| 8642f3ba31 | |||
| 6a262a7367 | |||
| eb9192ddb3 | |||
| 5587e75628 | |||
| 74bbb453e2 | |||
| 66842f6206 | |||
| dc1779275d | |||
| 10dff937b1 | |||
| d4e1fe3bbe | |||
| 179976ae57 | |||
| 1c758bb98c | |||
| 17c4f38ee3 | |||
| cd7e57d121 | |||
| 0f2c3f65cc | |||
| 7779666e27 | |||
| c74bd4403b | |||
| 04d23ddb43 | |||
| 0874e84393 | |||
| 57f57f30b1 | |||
| f37d613a0c | |||
| 87d0ff9154 | |||
| b3418f39b8 | |||
| f9e1ca0e2d | |||
| 2c45879669 | |||
| 1cdcfa2c2d | |||
| eab5b73846 | |||
| d961ba1ec7 | |||
| 1ba5e57ec6 | |||
| 1216d25f96 | |||
| fde693408e | |||
| 352a81a869 | |||
| b2562b1010 | |||
| 0d8ba51087 | |||
| 0b847fcea3 | |||
| bf2f49fe62 | |||
| 75e64b1a86 | |||
| 2167735022 | |||
| 4ee292cc1f | |||
| 961205940f | |||
| ffe797bd06 | |||
| b6c864547e | |||
| da369c2edc | |||
| 54dc31a616 | |||
| 9e0b985221 | |||
| eb47077082 | |||
| f9a482857d | |||
| 679a68b12f | |||
| 840a26c7ef | |||
| 030e69c02d | |||
| d9683cdb44 | |||
| 60a063dd7d | |||
| 5f0c1805a7 | |||
| cb7e66001b | |||
| 4ea838f1d7 | |||
| 573648fc4b | |||
| f0e090abea | |||
| 549dcf518c | |||
| c74e20c54a | |||
| c94a9fd9e9 | |||
| ce9749a8ef | |||
| 145da12017 | |||
| 5111f4c311 | |||
| 8f6384a083 | |||
| 762f778e1e | |||
| 4a11ba8f14 | |||
| 86090af4df | |||
| 2dea6e36bd | |||
| 38ce695708 | |||
| 41fe90faa3 | |||
| 9f54bdb1bf | |||
| 08e727aa41 | |||
| 176c17d630 | |||
| 62710f6619 | |||
| e4dbb96b3e | |||
| 832532213a | |||
| eb04ac0c3a | |||
| 1946508325 | |||
| 89d1c5124f | |||
| 1e7a3299a5 | |||
| cae3a77331 | |||
| 2e1e57ce27 | |||
| 45b6ed2847 | |||
| 88eadf13a4 | |||
| dca5666b18 | |||
| e5d52cdf85 | |||
| 65e48826ff | |||
| 0cff507272 | |||
| 30afd71c05 | |||
| d2b6a154de | |||
| 278d5aa25c | |||
| 215f5a4a93 | |||
| 44185d748d | |||
| fe47f1f058 | |||
| 99ce183f41 | |||
| 2ed1947f36 | |||
| 97f3e8c179 | |||
| 38b0c31b87 | |||
| cb839da4d1 | |||
| 5ed730f17c | |||
| 30b1e5f820 | |||
| 8e5c70703e | |||
| 3cc3b25a7b | |||
| 44cf63fa52 | |||
| 12057c065b | |||
| c4e0b9735c | |||
| 218e9b9880 | |||
| 82d840966e | |||
| c62ff3bde9 | |||
| df2506b651 | |||
| efe9172f85 | |||
| b788bc6dab | |||
| 9134f2bbcb | |||
| d76cf2a162 | |||
| 2f96feb98f | |||
| a374c3950c | |||
| a93e3455fa |
@@ -174,9 +174,11 @@ The `run.sh` script will automatically:
|
||||
- ✅ Build the project
|
||||
- ✅ Start the server
|
||||
|
||||
**Networking defaults:** `run.sh` starts the server with **`--https`** and the repo **`config.yaml`** (local self-signed TLS; better for many concurrent streams). Use **`./run.sh --http`** for plain HTTP. In production, set **`server.tls_cert_path`** / **`server.tls_key_path`** in **`config.yaml`** (see comments there). For manual runs, add **`--https`** or **`CYBERSTRIKE_HTTPS=1`**; if **`-config`** is wrong, the binary prints a short usage hint on stderr.
|
||||
|
||||
**First-Time Configuration:**
|
||||
1. **Configure OpenAI-compatible API** (required before first use)
|
||||
- Open http://localhost:8080 after launch
|
||||
- After launch, open **`https://127.0.0.1:8080/`** (or **`https://localhost:8080/`**; replace **8080** with `server.port` in `config.yaml`) and accept the self-signed certificate warning once. If you used `./run.sh --http`, use **`http://`** instead.
|
||||
- Go to `Settings` → Fill in your API credentials:
|
||||
```yaml
|
||||
openai:
|
||||
@@ -197,21 +199,23 @@ The `run.sh` script will automatically:
|
||||
|
||||
**Alternative Launch Methods:**
|
||||
```bash
|
||||
# Direct Go run (requires manual setup)
|
||||
go run cmd/server/main.go
|
||||
# Direct Go run (set up env yourself); add --https to match run.sh defaults
|
||||
go run cmd/server/main.go --https
|
||||
|
||||
# Manual build
|
||||
go build -o cyberstrike-ai cmd/server/main.go
|
||||
./cyberstrike-ai
|
||||
./cyberstrike-ai --https
|
||||
```
|
||||
|
||||
If server logs show `client sent an HTTP request to an HTTPS server`, a client is still using **`http://`** on a TLS-only port—switch the URL to **`https://`**.
|
||||
|
||||
**Note:** The Python virtual environment (`venv/`) is automatically created and managed by `run.sh`. Tools that require Python (like `api-fuzzer`, `http-framework-test`, etc.) will automatically use this environment.
|
||||
|
||||
### Version Update (No Breaking Changes)
|
||||
|
||||
**CyberStrikeAI one-click upgrade (recommended):**
|
||||
1. (First time) enable the script: `chmod +x upgrade.sh`
|
||||
2. Upgrade with: `./upgrade.sh` (optional flags: `--tag vX.Y.Z`, `--no-venv`, `--preserve-custom`, `--yes`)
|
||||
2. Upgrade with: `./upgrade.sh` (optional flags: `--tag vX.Y.Z`, `--no-venv`, `--yes`). Local `tools/`, `roles/`, and `skills/` are always preserved.
|
||||
3. The script will back up your `config.yaml` and `data/`, upgrade the code from GitHub Release, update `config.yaml`'s `version`, then restart the server.
|
||||
|
||||
Recommended one-liner:
|
||||
@@ -281,7 +285,7 @@ Requirements / tips:
|
||||
- **Supervisor orchestrator**: fixed name **`orchestrator-supervisor.md`** (plus optional `orchestrator_instruction_supervisor`); requires at least one sub-agent.
|
||||
- **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only.
|
||||
- **Management** – Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`.
|
||||
- **Config** – `multi_agent` in `config.yaml`: `enabled`, `default_mode`, `robot_use_multi_agent`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
|
||||
- **Config** – `multi_agent` in `config.yaml`: `enabled`, `robot_default_agent_mode`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
|
||||
- **Details** – **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats).
|
||||
|
||||
### Skills System (Agent Skills + Eino)
|
||||
@@ -532,7 +536,7 @@ agents_dir: "agents" # Multi-agent Markdown definitions (orchestrator + sub-age
|
||||
multi_agent:
|
||||
enabled: false
|
||||
default_mode: "single" # single | multi (UI default when multi-agent is enabled)
|
||||
robot_use_multi_agent: false
|
||||
robot_default_agent_mode: react
|
||||
batch_use_multi_agent: false
|
||||
orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty
|
||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
|
||||
|
||||
+11
-7
@@ -173,9 +173,11 @@ chmod +x run.sh && ./run.sh
|
||||
- ✅ 编译构建项目
|
||||
- ✅ 启动服务器
|
||||
|
||||
**网络默认:** `run.sh` 会以 **`--https`** 并传入项目根 **`config.yaml`** 启动(本机自签证书,多路流式场景更稳)。只要明文 HTTP 用 **`./run.sh --http`**。生产环境在 **`config.yaml`** 的 **`server.tls_cert_path` / `server.tls_key_path`** 配正式证书(见文件内注释)。手动启动可加 **`--https`** 或环境变量 **`CYBERSTRIKE_HTTPS=1`**;`-config` 写错时程序会在终端提示正确写法。
|
||||
|
||||
**首次配置:**
|
||||
1. **配置 AI 模型 API**(首次使用前必填)
|
||||
- 启动后访问 http://localhost:8080
|
||||
- 启动后在浏览器打开 **`https://127.0.0.1:8080/`**(或 **`https://localhost:8080/`**;端口以 `config.yaml` 中 **`server.port`** 为准,默认 8080),并按提示信任自签证书。若使用 **`./run.sh --http`**,则改用 **`http://`** 访问。
|
||||
- 进入 `设置` → 填写 API 配置信息:
|
||||
```yaml
|
||||
openai:
|
||||
@@ -196,20 +198,22 @@ chmod +x run.sh && ./run.sh
|
||||
|
||||
**其他启动方式:**
|
||||
```bash
|
||||
# 直接运行(需手动配置环境)
|
||||
go run cmd/server/main.go
|
||||
# 直接运行(需自行配环境);与 run.sh 默认一致可加 --https
|
||||
go run cmd/server/main.go --https
|
||||
|
||||
# 手动编译
|
||||
go build -o cyberstrike-ai cmd/server/main.go
|
||||
./cyberstrike-ai
|
||||
./cyberstrike-ai --https
|
||||
```
|
||||
|
||||
若日志出现 `client sent an HTTP request to an HTTPS server`,说明仍有客户端用 **`http://`** 访问只提供 HTTPS 的端口,请改为 **`https://`**。
|
||||
|
||||
**说明:** Python 虚拟环境(`venv/`)由 `run.sh` 自动创建和管理。需要 Python 的工具(如 `api-fuzzer`、`http-framework-test` 等)会自动使用该环境。
|
||||
|
||||
### CyberStrikeAI 版本更新(无兼容性问题)
|
||||
|
||||
1. (首次使用)启用脚本:`chmod +x upgrade.sh`
|
||||
2. 一键升级:`./upgrade.sh`(可选参数:`--tag vX.Y.Z`、`--no-venv`、`--preserve-custom`、`--yes`)
|
||||
2. 一键升级:`./upgrade.sh`(可选参数:`--tag vX.Y.Z`、`--no-venv`、`--yes`)。本地的 `tools/`、`roles/`、`skills/` 会始终保留不被覆盖。
|
||||
3. 脚本会备份你的 `config.yaml` 和 `data/`,从 GitHub Release 升级代码,更新 `config.yaml` 的 `version` 字段后重启服务。
|
||||
|
||||
推荐的一键指令:
|
||||
@@ -279,7 +283,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
||||
- **Supervisor 主代理**:固定 **`orchestrator-supervisor.md`**(另可配 `orchestrator_instruction_supervisor`);至少需一名子代理。
|
||||
- **子代理**(**deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。
|
||||
- **界面管理**:**Agents → Agent 管理**;API `/api/multi-agent/markdown-agents`。
|
||||
- **配置项**:`multi_agent`:`enabled`、`default_mode`、`robot_use_multi_agent`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
|
||||
- **配置项**:`multi_agent`:`enabled`、`robot_default_agent_mode`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
|
||||
- **更多细节**:[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。
|
||||
|
||||
### Skills 技能系统(Agent Skills + Eino)
|
||||
@@ -530,7 +534,7 @@ agents_dir: "agents" # 多代理 Markdown(主代理 orchestrator.md + 子代
|
||||
multi_agent:
|
||||
enabled: false
|
||||
default_mode: "single" # single | multi(开启多代理时的界面默认模式)
|
||||
robot_use_multi_agent: false
|
||||
robot_default_agent_mode: react
|
||||
batch_use_multi_agent: false
|
||||
orchestrator_instruction: "" # Deep;orchestrator.md 正文为空时使用
|
||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"cyberstrike-ai/internal/logger"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"cyberstrike-ai/internal/storage"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -32,6 +33,23 @@ func main() {
|
||||
// 创建安全工具执行器
|
||||
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
||||
|
||||
// 初始化结果存储(与 internal/app/app.go 同样的逻辑)。
|
||||
// stdio 模式下原本不初始化,导致 'exec' 等查询型工具报"结果存储未初始化"。
|
||||
resultStorageDir := "tmp"
|
||||
if cfg.Agent.ResultStorageDir != "" {
|
||||
resultStorageDir = cfg.Agent.ResultStorageDir
|
||||
}
|
||||
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "创建结果存储目录失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "初始化结果存储失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
executor.SetResultStorage(resultStorage)
|
||||
|
||||
// 注册工具
|
||||
executor.RegisterTools(mcpServer)
|
||||
|
||||
|
||||
+43
-3
@@ -9,22 +9,62 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func main() {
|
||||
var configPath = flag.String("config", "config.yaml", "配置文件路径")
|
||||
var httpsBootstrap = flag.Bool("https", false, "启用主站 HTTPS:未配置 tls_cert_path/tls_key_path 时使用内存自签证书(本地测试);与 run.sh 默认行为一致")
|
||||
flag.Parse()
|
||||
|
||||
// 环境变量兼容(便于 systemd/docker 等不传参场景)
|
||||
if !*httpsBootstrap {
|
||||
v := strings.TrimSpace(os.Getenv("CYBERSTRIKE_HTTPS"))
|
||||
if v == "1" || strings.EqualFold(v, "true") || strings.EqualFold(v, "yes") {
|
||||
*httpsBootstrap = true
|
||||
}
|
||||
}
|
||||
|
||||
// 加载配置
|
||||
cfg, err := config.Load(*configPath)
|
||||
cp := strings.TrimSpace(*configPath)
|
||||
if cp == "" {
|
||||
cp = "config.yaml"
|
||||
}
|
||||
if strings.HasPrefix(cp, "-") {
|
||||
fmt.Fprintf(os.Stderr, "无效的 -config 路径 %q。\n若同时需要 HTTPS,请写成: ./cyberstrike-ai --https -config config.yaml(-config 后必须是 yaml 文件路径)。\n", cp)
|
||||
os.Exit(2)
|
||||
}
|
||||
cfg, err := config.Load(cp)
|
||||
if err != nil {
|
||||
fmt.Printf("加载配置失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if *httpsBootstrap {
|
||||
config.ApplyDevHTTPSBootstrap(cfg)
|
||||
}
|
||||
|
||||
port := cfg.Server.Port
|
||||
if port <= 0 {
|
||||
port = 8080
|
||||
}
|
||||
scheme := "http"
|
||||
if config.MainWebUIUsesHTTPS(&cfg.Server) {
|
||||
scheme = "https"
|
||||
}
|
||||
fmt.Println()
|
||||
fmt.Printf("→ Web 界面: %s://127.0.0.1:%d/\n", scheme, port)
|
||||
if scheme == "https" && cfg.Server.TLSAutoSelfSign {
|
||||
fmt.Println(" (内存自签证书:浏览器首次需确认「继续访问」)")
|
||||
}
|
||||
if scheme == "https" && config.ServerHTTPRedirectEnabled(&cfg.Server) {
|
||||
fmt.Printf(" (http://127.0.0.1:%d/ 将自动跳转到 HTTPS)\n", port)
|
||||
}
|
||||
fmt.Println()
|
||||
|
||||
// MCP 启用且 auth_header_value 为空时,自动生成随机密钥并写回配置
|
||||
if err := config.EnsureMCPAuth(*configPath, cfg); err != nil {
|
||||
if err := config.EnsureMCPAuth(cp, cfg); err != nil {
|
||||
fmt.Printf("MCP 鉴权配置失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
@@ -44,7 +84,7 @@ func main() {
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// 创建应用
|
||||
application, err := app.New(cfg, log)
|
||||
application, err := app.New(cfg, log, cp)
|
||||
if err != nil {
|
||||
log.Fatal("应用初始化失败", "error", err)
|
||||
}
|
||||
|
||||
+59
-6
@@ -10,11 +10,22 @@
|
||||
# ============================================
|
||||
|
||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||
version: "v1.6.5"
|
||||
version: "v1.6.22"
|
||||
# 服务器配置
|
||||
server:
|
||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||
port: 8080 # HTTP 服务端口,可通过浏览器访问 http://localhost:8080
|
||||
port: 8080 # 服务端口;未启用 TLS 时为 http://localhost:8080
|
||||
# --- 可选:HTTPS + HTTP/2(缓解浏览器对同源 HTTP/1.1 的并发连接数限制,多路 Deep 流式更稳)---
|
||||
# 启用 TLS 的条件(满足其一即可):tls_enabled: true,或 tls_auto_self_sign: true,或同时配置了 tls_cert_path + tls_key_path。
|
||||
# 启用后请用 https://127.0.0.1:<本端口>/ 访问;若仍用 http:// 访问同端口,将自动 308 跳转到 HTTPS(可用 tls_http_redirect: false 关闭)。
|
||||
tls_enabled: true
|
||||
# 启用 HTTPS 时,明文 HTTP 是否自动跳转到 HTTPS(默认 true;同端口嗅探 TLS/HTTP 后分流)
|
||||
# tls_http_redirect: true
|
||||
# 方式 A(推荐生产):PEM 证书与私钥路径
|
||||
# tls_cert_path: /path/to/fullchain.pem
|
||||
# tls_key_path: /path/to/privkey.pem
|
||||
# 方式 B(仅本地/测试):无证书文件时内存自签(浏览器会提示不受信任;SAN 含 localhost / 127.0.0.1)
|
||||
tls_auto_self_sign: true
|
||||
# 认证配置
|
||||
auth:
|
||||
password: # Web 登录密码,请修改为强密码
|
||||
@@ -23,6 +34,12 @@ auth:
|
||||
log:
|
||||
level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误)
|
||||
output: stdout # 日志输出位置: stdout(标准输出), stderr(标准错误), 或文件路径
|
||||
# 平台操作审计(系统设置 -> 日志审计;不记录对话正文与每次工具调用)
|
||||
audit:
|
||||
enabled: true
|
||||
retention_days: 15 # 0 表示不自动清理
|
||||
max_detail_bytes: 8192
|
||||
auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流
|
||||
# ============================================
|
||||
# 对话相关配置
|
||||
# ============================================
|
||||
@@ -41,6 +58,13 @@ openai:
|
||||
api_key: sk-xxxxxxx # API 密钥(必填)
|
||||
model: qwen3-max # 模型名称(必填)
|
||||
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
|
||||
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinking(extended thinking),mode: off 关闭
|
||||
reasoning:
|
||||
mode: on # auto | on | off;off 时不附加任何推理扩展字段
|
||||
effort: high # low | medium | high | max | xhigh(最高档:OpenAI 常用 xhigh,部分网关用 max,原样下发);空表示不指定
|
||||
allow_client_reasoning: true # false 时忽略对话请求体 reasoning,仅以下方为准
|
||||
profile: openai_compat # auto | deepseek_compat | openai_compat | output_config_effort
|
||||
# extra_request_fields: {} # 可选:管理员自定义根级 JSON 片段(高级)
|
||||
# ============================================
|
||||
# 信息收集(FOFA)配置(可选)
|
||||
# ============================================
|
||||
@@ -53,21 +77,23 @@ fofa:
|
||||
# Agent 配置
|
||||
# 达到最大迭代次数时,AI 会自动总结测试结果
|
||||
agent:
|
||||
max_iterations: 120 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
||||
max_iterations: 1200 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
||||
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
|
||||
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
|
||||
tool_timeout_minutes: 30 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||
# system_prompt_path: prompts/single-react.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
||||
|
||||
system_prompt_path: ""
|
||||
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
|
||||
hitl:
|
||||
# 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 []
|
||||
tool_whitelist: [read_file, list_dir, glob, grep]
|
||||
# 多代理(CloudWeGo Eino DeepAgent,与上方单 Agent /api/agent-loop 并存)
|
||||
# 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct
|
||||
# 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/stream;Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人/批量无请求体时固定按 deep
|
||||
# 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/stream;Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人按 robot_default_agent_mode
|
||||
multi_agent:
|
||||
enabled: true
|
||||
robot_use_multi_agent: true # true 时企业微信/钉钉/飞书机器人也走 Eino 多代理(成本更高)
|
||||
robot_default_agent_mode: eino_single # 企微/钉钉/飞书机器人默认对话模式:react | eino_single | deep | plan_execute | supervisor
|
||||
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
|
||||
max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations
|
||||
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件。
|
||||
@@ -107,9 +133,26 @@ multi_agent:
|
||||
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
|
||||
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
|
||||
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接
|
||||
run_retry_max_attempts: 0 # >0:429/5xx/网络抖动时 ADK 运行循环指数退避续跑次数;0=默认 10
|
||||
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
|
||||
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
|
||||
deep_model_retry_max_retries: 0 # >0:ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
|
||||
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
||||
# Eino callbacks + OpenTelemetry:框架级 span(与 Zap 对齐);默认不向终端用户 UI 推 eino_trace_*(见 sse_trace_to_client)
|
||||
eino_callbacks:
|
||||
enabled: true
|
||||
# log_only=仅 Zap+OTel(推荐默认)| sse/full=才启用流式回调副本关闭等(full 含 stream hooks)
|
||||
mode: log_only
|
||||
sse_trace_to_client: false # true:且 mode 为 sse/full 时,向前端时间线推送 eino_trace_*(排障/内网演示用)
|
||||
max_input_summary_runes: 400
|
||||
max_output_summary_runes: 400
|
||||
zap_verbose: false # true:Debug 附带 input/output 摘要
|
||||
otel:
|
||||
enabled: true
|
||||
service_name: cyberstrike-ai
|
||||
exporter: stdout # none | stdout(开发/本机)| otlphttp(生产接 Collector)
|
||||
otlp_endpoint: localhost:4318 # otlphttp 时使用,host:port,路径固定 /v1/traces
|
||||
sample_ratio: 1.0 # 0~1,ParentBased+TraceIDRatio
|
||||
# 数据库配置
|
||||
database:
|
||||
path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息
|
||||
@@ -202,6 +245,14 @@ knowledge:
|
||||
# 用于在手机端通过企业微信/钉钉/飞书与 CyberStrikeAI 对话,无需部署在服务器上也可使用
|
||||
# 在系统设置 -> 机器人设置 中可配置
|
||||
robots:
|
||||
wechat: # 微信 iLink(个人微信 ClawBot,扫码绑定)
|
||||
enabled: false
|
||||
bot_token: ""
|
||||
ilink_bot_id: ""
|
||||
ilink_user_id: ""
|
||||
base_url: https://ilinkai.weixin.qq.com
|
||||
bot_type: "3"
|
||||
bot_agent: CyberStrikeAI/1.0
|
||||
wecom: # 企业微信
|
||||
enabled: false
|
||||
token: ""
|
||||
@@ -213,11 +264,13 @@ robots:
|
||||
enabled: false
|
||||
client_id: ""
|
||||
client_secret: ""
|
||||
allow_conversation_id_fallback: false
|
||||
lark: # 飞书
|
||||
enabled: false
|
||||
app_id: ""
|
||||
app_secret: ""
|
||||
verify_token: ""
|
||||
allow_chat_id_fallback: false
|
||||
# ============================================
|
||||
# Skills 相关配置
|
||||
# ============================================
|
||||
|
||||
@@ -27,7 +27,14 @@ require (
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||
github.com/pkoukk/tiktoken-go v0.1.8
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||
go.opentelemetry.io/otel v1.34.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0
|
||||
go.opentelemetry.io/otel/sdk v1.34.0
|
||||
go.opentelemetry.io/otel/trace v1.34.0
|
||||
go.uber.org/zap v1.26.0
|
||||
golang.org/x/net v0.35.0
|
||||
golang.org/x/text v0.26.0
|
||||
golang.org/x/time v0.14.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -39,6 +46,7 @@ require (
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17 // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
@@ -46,6 +54,8 @@ require (
|
||||
github.com/evanphx/json-patch v0.5.2 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
@@ -53,6 +63,7 @@ require (
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/jsonschema-go v0.3.0 // indirect
|
||||
github.com/goph/emperror v0.17.2 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
@@ -71,14 +82,20 @@ require (
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/yargevad/filepathx v1.0.0 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.34.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.5.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/arch v0.15.0 // indirect
|
||||
golang.org/x/crypto v0.39.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
|
||||
golang.org/x/net v0.24.0 // indirect
|
||||
golang.org/x/oauth2 v0.30.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect
|
||||
google.golang.org/grpc v1.69.4 // indirect
|
||||
google.golang.org/protobuf v1.36.3 // indirect
|
||||
)
|
||||
|
||||
// 修复钉钉 Stream SDK 在长连接断开(熄屏/网络中断)后 "panic: send on closed channel" 问题
|
||||
|
||||
@@ -17,6 +17,8 @@ github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uS
|
||||
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
@@ -59,6 +61,11 @@ github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
@@ -75,8 +82,8 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
@@ -90,6 +97,8 @@ github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25d
|
||||
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 h1:VNqngBF40hVlDloBruUehVYC3ArSgIyScOAyMRqBxRg=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1/go.mod h1:RBRO7fro65R6tjKzYgLAFo0t1QEXY1Dp+i/bvpRiqiQ=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
@@ -154,6 +163,8 @@ github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtIS
|
||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI=
|
||||
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg=
|
||||
github.com/smarty/assertions v1.16.0 h1:EvHNkdRA4QHMrn75NZSoUQ/mAUXAYWfatfB01yTCzfY=
|
||||
@@ -191,6 +202,26 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY=
|
||||
go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 h1:OeNbIYk/2C15ckl7glBlOBp5+WlYsOElzTNmiPW/x60=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0/go.mod h1:7Bept48yIeqxP2OZ9/AqIpYS94h2or0aB4FypJTc8ZM=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0 h1:BEj3SPM81McUZHYjRS5pEgNgnmzGJ5tRpU5krWnV8Bs=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0/go.mod h1:9cKLGBDzI/F3NoHLQGm4ZrYdIHsvGt6ej6hUowxY0J4=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0 h1:jBpDk4HAUsrnVO1FsfCfCOTEc/MkInJmvfCHYLFiT80=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0/go.mod h1:H9LUIM1daaeZaz91vZcfeM0fejXPmgCYE8ZhzqfJuiU=
|
||||
go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ=
|
||||
go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE=
|
||||
go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A=
|
||||
go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8=
|
||||
go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k=
|
||||
go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE=
|
||||
go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4=
|
||||
go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4=
|
||||
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
|
||||
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
@@ -216,8 +247,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
|
||||
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -251,9 +282,14 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f h1:gap6+3Gk41EItBuyi4XX/bp4oqJ3UwuIMl25yGinuAA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:Ic02D47M+zbarjYYUlK57y316f2MoN0gjAwI3f2S95o=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50=
|
||||
google.golang.org/grpc v1.69.4 h1:MF5TftSMkd8GLw/m0KM6V8CMOCY6NZ1NQDPGFgbTt4A=
|
||||
google.golang.org/grpc v1.69.4/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4=
|
||||
google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU=
|
||||
google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
|
||||
+54
-10
@@ -193,6 +193,10 @@ type ChatMessage struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
// ToolName 仅 tool 角色:从 Eino/轨迹 JSON 的 name 或 tool_name 恢复,供续跑构造 ToolMessage。
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
// ReasoningContent 对应 OpenAI/DeepSeek 的 reasoning_content;思考模式 + 工具调用后续跑须回传(见 DeepSeek 文档)。
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义JSON序列化,将tool_calls中的arguments转换为JSON字符串
|
||||
@@ -206,11 +210,17 @@ func (cm ChatMessage) MarshalJSON() ([]byte, error) {
|
||||
if cm.Content != "" {
|
||||
aux["content"] = cm.Content
|
||||
}
|
||||
if cm.ReasoningContent != "" {
|
||||
aux["reasoning_content"] = cm.ReasoningContent
|
||||
}
|
||||
|
||||
// 添加tool_call_id(如果存在)
|
||||
if cm.ToolCallID != "" {
|
||||
aux["tool_call_id"] = cm.ToolCallID
|
||||
}
|
||||
if cm.ToolName != "" {
|
||||
aux["tool_name"] = cm.ToolName
|
||||
}
|
||||
|
||||
// 转换tool_calls,将arguments转换为JSON字符串
|
||||
if len(cm.ToolCalls) > 0 {
|
||||
@@ -438,6 +448,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
Content: msg.Content,
|
||||
ToolCalls: msg.ToolCalls,
|
||||
ToolCallID: msg.ToolCallID,
|
||||
ToolName: msg.ToolName,
|
||||
})
|
||||
addedCount++
|
||||
contentPreview := msg.Content
|
||||
@@ -587,11 +598,17 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
thinkingStreamSeq++
|
||||
thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq)
|
||||
thinkingStreamStarted := false
|
||||
var thinkingWire string
|
||||
|
||||
response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error {
|
||||
if delta == "" {
|
||||
return nil
|
||||
}
|
||||
var deltaOut string
|
||||
thinkingWire, deltaOut = openai.NormalizeStreamingDelta(thinkingWire, delta)
|
||||
if deltaOut == "" {
|
||||
return nil
|
||||
}
|
||||
if !thinkingStreamStarted {
|
||||
thinkingStreamStarted = true
|
||||
sendProgress("thinking_stream_start", " ", map[string]interface{}{
|
||||
@@ -600,10 +617,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
"toolStream": false,
|
||||
})
|
||||
}
|
||||
sendProgress("thinking_stream_delta", delta, map[string]interface{}{
|
||||
sendProgress("thinking_stream_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||
"streamId": thinkingStreamId,
|
||||
"iteration": i + 1,
|
||||
})
|
||||
}, thinkingWire))
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -657,8 +674,8 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
|
||||
// 检查是否有工具调用
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
// 思考内容:如果本轮启用了思考流式增量(thinking_stream_*),前端会去重;
|
||||
// 同时也需要在该“思考阶段结束”时补一条可落库的 thinking(用于刷新后持久化展示)。
|
||||
// ReAct 助手正文流式增量(thinking_stream_*)在 UI 上归为「思考」;若与 streamId 重复则前端会去重。
|
||||
// 该条 thinking 用于刷新后持久化展示(与流式聚合一致)。
|
||||
if choice.Message.Content != "" {
|
||||
sendProgress("thinking", choice.Message.Content, map[string]interface{}{
|
||||
"iteration": i + 1,
|
||||
@@ -816,10 +833,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "summary",
|
||||
})
|
||||
var summaryWire string
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
sendProgress("response_delta", delta, map[string]interface{}{
|
||||
var deltaOut string
|
||||
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
|
||||
if deltaOut == "" {
|
||||
return nil
|
||||
}
|
||||
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
}, summaryWire))
|
||||
return nil
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
@@ -863,10 +886,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "summary",
|
||||
})
|
||||
var summaryWire string
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
sendProgress("response_delta", delta, map[string]interface{}{
|
||||
var deltaOut string
|
||||
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
|
||||
if deltaOut == "" {
|
||||
return nil
|
||||
}
|
||||
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
}, summaryWire))
|
||||
return nil
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
@@ -910,10 +939,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"messageGeneratedBy": "max_iter_summary",
|
||||
})
|
||||
var summaryWire string
|
||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||
sendProgress("response_delta", delta, map[string]interface{}{
|
||||
var deltaOut string
|
||||
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
|
||||
if deltaOut == "" {
|
||||
return nil
|
||||
}
|
||||
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
}, summaryWire))
|
||||
return nil
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
@@ -1909,6 +1944,15 @@ func (a *Agent) ExecuteMCPToolForConversation(ctx context.Context, conversationI
|
||||
return a.executeToolViaMCP(ctx, toolName, args)
|
||||
}
|
||||
|
||||
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
|
||||
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
|
||||
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
if a == nil || a.mcpServer == nil {
|
||||
return ""
|
||||
}
|
||||
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
||||
}
|
||||
|
||||
// CancelMCPToolExecutionWithNote 取消一次进行中的 MCP 工具(先内部后外部),与监控页「终止工具」一致;note 非空时合并进返回给模型的文本。
|
||||
func (a *Agent) CancelMCPToolExecutionWithNote(executionID, note string) bool {
|
||||
executionID = strings.TrimSpace(executionID)
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseTraceMessages 解析落库的 last_react_input(OpenAI 风格 messages JSON 数组)。
|
||||
func ParseTraceMessages(traceInputJSON string) ([]ChatMessage, error) {
|
||||
traceInputJSON = strings.TrimSpace(traceInputJSON)
|
||||
if traceInputJSON == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var raw []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(traceInputJSON), &raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]ChatMessage, 0, len(raw))
|
||||
for _, msgMap := range raw {
|
||||
msg := ChatMessage{}
|
||||
role, _ := msgMap["role"].(string)
|
||||
if role == "" {
|
||||
continue
|
||||
}
|
||||
msg.Role = role
|
||||
if content, ok := msgMap["content"].(string); ok {
|
||||
msg.Content = content
|
||||
}
|
||||
if rc, ok := msgMap["reasoning_content"].(string); ok && strings.TrimSpace(rc) != "" {
|
||||
msg.ReasoningContent = rc
|
||||
}
|
||||
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
|
||||
if toolCallsArray, ok := toolCallsRaw.([]interface{}); ok {
|
||||
for _, tcRaw := range toolCallsArray {
|
||||
tcMap, ok := tcRaw.(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
toolCall := ToolCall{}
|
||||
if id, ok := tcMap["id"].(string); ok {
|
||||
toolCall.ID = id
|
||||
}
|
||||
if toolType, ok := tcMap["type"].(string); ok {
|
||||
toolCall.Type = toolType
|
||||
}
|
||||
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
|
||||
toolCall.Function = FunctionCall{}
|
||||
if name, ok := funcMap["name"].(string); ok {
|
||||
toolCall.Function.Name = name
|
||||
}
|
||||
if argsRaw, ok := funcMap["arguments"]; ok {
|
||||
if argsStr, ok := argsRaw.(string); ok {
|
||||
var argsMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil {
|
||||
toolCall.Function.Arguments = argsMap
|
||||
}
|
||||
} else if argsMap, ok := argsRaw.(map[string]interface{}); ok {
|
||||
toolCall.Function.Arguments = argsMap
|
||||
}
|
||||
}
|
||||
}
|
||||
if toolCall.ID != "" {
|
||||
msg.ToolCalls = append(msg.ToolCalls, toolCall)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
|
||||
msg.ToolCallID = toolCallID
|
||||
}
|
||||
if tn, ok := msgMap["tool_name"].(string); ok && strings.TrimSpace(tn) != "" {
|
||||
msg.ToolName = strings.TrimSpace(tn)
|
||||
} else if tn, ok := msgMap["name"].(string); ok && strings.TrimSpace(tn) != "" && strings.EqualFold(msg.Role, "tool") {
|
||||
msg.ToolName = strings.TrimSpace(tn)
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// ExtractLastUserTurnMessages 仅保留最后一次 user 提问起的消息(不含更早的用户轮次;跳过 system)。
|
||||
// 与「继续对话」续跑所用轨迹范围一致:当前任务轮次,而非整段多轮对话历史。
|
||||
func ExtractLastUserTurnMessages(msgs []ChatMessage) []ChatMessage {
|
||||
if len(msgs) == 0 {
|
||||
return msgs
|
||||
}
|
||||
lastUser := -1
|
||||
for i, m := range msgs {
|
||||
if strings.EqualFold(m.Role, "user") {
|
||||
lastUser = i
|
||||
}
|
||||
}
|
||||
if lastUser < 0 {
|
||||
return msgs
|
||||
}
|
||||
trimmed := msgs[lastUser:]
|
||||
out := make([]ChatMessage, 0, len(trimmed))
|
||||
for _, m := range trimmed {
|
||||
if strings.EqualFold(m.Role, "system") {
|
||||
continue
|
||||
}
|
||||
out = append(out, m)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ExtractLastUserTurnTraceJSON 在 JSON 轨迹上裁剪为最后一次 user 起的片段(供落库格式直接处理)。
|
||||
func ExtractLastUserTurnTraceJSON(traceInputJSON string) string {
|
||||
traceInputJSON = strings.TrimSpace(traceInputJSON)
|
||||
if traceInputJSON == "" {
|
||||
return traceInputJSON
|
||||
}
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(traceInputJSON), &arr); err != nil {
|
||||
return traceInputJSON
|
||||
}
|
||||
lastUser := -1
|
||||
for i, m := range arr {
|
||||
if r, _ := m["role"].(string); strings.EqualFold(r, "user") {
|
||||
lastUser = i
|
||||
}
|
||||
}
|
||||
if lastUser <= 0 {
|
||||
return traceInputJSON
|
||||
}
|
||||
trimmed := arr[lastUser:]
|
||||
b, err := json.Marshal(trimmed)
|
||||
if err != nil {
|
||||
return traceInputJSON
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// MergeAssistantTraceOutput 将 last_react_output 合并进轨迹最后一条 assistant(与 loadHistoryFromAgentTrace 一致)。
|
||||
func MergeAssistantTraceOutput(msgs []ChatMessage, assistantOut string) []ChatMessage {
|
||||
assistantOut = strings.TrimSpace(assistantOut)
|
||||
if assistantOut == "" || len(msgs) == 0 {
|
||||
return msgs
|
||||
}
|
||||
out := append([]ChatMessage(nil), msgs...)
|
||||
last := &out[len(out)-1]
|
||||
if strings.EqualFold(last.Role, "assistant") && len(last.ToolCalls) == 0 {
|
||||
last.Content = assistantOut
|
||||
return out
|
||||
}
|
||||
out = append(out, ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: assistantOut,
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
// MessagesToTraceJSON 将消息带序列化为 JSON(跳过 system)。
|
||||
func MessagesToTraceJSON(msgs []ChatMessage) (string, error) {
|
||||
filtered := make([]ChatMessage, 0, len(msgs))
|
||||
for _, m := range msgs {
|
||||
if strings.EqualFold(m.Role, "system") {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, m)
|
||||
}
|
||||
b, err := json.Marshal(filtered)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractLastUserTurnTraceJSON(t *testing.T) {
|
||||
raw := []map[string]interface{}{
|
||||
{"role": "user", "content": "old question"},
|
||||
{"role": "assistant", "content": "old answer"},
|
||||
{"role": "user", "content": "new target 1.1.1.1"},
|
||||
{"role": "assistant", "tool_calls": []interface{}{map[string]interface{}{
|
||||
"id": "c1", "type": "function",
|
||||
"function": map[string]interface{}{"name": "nmap", "arguments": "{}"},
|
||||
}}},
|
||||
{"role": "tool", "tool_call_id": "c1", "content": "open ports"},
|
||||
}
|
||||
b, _ := json.Marshal(raw)
|
||||
out := ExtractLastUserTurnTraceJSON(string(b))
|
||||
var trimmed []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(out), &trimmed); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(trimmed) != 3 {
|
||||
t.Fatalf("expected 3 messages, got %d", len(trimmed))
|
||||
}
|
||||
if trimmed[0]["content"] != "new target 1.1.1.1" {
|
||||
t.Fatalf("unexpected first message: %v", trimmed[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractLastUserTurnMessagesSkipsSystem(t *testing.T) {
|
||||
msgs := []ChatMessage{
|
||||
{Role: "system", Content: "sys"},
|
||||
{Role: "user", Content: "q"},
|
||||
{Role: "assistant", Content: "a"},
|
||||
}
|
||||
out := ExtractLastUserTurnMessages(msgs)
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("expected 2, got %d", len(out))
|
||||
}
|
||||
if out[0].Role != "user" {
|
||||
t.Fatal("expected user first")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeAssistantTraceOutput(t *testing.T) {
|
||||
msgs := []ChatMessage{
|
||||
{Role: "user", Content: "q"},
|
||||
{Role: "assistant", Content: "draft"},
|
||||
}
|
||||
out := MergeAssistantTraceOutput(msgs, "final summary")
|
||||
if out[len(out)-1].Content != "final summary" {
|
||||
t.Fatalf("expected merged output, got %q", out[len(out)-1].Content)
|
||||
}
|
||||
}
|
||||
+135
-11
@@ -3,8 +3,10 @@ package app
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -13,9 +15,11 @@ import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/c2"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/einoobserve"
|
||||
"cyberstrike-ai/internal/handler"
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
"cyberstrike-ai/internal/logger"
|
||||
@@ -29,6 +33,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// App 应用
|
||||
@@ -52,14 +57,16 @@ type App struct {
|
||||
robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel
|
||||
dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启
|
||||
larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启
|
||||
wechatCancel context.CancelFunc // 微信 iLink 长轮询取消函数
|
||||
c2Manager *c2.Manager // C2 管理器(未启用 C2 时为 nil)
|
||||
c2Watchdog *c2.SessionWatchdog // C2 会话看门狗
|
||||
c2WatchdogCancel context.CancelFunc // 看门狗取消函数
|
||||
c2Handler *handler.C2Handler // C2 REST(与 Manager 生命周期同步)
|
||||
auditSvc *audit.Service
|
||||
}
|
||||
|
||||
// New 创建新应用
|
||||
func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error) {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
router := gin.Default()
|
||||
|
||||
@@ -88,8 +95,14 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
return nil, fmt.Errorf("初始化数据库失败: %w", err)
|
||||
}
|
||||
|
||||
auditSvc := audit.NewService(db, cfg, log.Logger)
|
||||
audit.RegisterConversationCreateHook(auditSvc)
|
||||
auditSvc.PurgeExpired()
|
||||
audit.StartRetentionLoop(auditSvc, log.Logger)
|
||||
|
||||
// 创建MCP服务器(带数据库持久化)
|
||||
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
||||
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
||||
|
||||
// 创建安全工具执行器
|
||||
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
||||
@@ -216,6 +229,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
|
||||
// 创建知识库API处理器
|
||||
knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger)
|
||||
knowledgeHandler.SetAudit(auditSvc)
|
||||
log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
|
||||
|
||||
// 扫描知识库并建立索引(异步)
|
||||
@@ -290,10 +304,10 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
}()
|
||||
}
|
||||
|
||||
// 获取配置文件路径
|
||||
configPath := "config.yaml"
|
||||
if len(os.Args) > 1 {
|
||||
configPath = os.Args[1]
|
||||
// 配置文件路径必须由入口传入(与 flag -config 一致)。勿再用 os.Args[1],否则 ./cyberstrike-ai --https 会把 --https 当成路径。
|
||||
configPath = strings.TrimSpace(configPath)
|
||||
if configPath == "" {
|
||||
configPath = "config.yaml"
|
||||
}
|
||||
|
||||
skillsDir := skillpackage.SkillsRootFromConfig(cfg.SkillsDir, configPath)
|
||||
@@ -312,31 +326,42 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
log.Logger.Warn("创建 agents 目录失败", zap.String("path", agentsDir), zap.Error(err))
|
||||
}
|
||||
markdownAgentsHandler := handler.NewMarkdownAgentsHandler(agentsDir)
|
||||
markdownAgentsHandler.SetAudit(auditSvc)
|
||||
log.Logger.Info("多代理 Markdown 子 Agent 目录", zap.String("agentsDir", agentsDir))
|
||||
|
||||
// 创建处理器
|
||||
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
|
||||
agentHandler.SetAudit(auditSvc)
|
||||
agentHandler.SetAgentsMarkdownDir(agentsDir)
|
||||
// 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志
|
||||
if knowledgeManager != nil {
|
||||
agentHandler.SetKnowledgeManager(knowledgeManager)
|
||||
}
|
||||
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
|
||||
monitorHandler.SetAudit(auditSvc)
|
||||
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
||||
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
||||
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
||||
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
|
||||
authHandler.SetAudit(auditSvc)
|
||||
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
|
||||
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
|
||||
vulnerabilityHandler.SetAudit(auditSvc)
|
||||
webshellHandler := handler.NewWebShellHandler(log.Logger, db)
|
||||
webshellHandler.SetAudit(auditSvc)
|
||||
chatUploadsHandler := handler.NewChatUploadsHandler(log.Logger)
|
||||
chatUploadsHandler.SetAudit(auditSvc)
|
||||
registerWebshellTools(mcpServer, db, webshellHandler, log.Logger)
|
||||
registerWebshellManagementTools(mcpServer, db, webshellHandler, log.Logger)
|
||||
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
|
||||
configHandler.SetAudit(auditSvc)
|
||||
agentHandler.SetHitlToolWhitelistSaver(configHandler)
|
||||
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
|
||||
externalMCPHandler.SetAudit(auditSvc)
|
||||
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
|
||||
roleHandler.SetAudit(auditSvc)
|
||||
skillsHandler := handler.NewSkillsHandler(cfg, configPath, log.Logger)
|
||||
skillsHandler.SetAudit(auditSvc)
|
||||
fofaHandler := handler.NewFofaHandler(cfg, log.Logger)
|
||||
terminalHandler := handler.NewTerminalHandler(log.Logger)
|
||||
if db != nil {
|
||||
@@ -351,9 +376,12 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
registerC2Tools(mcpServer, c2Manager, log.Logger, cfg.Server.Port)
|
||||
}
|
||||
c2Handler := handler.NewC2Handler(c2Manager, log.Logger)
|
||||
c2Handler.SetAudit(auditSvc)
|
||||
|
||||
// 创建OpenAPI处理器
|
||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||
conversationHandler.SetAudit(auditSvc)
|
||||
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
|
||||
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
|
||||
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler)
|
||||
|
||||
@@ -379,6 +407,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
c2Watchdog: c2Watchdog,
|
||||
c2WatchdogCancel: watchdogCancel,
|
||||
c2Handler: c2Handler,
|
||||
auditSvc: auditSvc,
|
||||
}
|
||||
// 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启
|
||||
app.startRobotConnections()
|
||||
@@ -444,9 +473,11 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
configHandler.SetRetrieverUpdater(knowledgeRetriever)
|
||||
}
|
||||
|
||||
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书新配置生效
|
||||
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书/微信新配置生效
|
||||
configHandler.SetRobotRestarter(app)
|
||||
|
||||
wechatRobotHandler := handler.NewWechatRobotHandler(cfg, configHandler, log.Logger)
|
||||
|
||||
configHandler.SetC2Runtime(app)
|
||||
configHandler.SetC2ToolRegistrar(func() error {
|
||||
if app.config.C2.EnabledEffective() && app.c2Manager != nil {
|
||||
@@ -464,6 +495,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
notificationHandler,
|
||||
conversationHandler,
|
||||
robotHandler,
|
||||
wechatRobotHandler,
|
||||
groupHandler,
|
||||
configHandler,
|
||||
externalMCPHandler,
|
||||
@@ -478,6 +510,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
fofaHandler,
|
||||
terminalHandler,
|
||||
app.c2Handler,
|
||||
auditHandler,
|
||||
mcpServer,
|
||||
authManager,
|
||||
openAPIHandler,
|
||||
@@ -528,18 +561,49 @@ func (a *App) RunWithContext(ctx context.Context) error {
|
||||
}()
|
||||
}
|
||||
|
||||
// 启动主服务器
|
||||
// 启动主服务器(可选 HTTPS + HTTP/2,见 config server.tls_*)
|
||||
addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
|
||||
a.logger.Info("启动HTTP服务器", zap.String("address", addr))
|
||||
tlsMode, tlsConf, certFile, keyFile, tlsErr := prepareMainServerTLS(&a.config.Server)
|
||||
if tlsErr != nil {
|
||||
return tlsErr
|
||||
}
|
||||
|
||||
srv := &http.Server{Addr: addr, Handler: a.router}
|
||||
var mainMux *mainServerMux
|
||||
httpRedirect := config.ServerHTTPRedirectEnabled(&a.config.Server)
|
||||
if tlsMode != mainTLSOff {
|
||||
srv.TLSConfig = tlsConf
|
||||
if err := http2.ConfigureServer(srv, &http2.Server{}); err != nil {
|
||||
return fmt.Errorf("主服务 HTTP/2 配置失败: %w", err)
|
||||
}
|
||||
switch tlsMode {
|
||||
case mainTLSFromFiles:
|
||||
a.logger.Info("启动 HTTPS 主服务(已启用 HTTP/2 协商)",
|
||||
zap.String("address", addr),
|
||||
zap.String("cert", certFile),
|
||||
)
|
||||
case mainTLSInMemorySelfSigned:
|
||||
a.logger.Info("启动 HTTPS 主服务(内存自签证书,仅测试;已启用 HTTP/2 协商)",
|
||||
zap.String("address", addr),
|
||||
)
|
||||
}
|
||||
if httpRedirect {
|
||||
a.logger.Info("已启用 HTTP→HTTPS 自动跳转(同端口嗅探分流)", zap.String("address", addr))
|
||||
}
|
||||
} else {
|
||||
a.logger.Info("启动 HTTP 主服务", zap.String("address", addr))
|
||||
}
|
||||
|
||||
// 监听 context 取消,优雅关闭 HTTP 服务器
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
if mainMux != nil {
|
||||
if err := mainMux.Shutdown(shutdownCtx); err != nil {
|
||||
a.logger.Error("HTTP/HTTPS 分流服务器关闭失败", zap.Error(err))
|
||||
}
|
||||
} else if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
a.logger.Error("HTTP服务器关闭失败", zap.Error(err))
|
||||
}
|
||||
if mcpServer != nil {
|
||||
@@ -549,7 +613,36 @@ func (a *App) RunWithContext(ctx context.Context) error {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
var err error
|
||||
switch {
|
||||
case tlsMode != mainTLSOff && httpRedirect:
|
||||
var tlsConfReady *tls.Config
|
||||
tlsConfReady, err = ensureMainTLSConfigCerts(tlsMode, tlsConf, certFile, keyFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("加载 TLS 证书: %w", err)
|
||||
}
|
||||
srv.TLSConfig = tlsConfReady
|
||||
var ln net.Listener
|
||||
ln, err = net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mainMux = newMainServerMux(ln, srv, portFromListenAddr(addr), a.logger.Logger)
|
||||
err = mainMux.Serve()
|
||||
case tlsMode == mainTLSOff:
|
||||
err = srv.ListenAndServe()
|
||||
case tlsMode == mainTLSFromFiles:
|
||||
err = srv.ListenAndServeTLS(certFile, keyFile)
|
||||
case tlsMode == mainTLSInMemorySelfSigned:
|
||||
var ln net.Listener
|
||||
ln, err = tls.Listen("tcp", addr, srv.TLSConfig)
|
||||
if err == nil {
|
||||
err = srv.Serve(ln)
|
||||
}
|
||||
default:
|
||||
err = srv.ListenAndServe()
|
||||
}
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -557,6 +650,10 @@ func (a *App) RunWithContext(ctx context.Context) error {
|
||||
|
||||
// Shutdown 关闭应用
|
||||
func (a *App) Shutdown() {
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_ = einoobserve.ShutdownOtel(shutdownCtx)
|
||||
shutdownCancel()
|
||||
|
||||
// 停止钉钉/飞书长连接
|
||||
a.robotMu.Lock()
|
||||
if a.dingCancel != nil {
|
||||
@@ -606,9 +703,14 @@ func (a *App) startRobotConnections() {
|
||||
a.dingCancel = cancel
|
||||
go robot.StartDing(ctx, cfg.Robots, a.robotHandler, a.logger.Logger)
|
||||
}
|
||||
if cfg.Robots.Wechat.Enabled && cfg.Robots.Wechat.BotToken != "" {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
a.wechatCancel = cancel
|
||||
go robot.StartWechat(ctx, cfg.Robots, a.robotHandler, cfg.Version, a.logger.Logger)
|
||||
}
|
||||
}
|
||||
|
||||
// RestartRobotConnections 重启钉钉/飞书长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter)
|
||||
// RestartRobotConnections 重启钉钉/飞书/微信长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter)
|
||||
func (a *App) RestartRobotConnections() {
|
||||
a.robotMu.Lock()
|
||||
if a.dingCancel != nil {
|
||||
@@ -619,6 +721,10 @@ func (a *App) RestartRobotConnections() {
|
||||
a.larkCancel()
|
||||
a.larkCancel = nil
|
||||
}
|
||||
if a.wechatCancel != nil {
|
||||
a.wechatCancel()
|
||||
a.wechatCancel = nil
|
||||
}
|
||||
a.robotMu.Unlock()
|
||||
// 给旧 goroutine 一点时间退出
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
@@ -634,6 +740,7 @@ func setupRoutes(
|
||||
notificationHandler *handler.NotificationHandler,
|
||||
conversationHandler *handler.ConversationHandler,
|
||||
robotHandler *handler.RobotHandler,
|
||||
wechatRobotHandler *handler.WechatRobotHandler,
|
||||
groupHandler *handler.GroupHandler,
|
||||
configHandler *handler.ConfigHandler,
|
||||
externalMCPHandler *handler.ExternalMCPHandler,
|
||||
@@ -648,6 +755,7 @@ func setupRoutes(
|
||||
fofaHandler *handler.FofaHandler,
|
||||
terminalHandler *handler.TerminalHandler,
|
||||
c2Handler *handler.C2Handler,
|
||||
auditHandler *handler.AuditHandler,
|
||||
mcpServer *mcp.Server,
|
||||
authManager *security.AuthManager,
|
||||
openAPIHandler *handler.OpenAPIHandler,
|
||||
@@ -682,6 +790,12 @@ func setupRoutes(
|
||||
// 机器人测试(需登录):POST /api/robot/test,body: {"platform":"dingtalk","user_id":"test","text":"帮助"},用于验证机器人逻辑
|
||||
protected.POST("/robot/test", robotHandler.HandleRobotTest)
|
||||
|
||||
// 微信 iLink 扫码绑定(需登录)
|
||||
protected.POST("/robot/wechat/qrcode", wechatRobotHandler.HandleWechatQRCode)
|
||||
protected.GET("/robot/wechat/qrcode/status", wechatRobotHandler.HandleWechatQRCodeStatus)
|
||||
protected.POST("/robot/wechat/qrcode/verify", wechatRobotHandler.HandleWechatVerifyCode)
|
||||
protected.GET("/robot/wechat/status", wechatRobotHandler.HandleWechatStatus)
|
||||
|
||||
// Agent Loop
|
||||
protected.POST("/agent-loop", agentHandler.AgentLoop)
|
||||
// Agent Loop 流式输出
|
||||
@@ -778,6 +892,13 @@ func setupRoutes(
|
||||
protected.POST("/terminal/run/stream", terminalHandler.RunCommandStream)
|
||||
protected.GET("/terminal/ws", terminalHandler.RunCommandWS)
|
||||
|
||||
// 平台审计日志
|
||||
protected.GET("/audit/meta", auditHandler.Meta)
|
||||
protected.GET("/audit/summary", auditHandler.Summary)
|
||||
protected.GET("/audit/logs", auditHandler.ListLogs)
|
||||
protected.GET("/audit/logs/export", auditHandler.ExportLogs)
|
||||
protected.GET("/audit/logs/:id", auditHandler.GetLog)
|
||||
|
||||
// 外部MCP管理
|
||||
protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs)
|
||||
protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats)
|
||||
@@ -1839,6 +1960,9 @@ func initializeKnowledge(
|
||||
|
||||
// 创建知识库API处理器
|
||||
knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger)
|
||||
if app != nil && app.auditSvc != nil {
|
||||
knowledgeHandler.SetAudit(app.auditSvc)
|
||||
}
|
||||
logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
|
||||
|
||||
// 设置知识库管理器到AgentHandler以便记录检索日志
|
||||
|
||||
@@ -0,0 +1,196 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// peekedConn 在已预读首字节后仍将连接交给 net/http 或 crypto/tls。
|
||||
type peekedConn struct {
|
||||
net.Conn
|
||||
r *bufio.Reader
|
||||
}
|
||||
|
||||
func (c *peekedConn) Read(p []byte) (int, error) {
|
||||
return c.r.Read(p)
|
||||
}
|
||||
|
||||
// oneConnListener 供 http.Server.Serve 处理单条 TCP 连接(含 keep-alive)。
|
||||
type oneConnListener struct {
|
||||
conn net.Conn
|
||||
addr net.Addr
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (l *oneConnListener) Accept() (net.Conn, error) {
|
||||
var c net.Conn
|
||||
l.once.Do(func() {
|
||||
c = l.conn
|
||||
l.conn = nil
|
||||
})
|
||||
if c == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (l *oneConnListener) Close() error { return nil }
|
||||
func (l *oneConnListener) Addr() net.Addr { return l.addr }
|
||||
|
||||
func isTLSHandshakeRecord(b byte) bool {
|
||||
return b == 0x16
|
||||
}
|
||||
|
||||
func newHTTPToHTTPSRedirectHandler(httpsPort int) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
host := r.Host
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
var target string
|
||||
if httpsPort == 443 {
|
||||
target = fmt.Sprintf("https://%s%s", host, r.URL.RequestURI())
|
||||
} else {
|
||||
target = fmt.Sprintf("https://%s:%d%s", host, httpsPort, r.URL.RequestURI())
|
||||
}
|
||||
http.Redirect(w, r, target, http.StatusPermanentRedirect)
|
||||
})
|
||||
}
|
||||
|
||||
func portFromListenAddr(addr string) int {
|
||||
_, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return 443
|
||||
}
|
||||
p, err := strconv.Atoi(portStr)
|
||||
if err != nil || p <= 0 {
|
||||
return 443
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func ensureMainTLSConfigCerts(mode mainTLSMode, tlsConf *tls.Config, certFile, keyFile string) (*tls.Config, error) {
|
||||
if mode != mainTLSFromFiles {
|
||||
return tlsConf, nil
|
||||
}
|
||||
if tlsConf == nil {
|
||||
tlsConf = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
}
|
||||
if len(tlsConf.Certificates) > 0 {
|
||||
return tlsConf, nil
|
||||
}
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConf.Certificates = []tls.Certificate{cert}
|
||||
return tlsConf, nil
|
||||
}
|
||||
|
||||
type mainServerMux struct {
|
||||
ln net.Listener
|
||||
httpsSrv *http.Server
|
||||
redirectSrv *http.Server
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func newMainServerMux(ln net.Listener, httpsSrv *http.Server, httpsPort int, logger *zap.Logger) *mainServerMux {
|
||||
return &mainServerMux{
|
||||
ln: ln,
|
||||
httpsSrv: httpsSrv,
|
||||
redirectSrv: &http.Server{Handler: newHTTPToHTTPSRedirectHandler(httpsPort), ReadHeaderTimeout: 10 * time.Second},
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mainServerMux) Serve() error {
|
||||
for {
|
||||
conn, err := m.ln.Accept()
|
||||
if err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return http.ErrServerClosed
|
||||
}
|
||||
return err
|
||||
}
|
||||
go m.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mainServerMux) handleConn(raw net.Conn) {
|
||||
if err := raw.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||
_ = raw.Close()
|
||||
return
|
||||
}
|
||||
br := bufio.NewReader(raw)
|
||||
b, err := br.Peek(1)
|
||||
if err != nil {
|
||||
_ = raw.Close()
|
||||
return
|
||||
}
|
||||
_ = raw.SetReadDeadline(time.Time{})
|
||||
|
||||
pc := &peekedConn{Conn: raw, r: br}
|
||||
ocl := &oneConnListener{conn: pc, addr: raw.LocalAddr()}
|
||||
|
||||
if isTLSHandshakeRecord(b[0]) {
|
||||
m.serveHTTPS(pc, raw.LocalAddr())
|
||||
return
|
||||
}
|
||||
if err := m.redirectSrv.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
|
||||
m.logger.Debug("HTTP 重定向连接处理结束", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// serveHTTPS 在已嗅探为 TLS 的连接上完成握手,再按 ALPN 走 HTTP/2 或 HTTP/1.1。
|
||||
// 不能对同一 http.Server 并发调用 Serve(TLSConfig!=nil),否则握手/ALPN 会异常(浏览器 ERR_SSL_PROTOCOL_ERROR)。
|
||||
func (m *mainServerMux) serveHTTPS(pc *peekedConn, localAddr net.Addr) {
|
||||
tlsConn := tls.Server(pc, m.httpsSrv.TLSConfig)
|
||||
handCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
if err := tlsConn.HandshakeContext(handCtx); err != nil {
|
||||
m.logger.Debug("TLS 握手失败", zap.Error(err))
|
||||
_ = pc.Close()
|
||||
return
|
||||
}
|
||||
|
||||
srv := m.httpsSrv
|
||||
if srv.TLSNextProto != nil {
|
||||
proto := tlsConn.ConnectionState().NegotiatedProtocol
|
||||
if fn := srv.TLSNextProto[proto]; fn != nil {
|
||||
fn(srv, tlsConn, srv.Handler)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
plain := *srv
|
||||
plain.TLSConfig = nil
|
||||
ocl := &oneConnListener{conn: tlsConn, addr: localAddr}
|
||||
if err := plain.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
|
||||
m.logger.Debug("HTTPS 连接处理结束", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mainServerMux) Shutdown(ctx context.Context) error {
|
||||
_ = m.ln.Close()
|
||||
var err1, err2 error
|
||||
if m.httpsSrv != nil {
|
||||
err1 = m.httpsSrv.Shutdown(ctx)
|
||||
}
|
||||
if m.redirectSrv != nil {
|
||||
err2 = m.redirectSrv.Shutdown(ctx)
|
||||
}
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
func TestNewHTTPToHTTPSRedirectHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
httpsPort int
|
||||
host string
|
||||
uri string
|
||||
wantTarget string
|
||||
}{
|
||||
{
|
||||
name: "non standard port",
|
||||
httpsPort: 8080,
|
||||
host: "127.0.0.1:8080",
|
||||
uri: "/login?next=/",
|
||||
wantTarget: "https://127.0.0.1:8080/login?next=/",
|
||||
},
|
||||
{
|
||||
name: "standard port",
|
||||
httpsPort: 443,
|
||||
host: "example.com:80",
|
||||
uri: "/",
|
||||
wantTarget: "https://example.com/",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := newHTTPToHTTPSRedirectHandler(tt.httpsPort)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://"+tt.host+tt.uri, nil)
|
||||
req.Host = tt.host
|
||||
rec := httptest.NewRecorder()
|
||||
h.ServeHTTP(rec, req)
|
||||
if rec.Code != http.StatusPermanentRedirect {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusPermanentRedirect)
|
||||
}
|
||||
if got := rec.Header().Get("Location"); got != tt.wantTarget {
|
||||
t.Fatalf("Location = %q, want %q", got, tt.wantTarget)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsTLSHandshakeRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !isTLSHandshakeRecord(0x16) {
|
||||
t.Fatal("expected TLS handshake record")
|
||||
}
|
||||
if isTLSHandshakeRecord('G') {
|
||||
t.Fatal("GET should not be TLS")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerHTTPRedirectEnabled(t *testing.T) {
|
||||
t.Parallel()
|
||||
disabled := false
|
||||
enabled := true
|
||||
if config.ServerHTTPRedirectEnabled(nil) {
|
||||
t.Fatal("nil config should disable redirect")
|
||||
}
|
||||
if !config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true}) {
|
||||
t.Fatal("HTTPS without explicit flag should enable redirect")
|
||||
}
|
||||
if config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true, TLSHTTPRedirect: &disabled}) {
|
||||
t.Fatal("explicit false should disable redirect")
|
||||
}
|
||||
if !config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true, TLSHTTPRedirect: &enabled}) {
|
||||
t.Fatal("explicit true should enable redirect")
|
||||
}
|
||||
if config.ServerHTTPRedirectEnabled(&config.ServerConfig{}) {
|
||||
t.Fatal("plain HTTP should not redirect")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMainServerMuxHTTPRedirectAndHTTPS(t *testing.T) {
|
||||
cert, err := generateMainServerSelfSignedCert()
|
||||
if err != nil {
|
||||
t.Fatalf("generate cert: %v", err)
|
||||
}
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "ok")
|
||||
})
|
||||
srv := &http.Server{Handler: handler, TLSConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}}
|
||||
if err := http2.ConfigureServer(srv, &http2.Server{}); err != nil {
|
||||
t.Fatalf("configure http2: %v", err)
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
mux := newMainServerMux(ln, srv, portFromListenAddr(ln.Addr().String()), nil)
|
||||
go func() { _ = mux.Serve() }()
|
||||
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12},
|
||||
},
|
||||
CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
addr := ln.Addr().String()
|
||||
|
||||
httpResp, err := client.Get("http://" + addr + "/")
|
||||
if err != nil {
|
||||
t.Fatalf("http get: %v", err)
|
||||
}
|
||||
_ = httpResp.Body.Close()
|
||||
if httpResp.StatusCode != http.StatusPermanentRedirect {
|
||||
t.Fatalf("http status = %d, want %d", httpResp.StatusCode, http.StatusPermanentRedirect)
|
||||
}
|
||||
if got := httpResp.Header.Get("Location"); got != "https://127.0.0.1:"+strconv.Itoa(portFromListenAddr(addr))+"/" {
|
||||
t.Fatalf("Location = %q", got)
|
||||
}
|
||||
|
||||
httpsResp, err := client.Get("https://" + addr + "/")
|
||||
if err != nil {
|
||||
t.Fatalf("https get: %v", err)
|
||||
}
|
||||
defer httpsResp.Body.Close()
|
||||
if httpsResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("https status = %d, want %d", httpsResp.StatusCode, http.StatusOK)
|
||||
}
|
||||
body, _ := io.ReadAll(httpsResp.Body)
|
||||
if string(body) != "ok" {
|
||||
t.Fatalf("body = %q, want ok", body)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
)
|
||||
|
||||
// mainTLSMode 主 Web 服务 TLS 启动方式。
|
||||
type mainTLSMode int
|
||||
|
||||
const (
|
||||
mainTLSOff mainTLSMode = iota
|
||||
mainTLSFromFiles
|
||||
mainTLSInMemorySelfSigned
|
||||
)
|
||||
|
||||
// prepareMainServerTLS 根据 server 配置决定主站是否启用 HTTPS(及 HTTP/2 协商)。
|
||||
// fromFiles:使用 tls_cert_path + tls_key_path,由 http.Server.ListenAndServeTLS 加载 PEM。
|
||||
// inMemory:tls_auto_self_sign 生成的自签证书,仅用于本地/测试。
|
||||
func prepareMainServerTLS(cfg *config.ServerConfig) (mode mainTLSMode, tlsConf *tls.Config, certFile, keyFile string, err error) {
|
||||
if cfg == nil || !config.MainWebUIUsesHTTPS(cfg) {
|
||||
return mainTLSOff, nil, "", "", nil
|
||||
}
|
||||
certFile = strings.TrimSpace(cfg.TLSCertPath)
|
||||
keyFile = strings.TrimSpace(cfg.TLSKeyPath)
|
||||
if certFile != "" && keyFile != "" {
|
||||
// 证书由 ListenAndServeTLS 从文件加载;此处仅提供最小 TLS 配置供 http2.ConfigureServer 合并 ALPN。
|
||||
return mainTLSFromFiles, &tls.Config{MinVersion: tls.VersionTLS12}, certFile, keyFile, nil
|
||||
}
|
||||
if cfg.TLSAutoSelfSign {
|
||||
cert, genErr := generateMainServerSelfSignedCert()
|
||||
if genErr != nil {
|
||||
return mainTLSOff, nil, "", "", fmt.Errorf("生成自签 TLS 证书: %w", genErr)
|
||||
}
|
||||
tlsConf = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
Certificates: []tls.Certificate{cert},
|
||||
}
|
||||
return mainTLSInMemorySelfSigned, tlsConf, "", "", nil
|
||||
}
|
||||
return mainTLSOff, nil, "", "", fmt.Errorf("server: 已启用 TLS(tls_enabled / tls_auto_self_sign / 证书路径),请设置 tls_cert_path 与 tls_key_path,或将 tls_auto_self_sign 设为 true(仅测试环境)")
|
||||
}
|
||||
|
||||
func generateMainServerSelfSignedCert() (tls.Certificate, error) {
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
serial, err := rand.Int(rand.Reader, big.NewInt(1<<62))
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{CommonName: "CyberStrikeAI"},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")},
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
keyDER, err := x509.MarshalECPrivateKey(priv)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||
return tls.X509KeyPair(certPEM, keyPEM)
|
||||
}
|
||||
@@ -82,7 +82,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
|
||||
}
|
||||
}
|
||||
|
||||
// BuildChainFromConversation 从对话构建攻击链(简化版本:用户输入+最后一轮ReAct输入+大模型输出)
|
||||
// BuildChainFromConversation 从对话构建攻击链(单次 LLM 调用;输入为当前任务轮次的 last_react 轨迹,与继续对话续跑范围一致)。
|
||||
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
|
||||
b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID))
|
||||
|
||||
@@ -157,33 +157,34 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
var reactInputFinal string
|
||||
var dataSource string // 记录数据来源
|
||||
|
||||
// 如果成功获取到保存的ReAct数据,直接使用
|
||||
if reactInputJSON != "" && modelOutput != "" {
|
||||
// 计算 ReAct 输入的哈希值,用于追踪
|
||||
hash := sha256.Sum256([]byte(reactInputJSON))
|
||||
reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识
|
||||
// 优先使用落库的代理轨迹(与继续对话 loadHistoryFromAgentTrace 同源),并裁剪为「当前任务轮次」
|
||||
if reactInputJSON != "" {
|
||||
trimmedJSON := agent.ExtractLastUserTurnTraceJSON(reactInputJSON)
|
||||
hash := sha256.Sum256([]byte(trimmedJSON))
|
||||
reactInputHash := hex.EncodeToString(hash[:])[:16]
|
||||
|
||||
// 统计消息数量
|
||||
var messageCount int
|
||||
var tempMessages []interface{}
|
||||
if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil {
|
||||
messageCount = len(tempMessages)
|
||||
if msgs, parseErr := agent.ParseTraceMessages(trimmedJSON); parseErr == nil {
|
||||
messageCount = len(msgs)
|
||||
msgs = agent.MergeAssistantTraceOutput(msgs, modelOutput)
|
||||
reactInputFinal = b.formatAgentTraceFromChatMessages(msgs)
|
||||
} else {
|
||||
b.logger.Warn("解析代理轨迹失败,回退原始 JSON 格式化", zap.Error(parseErr))
|
||||
reactInputFinal = b.formatAgentTraceInputFromJSON(trimmedJSON)
|
||||
if strings.TrimSpace(modelOutput) != "" {
|
||||
reactInputFinal += "\n\n## 助手结论(last_react_output)\n\n" + modelOutput
|
||||
}
|
||||
}
|
||||
|
||||
dataSource = "database_last_agent_trace"
|
||||
b.logger.Info("使用保存的ReAct数据构建攻击链",
|
||||
dataSource = "last_user_turn_agent_trace"
|
||||
b.logger.Info("使用当前任务轮次代理轨迹构建攻击链(与续跑上下文范围一致)",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("dataSource", dataSource),
|
||||
zap.Int("reactInputSize", len(reactInputJSON)),
|
||||
zap.Int("traceInputSizeBeforeTrim", len(reactInputJSON)),
|
||||
zap.Int("traceInputSizeAfterTrim", len(trimmedJSON)),
|
||||
zap.Int("messageCount", messageCount),
|
||||
zap.String("reactInputHash", reactInputHash),
|
||||
zap.Int("modelOutputSize", len(modelOutput)))
|
||||
|
||||
// 从保存的ReAct输入(JSON格式)中提取用户输入
|
||||
// userInput = b.extractUserInputFromReActInput(reactInputJSON)
|
||||
|
||||
// 将JSON格式的messages转换为可读格式
|
||||
reactInputFinal = b.formatAgentTraceInputFromJSON(reactInputJSON)
|
||||
} else {
|
||||
// 2. 如果没有保存的ReAct数据,从对话消息构建
|
||||
dataSource = "messages_table"
|
||||
@@ -243,8 +244,15 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 构建简化的prompt,一次性传递给大模型
|
||||
prompt := b.buildSimplePrompt(reactInputFinal, modelOutput)
|
||||
// 3. 按 token 预算压缩输入,再构建 prompt(避免超出模型上下文)
|
||||
reactInputFinal, modelOutput, _ = b.fitAttackChainPayload(reactInputFinal, modelOutput)
|
||||
|
||||
// 4. 构建 prompt 并单次调用大模型(助手结论已并入轨迹时不再重复传入)
|
||||
promptAssistantOut := modelOutput
|
||||
if reactInputJSON != "" {
|
||||
promptAssistantOut = ""
|
||||
}
|
||||
prompt := b.buildSimplePrompt(reactInputFinal, promptAssistantOut)
|
||||
// fmt.Println(prompt)
|
||||
// 6. 调用AI生成攻击链(一次性,不做任何处理)
|
||||
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
|
||||
@@ -301,7 +309,7 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
|
||||
// 目标:以主 agent(编排器)视角输出整轮迭代
|
||||
// - 保留:编排器工具调用/结果、对子代理的 task 调度、子代理最终回复(不含推理)
|
||||
// - 丢弃:thinking/planning/progress 等噪声、子代理的工具细节与推理过程
|
||||
if d.EventType == "progress" || d.EventType == "thinking" || d.EventType == "planning" {
|
||||
if d.EventType == "progress" || d.EventType == "thinking" || d.EventType == "reasoning_chain" || d.EventType == "planning" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -366,10 +374,17 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
|
||||
// buildAgentTraceInput 构建最后一轮ReAct的输入(历史消息+当前用户输入)
|
||||
// buildAgentTraceInput 构建最后一轮 ReAct 的输入(从最后一条 user 消息起,不含更早轮次)。
|
||||
func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
|
||||
start := 0
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if strings.EqualFold(messages[i].Role, "user") {
|
||||
start = i
|
||||
break
|
||||
}
|
||||
}
|
||||
var builder strings.Builder
|
||||
for _, msg := range messages {
|
||||
for _, msg := range messages[start:] {
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
|
||||
}
|
||||
return builder.String()
|
||||
@@ -396,67 +411,66 @@ func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
|
||||
// return ""
|
||||
// }
|
||||
|
||||
// formatAgentTraceInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
|
||||
// formatAgentTraceInputFromJSON 将 JSON 轨迹转为可读文本(会先按当前任务轮次裁剪)。
|
||||
func (b *Builder) formatAgentTraceInputFromJSON(reactInputJSON string) string {
|
||||
var messages []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
|
||||
trimmed := agent.ExtractLastUserTurnTraceJSON(reactInputJSON)
|
||||
msgs, err := agent.ParseTraceMessages(trimmed)
|
||||
if err != nil {
|
||||
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
||||
return reactInputJSON // 如果解析失败,返回原始JSON
|
||||
return trimmed
|
||||
}
|
||||
return b.formatAgentTraceFromChatMessages(msgs)
|
||||
}
|
||||
|
||||
// formatAgentTraceFromChatMessages 将代理消息带格式化为攻击链分析输入(与续跑轨迹字段一致)。
|
||||
func (b *Builder) formatAgentTraceFromChatMessages(msgs []agent.ChatMessage) string {
|
||||
var builder strings.Builder
|
||||
for _, msg := range messages {
|
||||
role, _ := msg["role"].(string)
|
||||
content, _ := msg["content"].(string)
|
||||
for _, msg := range msgs {
|
||||
role := msg.Role
|
||||
content := msg.Content
|
||||
|
||||
// 处理assistant消息:提取tool_calls信息
|
||||
if role == "assistant" {
|
||||
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 {
|
||||
// 如果有文本内容,先显示
|
||||
if content != "" {
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
|
||||
}
|
||||
// 详细显示每个工具调用
|
||||
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls)))
|
||||
for i, toolCall := range toolCalls {
|
||||
if tc, ok := toolCall.(map[string]interface{}); ok {
|
||||
toolCallID, _ := tc["id"].(string)
|
||||
if funcData, ok := tc["function"].(map[string]interface{}); ok {
|
||||
toolName, _ := funcData["name"].(string)
|
||||
arguments, _ := funcData["arguments"].(string)
|
||||
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
|
||||
builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID))
|
||||
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName))
|
||||
builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments))
|
||||
}
|
||||
if strings.EqualFold(role, "assistant") && len(msg.ToolCalls) > 0 {
|
||||
if content != "" {
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(msg.ToolCalls)))
|
||||
for i, tc := range msg.ToolCalls {
|
||||
args := ""
|
||||
if tc.Function.Arguments != nil {
|
||||
if b, err := json.Marshal(tc.Function.Arguments); err == nil {
|
||||
args = string(b)
|
||||
}
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
continue
|
||||
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
|
||||
builder.WriteString(fmt.Sprintf(" ID: %s\n", tc.ID))
|
||||
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", tc.Function.Name))
|
||||
builder.WriteString(fmt.Sprintf(" 参数: %s\n", args))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
continue
|
||||
}
|
||||
|
||||
// 处理tool消息:显示tool_call_id和完整内容
|
||||
if role == "tool" {
|
||||
toolCallID, _ := msg["tool_call_id"].(string)
|
||||
if toolCallID != "" {
|
||||
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content))
|
||||
if strings.EqualFold(role, "tool") {
|
||||
if msg.ToolCallID != "" {
|
||||
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, msg.ToolCallID, content))
|
||||
} else {
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 其他消息类型(system, user等)正常显示
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// buildSimplePrompt 构建简化的prompt
|
||||
func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
||||
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据对话记录和工具执行结果,构建一个逻辑清晰、有教育意义的攻击链图,完整展现渗透测试的思维过程和执行路径。
|
||||
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据**当前任务轮次**的对话记录和工具执行结果,一次性输出攻击链 JSON(不要分多轮追问)。
|
||||
|
||||
## 输入范围(与「继续对话」续跑一致)
|
||||
- 下方「ReAct 轨迹」仅包含**最后一次用户提问之后**的消息与工具结果(last_react 当前任务轮次),不含更早的用户提问轮次。
|
||||
- 「助手结论」为同轮任务的最终输出摘要(last_react_output);节点须与轨迹中的实际工具执行一致,严禁编造。
|
||||
|
||||
## 核心目标
|
||||
|
||||
@@ -618,12 +632,9 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
||||
5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability)
|
||||
6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径)
|
||||
|
||||
## 最后一轮ReAct输入
|
||||
## 当前任务 ReAct 轨迹(含工具执行;助手结论见轨迹末尾 assistant)
|
||||
|
||||
%s
|
||||
|
||||
## 大模型输出
|
||||
|
||||
%s
|
||||
|
||||
## 输出格式
|
||||
@@ -752,7 +763,15 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
||||
9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。
|
||||
10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。
|
||||
|
||||
现在开始分析并构建攻击链:`, reactInput, modelOutput)
|
||||
现在开始分析并构建攻击链:`, reactInput, assistantOutSection(modelOutput))
|
||||
}
|
||||
|
||||
func assistantOutSection(modelOutput string) string {
|
||||
modelOutput = strings.TrimSpace(modelOutput)
|
||||
if modelOutput == "" {
|
||||
return ""
|
||||
}
|
||||
return "\n## 助手结论(补充)\n\n" + modelOutput + "\n"
|
||||
}
|
||||
|
||||
// saveChain 保存攻击链到数据库
|
||||
@@ -812,7 +831,7 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (
|
||||
},
|
||||
},
|
||||
"temperature": 0.3,
|
||||
"max_completion_tokens": 80000,
|
||||
"max_completion_tokens": attackChainMaxCompletionTokens(b.maxTokens),
|
||||
}
|
||||
|
||||
var apiResponse struct {
|
||||
|
||||
@@ -0,0 +1,248 @@
|
||||
package attackchain
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
attackChainTruncationMarker = "\n\n...[攻击链输入已截断 / attack chain input truncated]...\n\n"
|
||||
attackChainSystemReserve = 256
|
||||
attackChainSafetyReserve = 2048
|
||||
)
|
||||
|
||||
// attackChainMaxCompletionTokens 为攻击链 JSON 输出预留的 completion token 上限。
|
||||
func attackChainMaxCompletionTokens(maxTotal int) int {
|
||||
const capTokens = 16384
|
||||
if maxTotal <= 0 {
|
||||
return 8192
|
||||
}
|
||||
v := maxTotal / 8
|
||||
if v < 4096 {
|
||||
v = 4096
|
||||
}
|
||||
if v > capTokens {
|
||||
v = capTokens
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (b *Builder) modelName() string {
|
||||
if b.openAIConfig != nil && b.openAIConfig.Model != "" {
|
||||
return b.openAIConfig.Model
|
||||
}
|
||||
return "gpt-4"
|
||||
}
|
||||
|
||||
func (b *Builder) countTokens(text string) int {
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
n, err := b.tokenCounter.Count(b.modelName(), text)
|
||||
if err != nil {
|
||||
return utf8.RuneCountInString(text) / 4
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// attackChainPayloadTokenBudget 计算 reactInput + modelOutput 可用的 token 预算。
|
||||
func (b *Builder) attackChainPayloadTokenBudget() int {
|
||||
maxTotal := b.maxTokens
|
||||
if maxTotal <= 0 {
|
||||
maxTotal = 100000
|
||||
}
|
||||
templateTok := b.countTokens(b.buildSimplePrompt("", ""))
|
||||
completion := attackChainMaxCompletionTokens(maxTotal)
|
||||
reserve := templateTok + attackChainSystemReserve + completion + attackChainSafetyReserve
|
||||
budget := maxTotal - reserve
|
||||
minBudget := maxTotal * 35 / 100
|
||||
if budget < minBudget {
|
||||
budget = minBudget
|
||||
}
|
||||
if budget < 4096 {
|
||||
budget = 4096
|
||||
}
|
||||
return budget
|
||||
}
|
||||
|
||||
// fitAttackChainPayload 在构建最终 prompt 前压缩 ReAct 轨迹与模型输出,避免超出模型上下文。
|
||||
func (b *Builder) fitAttackChainPayload(reactInput, modelOutput string) (string, string, bool) {
|
||||
budget := b.attackChainPayloadTokenBudget()
|
||||
modelBudget := budget * 15 / 100
|
||||
if modelBudget < 512 {
|
||||
modelBudget = 512
|
||||
}
|
||||
reactBudget := budget - modelBudget
|
||||
|
||||
origReactTok := b.countTokens(reactInput)
|
||||
origModelTok := b.countTokens(modelOutput)
|
||||
truncated := false
|
||||
|
||||
outModel := modelOutput
|
||||
if origModelTok > modelBudget {
|
||||
outModel = truncateTextByTokens(b, modelOutput, modelBudget)
|
||||
truncated = true
|
||||
}
|
||||
|
||||
outReact := reactInput
|
||||
perToolLimits := []int{12000, 6000, 3000, 1500, 800}
|
||||
for _, lim := range perToolLimits {
|
||||
compact := compactFormattedToolBodies(outReact, lim)
|
||||
if compact != outReact {
|
||||
outReact = compact
|
||||
truncated = true
|
||||
}
|
||||
if b.countTokens(outReact) <= reactBudget {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if b.countTokens(outReact) > reactBudget {
|
||||
outReact = truncateTextByTokens(b, outReact, reactBudget)
|
||||
truncated = true
|
||||
}
|
||||
|
||||
if truncated {
|
||||
b.logger.Info("攻击链输入已按 token 预算截断",
|
||||
zap.Int("maxTotalTokens", b.maxTokens),
|
||||
zap.Int("payloadBudget", budget),
|
||||
zap.Int("reactBudget", reactBudget),
|
||||
zap.Int("modelBudget", modelBudget),
|
||||
zap.Int("reactInputTokensBefore", origReactTok),
|
||||
zap.Int("reactInputTokensAfter", b.countTokens(outReact)),
|
||||
zap.Int("modelOutputTokensBefore", origModelTok),
|
||||
zap.Int("modelOutputTokensAfter", b.countTokens(outModel)),
|
||||
zap.Int("maxCompletionTokens", attackChainMaxCompletionTokens(b.maxTokens)),
|
||||
)
|
||||
}
|
||||
|
||||
return outReact, outModel, truncated
|
||||
}
|
||||
|
||||
// compactFormattedToolBodies 缩短格式化 trace 中 [tool] 消息的正文,保留工具头与调用 ID。
|
||||
func compactFormattedToolBodies(s string, maxRunesPerBody int) string {
|
||||
if maxRunesPerBody <= 0 || s == "" {
|
||||
return s
|
||||
}
|
||||
const marker = "[tool]"
|
||||
var out strings.Builder
|
||||
remaining := s
|
||||
changed := false
|
||||
for {
|
||||
idx := strings.Index(remaining, marker)
|
||||
if idx < 0 {
|
||||
out.WriteString(remaining)
|
||||
break
|
||||
}
|
||||
out.WriteString(remaining[:idx])
|
||||
remaining = remaining[idx:]
|
||||
nl := strings.IndexByte(remaining, '\n')
|
||||
if nl < 0 {
|
||||
out.WriteString(remaining)
|
||||
break
|
||||
}
|
||||
header := remaining[:nl+1]
|
||||
remaining = remaining[nl+1:]
|
||||
bodyEnd := strings.Index(remaining, "\n\n[")
|
||||
var body, rest string
|
||||
if bodyEnd < 0 {
|
||||
body = remaining
|
||||
rest = ""
|
||||
} else {
|
||||
body = remaining[:bodyEnd]
|
||||
rest = remaining[bodyEnd:]
|
||||
}
|
||||
if runeLen(body) > maxRunesPerBody {
|
||||
body = truncateRunesWithNotice(body, maxRunesPerBody)
|
||||
changed = true
|
||||
}
|
||||
out.WriteString(header)
|
||||
out.WriteString(body)
|
||||
remaining = rest
|
||||
if rest == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return s
|
||||
}
|
||||
return out.String()
|
||||
}
|
||||
|
||||
func truncateTextByTokens(b *Builder, text string, maxTokens int) string {
|
||||
if maxTokens <= 0 || text == "" {
|
||||
return ""
|
||||
}
|
||||
if b.countTokens(text) <= maxTokens {
|
||||
return text
|
||||
}
|
||||
markerTok := b.countTokens(attackChainTruncationMarker)
|
||||
usable := maxTokens - markerTok
|
||||
if usable < 256 {
|
||||
usable = maxTokens / 2
|
||||
}
|
||||
headBudget := usable * 60 / 100
|
||||
tailBudget := usable - headBudget
|
||||
head := takeTokensFromStart(b, text, headBudget)
|
||||
tail := takeTokensFromEnd(b, text, tailBudget)
|
||||
return head + attackChainTruncationMarker + tail
|
||||
}
|
||||
|
||||
func takeTokensFromStart(b *Builder, text string, maxTokens int) string {
|
||||
rs := []rune(text)
|
||||
if len(rs) == 0 || maxTokens <= 0 {
|
||||
return ""
|
||||
}
|
||||
lo, hi := 0, len(rs)
|
||||
for lo < hi {
|
||||
mid := (lo + hi + 1) / 2
|
||||
if b.countTokens(string(rs[:mid])) <= maxTokens {
|
||||
lo = mid
|
||||
} else {
|
||||
hi = mid - 1
|
||||
}
|
||||
}
|
||||
return string(rs[:lo])
|
||||
}
|
||||
|
||||
func takeTokensFromEnd(b *Builder, text string, maxTokens int) string {
|
||||
rs := []rune(text)
|
||||
if len(rs) == 0 || maxTokens <= 0 {
|
||||
return ""
|
||||
}
|
||||
lo, hi := 0, len(rs)
|
||||
for lo < hi {
|
||||
mid := (lo + hi) / 2
|
||||
if b.countTokens(string(rs[mid:])) <= maxTokens {
|
||||
hi = mid
|
||||
} else {
|
||||
lo = mid + 1
|
||||
}
|
||||
}
|
||||
return string(rs[lo:])
|
||||
}
|
||||
|
||||
func truncateRunesWithNotice(s string, maxRunes int) string {
|
||||
rs := []rune(s)
|
||||
if len(rs) <= maxRunes {
|
||||
return s
|
||||
}
|
||||
const notice = "\n...[工具输出已截断 / tool output truncated]...\n"
|
||||
noticeRunes := []rune(notice)
|
||||
keep := maxRunes - len(noticeRunes)
|
||||
if keep < 200 {
|
||||
keep = maxRunes * 2 / 3
|
||||
}
|
||||
if keep < 1 {
|
||||
return notice
|
||||
}
|
||||
head := keep * 70 / 100
|
||||
tail := keep - head
|
||||
return string(rs[:head]) + notice + string(rs[len(rs)-tail:])
|
||||
}
|
||||
|
||||
func runeLen(s string) int {
|
||||
return len([]rune(s))
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package attackchain
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func testBuilder(maxTotal int) *Builder {
|
||||
return &Builder{
|
||||
logger: zap.NewNop(),
|
||||
openAIConfig: &config.OpenAIConfig{Model: "gpt-4"},
|
||||
tokenCounter: agent.NewTikTokenCounter(),
|
||||
maxTokens: maxTotal,
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactFormattedToolBodies(t *testing.T) {
|
||||
long := strings.Repeat("x", 20000)
|
||||
in := "[user]: hi\n\n[tool] (tool_call_id: abc):\n" + long + "\n\n[assistant]: done\n"
|
||||
out := compactFormattedToolBodies(in, 500)
|
||||
if strings.Contains(out, strings.Repeat("x", 10000)) {
|
||||
t.Fatal("expected tool body to be truncated")
|
||||
}
|
||||
if !strings.Contains(out, "[user]: hi") {
|
||||
t.Fatal("expected user header preserved")
|
||||
}
|
||||
if !strings.Contains(out, "[assistant]: done") {
|
||||
t.Fatal("expected assistant header preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFitAttackChainPayloadWithinBudget(t *testing.T) {
|
||||
b := testBuilder(32000)
|
||||
react := strings.Repeat("scan ", 50000)
|
||||
model := strings.Repeat("result ", 10000)
|
||||
r, m, truncated := b.fitAttackChainPayload(react, model)
|
||||
if !truncated {
|
||||
t.Fatal("expected truncation for large payload")
|
||||
}
|
||||
prompt := b.buildSimplePrompt(r, m)
|
||||
total := b.countTokens(prompt) + attackChainMaxCompletionTokens(b.maxTokens) + attackChainSystemReserve
|
||||
if total > b.maxTokens+attackChainSafetyReserve {
|
||||
t.Fatalf("prompt still too large: estimated %d > max %d", total, b.maxTokens)
|
||||
}
|
||||
_ = m
|
||||
}
|
||||
|
||||
func TestAttackChainMaxCompletionTokens(t *testing.T) {
|
||||
if got := attackChainMaxCompletionTokens(120000); got != 15000 && got != 16384 {
|
||||
// 120000/8 = 15000
|
||||
if got < 4096 || got > 16384 {
|
||||
t.Fatalf("unexpected completion cap: %d", got)
|
||||
}
|
||||
}
|
||||
if got := attackChainMaxCompletionTokens(0); got != 8192 {
|
||||
t.Fatalf("expected default 8192, got %d", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RegisterConversationCreateHook records platform audit rows for every new conversation.
|
||||
func RegisterConversationCreateHook(s *Service) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
database.SetConversationCreateHook(func(conv *database.Conversation, meta database.ConversationCreateMeta) {
|
||||
detail := map[string]interface{}{
|
||||
"title": conv.Title,
|
||||
"source": meta.Source,
|
||||
}
|
||||
if meta.WebShellConnectionID != "" {
|
||||
detail["webshell_connection_id"] = meta.WebShellConnectionID
|
||||
}
|
||||
s.Record(nil, Entry{
|
||||
Category: "conversation",
|
||||
Action: "create",
|
||||
Result: "success",
|
||||
Message: "创建对话",
|
||||
ResourceType: "conversation",
|
||||
ResourceID: conv.ID,
|
||||
Detail: detail,
|
||||
ClientIP: meta.ClientIP,
|
||||
SessionHint: meta.SessionHint,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ConversationCreateMeta builds audit metadata for conversation creation.
|
||||
func ConversationCreateMeta(source string) database.ConversationCreateMeta {
|
||||
return database.ConversationCreateMeta{Source: strings.TrimSpace(source)}
|
||||
}
|
||||
|
||||
// ConversationCreateMetaFromGin includes client IP and session hint when available.
|
||||
func ConversationCreateMetaFromGin(c *gin.Context, source string) database.ConversationCreateMeta {
|
||||
m := ConversationCreateMeta(source)
|
||||
if c == nil {
|
||||
return m
|
||||
}
|
||||
m.ClientIP = c.ClientIP()
|
||||
if token := c.GetString(security.ContextAuthTokenKey); token != "" {
|
||||
m.SessionHint = sessionHint(token)
|
||||
}
|
||||
return m
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package audit
|
||||
|
||||
// RetentionDays returns configured retention; 0 means keep forever.
|
||||
func (s *Service) RetentionDays() int {
|
||||
if s == nil || s.cfg == nil {
|
||||
return 0
|
||||
}
|
||||
return s.cfg.Audit.RetentionDaysEffective()
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package audit
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
// RecordAction writes a platform audit row with common defaults.
|
||||
func (s *Service) RecordAction(c *gin.Context, category, action, result, message, resourceType, resourceID string, detail map[string]interface{}) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.Record(c, Entry{
|
||||
Category: category,
|
||||
Action: action,
|
||||
Result: result,
|
||||
Message: message,
|
||||
ResourceType: resourceType,
|
||||
ResourceID: resourceID,
|
||||
Detail: detail,
|
||||
})
|
||||
}
|
||||
|
||||
// RecordOK is a shorthand for successful operations.
|
||||
func (s *Service) RecordOK(c *gin.Context, category, action, message, resourceType, resourceID string, detail map[string]interface{}) {
|
||||
s.RecordAction(c, category, action, "success", message, resourceType, resourceID, detail)
|
||||
}
|
||||
|
||||
// RecordFail is a shorthand for failed operations.
|
||||
func (s *Service) RecordFail(c *gin.Context, category, action, message string, detail map[string]interface{}) {
|
||||
s.RecordAction(c, category, action, "failure", message, "", "", detail)
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
var auditActionsResourceRemoved = map[string]bool{
|
||||
"delete": true,
|
||||
"item_delete": true,
|
||||
"connection_delete": true,
|
||||
"listener_delete": true,
|
||||
"session_delete": true,
|
||||
"task_delete": true,
|
||||
"execution_delete": true,
|
||||
"execution_delete_batch": true,
|
||||
"delete_queue": true,
|
||||
"delete_batch_task": true,
|
||||
"markdown_delete": true,
|
||||
}
|
||||
|
||||
// ApplyResourceAvailability sets log.ResourceAvailable when the linked resource can be checked.
|
||||
func ApplyResourceAvailability(db *database.DB, log *database.AuditLog) {
|
||||
if log == nil || strings.TrimSpace(log.ResourceID) == "" {
|
||||
return
|
||||
}
|
||||
if auditActionsResourceRemoved[log.Action] {
|
||||
f := false
|
||||
log.ResourceAvailable = &f
|
||||
return
|
||||
}
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
available, known := resourceStillExists(db, log.ResourceType, log.ResourceID)
|
||||
if known {
|
||||
log.ResourceAvailable = &available
|
||||
}
|
||||
}
|
||||
|
||||
func resourceStillExists(db *database.DB, resourceType, resourceID string) (bool, bool) {
|
||||
resourceID = strings.TrimSpace(resourceID)
|
||||
if resourceID == "" {
|
||||
return false, false
|
||||
}
|
||||
t := strings.TrimSpace(resourceType)
|
||||
if t == "" {
|
||||
if len(resourceID) > 8 && !strings.HasPrefix(resourceID, "c2_") {
|
||||
t = "conversation"
|
||||
} else {
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
switch t {
|
||||
case "conversation":
|
||||
ok, err := db.ConversationExists(resourceID)
|
||||
return ok, err == nil
|
||||
case "vulnerability":
|
||||
_, err := db.GetVulnerability(resourceID)
|
||||
if err != nil {
|
||||
return false, strings.Contains(err.Error(), "不存在")
|
||||
}
|
||||
return true, true
|
||||
case "batch_queue":
|
||||
_, err := db.GetBatchQueue(resourceID)
|
||||
return err == nil, true
|
||||
case "c2_listener":
|
||||
_, err := db.GetC2Listener(resourceID)
|
||||
return err == nil, true
|
||||
case "c2_session":
|
||||
_, err := db.GetC2Session(resourceID)
|
||||
return err == nil, true
|
||||
case "c2_task":
|
||||
_, err := db.GetC2Task(resourceID)
|
||||
return err == nil, true
|
||||
case "webshell_connection":
|
||||
c, err := db.GetWebshellConnection(resourceID)
|
||||
return err == nil && c != nil, true
|
||||
case "tool_execution":
|
||||
_, err := db.GetToolExecution(resourceID)
|
||||
return err == nil, true
|
||||
default:
|
||||
return false, false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// auditRetentionPurgeInterval is how often PurgeExpired runs while the process is up (startup also purges once).
|
||||
const auditRetentionPurgeInterval = time.Hour
|
||||
|
||||
// StartRetentionLoop periodically purges expired audit rows.
|
||||
func StartRetentionLoop(s *Service, logger *zap.Logger) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(auditRetentionPurgeInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
s.PurgeExpired()
|
||||
if logger != nil {
|
||||
logger.Debug("audit retention tick completed")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var sensitiveKeySubstrings = []string{
|
||||
"password", "api_key", "apikey", "secret", "token", "authorization",
|
||||
"credential", "private_key", "access_key",
|
||||
}
|
||||
|
||||
// SanitizeDetail redacts sensitive keys and truncates serialized size.
|
||||
func SanitizeDetail(detail map[string]interface{}, maxBytes int) map[string]interface{} {
|
||||
if detail == nil {
|
||||
return nil
|
||||
}
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 8192
|
||||
}
|
||||
out := sanitizeValue("", detail)
|
||||
if m, ok := out.(map[string]interface{}); ok {
|
||||
b, _ := json.Marshal(m)
|
||||
if len(b) > maxBytes {
|
||||
return map[string]interface{}{
|
||||
"_truncated": true,
|
||||
"_preview": string(b[:maxBytes]),
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
return map[string]interface{}{"value": out}
|
||||
}
|
||||
|
||||
func sanitizeValue(key string, v interface{}) interface{} {
|
||||
kl := strings.ToLower(key)
|
||||
for _, sub := range sensitiveKeySubstrings {
|
||||
if strings.Contains(kl, sub) {
|
||||
return "***"
|
||||
}
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case map[string]interface{}:
|
||||
m := make(map[string]interface{}, len(t))
|
||||
for k, val := range t {
|
||||
m[k] = sanitizeValue(k, val)
|
||||
}
|
||||
return m
|
||||
case []interface{}:
|
||||
arr := make([]interface{}, len(t))
|
||||
for i, val := range t {
|
||||
arr[i] = sanitizeValue(key, val)
|
||||
}
|
||||
return arr
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Service persists platform audit logs.
|
||||
type Service struct {
|
||||
db *database.DB
|
||||
cfg *config.Config
|
||||
logger *zap.Logger
|
||||
failThrottle *failureThrottle
|
||||
}
|
||||
|
||||
// NewService creates an audit service.
|
||||
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
|
||||
return &Service{
|
||||
db: db,
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
failThrottle: newFailureThrottle(),
|
||||
}
|
||||
}
|
||||
|
||||
// Enabled reports whether audit persistence is on.
|
||||
func (s *Service) Enabled() bool {
|
||||
if s == nil || s.cfg == nil {
|
||||
return false
|
||||
}
|
||||
return s.cfg.Audit.EnabledEffective()
|
||||
}
|
||||
|
||||
// Record writes one audit row from a Gin request context.
|
||||
func (s *Service) Record(c *gin.Context, e Entry) {
|
||||
if s == nil || !s.Enabled() || s.db == nil {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(e.Category) == "" || strings.TrimSpace(e.Action) == "" {
|
||||
return
|
||||
}
|
||||
if e.Result == "failure" && !s.allowFailureAudit(c, e) {
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(e.Result) == "" {
|
||||
e.Result = "success"
|
||||
}
|
||||
if strings.TrimSpace(e.Level) == "" {
|
||||
if e.Result == "failure" {
|
||||
e.Level = "warn"
|
||||
} else {
|
||||
e.Level = "info"
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(e.Actor) == "" {
|
||||
e.Actor = "admin"
|
||||
}
|
||||
maxDetail := s.cfg.Audit.MaxDetailBytesEffective()
|
||||
detail := SanitizeDetail(e.Detail, maxDetail)
|
||||
|
||||
sessionHintVal := e.SessionHint
|
||||
if sessionHintVal == "" && c != nil {
|
||||
if token := c.GetString(security.ContextAuthTokenKey); token != "" {
|
||||
sessionHintVal = sessionHint(token)
|
||||
}
|
||||
}
|
||||
clientIPVal := e.ClientIP
|
||||
if clientIPVal == "" {
|
||||
clientIPVal = clientIP(c)
|
||||
}
|
||||
|
||||
row := &database.AuditLog{
|
||||
ID: "audit_" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||
CreatedAt: time.Now(),
|
||||
Level: e.Level,
|
||||
Category: e.Category,
|
||||
Action: e.Action,
|
||||
Result: e.Result,
|
||||
Actor: e.Actor,
|
||||
SessionHint: sessionHintVal,
|
||||
ClientIP: clientIPVal,
|
||||
UserAgent: userAgent(c),
|
||||
ResourceType: e.ResourceType,
|
||||
ResourceID: e.ResourceID,
|
||||
Message: e.Message,
|
||||
Detail: detail,
|
||||
}
|
||||
if err := s.db.AppendAuditLog(row); err != nil && s.logger != nil {
|
||||
s.logger.Warn("写入审计日志失败",
|
||||
zap.String("action", e.Action),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// RecordSystem writes an audit row without HTTP context (e.g. retention cleanup).
|
||||
func (s *Service) RecordSystem(e Entry) {
|
||||
s.Record(nil, e)
|
||||
}
|
||||
|
||||
// PurgeExpired deletes rows older than retention_days when configured.
|
||||
func (s *Service) PurgeExpired() {
|
||||
if s == nil || s.db == nil || s.cfg == nil {
|
||||
return
|
||||
}
|
||||
days := s.cfg.Audit.RetentionDaysEffective()
|
||||
if days <= 0 {
|
||||
return
|
||||
}
|
||||
cutoff := time.Now().AddDate(0, 0, -days)
|
||||
n, err := s.db.DeleteAuditLogsBefore(cutoff)
|
||||
if err != nil {
|
||||
if s.logger != nil {
|
||||
s.logger.Warn("清理过期审计日志失败", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
if n > 0 && s.logger != nil {
|
||||
s.logger.Info("已清理过期审计日志", zap.Int64("deleted", n))
|
||||
}
|
||||
}
|
||||
|
||||
// HintFromToken returns a short stable hash prefix for a session token.
|
||||
func HintFromToken(token string) string {
|
||||
return sessionHint(token)
|
||||
}
|
||||
|
||||
func sessionHint(token string) string {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(sum[:4])
|
||||
}
|
||||
|
||||
func (s *Service) allowFailureAudit(c *gin.Context, e Entry) bool {
|
||||
if !isAuthFailureThrottled(e.Category, e.Action) {
|
||||
return true
|
||||
}
|
||||
cooldown := time.Duration(s.cfg.Audit.AuthFailureCooldownEffective()) * time.Second
|
||||
key := authFailureThrottleKey(e.Category, e.Action, clientIP(c))
|
||||
return s.failThrottle.allow(key, cooldown)
|
||||
}
|
||||
|
||||
func clientIP(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
func userAgent(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
ua := c.GetHeader("User-Agent")
|
||||
if len(ua) > 512 {
|
||||
return ua[:512]
|
||||
}
|
||||
return ua
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
package audit
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// failureThrottle deduplicates high-frequency failure audit rows (e.g. wrong password).
|
||||
type failureThrottle struct {
|
||||
mu sync.Mutex
|
||||
last map[string]time.Time
|
||||
}
|
||||
|
||||
func newFailureThrottle() *failureThrottle {
|
||||
return &failureThrottle{last: make(map[string]time.Time)}
|
||||
}
|
||||
|
||||
// allow reports whether a row with the given key may be written now.
|
||||
func (t *failureThrottle) allow(key string, cooldown time.Duration) bool {
|
||||
if t == nil || cooldown <= 0 || key == "" {
|
||||
return true
|
||||
}
|
||||
now := time.Now()
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if prev, ok := t.last[key]; ok && now.Sub(prev) < cooldown {
|
||||
return false
|
||||
}
|
||||
t.last[key] = now
|
||||
if len(t.last) > 4096 {
|
||||
for k, ts := range t.last {
|
||||
if now.Sub(ts) > cooldown*2 {
|
||||
delete(t.last, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// authFailureThrottleKey builds a per-IP key for auth failure deduplication.
|
||||
func authFailureThrottleKey(category, action, clientIP string) string {
|
||||
return category + ":" + action + ":" + clientIP
|
||||
}
|
||||
|
||||
func isAuthFailureThrottled(category, action string) bool {
|
||||
if category != "auth" {
|
||||
return false
|
||||
}
|
||||
switch action {
|
||||
case "login", "change_password":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package audit
|
||||
|
||||
// Entry describes one platform audit record (not chat/tool execution bodies).
|
||||
type Entry struct {
|
||||
Level string
|
||||
Category string
|
||||
Action string
|
||||
Result string // success | failure
|
||||
Actor string
|
||||
SessionHint string
|
||||
ResourceType string
|
||||
ResourceID string
|
||||
Message string
|
||||
Detail map[string]interface{}
|
||||
ClientIP string // optional when c is nil (robot, batch, DB hook)
|
||||
}
|
||||
@@ -239,13 +239,15 @@ func (m *Manager) StartListener(id string) (*database.C2Listener, error) {
|
||||
}
|
||||
cfg.ApplyDefaults()
|
||||
|
||||
// 通过工厂创建具体实现
|
||||
// 通过工厂创建具体实现。必须使用 rec 的副本:HTTP handler 在返回 JSON 前会清空
|
||||
// rec.ImplantToken / EncryptionKey 做脱敏,若 listener 实现持有同一指针会导致 beacon 鉴权永久失败。
|
||||
listenerRec := *rec
|
||||
factory := m.registry.Get(rec.Type)
|
||||
if factory == nil {
|
||||
return nil, ErrUnsupportedType
|
||||
}
|
||||
inst, err := factory(ListenerCreationCtx{
|
||||
Listener: rec,
|
||||
Listener: &listenerRec,
|
||||
Config: cfg,
|
||||
Manager: m,
|
||||
Logger: m.logger.With(zap.String("listener_id", rec.ID), zap.String("type", rec.Type)),
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// 回归:StartListener 返回的 rec 被 handler 脱敏清空 ImplantToken 后,运行中的 HTTP listener 仍能鉴权。
|
||||
func TestStartListener_ImplantTokenSurvivesHandlerRedaction(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
lnPick, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
port := lnPick.Addr().(*net.TCPAddr).Port
|
||||
_ = lnPick.Close()
|
||||
|
||||
mgr := NewManager(db, zap.NewNop(), tmp)
|
||||
mgr.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
|
||||
rec, err := mgr.CreateListener(CreateListenerInput{
|
||||
Name: "t",
|
||||
Type: string(ListenerTypeHTTPBeacon),
|
||||
BindHost: "127.0.0.1",
|
||||
BindPort: port,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token := rec.ImplantToken
|
||||
|
||||
rec, err = mgr.StartListener(rec.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// 模拟 internal/handler/c2.go StartListener 在 JSON 响应前的脱敏
|
||||
rec.ImplantToken = ""
|
||||
rec.EncryptionKey = ""
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
body := `{"hostname":"n","username":"u","os":"Linux","arch":"amd64","internal_ip":"10.0.0.1","pid":42}`
|
||||
req, _ := http.NewRequest(http.MethodPost, "http://127.0.0.1:"+strconv.Itoa(port)+"/check_in", strings.NewReader(body))
|
||||
req.Header.Set("X-Implant-Token", token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", resp.StatusCode, b)
|
||||
}
|
||||
if !strings.Contains(string(b), "session_id") {
|
||||
t.Fatalf("expected session_id in body: %s", b)
|
||||
}
|
||||
_ = mgr.StopListener(rec.ID)
|
||||
}
|
||||
+273
-14
@@ -26,6 +26,7 @@ type Config struct {
|
||||
Security SecurityConfig `yaml:"security"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"`
|
||||
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
||||
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
||||
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
|
||||
@@ -39,9 +40,9 @@ type Config struct {
|
||||
|
||||
// MultiAgentConfig 基于 CloudWeGo Eino adk/prebuilt 的多代理编排(deep | plan_execute | supervisor,与单 Agent /agent-loop 并存)。
|
||||
type MultiAgentConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
RobotUseMultiAgent bool `yaml:"robot_use_multi_agent" json:"robot_use_multi_agent"` // 为 true 时钉钉/飞书/企微机器人走 Eino 多代理
|
||||
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
RobotDefaultAgentMode string `yaml:"robot_default_agent_mode,omitempty" json:"robot_default_agent_mode,omitempty"` // react | eino_single | deep | plan_execute | supervisor
|
||||
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
||||
// Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。
|
||||
Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"`
|
||||
MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // 主代理 / 执行器最大推理轮次(Deep、Supervisor、plan_execute 的 Executor)
|
||||
@@ -63,6 +64,126 @@ type MultiAgentConfig struct {
|
||||
EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"`
|
||||
// EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras.
|
||||
EinoMiddleware MultiAgentEinoMiddlewareConfig `yaml:"eino_middleware,omitempty" json:"eino_middleware,omitempty"`
|
||||
// EinoCallbacks attaches CloudWeGo eino callbacks.InitCallbacks on ADK Runner context (structured logs + optional SSE trace).
|
||||
EinoCallbacks MultiAgentEinoCallbacksConfig `yaml:"eino_callbacks,omitempty" json:"eino_callbacks,omitempty"`
|
||||
}
|
||||
|
||||
// MultiAgentEinoCallbacksConfig enables Eino unified callbacks on each ADK agent run (deep / plan_execute / supervisor / eino_single).
|
||||
// Modes: log_only (zap + optional OTel; no SSE to browser), sse (adds client SSE eino_trace_* when sse_trace_to_client), full (sse rules + stream callback copies closed).
|
||||
type MultiAgentEinoCallbacksConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"` // log_only | sse | full; empty with enabled=true defaults to log_only
|
||||
// SseTraceToClient when true emits eino_trace_* SSE for UI (use only for admin/debug; nil/false recommended in production).
|
||||
SseTraceToClient *bool `yaml:"sse_trace_to_client,omitempty" json:"sse_trace_to_client,omitempty"`
|
||||
// Otel configures OpenTelemetry trace export (independent of mode; exporter none disables export even if enabled).
|
||||
Otel MultiAgentEinoCallbacksOtelConfig `yaml:"otel,omitempty" json:"otel,omitempty"`
|
||||
// MaxInputSummaryRunes / MaxOutputSummaryRunes cap text placed in SSE payloads and debug logs (not full payloads).
|
||||
MaxInputSummaryRunes int `yaml:"max_input_summary_runes,omitempty" json:"max_input_summary_runes,omitempty"`
|
||||
MaxOutputSummaryRunes int `yaml:"max_output_summary_runes,omitempty" json:"max_output_summary_runes,omitempty"`
|
||||
// ZapVerbose when true logs input/output summaries at zap.Debug on start/end; false uses Info with short fields only.
|
||||
ZapVerbose bool `yaml:"zap_verbose,omitempty" json:"zap_verbose,omitempty"`
|
||||
}
|
||||
|
||||
// MultiAgentEinoCallbacksOtelConfig OpenTelemetry for Eino callback spans (W3C trace in collector / stdout).
|
||||
type MultiAgentEinoCallbacksOtelConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
ServiceName string `yaml:"service_name,omitempty" json:"service_name,omitempty"`
|
||||
Exporter string `yaml:"exporter,omitempty" json:"exporter,omitempty"` // none | stdout | otlphttp
|
||||
OTLPEndpoint string `yaml:"otlp_endpoint,omitempty" json:"otlp_endpoint,omitempty"` // host:port, e.g. localhost:4318 (path /v1/traces)
|
||||
SampleRatio float64 `yaml:"sample_ratio,omitempty" json:"sample_ratio,omitempty"` // 0–1, default 1.0
|
||||
}
|
||||
|
||||
// EinoCallbacksModeEffective returns off | log_only | sse | full.
|
||||
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksModeEffective() string {
|
||||
if !c.Enabled {
|
||||
return "off"
|
||||
}
|
||||
m := strings.TrimSpace(strings.ToLower(c.Mode))
|
||||
switch m {
|
||||
case "log_only":
|
||||
return "log_only"
|
||||
case "sse":
|
||||
return "sse"
|
||||
case "full":
|
||||
return "full"
|
||||
case "":
|
||||
return "log_only"
|
||||
default:
|
||||
return "log_only"
|
||||
}
|
||||
}
|
||||
|
||||
// SseTraceToClientEffective is false unless explicitly set true (best practice: do not expose framework traces to end users by default).
|
||||
func (c MultiAgentEinoCallbacksConfig) SseTraceToClientEffective() bool {
|
||||
if c.SseTraceToClient == nil {
|
||||
return false
|
||||
}
|
||||
return *c.SseTraceToClient
|
||||
}
|
||||
|
||||
// ShouldEmitEinoTraceSSE is true when client-visible trace events should be sent over progress/SSE.
|
||||
func (c MultiAgentEinoCallbacksConfig) ShouldEmitEinoTraceSSE(mode string) bool {
|
||||
if !c.SseTraceToClientEffective() {
|
||||
return false
|
||||
}
|
||||
return mode == "sse" || mode == "full"
|
||||
}
|
||||
|
||||
// OtelExporterEffective returns none | stdout | otlphttp.
|
||||
func (c MultiAgentEinoCallbacksOtelConfig) OtelExporterEffective() string {
|
||||
e := strings.TrimSpace(strings.ToLower(c.Exporter))
|
||||
switch e {
|
||||
case "none", "stdout", "otlphttp":
|
||||
return e
|
||||
case "":
|
||||
if c.Enabled {
|
||||
return "stdout"
|
||||
}
|
||||
return "none"
|
||||
default:
|
||||
return "none"
|
||||
}
|
||||
}
|
||||
|
||||
// OtelTracingActive is true when spans should be started (enabled + non-none exporter).
|
||||
func (c MultiAgentEinoCallbacksConfig) OtelTracingActive() bool {
|
||||
if !c.Otel.Enabled {
|
||||
return false
|
||||
}
|
||||
return c.Otel.OtelExporterEffective() != "none"
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoCallbacksOtelConfig) ServiceNameEffective() string {
|
||||
s := strings.TrimSpace(c.ServiceName)
|
||||
if s != "" {
|
||||
return s
|
||||
}
|
||||
return "cyberstrike-ai"
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoCallbacksOtelConfig) SampleRatioEffective() float64 {
|
||||
r := c.SampleRatio
|
||||
if r <= 0 {
|
||||
return 1.0
|
||||
}
|
||||
if r > 1 {
|
||||
return 1.0
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksMaxInputSummaryRunes() int {
|
||||
if c.MaxInputSummaryRunes > 0 {
|
||||
return c.MaxInputSummaryRunes
|
||||
}
|
||||
return 400
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksMaxOutputSummaryRunes() int {
|
||||
if c.MaxOutputSummaryRunes > 0 {
|
||||
return c.MaxOutputSummaryRunes
|
||||
}
|
||||
return 400
|
||||
}
|
||||
|
||||
// MultiAgentEinoMiddlewareConfig optional Eino ADK middleware and Deep / supervisor tuning.
|
||||
@@ -90,7 +211,8 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
||||
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
||||
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
||||
// HistoryInputBudgetRatio caps pre-agent history tokens as max_total_tokens * ratio (default 0.35).
|
||||
// HistoryInputBudgetRatio 已不影响 Eino:从 last_react 轨迹转 ADK 消息时**不再**按 token 比例裁剪(完整注入)。
|
||||
// 字段仍保留,便于旧版 config 不报错;新部署可省略。
|
||||
HistoryInputBudgetRatio float64 `yaml:"history_input_budget_ratio,omitempty" json:"history_input_budget_ratio,omitempty"`
|
||||
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
||||
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
||||
@@ -106,6 +228,10 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
|
||||
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries).
|
||||
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
|
||||
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时 handler 分段续跑次数;0=默认 10。
|
||||
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
||||
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
||||
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
||||
// TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended).
|
||||
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
|
||||
}
|
||||
@@ -241,9 +367,9 @@ type MultiAgentSubConfig struct {
|
||||
|
||||
// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。
|
||||
type MultiAgentPublic struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
SubAgentCount int `json:"sub_agent_count"`
|
||||
Orchestration string `json:"orchestration,omitempty"`
|
||||
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
|
||||
@@ -251,6 +377,18 @@ type MultiAgentPublic struct {
|
||||
ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"`
|
||||
}
|
||||
|
||||
// NormalizeRobotAgentMode 解析机器人默认对话模式(react | eino_single | deep | plan_execute | supervisor);空值视为 react。
|
||||
func NormalizeRobotAgentMode(ma MultiAgentConfig) string {
|
||||
s := strings.TrimSpace(strings.ToLower(ma.RobotDefaultAgentMode))
|
||||
if s == "" || s == "single" || s == "react" {
|
||||
return "react"
|
||||
}
|
||||
if s == "eino_single" {
|
||||
return "eino_single"
|
||||
}
|
||||
return NormalizeMultiAgentOrchestration(s)
|
||||
}
|
||||
|
||||
// NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。
|
||||
func NormalizeMultiAgentOrchestration(s string) string {
|
||||
v := strings.TrimSpace(strings.ToLower(s))
|
||||
@@ -266,21 +404,35 @@ func NormalizeMultiAgentOrchestration(s string) string {
|
||||
|
||||
// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。
|
||||
type MultiAgentAPIUpdate struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
Enabled bool `json:"enabled"`
|
||||
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
|
||||
ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"`
|
||||
// 指针区分「JSON 未传该字段」与「传空数组要清空」;省略时不应覆盖 YAML 中的常驻工具白名单。
|
||||
ToolSearchAlwaysVisibleTools *[]string `json:"tool_search_always_visible_tools,omitempty"`
|
||||
}
|
||||
|
||||
// RobotsConfig 机器人配置(企业微信、钉钉、飞书等)
|
||||
// RobotsConfig 机器人配置(企业微信、钉钉、飞书、微信 iLink 等)
|
||||
type RobotsConfig struct {
|
||||
Session RobotSessionConfig `yaml:"session,omitempty" json:"session,omitempty"` // 机器人会话隔离策略
|
||||
Wechat RobotWechatConfig `yaml:"wechat,omitempty" json:"wechat,omitempty"` // 微信(iLink 扫码绑定)
|
||||
Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信
|
||||
Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉
|
||||
Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书
|
||||
}
|
||||
|
||||
// RobotWechatConfig 微信 iLink 机器人配置(个人微信 ClawBot / iLink 协议)
|
||||
type RobotWechatConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
BotToken string `yaml:"bot_token,omitempty" json:"bot_token,omitempty"`
|
||||
ILinkBotID string `yaml:"ilink_bot_id,omitempty" json:"ilink_bot_id,omitempty"`
|
||||
ILinkUserID string `yaml:"ilink_user_id,omitempty" json:"ilink_user_id,omitempty"`
|
||||
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://ilinkai.weixin.qq.com
|
||||
BotType string `yaml:"bot_type,omitempty" json:"bot_type,omitempty"` // get_bot_qrcode 参数,默认 3
|
||||
BotAgent string `yaml:"bot_agent,omitempty" json:"bot_agent,omitempty"` // base_info.bot_agent
|
||||
GetUpdatesBuf string `yaml:"get_updates_buf,omitempty" json:"get_updates_buf,omitempty"` // 长轮询游标(运行时)
|
||||
}
|
||||
|
||||
// RobotSessionConfig 机器人会话隔离策略
|
||||
type RobotSessionConfig struct {
|
||||
StrictUserIdentity *bool `yaml:"strict_user_identity,omitempty" json:"strict_user_identity,omitempty"` // true 时只允许真实用户标识,不允许会话/群 ID 兜底
|
||||
@@ -322,8 +474,17 @@ type RobotLarkConfig struct {
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Host string `yaml:"host"`
|
||||
Port int `yaml:"port"`
|
||||
Host string `yaml:"host" json:"host"`
|
||||
Port int `yaml:"port" json:"port"`
|
||||
// TLSEnabled 为 true 时主 Web UI 使用 HTTPS;现代浏览器在同源下会协商 HTTP/2,缓解 HTTP/1.1 每源并发连接数限制。
|
||||
TLSEnabled bool `yaml:"tls_enabled,omitempty" json:"tls_enabled,omitempty"`
|
||||
// TLSCertPath / TLSKeyPath 非空时从 PEM 文件加载证书(生产环境推荐)。
|
||||
TLSCertPath string `yaml:"tls_cert_path,omitempty" json:"tls_cert_path,omitempty"`
|
||||
TLSKeyPath string `yaml:"tls_key_path,omitempty" json:"tls_key_path,omitempty"`
|
||||
// TLSAutoSelfSign 为 true 且未配置有效证书路径时,启动时生成内存自签证书(仅本地/测试;浏览器会提示不受信任)。
|
||||
TLSAutoSelfSign bool `yaml:"tls_auto_self_sign,omitempty" json:"tls_auto_self_sign,omitempty"`
|
||||
// TLSHTTPRedirect 为 false 时禁用 HTTP→HTTPS 跳转;省略或为 true 且已启用 HTTPS 时,明文 HTTP 访问将 308 跳转到 HTTPS(同端口嗅探分流)。
|
||||
TLSHTTPRedirect *bool `yaml:"tls_http_redirect,omitempty" json:"tls_http_redirect,omitempty"`
|
||||
}
|
||||
|
||||
type LogConfig struct {
|
||||
@@ -345,6 +506,48 @@ type OpenAIConfig struct {
|
||||
BaseURL string `yaml:"base_url" json:"base_url"`
|
||||
Model string `yaml:"model" json:"model"`
|
||||
MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"`
|
||||
// Reasoning 控制 Eino ChatModel 的 thinking / reasoning_effort / output_config 等(仅 Eino 路径生效;原生 ReAct 忽略)。
|
||||
Reasoning OpenAIReasoningConfig `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIReasoningConfig 全局默认与网关 profile(对话页可通过 ChatRequest.reasoning 覆盖,受 AllowClientReasoning 约束)。
|
||||
type OpenAIReasoningConfig struct {
|
||||
// Mode: auto(默认)| on | off | default(与 auto 相同)。off 时不向模型附加推理扩展字段。
|
||||
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
||||
// Effort: low | medium | high | max | xhigh;max/xhigh 为不同网关最高档命名,原样下发、不互转。空表示不单独指定强度。
|
||||
Effort string `yaml:"effort,omitempty" json:"effort,omitempty"`
|
||||
// AllowClientReasoning 为 false 时忽略请求体 reasoning;nil 或未设置等同于 true。
|
||||
AllowClientReasoning *bool `yaml:"allow_client_reasoning,omitempty" json:"allow_client_reasoning,omitempty"`
|
||||
// Profile: auto | deepseek_compat | openai_compat | output_config_effort
|
||||
Profile string `yaml:"profile,omitempty" json:"profile,omitempty"`
|
||||
// ExtraRequestFields 合并进 Chat Completions 根 JSON(管理员用;与自动字段同名时后者覆盖)。
|
||||
ExtraRequestFields map[string]interface{} `yaml:"extra_request_fields,omitempty" json:"extra_request_fields,omitempty"`
|
||||
}
|
||||
|
||||
// ModeEffective returns auto when empty or default.
|
||||
func (c OpenAIReasoningConfig) ModeEffective() string {
|
||||
m := strings.ToLower(strings.TrimSpace(c.Mode))
|
||||
if m == "" || m == "default" {
|
||||
return "auto"
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// ProfileEffective returns auto when empty.
|
||||
func (c OpenAIReasoningConfig) ProfileEffective() string {
|
||||
p := strings.ToLower(strings.TrimSpace(c.Profile))
|
||||
if p == "" {
|
||||
return "auto"
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// AllowClientReasoningEffective true when client may send ChatRequest.reasoning.
|
||||
func (c OpenAIReasoningConfig) AllowClientReasoningEffective() bool {
|
||||
if c.AllowClientReasoning == nil {
|
||||
return true
|
||||
}
|
||||
return *c.AllowClientReasoning
|
||||
}
|
||||
|
||||
type FofaConfig struct {
|
||||
@@ -389,6 +592,51 @@ type AuthConfig struct {
|
||||
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// AuditConfig platform operation audit log settings (not chat/tool execution bodies).
|
||||
type AuditConfig struct {
|
||||
// Enabled nil or true enables persistence; explicit false disables.
|
||||
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
|
||||
RetentionDays int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
|
||||
MaxDetailBytes int `yaml:"max_detail_bytes,omitempty" json:"max_detail_bytes,omitempty"`
|
||||
// AuthFailureCooldownSeconds: per-IP cooldown for auth login/change_password failure audit rows; -1 disables; 0 uses default 60.
|
||||
AuthFailureCooldownSeconds int `yaml:"auth_failure_cooldown_seconds,omitempty" json:"auth_failure_cooldown_seconds,omitempty"`
|
||||
}
|
||||
|
||||
// EnabledEffective returns true unless audit.enabled is explicitly false.
|
||||
func (a AuditConfig) EnabledEffective() bool {
|
||||
if a.Enabled == nil {
|
||||
return true
|
||||
}
|
||||
return *a.Enabled
|
||||
}
|
||||
|
||||
// RetentionDaysEffective returns retention; 0 means keep forever.
|
||||
func (a AuditConfig) RetentionDaysEffective() int {
|
||||
if a.RetentionDays < 0 {
|
||||
return 0
|
||||
}
|
||||
return a.RetentionDays
|
||||
}
|
||||
|
||||
// MaxDetailBytesEffective caps serialized detail JSON size.
|
||||
func (a AuditConfig) MaxDetailBytesEffective() int {
|
||||
if a.MaxDetailBytes <= 0 {
|
||||
return 8192
|
||||
}
|
||||
return a.MaxDetailBytes
|
||||
}
|
||||
|
||||
// AuthFailureCooldownEffective returns seconds between duplicate auth-failure audit rows per IP (default 60; -1 disables).
|
||||
func (a AuditConfig) AuthFailureCooldownEffective() int {
|
||||
if a.AuthFailureCooldownSeconds < 0 {
|
||||
return 0
|
||||
}
|
||||
if a.AuthFailureCooldownSeconds == 0 {
|
||||
return 60
|
||||
}
|
||||
return a.AuthFailureCooldownSeconds
|
||||
}
|
||||
|
||||
// ExternalMCPConfig 外部MCP配置
|
||||
type ExternalMCPConfig struct {
|
||||
Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"`
|
||||
@@ -481,6 +729,9 @@ func Load(path string) (*Config, error) {
|
||||
if cfg.Auth.SessionDurationHours <= 0 {
|
||||
cfg.Auth.SessionDurationHours = 12
|
||||
}
|
||||
if cfg.Audit.MaxDetailBytes <= 0 {
|
||||
cfg.Audit.MaxDetailBytes = 8192
|
||||
}
|
||||
if strings.TrimSpace(cfg.Auth.Password) == "" {
|
||||
password, err := generateStrongPassword(24)
|
||||
if err != nil {
|
||||
@@ -984,6 +1235,14 @@ func Default() *Config {
|
||||
Auth: AuthConfig{
|
||||
SessionDurationHours: 12,
|
||||
},
|
||||
Audit: func() AuditConfig {
|
||||
on := true
|
||||
return AuditConfig{
|
||||
RetentionDays: 90,
|
||||
MaxDetailBytes: 8192,
|
||||
Enabled: &on,
|
||||
}
|
||||
}(),
|
||||
Robots: RobotsConfig{
|
||||
Session: RobotSessionConfig{
|
||||
StrictUserIdentity: &strictRobotIdentity,
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
package config
|
||||
|
||||
import "strings"
|
||||
|
||||
// MainWebUIUsesHTTPS 判断主 Web UI 是否以 HTTPS 监听(与 internal/app.prepareMainServerTLS 前置条件一致)。
|
||||
func MainWebUIUsesHTTPS(s *ServerConfig) bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
if s.TLSEnabled {
|
||||
return true
|
||||
}
|
||||
if s.TLSAutoSelfSign {
|
||||
return true
|
||||
}
|
||||
cert := strings.TrimSpace(s.TLSCertPath)
|
||||
key := strings.TrimSpace(s.TLSKeyPath)
|
||||
return cert != "" && key != ""
|
||||
}
|
||||
|
||||
// ServerHTTPRedirectEnabled 是否在主站启用 HTTPS 时把明文 HTTP 请求重定向到 HTTPS(默认开启)。
|
||||
func ServerHTTPRedirectEnabled(s *ServerConfig) bool {
|
||||
if s == nil || !MainWebUIUsesHTTPS(s) {
|
||||
return false
|
||||
}
|
||||
if s.TLSHTTPRedirect == nil {
|
||||
return true
|
||||
}
|
||||
return *s.TLSHTTPRedirect
|
||||
}
|
||||
|
||||
// ApplyDevHTTPSBootstrap 供 --https / 一键脚本使用:强制开启主站 TLS。
|
||||
// 若已配置 tls_cert_path 与 tls_key_path 则仅用 PEM,不开启自签;否则启用 tls_auto_self_sign(内存证书,仅本地测试)。
|
||||
func ApplyDevHTTPSBootstrap(cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
cfg.Server.TLSEnabled = true
|
||||
cert := strings.TrimSpace(cfg.Server.TLSCertPath)
|
||||
key := strings.TrimSpace(cfg.Server.TLSKeyPath)
|
||||
if cert != "" && key != "" {
|
||||
cfg.Server.TLSAutoSelfSign = false
|
||||
return
|
||||
}
|
||||
cfg.Server.TLSAutoSelfSign = true
|
||||
}
|
||||
@@ -0,0 +1,210 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AuditLog platform operation audit record.
|
||||
type AuditLog struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Level string `json:"level"`
|
||||
Category string `json:"category"`
|
||||
Action string `json:"action"`
|
||||
Result string `json:"result"`
|
||||
Actor string `json:"actor"`
|
||||
SessionHint string `json:"sessionHint,omitempty"`
|
||||
ClientIP string `json:"clientIp,omitempty"`
|
||||
UserAgent string `json:"userAgent,omitempty"`
|
||||
ResourceType string `json:"resourceType,omitempty"`
|
||||
ResourceID string `json:"resourceId,omitempty"`
|
||||
ResourceAvailable *bool `json:"resourceAvailable,omitempty"` // API-only: whether linked resource still exists
|
||||
Message string `json:"message"`
|
||||
Detail map[string]interface{} `json:"detail,omitempty"`
|
||||
}
|
||||
|
||||
// ListAuditLogsFilter query parameters.
|
||||
type ListAuditLogsFilter struct {
|
||||
Level string
|
||||
Category string
|
||||
Action string
|
||||
Result string
|
||||
Query string
|
||||
ResourceType string
|
||||
ResourceID string
|
||||
Since *time.Time
|
||||
Until *time.Time
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
func buildAuditLogsWhere(filter ListAuditLogsFilter) (string, []interface{}) {
|
||||
conditions := []string{"1=1"}
|
||||
args := []interface{}{}
|
||||
if filter.Level != "" {
|
||||
conditions = append(conditions, "level = ?")
|
||||
args = append(args, filter.Level)
|
||||
}
|
||||
if filter.Category != "" {
|
||||
conditions = append(conditions, "category = ?")
|
||||
args = append(args, filter.Category)
|
||||
}
|
||||
if filter.Action != "" {
|
||||
conditions = append(conditions, "action = ?")
|
||||
args = append(args, filter.Action)
|
||||
}
|
||||
if filter.Result != "" {
|
||||
conditions = append(conditions, "result = ?")
|
||||
args = append(args, filter.Result)
|
||||
}
|
||||
if filter.ResourceType != "" {
|
||||
conditions = append(conditions, "resource_type = ?")
|
||||
args = append(args, filter.ResourceType)
|
||||
}
|
||||
if filter.ResourceID != "" {
|
||||
conditions = append(conditions, "resource_id = ?")
|
||||
args = append(args, filter.ResourceID)
|
||||
}
|
||||
if filter.Since != nil {
|
||||
conditions = append(conditions, "created_at >= ?")
|
||||
args = append(args, *filter.Since)
|
||||
}
|
||||
if filter.Until != nil {
|
||||
conditions = append(conditions, "created_at <= ?")
|
||||
args = append(args, *filter.Until)
|
||||
}
|
||||
if q := strings.TrimSpace(filter.Query); q != "" {
|
||||
like := "%" + q + "%"
|
||||
conditions = append(conditions, "(message LIKE ? OR resource_id LIKE ? OR action LIKE ? OR category LIKE ?)")
|
||||
args = append(args, like, like, like, like)
|
||||
}
|
||||
return strings.Join(conditions, " AND "), args
|
||||
}
|
||||
|
||||
// AppendAuditLog inserts one audit row.
|
||||
func (db *DB) AppendAuditLog(row *AuditLog) error {
|
||||
if row == nil {
|
||||
return errors.New("audit log is nil")
|
||||
}
|
||||
if strings.TrimSpace(row.ID) == "" {
|
||||
return errors.New("audit id is required")
|
||||
}
|
||||
if row.CreatedAt.IsZero() {
|
||||
row.CreatedAt = time.Now()
|
||||
}
|
||||
if strings.TrimSpace(row.Level) == "" {
|
||||
row.Level = "info"
|
||||
}
|
||||
detailJSON := ""
|
||||
if len(row.Detail) > 0 {
|
||||
if b, err := json.Marshal(row.Detail); err == nil {
|
||||
detailJSON = string(b)
|
||||
}
|
||||
}
|
||||
query := `
|
||||
INSERT INTO audit_logs (
|
||||
id, created_at, level, category, action, result, actor, session_hint,
|
||||
client_ip, user_agent, resource_type, resource_id, message, detail_json
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
_, err := db.Exec(query,
|
||||
row.ID, row.CreatedAt, row.Level, row.Category, row.Action, row.Result,
|
||||
row.Actor, row.SessionHint, row.ClientIP, row.UserAgent,
|
||||
row.ResourceType, row.ResourceID, row.Message, detailJSON,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetAuditLogByID returns one row.
|
||||
func (db *DB) GetAuditLogByID(id string) (*AuditLog, error) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return nil, errors.New("id is required")
|
||||
}
|
||||
query := `
|
||||
SELECT id, created_at, level, category, action, result, actor,
|
||||
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
|
||||
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
|
||||
FROM audit_logs WHERE id = ?
|
||||
`
|
||||
var row AuditLog
|
||||
var detailJSON string
|
||||
err := db.QueryRow(query, id).Scan(
|
||||
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
|
||||
&row.SessionHint, &row.ClientIP, &row.UserAgent,
|
||||
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if detailJSON != "" {
|
||||
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
|
||||
}
|
||||
return &row, nil
|
||||
}
|
||||
|
||||
// CountAuditLogs counts rows matching filter.
|
||||
func (db *DB) CountAuditLogs(filter ListAuditLogsFilter) (int64, error) {
|
||||
where, args := buildAuditLogsWhere(filter)
|
||||
query := `SELECT COUNT(*) FROM audit_logs WHERE ` + where
|
||||
var n int64
|
||||
err := db.QueryRow(query, args...).Scan(&n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// ListAuditLogs lists audit rows newest first.
|
||||
func (db *DB) ListAuditLogs(filter ListAuditLogsFilter) ([]*AuditLog, error) {
|
||||
where, args := buildAuditLogsWhere(filter)
|
||||
limit := filter.Limit
|
||||
if limit <= 0 || limit > 500 {
|
||||
limit = 50
|
||||
}
|
||||
offset := filter.Offset
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
query := `
|
||||
SELECT id, created_at, level, category, action, result, actor,
|
||||
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
|
||||
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
|
||||
FROM audit_logs
|
||||
WHERE ` + where + `
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
`
|
||||
args = append(args, limit, offset)
|
||||
rows, err := db.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var list []*AuditLog
|
||||
for rows.Next() {
|
||||
var row AuditLog
|
||||
var detailJSON string
|
||||
if err := rows.Scan(
|
||||
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
|
||||
&row.SessionHint, &row.ClientIP, &row.UserAgent,
|
||||
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
|
||||
); err != nil {
|
||||
continue
|
||||
}
|
||||
if detailJSON != "" {
|
||||
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
|
||||
}
|
||||
list = append(list, &row)
|
||||
}
|
||||
return list, rows.Err()
|
||||
}
|
||||
|
||||
// DeleteAuditLogsBefore removes rows older than cutoff.
|
||||
func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) {
|
||||
res, err := db.Exec(`DELETE FROM audit_logs WHERE created_at < ?`, cutoff)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
@@ -25,23 +25,24 @@ type Conversation struct {
|
||||
|
||||
// Message 消息
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
|
||||
ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoningContent,omitempty"`
|
||||
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
|
||||
ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// CreateConversation 创建新对话
|
||||
func (db *DB) CreateConversation(title string) (*Conversation, error) {
|
||||
return db.CreateConversationWithWebshell("", title)
|
||||
func (db *DB) CreateConversation(title string, meta ConversationCreateMeta) (*Conversation, error) {
|
||||
return db.CreateConversationWithWebshell("", title, meta)
|
||||
}
|
||||
|
||||
// CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话)
|
||||
func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string) (*Conversation, error) {
|
||||
func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string, meta ConversationCreateMeta) (*Conversation, error) {
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
|
||||
@@ -61,12 +62,17 @@ func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string)
|
||||
return nil, fmt.Errorf("创建对话失败: %w", err)
|
||||
}
|
||||
|
||||
return &Conversation{
|
||||
conv := &Conversation{
|
||||
ID: id,
|
||||
Title: title,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
if webshellConnectionID != "" {
|
||||
meta.WebShellConnectionID = webshellConnectionID
|
||||
}
|
||||
notifyConversationCreated(conv, meta)
|
||||
return conv, nil
|
||||
}
|
||||
|
||||
// GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化)
|
||||
@@ -116,6 +122,7 @@ func (db *DB) GetConversationByWebshellConnectionID(connectionID string) (*Conve
|
||||
}
|
||||
for i := range conv.Messages {
|
||||
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
||||
details = DedupeConsecutiveProcessDetails(details)
|
||||
detailsJSON := make([]map[string]interface{}, len(details))
|
||||
for j, detail := range details {
|
||||
var data interface{}
|
||||
@@ -180,6 +187,23 @@ func (db *DB) ListConversationsByWebshellConnectionID(connectionID string) ([]We
|
||||
return list, rows.Err()
|
||||
}
|
||||
|
||||
// ConversationExists reports whether a conversation row exists (lightweight check for audit links).
|
||||
func (db *DB) ConversationExists(id string) (bool, error) {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
return false, nil
|
||||
}
|
||||
var one int
|
||||
err := db.QueryRow("SELECT 1 FROM conversations WHERE id = ? LIMIT 1", id).Scan(&one)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetConversation 获取对话
|
||||
func (db *DB) GetConversation(id string) (*Conversation, error) {
|
||||
var conv Conversation
|
||||
@@ -234,6 +258,7 @@ func (db *DB) GetConversation(id string) (*Conversation, error) {
|
||||
// 将过程详情附加到对应的消息上
|
||||
for i := range conv.Messages {
|
||||
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
||||
details = DedupeConsecutiveProcessDetails(details)
|
||||
// 将ProcessDetail转换为JSON格式,以便前端使用
|
||||
detailsJSON := make([]map[string]interface{}, len(details))
|
||||
for j, detail := range details {
|
||||
@@ -498,8 +523,8 @@ func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs [
|
||||
}
|
||||
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)",
|
||||
id, conversationID, role, content, mcpIDsJSON, now, now,
|
||||
"INSERT INTO messages (id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
id, conversationID, role, content, "", mcpIDsJSON, now, now,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("添加消息失败: %w", err)
|
||||
@@ -523,10 +548,30 @@ func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs [
|
||||
return message, nil
|
||||
}
|
||||
|
||||
// UpdateAssistantMessageFinalize 更新助手消息终态(正文、MCP id、思考链聚合文本,供无轨迹回退时回放)。
|
||||
func (db *DB) UpdateAssistantMessageFinalize(messageID, content string, mcpExecutionIDs []string, reasoningContent string) error {
|
||||
var mcpIDsJSON string
|
||||
if len(mcpExecutionIDs) > 0 {
|
||||
jsonData, err := json.Marshal(mcpExecutionIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化MCP执行ID失败: %w", err)
|
||||
}
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, err := db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, reasoning_content = ?, updated_at = ? WHERE id = ?",
|
||||
content, mcpIDsJSON, strings.TrimSpace(reasoningContent), time.Now(), messageID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新助手消息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMessages 获取对话的所有消息
|
||||
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
||||
"SELECT id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
||||
conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -537,13 +582,17 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||
var messages []Message
|
||||
for rows.Next() {
|
||||
var msg Message
|
||||
var reasoning sql.NullString
|
||||
var mcpIDsJSON sql.NullString
|
||||
var createdAt string
|
||||
var updatedAt sql.NullString
|
||||
|
||||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
|
||||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &reasoning, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描消息失败: %w", err)
|
||||
}
|
||||
if reasoning.Valid {
|
||||
msg.ReasoningContent = reasoning.String
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err error
|
||||
@@ -683,7 +732,7 @@ type ProcessDetail struct {
|
||||
ID string `json:"id"`
|
||||
MessageID string `json:"messageId"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
EventType string `json:"eventType"` // iteration, thinking, tool_calls_detected, tool_call, tool_result, progress, error
|
||||
EventType string `json:"eventType"` // iteration, thinking, reasoning_chain, tool_calls_detected, tool_call, tool_result, progress, error
|
||||
Message string `json:"message"`
|
||||
Data string `json:"data"` // JSON格式的数据
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
package database
|
||||
|
||||
// ConversationCreateMeta describes how a conversation was created (for audit hooks).
|
||||
type ConversationCreateMeta struct {
|
||||
Source string
|
||||
WebShellConnectionID string
|
||||
ClientIP string
|
||||
SessionHint string
|
||||
}
|
||||
|
||||
// ConversationCreateHook is invoked after a conversation row is inserted.
|
||||
type ConversationCreateHook func(conv *Conversation, meta ConversationCreateMeta)
|
||||
|
||||
var conversationCreateHook ConversationCreateHook
|
||||
|
||||
// SetConversationCreateHook registers a global hook (e.g. platform audit).
|
||||
func SetConversationCreateHook(h ConversationCreateHook) {
|
||||
conversationCreateHook = h
|
||||
}
|
||||
|
||||
func notifyConversationCreated(conv *Conversation, meta ConversationCreateMeta) {
|
||||
if conversationCreateHook == nil || conv == nil {
|
||||
return
|
||||
}
|
||||
if meta.Source == "" {
|
||||
meta.Source = "unknown"
|
||||
}
|
||||
conversationCreateHook(conv, meta)
|
||||
}
|
||||
@@ -387,6 +387,24 @@ func (db *DB) initTables() error {
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
createAuditLogsTable := `
|
||||
CREATE TABLE IF NOT EXISTS audit_logs (
|
||||
id TEXT PRIMARY KEY,
|
||||
created_at DATETIME NOT NULL,
|
||||
level TEXT NOT NULL DEFAULT 'info',
|
||||
category TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
result TEXT NOT NULL,
|
||||
actor TEXT NOT NULL DEFAULT 'admin',
|
||||
session_hint TEXT,
|
||||
client_ip TEXT,
|
||||
user_agent TEXT,
|
||||
resource_type TEXT,
|
||||
resource_id TEXT,
|
||||
message TEXT NOT NULL,
|
||||
detail_json TEXT
|
||||
);`
|
||||
|
||||
createC2ProfilesTable := `
|
||||
CREATE TABLE IF NOT EXISTS c2_profiles (
|
||||
id TEXT PRIMARY KEY,
|
||||
@@ -445,6 +463,10 @@ func (db *DB) initTables() error {
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_events_created_at ON c2_events(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_events_category ON c2_events(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_events_session ON c2_events(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_created_at ON audit_logs(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_category ON audit_logs(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_action ON audit_logs(action);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_logs_result ON audit_logs(result);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(createConversationsTable); err != nil {
|
||||
@@ -514,6 +536,10 @@ func (db *DB) initTables() error {
|
||||
return fmt.Errorf("创建webshell_connection_states表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createAuditLogsTable); err != nil {
|
||||
return fmt.Errorf("创建audit_logs表失败: %w", err)
|
||||
}
|
||||
|
||||
for tableName, ddl := range map[string]string{
|
||||
"c2_listeners": createC2ListenersTable,
|
||||
"c2_sessions": createC2SessionsTable,
|
||||
@@ -594,6 +620,25 @@ func (db *DB) migrateMessagesTable() error {
|
||||
|
||||
// 回填已有数据:让 updated_at 至少等于 created_at,避免前端出现空/当前时间回退。
|
||||
_, _ = db.Exec("UPDATE messages SET updated_at = created_at WHERE updated_at IS NULL OR updated_at = ''")
|
||||
|
||||
// reasoning_content:DeepSeek 思考模式 + 工具调用续跑;与 last_react_input 互补,供消息表回退路径回放
|
||||
var rcColCount int
|
||||
errRC := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name='reasoning_content'").Scan(&rcColCount)
|
||||
if errRC != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", addErr)
|
||||
}
|
||||
}
|
||||
} else if rcColCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); err != nil {
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DedupeConsecutiveProcessDetails 去掉相邻且语义相同的过程详情(使用 DB 中 data 列原始 JSON 作指纹,避免 map 序列化键序不稳定)。
|
||||
func DedupeConsecutiveProcessDetails(rows []ProcessDetail) []ProcessDetail {
|
||||
if len(rows) < 2 {
|
||||
return rows
|
||||
}
|
||||
out := make([]ProcessDetail, 0, len(rows))
|
||||
var lastKey string
|
||||
for _, d := range rows {
|
||||
key := processDetailRowKey(d)
|
||||
if len(out) > 0 && key != "" && key == lastKey {
|
||||
continue
|
||||
}
|
||||
out = append(out, d)
|
||||
lastKey = key
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func processDetailRowKey(d ProcessDetail) string {
|
||||
return fmt.Sprintf("%s\x00%s\x00%s", d.EventType, strings.TrimSpace(d.Message), d.Data)
|
||||
}
|
||||
@@ -3,12 +3,84 @@ package database
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// VulnerabilityListFilter 列表/统计/导出共用的筛选条件
|
||||
type VulnerabilityListFilter struct {
|
||||
ID string
|
||||
Search string // 关键词模糊匹配(标题、描述、类型、目标等)
|
||||
ConversationID string
|
||||
Severity string
|
||||
Status string
|
||||
TaskID string
|
||||
ConversationTag string
|
||||
TaskTag string
|
||||
}
|
||||
|
||||
func escapeVulnerabilityLikePattern(s string) string {
|
||||
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||
s = strings.ReplaceAll(s, `%`, `\%`)
|
||||
s = strings.ReplaceAll(s, `_`, `\_`)
|
||||
return "%" + s + "%"
|
||||
}
|
||||
|
||||
func (f VulnerabilityListFilter) appendWhere(query string, args []interface{}) (string, []interface{}) {
|
||||
if f.ID != "" {
|
||||
query += " AND id = ?"
|
||||
args = append(args, f.ID)
|
||||
}
|
||||
if f.ConversationID != "" {
|
||||
query += " AND conversation_id = ?"
|
||||
args = append(args, f.ConversationID)
|
||||
}
|
||||
if f.TaskID != "" {
|
||||
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
||||
args = append(args, f.TaskID, f.TaskID)
|
||||
}
|
||||
if f.ConversationTag != "" {
|
||||
query += " AND conversation_tag = ?"
|
||||
args = append(args, f.ConversationTag)
|
||||
}
|
||||
if f.TaskTag != "" {
|
||||
query += " AND task_tag = ?"
|
||||
args = append(args, f.TaskTag)
|
||||
}
|
||||
if f.Severity != "" {
|
||||
query += " AND severity = ?"
|
||||
args = append(args, f.Severity)
|
||||
}
|
||||
if f.Status != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, f.Status)
|
||||
}
|
||||
search := strings.TrimSpace(f.Search)
|
||||
if search != "" {
|
||||
pattern := escapeVulnerabilityLikePattern(search)
|
||||
query += ` AND (
|
||||
LOWER(id) LIKE LOWER(?) OR
|
||||
LOWER(title) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(description, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(vulnerability_type, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(target, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(proof, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(impact, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(recommendation, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(conversation_id, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(conversation_tag, '')) LIKE LOWER(?) OR
|
||||
LOWER(COALESCE(task_tag, '')) LIKE LOWER(?)
|
||||
)`
|
||||
for i := 0; i < 11; i++ {
|
||||
args = append(args, pattern)
|
||||
}
|
||||
}
|
||||
return query, args
|
||||
}
|
||||
|
||||
// Vulnerability 漏洞
|
||||
type Vulnerability struct {
|
||||
ID string `json:"id"`
|
||||
@@ -97,7 +169,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
||||
}
|
||||
|
||||
// ListVulnerabilities 列出漏洞
|
||||
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) {
|
||||
func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) {
|
||||
query := `
|
||||
SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag,
|
||||
vulnerability_type, target, proof, impact, recommendation,
|
||||
@@ -108,35 +180,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
|
||||
WHERE 1=1
|
||||
`
|
||||
args := []interface{}{}
|
||||
|
||||
if id != "" {
|
||||
query += " AND id = ?"
|
||||
args = append(args, id)
|
||||
}
|
||||
if conversationID != "" {
|
||||
query += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if taskID != "" {
|
||||
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
||||
args = append(args, taskID, taskID)
|
||||
}
|
||||
if conversationTag != "" {
|
||||
query += " AND conversation_tag = ?"
|
||||
args = append(args, conversationTag)
|
||||
}
|
||||
if taskTag != "" {
|
||||
query += " AND task_tag = ?"
|
||||
args = append(args, taskTag)
|
||||
}
|
||||
if severity != "" {
|
||||
query += " AND severity = ?"
|
||||
args = append(args, severity)
|
||||
}
|
||||
if status != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
query, args = filter.appendWhere(query, args)
|
||||
|
||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
@@ -168,38 +212,10 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
|
||||
}
|
||||
|
||||
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
|
||||
func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) {
|
||||
func (db *DB) CountVulnerabilities(filter VulnerabilityListFilter) (int, error) {
|
||||
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
|
||||
if id != "" {
|
||||
query += " AND id = ?"
|
||||
args = append(args, id)
|
||||
}
|
||||
if conversationID != "" {
|
||||
query += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if taskID != "" {
|
||||
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
||||
args = append(args, taskID, taskID)
|
||||
}
|
||||
if conversationTag != "" {
|
||||
query += " AND conversation_tag = ?"
|
||||
args = append(args, conversationTag)
|
||||
}
|
||||
if taskTag != "" {
|
||||
query += " AND task_tag = ?"
|
||||
args = append(args, taskTag)
|
||||
}
|
||||
if severity != "" {
|
||||
query += " AND severity = ?"
|
||||
args = append(args, severity)
|
||||
}
|
||||
if status != "" {
|
||||
query += " AND status = ?"
|
||||
args = append(args, status)
|
||||
}
|
||||
query, args = filter.appendWhere(query, args)
|
||||
|
||||
var count int
|
||||
err := db.QueryRow(query, args...).Scan(&count)
|
||||
@@ -245,19 +261,12 @@ func (db *DB) DeleteVulnerability(id string) error {
|
||||
}
|
||||
|
||||
// GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致)
|
||||
func (db *DB) GetVulnerabilityStats(conversationID, taskID string) (map[string]interface{}, error) {
|
||||
func (db *DB) GetVulnerabilityStats(filter VulnerabilityListFilter) (map[string]interface{}, error) {
|
||||
stats := make(map[string]interface{})
|
||||
|
||||
where := "WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
if conversationID != "" {
|
||||
where += " AND conversation_id = ?"
|
||||
args = append(args, conversationID)
|
||||
}
|
||||
if taskID != "" {
|
||||
where += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
||||
args = append(args, taskID, taskID)
|
||||
}
|
||||
where, args = filter.appendWhere(where, args)
|
||||
|
||||
// 总漏洞数
|
||||
var totalCount int
|
||||
|
||||
@@ -23,12 +23,16 @@ type ExecutionRecorder func(executionID string)
|
||||
const ToolErrorPrefix = "__CYBERSTRIKE_AI_TOOL_ERROR__\n"
|
||||
|
||||
// ToolsFromDefinitions 将单 Agent 使用的 OpenAI 风格工具定义转为 Eino InvokableTool,执行时走 Agent 的 MCP 路径。
|
||||
// invokeNotify 可选:与 runEinoADKAgentLoop 共享,在 InvokableRun 返回时触发 UI 与 pending 清理(与 ADK Tool 事件去重)。
|
||||
// einoAgentName 为该套工具所属 ChatModelAgent 的 Name(主代理或子代理 id),用于 SSE 上的 einoAgent 字段。
|
||||
func ToolsFromDefinitions(
|
||||
ag *agent.Agent,
|
||||
holder *ConversationHolder,
|
||||
defs []agent.Tool,
|
||||
rec ExecutionRecorder,
|
||||
toolOutputChunk func(toolName, toolCallID, chunk string),
|
||||
invokeNotify *ToolInvokeNotifyHolder,
|
||||
einoAgentName string,
|
||||
) ([]tool.BaseTool, error) {
|
||||
out := make([]tool.BaseTool, 0, len(defs))
|
||||
for _, d := range defs {
|
||||
@@ -40,12 +44,14 @@ func ToolsFromDefinitions(
|
||||
return nil, fmt.Errorf("tool %q: %w", d.Function.Name, err)
|
||||
}
|
||||
out = append(out, &mcpBridgeTool{
|
||||
info: info,
|
||||
name: d.Function.Name,
|
||||
agent: ag,
|
||||
holder: holder,
|
||||
record: rec,
|
||||
chunk: toolOutputChunk,
|
||||
info: info,
|
||||
name: d.Function.Name,
|
||||
agent: ag,
|
||||
holder: holder,
|
||||
record: rec,
|
||||
chunk: toolOutputChunk,
|
||||
invokeNotify: invokeNotify,
|
||||
einoAgentName: strings.TrimSpace(einoAgentName),
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
@@ -77,12 +83,14 @@ func toolInfoFromDefinition(d agent.Tool) (*schema.ToolInfo, error) {
|
||||
}
|
||||
|
||||
type mcpBridgeTool struct {
|
||||
info *schema.ToolInfo
|
||||
name string
|
||||
agent *agent.Agent
|
||||
holder *ConversationHolder
|
||||
record ExecutionRecorder
|
||||
chunk func(toolName, toolCallID, chunk string)
|
||||
info *schema.ToolInfo
|
||||
name string
|
||||
agent *agent.Agent
|
||||
holder *ConversationHolder
|
||||
record ExecutionRecorder
|
||||
chunk func(toolName, toolCallID, chunk string)
|
||||
invokeNotify *ToolInvokeNotifyHolder
|
||||
einoAgentName string
|
||||
}
|
||||
|
||||
func (m *mcpBridgeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
@@ -90,8 +98,27 @@ func (m *mcpBridgeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
return m.info, nil
|
||||
}
|
||||
|
||||
func (m *mcpBridgeTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
||||
func (m *mcpBridgeTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (out string, err error) {
|
||||
_ = opts
|
||||
toolCallID := compose.GetToolCallID(ctx)
|
||||
defer func() {
|
||||
if m.invokeNotify == nil {
|
||||
return
|
||||
}
|
||||
tid := strings.TrimSpace(toolCallID)
|
||||
if tid == "" {
|
||||
return
|
||||
}
|
||||
success := err == nil && !strings.HasPrefix(out, ToolErrorPrefix)
|
||||
body := out
|
||||
if err != nil {
|
||||
success = false
|
||||
} else if strings.HasPrefix(out, ToolErrorPrefix) {
|
||||
success = false
|
||||
body = strings.TrimPrefix(out, ToolErrorPrefix)
|
||||
}
|
||||
m.invokeNotify.Fire(tid, m.name, m.einoAgentName, success, body, err)
|
||||
}()
|
||||
return runMCPToolInvocation(ctx, m.agent, m.holder, m.name, argumentsInJSON, m.record, m.chunk)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
package einomcp
|
||||
|
||||
import "sync"
|
||||
|
||||
// ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP 桥在每次 InvokableRun 结束时 Fire,
|
||||
// 用于在 ADK 未透出 schema.Tool 事件时仍推送 tool_result、清 pending,避免 UI 卡在「执行中」或迭代末 force-close。
|
||||
type ToolInvokeNotifyHolder struct {
|
||||
mu sync.RWMutex
|
||||
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
|
||||
}
|
||||
|
||||
// NewToolInvokeNotifyHolder 创建可在 ToolsFromDefinitions 与 run loop 之间共享的 holder。
|
||||
func NewToolInvokeNotifyHolder() *ToolInvokeNotifyHolder {
|
||||
return &ToolInvokeNotifyHolder{}
|
||||
}
|
||||
|
||||
// Set 由 runEinoADKAgentLoop 在开始消费 iter 之前调用;可多次覆盖(通常仅一次)。
|
||||
func (h *ToolInvokeNotifyHolder) Set(fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.fn = fn
|
||||
}
|
||||
|
||||
// Fire 由 mcpBridgeTool 在工具调用返回时调用;若尚未 Set 或 toolCallID 为空则忽略。
|
||||
func (h *ToolInvokeNotifyHolder) Fire(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
h.mu.RLock()
|
||||
fn := h.fn
|
||||
h.mu.RUnlock()
|
||||
if fn == nil {
|
||||
return
|
||||
}
|
||||
fn(toolCallID, toolName, einoAgent, success, content, invokeErr)
|
||||
}
|
||||
@@ -0,0 +1,435 @@
|
||||
// Package einoobserve attaches CloudWeGo Eino [callbacks.Handler] to ADK Runner contexts for
|
||||
// structured logging and optional SSE trace events (eino_trace_*).
|
||||
package einoobserve
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/google/uuid"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ctxSpanKey struct{}
|
||||
|
||||
type ctxOtelSpanKey struct{}
|
||||
|
||||
// Params for attaching per-run callback instrumentation.
|
||||
type Params struct {
|
||||
Logger *zap.Logger
|
||||
Progress func(eventType, message string, data interface{})
|
||||
ConversationID string
|
||||
OrchMode string
|
||||
OrchestratorName string
|
||||
}
|
||||
|
||||
// AttachAgentRunCallbacks returns ctx wrapped with callbacks.InitCallbacks when enabled.
|
||||
// Safe to call with nil cfg or disabled cfg (returns ctx unchanged).
|
||||
func AttachAgentRunCallbacks(ctx context.Context, cfg *config.MultiAgentEinoCallbacksConfig, p Params) context.Context {
|
||||
if ctx == nil {
|
||||
return ctx
|
||||
}
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return ctx
|
||||
}
|
||||
mode := cfg.EinoCallbacksModeEffective()
|
||||
if mode == "off" {
|
||||
return ctx
|
||||
}
|
||||
runID := uuid.New().String()
|
||||
if p.Progress != nil && cfg.ShouldEmitEinoTraceSSE(mode) {
|
||||
p.Progress("eino_trace_run", "Eino callbacks session", map[string]interface{}{
|
||||
"runId": runID,
|
||||
"conversationId": strings.TrimSpace(p.ConversationID),
|
||||
"orchestration": strings.TrimSpace(p.OrchMode),
|
||||
"orchestratorName": strings.TrimSpace(p.OrchestratorName),
|
||||
"observeMode": mode,
|
||||
"source": "eino_callbacks",
|
||||
})
|
||||
}
|
||||
h := &runHandler{
|
||||
cfg: *cfg,
|
||||
mode: mode,
|
||||
params: p,
|
||||
runID: runID,
|
||||
}
|
||||
b := callbacks.NewHandlerBuilder().
|
||||
OnStartFn(h.onStart).
|
||||
OnEndFn(h.onEnd).
|
||||
OnErrorFn(h.onError)
|
||||
if mode == "full" {
|
||||
b = b.OnStartWithStreamInputFn(h.onStartStreamIn).OnEndWithStreamOutputFn(h.onEndStreamOut)
|
||||
}
|
||||
ri := &callbacks.RunInfo{
|
||||
Name: "CyberStrikeADKRun",
|
||||
Type: strings.TrimSpace(p.OrchMode),
|
||||
Component: components.Component("AgentSession"),
|
||||
}
|
||||
return callbacks.InitCallbacks(ctx, ri, b.Build())
|
||||
}
|
||||
|
||||
type runHandler struct {
|
||||
cfg config.MultiAgentEinoCallbacksConfig
|
||||
mode string
|
||||
params Params
|
||||
runID string
|
||||
|
||||
mu sync.Mutex
|
||||
spanStack []string
|
||||
seq atomic.Uint64
|
||||
}
|
||||
|
||||
func (h *runHandler) genSpanID() string {
|
||||
return fmt.Sprintf("%s-%d", h.runID, h.seq.Add(1))
|
||||
}
|
||||
|
||||
func (h *runHandler) popSpan() (id string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if len(h.spanStack) == 0 {
|
||||
return ""
|
||||
}
|
||||
id = h.spanStack[len(h.spanStack)-1]
|
||||
h.spanStack = h.spanStack[:len(h.spanStack)-1]
|
||||
return id
|
||||
}
|
||||
|
||||
// popMatching removes the given id from the stack top if it matches; otherwise pops until empty or match (rare ordering mismatch).
|
||||
func (h *runHandler) popMatching(want string) string {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if want == "" {
|
||||
if len(h.spanStack) == 0 {
|
||||
return ""
|
||||
}
|
||||
id := h.spanStack[len(h.spanStack)-1]
|
||||
h.spanStack = h.spanStack[:len(h.spanStack)-1]
|
||||
return id
|
||||
}
|
||||
for len(h.spanStack) > 0 {
|
||||
top := h.spanStack[len(h.spanStack)-1]
|
||||
h.spanStack = h.spanStack[:len(h.spanStack)-1]
|
||||
if top == want {
|
||||
return top
|
||||
}
|
||||
}
|
||||
return want
|
||||
}
|
||||
|
||||
func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
||||
var parentID string
|
||||
h.mu.Lock()
|
||||
if len(h.spanStack) > 0 {
|
||||
parentID = h.spanStack[len(h.spanStack)-1]
|
||||
}
|
||||
spanID := h.genSpanID()
|
||||
h.spanStack = append(h.spanStack, spanID)
|
||||
h.mu.Unlock()
|
||||
|
||||
inSum := summarizeCallbackInput(input, h.cfg.EinoCallbacksMaxInputSummaryRunes())
|
||||
if h.cfg.OtelTracingActive() {
|
||||
tracer := otel.Tracer("cyberstrike/eino")
|
||||
spanName := callbackSpanName(info)
|
||||
var sp trace.Span
|
||||
ctx, sp = tracer.Start(ctx, spanName,
|
||||
trace.WithSpanKind(trace.SpanKindInternal),
|
||||
trace.WithAttributes(
|
||||
attribute.String("eino.component", string(info.Component)),
|
||||
attribute.String("eino.name", info.Name),
|
||||
attribute.String("eino.type", info.Type),
|
||||
attribute.String("cyberstrike.run_id", h.runID),
|
||||
attribute.String("cyberstrike.conversation_id", strings.TrimSpace(h.params.ConversationID)),
|
||||
attribute.String("cyberstrike.orchestration", strings.TrimSpace(h.params.OrchMode)),
|
||||
),
|
||||
)
|
||||
if inSum != "" {
|
||||
sp.SetAttributes(attribute.String("eino.input.summary", truncateForAttr(inSum, 256)))
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxOtelSpanKey{}, sp)
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
fields := []zap.Field{
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("spanId", spanID),
|
||||
zap.String("parentSpanId", parentID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
zap.String("type", info.Type),
|
||||
zap.String("phase", "start"),
|
||||
}
|
||||
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
||||
if sc := sp.SpanContext(); sc.IsValid() {
|
||||
fields = append(fields,
|
||||
zap.String("trace_id", sc.TraceID().String()),
|
||||
zap.String("otel_span_id", sc.SpanID().String()),
|
||||
)
|
||||
}
|
||||
}
|
||||
if h.cfg.ZapVerbose {
|
||||
h.params.Logger.Debug("eino_callback", append(fields, zap.String("inputSummary", inSum))...)
|
||||
} else {
|
||||
h.params.Logger.Info("eino_callback", fields...)
|
||||
}
|
||||
}
|
||||
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
|
||||
h.params.Progress("eino_trace_start", "", map[string]interface{}{
|
||||
"runId": h.runID,
|
||||
"spanId": spanID,
|
||||
"parentSpanId": parentID,
|
||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||
"component": string(info.Component),
|
||||
"name": info.Name,
|
||||
"type": info.Type,
|
||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
"inputSummary": inSum,
|
||||
"source": "eino_callbacks",
|
||||
})
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxSpanKey{}, spanID)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *runHandler) onEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
|
||||
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
|
||||
if spanID == "" {
|
||||
spanID = h.popSpan()
|
||||
} else {
|
||||
spanID = h.popMatching(spanID)
|
||||
}
|
||||
outSum := summarizeCallbackOutput(output, h.cfg.EinoCallbacksMaxOutputSummaryRunes())
|
||||
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
||||
if outSum != "" {
|
||||
sp.SetAttributes(attribute.String("eino.output.summary", truncateForAttr(outSum, 256)))
|
||||
}
|
||||
sp.SetStatus(codes.Ok, "")
|
||||
sp.End()
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
fields := []zap.Field{
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("spanId", spanID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
zap.String("type", info.Type),
|
||||
zap.String("phase", "end"),
|
||||
}
|
||||
if h.cfg.ZapVerbose {
|
||||
h.params.Logger.Debug("eino_callback", append(fields, zap.String("outputSummary", outSum))...)
|
||||
} else {
|
||||
h.params.Logger.Info("eino_callback", fields...)
|
||||
}
|
||||
}
|
||||
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
|
||||
h.params.Progress("eino_trace_end", "", map[string]interface{}{
|
||||
"runId": h.runID,
|
||||
"spanId": spanID,
|
||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||
"component": string(info.Component),
|
||||
"name": info.Name,
|
||||
"type": info.Type,
|
||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
"outputSummary": outSum,
|
||||
"source": "eino_callbacks",
|
||||
})
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *runHandler) onError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
||||
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
|
||||
if spanID == "" {
|
||||
spanID = h.popSpan()
|
||||
} else {
|
||||
spanID = h.popMatching(spanID)
|
||||
}
|
||||
msg := ""
|
||||
if err != nil {
|
||||
msg = truncateRunes(err.Error(), h.cfg.EinoCallbacksMaxOutputSummaryRunes())
|
||||
}
|
||||
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
||||
if err != nil {
|
||||
sp.RecordError(err)
|
||||
}
|
||||
sp.SetStatus(codes.Error, msg)
|
||||
sp.End()
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
h.params.Logger.Warn("eino_callback_error",
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("spanId", spanID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
zap.String("type", info.Type),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
|
||||
h.params.Progress("eino_trace_error", msg, map[string]interface{}{
|
||||
"runId": h.runID,
|
||||
"spanId": spanID,
|
||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||
"component": string(info.Component),
|
||||
"name": info.Name,
|
||||
"type": info.Type,
|
||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
"error": msg,
|
||||
"source": "eino_callbacks",
|
||||
})
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *runHandler) onStartStreamIn(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
|
||||
if input != nil {
|
||||
input.Close()
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
h.params.Logger.Debug("eino_callback_stream_in",
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *runHandler) onEndStreamOut(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
|
||||
if output != nil {
|
||||
output.Close()
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
h.params.Logger.Debug("eino_callback_stream_out",
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func callbackSpanName(info *callbacks.RunInfo) string {
|
||||
if info == nil {
|
||||
return "eino.callback"
|
||||
}
|
||||
comp := strings.TrimSpace(string(info.Component))
|
||||
name := strings.TrimSpace(info.Name)
|
||||
typ := strings.TrimSpace(info.Type)
|
||||
if name != "" && comp != "" {
|
||||
return comp + "/" + name
|
||||
}
|
||||
if typ != "" && comp != "" {
|
||||
return comp + "[" + typ + "]"
|
||||
}
|
||||
if comp != "" {
|
||||
return comp
|
||||
}
|
||||
return "eino.callback"
|
||||
}
|
||||
|
||||
func truncateForAttr(s string, maxRunes int) string {
|
||||
return truncateRunes(s, maxRunes)
|
||||
}
|
||||
|
||||
func summarizeCallbackInput(in callbacks.CallbackInput, maxRunes int) string {
|
||||
if in == nil {
|
||||
return ""
|
||||
}
|
||||
if ai := adk.ConvAgentCallbackInput(in); ai != nil {
|
||||
parts := []string{"agent"}
|
||||
if ai.Input != nil {
|
||||
parts = append(parts, fmt.Sprintf("messages=%d", len(ai.Input.Messages)))
|
||||
}
|
||||
if ai.ResumeInfo != nil {
|
||||
parts = append(parts, "resume=true")
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
if mi := model.ConvCallbackInput(in); mi != nil {
|
||||
return fmt.Sprintf("chatModel messages=%d tools=%d", len(mi.Messages), len(mi.Tools))
|
||||
}
|
||||
if ti := tool.ConvCallbackInput(in); ti != nil {
|
||||
raw := ti.ArgumentsInJSON
|
||||
return "tool args=" + truncateRunes(raw, maxRunes)
|
||||
}
|
||||
b, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%T", in)
|
||||
}
|
||||
return truncateRunes(string(b), maxRunes)
|
||||
}
|
||||
|
||||
func summarizeCallbackOutput(out callbacks.CallbackOutput, maxRunes int) string {
|
||||
if out == nil {
|
||||
return ""
|
||||
}
|
||||
if ao := adk.ConvAgentCallbackOutput(out); ao != nil {
|
||||
return "agent_events=stream"
|
||||
}
|
||||
if mo := model.ConvCallbackOutput(out); mo != nil && mo.Message != nil {
|
||||
s := ""
|
||||
if mo.Message.Content != "" {
|
||||
s = mo.Message.Content
|
||||
}
|
||||
if mo.TokenUsage != nil {
|
||||
return fmt.Sprintf("tokens total=%d completion=%d prompt=%d text=%s",
|
||||
mo.TokenUsage.TotalTokens, mo.TokenUsage.CompletionTokens, mo.TokenUsage.PromptTokens,
|
||||
truncateRunes(s, minInt(120, maxRunes)))
|
||||
}
|
||||
return "assistant len=" + itoa(len(s))
|
||||
}
|
||||
if to := tool.ConvCallbackOutput(out); to != nil {
|
||||
if to.Response != "" {
|
||||
return truncateRunes(to.Response, maxRunes)
|
||||
}
|
||||
if to.ToolOutput != nil {
|
||||
return "tool_result multimodal"
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(out)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%T", out)
|
||||
}
|
||||
return truncateRunes(string(b), maxRunes)
|
||||
}
|
||||
|
||||
func minInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func itoa(n int) string {
|
||||
return fmt.Sprintf("%d", n)
|
||||
}
|
||||
|
||||
func truncateRunes(s string, maxRunes int) string {
|
||||
if maxRunes <= 0 {
|
||||
return ""
|
||||
}
|
||||
r := []rune(s)
|
||||
if len(r) <= maxRunes {
|
||||
return s
|
||||
}
|
||||
return string(r[:maxRunes]) + "…"
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package einoobserve
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
)
|
||||
|
||||
func TestAttachAgentRunCallbacks_Disabled(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cfg := &config.MultiAgentEinoCallbacksConfig{Enabled: false}
|
||||
out := AttachAgentRunCallbacks(ctx, cfg, Params{})
|
||||
if out != ctx {
|
||||
t.Fatalf("expected same ctx when disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateRunes(t *testing.T) {
|
||||
if got := truncateRunes("abc", 10); got != "abc" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
if got := truncateRunes("abcdefghij", 4); got != "abcd…" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package einoobserve
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
|
||||
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
otelMu sync.Mutex
|
||||
otelShutdown func(context.Context) error
|
||||
otelInitialized bool
|
||||
)
|
||||
|
||||
// InitOtelFromConfig installs the global OpenTelemetry TracerProvider when
|
||||
// eino_callbacks.otel is enabled and exporter is not none. Safe to call multiple times.
|
||||
func InitOtelFromConfig(cfg *config.MultiAgentEinoCallbacksConfig, log *zap.Logger) (shutdown func(context.Context) error, err error) {
|
||||
shutdown = func(context.Context) error { return nil }
|
||||
if cfg == nil || !cfg.OtelTracingActive() {
|
||||
return shutdown, nil
|
||||
}
|
||||
|
||||
otelMu.Lock()
|
||||
defer otelMu.Unlock()
|
||||
if otelInitialized {
|
||||
if otelShutdown != nil {
|
||||
return otelShutdown, nil
|
||||
}
|
||||
return shutdown, nil
|
||||
}
|
||||
|
||||
oc := cfg.Otel
|
||||
expKind := oc.OtelExporterEffective()
|
||||
ctx := context.Background()
|
||||
|
||||
var exporter sdktrace.SpanExporter
|
||||
switch expKind {
|
||||
case "stdout":
|
||||
exporter, err = stdouttrace.New()
|
||||
if err != nil {
|
||||
return shutdown, fmt.Errorf("eino otel stdout exporter: %w", err)
|
||||
}
|
||||
case "otlphttp":
|
||||
ep := strings.TrimSpace(oc.OTLPEndpoint)
|
||||
if ep == "" {
|
||||
ep = "localhost:4318"
|
||||
}
|
||||
exporter, err = otlptracehttp.New(ctx,
|
||||
otlptracehttp.WithEndpoint(ep),
|
||||
otlptracehttp.WithURLPath("/v1/traces"),
|
||||
)
|
||||
if err != nil {
|
||||
return shutdown, fmt.Errorf("eino otel otlphttp exporter: %w", err)
|
||||
}
|
||||
default:
|
||||
return shutdown, nil
|
||||
}
|
||||
|
||||
res, err := resource.New(ctx,
|
||||
resource.WithAttributes(
|
||||
semconv.ServiceName(oc.ServiceNameEffective()),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return shutdown, fmt.Errorf("eino otel resource: %w", err)
|
||||
}
|
||||
|
||||
sampler := sdktrace.ParentBased(sdktrace.TraceIDRatioBased(oc.SampleRatioEffective()))
|
||||
tp := sdktrace.NewTracerProvider(
|
||||
sdktrace.WithBatcher(exporter),
|
||||
sdktrace.WithResource(res),
|
||||
sdktrace.WithSampler(sampler),
|
||||
)
|
||||
otel.SetTracerProvider(tp)
|
||||
|
||||
otelShutdown = tp.Shutdown
|
||||
otelInitialized = true
|
||||
if log != nil {
|
||||
log.Info("eino otel: tracer provider initialized",
|
||||
zap.String("exporter", expKind),
|
||||
zap.String("service", oc.ServiceNameEffective()),
|
||||
zap.Float64("sample_ratio", oc.SampleRatioEffective()),
|
||||
)
|
||||
}
|
||||
return otelShutdown, nil
|
||||
}
|
||||
|
||||
// ShutdownOtel flushes and shuts down the global TracerProvider if it was installed.
|
||||
func ShutdownOtel(ctx context.Context) error {
|
||||
otelMu.Lock()
|
||||
fn := otelShutdown
|
||||
otelShutdown = nil
|
||||
inited := otelInitialized
|
||||
otelInitialized = false
|
||||
otelMu.Unlock()
|
||||
if !inited || fn == nil {
|
||||
return nil
|
||||
}
|
||||
return fn(ctx)
|
||||
}
|
||||
+313
-138
@@ -17,11 +17,14 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/robfig/cron/v3"
|
||||
@@ -129,6 +132,12 @@ type AgentHandler struct {
|
||||
batchRunning map[string]struct{}
|
||||
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
|
||||
hitlWhitelistSaver HitlToolWhitelistSaver
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *AgentHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
||||
@@ -201,6 +210,14 @@ type ChatAttachment struct {
|
||||
ServerPath string `json:"serverPath,omitempty"` // 已保存在 chat_uploads 下的绝对路径(由 POST /api/chat-uploads 返回)
|
||||
}
|
||||
|
||||
// ChatReasoningRequest 对话页「模型推理」意图(仅 Eino 路径消费;原生 agent-loop 忽略)。
|
||||
type ChatReasoningRequest struct {
|
||||
// Mode: default(跟随系统)| off | on | auto
|
||||
Mode string `json:"mode,omitempty"`
|
||||
// Effort: low | medium | high | max | xhigh(原样下发;不同网关最高档命名不同)。空表示不指定。
|
||||
Effort string `json:"effort,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest 聊天请求
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
@@ -209,10 +226,18 @@ type ChatRequest struct {
|
||||
Attachments []ChatAttachment `json:"attachments,omitempty"`
|
||||
WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具
|
||||
Hitl *HITLRequest `json:"hitl,omitempty"`
|
||||
Reasoning *ChatReasoningRequest `json:"reasoning,omitempty"`
|
||||
// Orchestration 仅对 /api/multi-agent、/api/multi-agent/stream:deep | plan_execute | supervisor;空则等同 deep。机器人/批量等无请求体时由服务端默认 deep。/api/eino-agent* 不使用此字段。
|
||||
Orchestration string `json:"orchestration,omitempty"`
|
||||
}
|
||||
|
||||
func chatReasoningToClientIntent(r *ChatReasoningRequest) *reasoning.ClientIntent {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return &reasoning.ClientIntent{Mode: r.Mode, Effort: r.Effort}
|
||||
}
|
||||
|
||||
type HITLRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
@@ -535,7 +560,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
conversationID := req.ConversationID
|
||||
if conversationID == "" {
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "agent_loop"))
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -567,14 +592,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
|
||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
} else {
|
||||
@@ -706,11 +724,43 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) finalizeRobotAgentError(ctx context.Context, assistantMessageID, conversationID string, resultMA *multiagent.RunResult, errMA error) (string, string, error) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
}
|
||||
errMsg := "执行失败: " + errMA.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||
}
|
||||
return "", conversationID, errMA
|
||||
}
|
||||
|
||||
func (h *AgentHandler) finalizeRobotAgentSuccess(assistantMessageID, conversationID string, resultMA *multiagent.RunResult) (string, string, error) {
|
||||
if assistantMessageID != "" {
|
||||
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resultMA.Response, resultMA.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(resultMA.LastAgentTraceInput)); errU != nil {
|
||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
|
||||
}
|
||||
} else {
|
||||
if _, err := h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil {
|
||||
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
|
||||
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
|
||||
}
|
||||
return resultMA.Response, conversationID, nil
|
||||
}
|
||||
|
||||
// ProcessMessageForRobot 供机器人(企业微信/钉钉/飞书)调用:与 /api/agent-loop/stream 相同执行路径(含 progressCallback、过程详情),仅不发送 SSE,最后返回完整回复
|
||||
func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationID, message, role string) (response string, convID string, err error) {
|
||||
func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, conversationID, message, role string) (response string, convID string, err error) {
|
||||
if conversationID == "" {
|
||||
title := safeTruncateString(message, 50)
|
||||
conv, createErr := h.db.CreateConversation(title)
|
||||
src := "robot"
|
||||
if strings.TrimSpace(platform) != "" {
|
||||
src = "robot:" + strings.TrimSpace(platform)
|
||||
}
|
||||
conv, createErr := h.db.CreateConversation(title, audit.ConversationCreateMeta(src))
|
||||
if createErr != nil {
|
||||
return "", "", fmt.Errorf("创建对话失败: %w", createErr)
|
||||
}
|
||||
@@ -758,61 +808,92 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
||||
if assistantMsg != nil {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
progressCallback := h.createProgressCallback(ctx, nil, conversationID, assistantMessageID, nil)
|
||||
|
||||
useRobotMulti := h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.RobotUseMultiAgent
|
||||
if useRobotMulti {
|
||||
resultMA, errMA := multiagent.RunDeepAgent(
|
||||
ctx,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
finalMessage,
|
||||
agentHistoryMessages,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
"deep",
|
||||
)
|
||||
if errMA != nil {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
}
|
||||
errMsg := "执行失败: " + errMA.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||
}
|
||||
return "", conversationID, errMA
|
||||
// 注册运行中任务并向 taskEventBus 镜像进度事件,供 Web 端 task-events 补流(与 agent-loop/stream 一致)。
|
||||
taskCtx, cancelWithCause := context.WithCancelCause(ctx)
|
||||
defer cancelWithCause(nil)
|
||||
taskStatus := "completed"
|
||||
defer func() {
|
||||
h.tasks.FinishTask(conversationID, taskStatus)
|
||||
}()
|
||||
if _, err := h.tasks.StartTask(conversationID, message, cancelWithCause); err != nil {
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
return "", conversationID, fmt.Errorf("当前会话已有任务正在执行中,请稍后再试")
|
||||
}
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(resultMA.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(resultMA.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, err = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
resultMA.Response, mcpIDsJSON, time.Now(), assistantMessageID,
|
||||
return "", conversationID, fmt.Errorf("无法启动任务: %w", err)
|
||||
}
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, nil)
|
||||
|
||||
robotMode := "react"
|
||||
if h.config != nil {
|
||||
robotMode = config.NormalizeRobotAgentMode(h.config.MultiAgent)
|
||||
}
|
||||
switch robotMode {
|
||||
case "eino_single":
|
||||
curHist := agentHistoryMessages
|
||||
curMsg := finalMessage
|
||||
segmentUserMessage := finalMessage
|
||||
var resultMA *multiagent.RunResult
|
||||
var errMA error
|
||||
var transientRunAttempts int
|
||||
for {
|
||||
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
|
||||
conversationID, curMsg, curHist, roleTools, progressCallback, nil,
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(err))
|
||||
if errMA == nil {
|
||||
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
|
||||
transientRunAttempts = 0
|
||||
break
|
||||
}
|
||||
} else {
|
||||
if _, err = h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil {
|
||||
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
|
||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
}
|
||||
taskStatus = "failed"
|
||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||
}
|
||||
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
|
||||
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
|
||||
return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA)
|
||||
case "deep", "plan_execute", "supervisor":
|
||||
if h.config == nil || !h.config.MultiAgent.Enabled {
|
||||
h.logger.Warn("机器人配置为多代理模式但未启用 multi_agent,回退原生 ReAct",
|
||||
zap.String("robot_mode", robotMode))
|
||||
break
|
||||
}
|
||||
return resultMA.Response, conversationID, nil
|
||||
curHist := agentHistoryMessages
|
||||
curMsg := finalMessage
|
||||
segmentUserMessage := finalMessage
|
||||
var resultMA *multiagent.RunResult
|
||||
var errMA error
|
||||
var transientRunAttempts int
|
||||
for {
|
||||
resultMA, errMA = multiagent.RunDeepAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
|
||||
conversationID, curMsg, curHist, roleTools, progressCallback,
|
||||
h.agentsMarkdownDir, robotMode, nil,
|
||||
)
|
||||
if errMA == nil {
|
||||
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
|
||||
transientRunAttempts = 0
|
||||
break
|
||||
}
|
||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
}
|
||||
taskStatus = "failed"
|
||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||
}
|
||||
return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA)
|
||||
}
|
||||
|
||||
result, err := h.agent.AgentLoopWithProgress(ctx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||
if err != nil {
|
||||
taskStatus = "failed"
|
||||
errMsg := "执行失败: " + err.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
@@ -823,17 +904,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
||||
|
||||
// 更新助手消息内容与 MCP 执行 ID(与 stream 一致)
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, err = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
result.Response, mcpIDsJSON, time.Now(), assistantMessageID,
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(err))
|
||||
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)); errU != nil {
|
||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
|
||||
}
|
||||
} else {
|
||||
if _, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs); err != nil {
|
||||
@@ -853,6 +925,23 @@ type StreamEvent struct {
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// publishProgressToTaskEventBus 将进度事件镜像到 taskEventBus(机器人/无 HTTP SSE 客户端时供 Web task-events 订阅)。
|
||||
func (h *AgentHandler) publishProgressToTaskEventBus(conversationID, eventType, message string, data interface{}) {
|
||||
if h == nil || h.taskEventBus == nil || strings.TrimSpace(conversationID) == "" {
|
||||
return
|
||||
}
|
||||
event := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
eventJSON, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
sseLine := make([]byte, 0, len(eventJSON)+8)
|
||||
sseLine = append(sseLine, []byte("data: ")...)
|
||||
sseLine = append(sseLine, eventJSON...)
|
||||
sseLine = append(sseLine, '\n', '\n')
|
||||
h.taskEventBus.Publish(conversationID, sseLine)
|
||||
}
|
||||
|
||||
// createProgressCallback 创建进度回调函数,用于保存processDetails
|
||||
// sendEventFunc: 可选的流式事件发送函数,如果为nil则不发送流式事件
|
||||
func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback {
|
||||
@@ -891,10 +980,12 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
return ""
|
||||
}
|
||||
|
||||
// thinking_stream_*:不逐条落库,按 streamId 聚合,在后续关键事件前补一条可持久化的 thinking
|
||||
// thinking_stream_*(ReAct 等助手正文流)与 reasoning_chain_stream_*(Eino ReasoningContent):
|
||||
// 不逐条落库,按 streamId 聚合,flush 时分别落 thinking / reasoning_chain。
|
||||
type thinkingBuf struct {
|
||||
b strings.Builder
|
||||
meta map[string]interface{}
|
||||
b strings.Builder
|
||||
meta map[string]interface{}
|
||||
persistAs string // "thinking" | "reasoning_chain"
|
||||
}
|
||||
thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf
|
||||
flushedThinking := make(map[string]bool) // streamId -> flushed
|
||||
@@ -948,17 +1039,23 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
data[k] = v
|
||||
}
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "thinking", content, data); err != nil {
|
||||
h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", "thinking"))
|
||||
persist := tb.persistAs
|
||||
if persist != "reasoning_chain" {
|
||||
persist = "thinking"
|
||||
}
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, persist, content, data); err != nil {
|
||||
h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", persist))
|
||||
}
|
||||
flushedThinking[sid] = true
|
||||
}
|
||||
}
|
||||
|
||||
return func(eventType, message string, data interface{}) {
|
||||
// 如果提供了sendEventFunc,发送流式事件
|
||||
// 流式:写 HTTP SSE;非流式(机器人等):镜像到 taskEventBus 供 Web 订阅
|
||||
if sendEventFunc != nil {
|
||||
sendEventFunc(eventType, message, data)
|
||||
} else {
|
||||
h.publishProgressToTaskEventBus(conversationID, eventType, message, data)
|
||||
}
|
||||
|
||||
// 保存tool_call事件中的参数
|
||||
@@ -1159,7 +1256,16 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
return
|
||||
}
|
||||
if eventType == "response_delta" {
|
||||
respPlan.b.WriteString(message)
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if acc, okAcc := dataMap[openai.SSEAccumulatedKey].(string); okAcc {
|
||||
respPlan.b.Reset()
|
||||
respPlan.b.WriteString(acc)
|
||||
} else {
|
||||
respPlan.b.WriteString(message)
|
||||
}
|
||||
} else {
|
||||
respPlan.b.WriteString(message)
|
||||
}
|
||||
if dataMap, ok := data.(map[string]interface{}); ok && respPlan.meta == nil {
|
||||
respPlan.meta = make(map[string]interface{}, len(dataMap))
|
||||
for k, v := range dataMap {
|
||||
@@ -1177,14 +1283,20 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
return
|
||||
}
|
||||
|
||||
// 聚合 thinking_stream_*(ReasoningContent),不逐条落库
|
||||
if eventType == "thinking_stream_start" {
|
||||
// 聚合 thinking_stream_* / reasoning_chain_stream_*,不逐条落库
|
||||
if eventType == "thinking_stream_start" || eventType == "reasoning_chain_stream_start" {
|
||||
persistAs := "thinking"
|
||||
if eventType == "reasoning_chain_stream_start" {
|
||||
persistAs = "reasoning_chain"
|
||||
}
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
||||
tb := thinkingStreams[sid]
|
||||
if tb == nil {
|
||||
tb = &thinkingBuf{meta: map[string]interface{}{}}
|
||||
tb = &thinkingBuf{meta: map[string]interface{}{}, persistAs: persistAs}
|
||||
thinkingStreams[sid] = tb
|
||||
} else {
|
||||
tb.persistAs = persistAs
|
||||
}
|
||||
// 记录元信息(source/einoAgent/einoRole/iteration 等)
|
||||
for k, v := range dataMap {
|
||||
@@ -1194,16 +1306,26 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
return
|
||||
}
|
||||
if eventType == "thinking_stream_delta" {
|
||||
if eventType == "thinking_stream_delta" || eventType == "reasoning_chain_stream_delta" {
|
||||
persistAs := "thinking"
|
||||
if eventType == "reasoning_chain_stream_delta" {
|
||||
persistAs = "reasoning_chain"
|
||||
}
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
||||
tb := thinkingStreams[sid]
|
||||
if tb == nil {
|
||||
tb = &thinkingBuf{meta: map[string]interface{}{}}
|
||||
tb = &thinkingBuf{meta: map[string]interface{}{}, persistAs: persistAs}
|
||||
thinkingStreams[sid] = tb
|
||||
} else if tb.persistAs == "" {
|
||||
tb.persistAs = persistAs
|
||||
}
|
||||
if acc, okAcc := dataMap[openai.SSEAccumulatedKey].(string); okAcc {
|
||||
tb.b.Reset()
|
||||
tb.b.WriteString(acc)
|
||||
} else {
|
||||
tb.b.WriteString(message)
|
||||
}
|
||||
// delta 片段直接拼接;message 本身就是 reasoning content
|
||||
tb.b.WriteString(message)
|
||||
// 有时 delta 先到 start 未到,补充元信息
|
||||
for k, v := range dataMap {
|
||||
tb.meta[k] = v
|
||||
@@ -1213,10 +1335,9 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
return
|
||||
}
|
||||
|
||||
// 当 Agent 同时发送 thinking_stream_* 和 thinking(带同一 streamId)时,
|
||||
// thinking_stream_* 已经会在 flushThinkingStreams() 聚合落库;
|
||||
// 这里跳过同 streamId 的 thinking,避免 processDetails 双份展示。
|
||||
if eventType == "thinking" {
|
||||
// 当 Agent 同时发送 *_stream_* 与同名 streamId 的 thinking/reasoning_chain 时,
|
||||
// 流式聚合已会在 flushThinkingStreams() 落库;此处跳过逐条重复。
|
||||
if eventType == "thinking" || eventType == "reasoning_chain" {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
||||
if tb, exists := thinkingStreams[sid]; exists && tb != nil {
|
||||
@@ -1239,13 +1360,17 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
eventType != "response_start" &&
|
||||
eventType != "response_delta" &&
|
||||
eventType != "tool_result_delta" &&
|
||||
eventType != "eino_trace_run" &&
|
||||
eventType != "eino_trace_start" &&
|
||||
eventType != "eino_trace_end" &&
|
||||
eventType != "eino_trace_error" &&
|
||||
eventType != "eino_agent_reply_stream_start" &&
|
||||
eventType != "eino_agent_reply_stream_delta" &&
|
||||
eventType != "eino_agent_reply_stream_end" {
|
||||
if eventType == "tool_result" {
|
||||
discardPlanningIfEchoesToolResult(&respPlan, data)
|
||||
}
|
||||
// 在关键过程事件落库前,先把「规划中」与 thinking_stream 落库
|
||||
// 在关键过程事件落库前,先把「规划中」与聚合中的 thinking / reasoning_chain 流落库
|
||||
flushResponsePlan()
|
||||
flushThinkingStreams()
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, eventType, message, data); err != nil {
|
||||
@@ -1392,10 +1517,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
var conv *database.Conversation
|
||||
var err error
|
||||
meta := audit.ConversationCreateMetaFromGin(c, "agent_loop_stream")
|
||||
if req.WebShellConnectionID != "" {
|
||||
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title)
|
||||
meta.Source = "webshell_chat"
|
||||
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title, meta)
|
||||
} else {
|
||||
conv, err = h.db.CreateConversation(title)
|
||||
conv, err = h.db.CreateConversation(title, meta)
|
||||
}
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
@@ -1427,14 +1554,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
|
||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
} else {
|
||||
@@ -1727,20 +1847,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
|
||||
// 更新助手消息内容
|
||||
if assistantMsg != nil {
|
||||
_, err = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
result.Response,
|
||||
func() string {
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
return string(jsonData)
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
time.Now(), assistantMessageID,
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Error("更新助手消息失败", zap.Error(err))
|
||||
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)); errU != nil {
|
||||
h.logger.Error("更新助手消息失败", zap.Error(errU))
|
||||
}
|
||||
} else {
|
||||
// 如果之前创建失败,现在创建
|
||||
@@ -1789,27 +1897,51 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
execID := h.tasks.ActiveMCPExecutionID(req.ConversationID)
|
||||
if execID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "当前没有正在执行的 MCP 工具(例如模型尚在推理、尚未发起工具调用)。请等待工具开始执行后再试,或使用「彻底停止」结束整轮任务。"})
|
||||
return
|
||||
}
|
||||
note := strings.TrimSpace(req.Reason)
|
||||
if !h.agent.CancelMCPToolExecutionWithNote(execID, note) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
|
||||
if execID != "" {
|
||||
if !h.agent.CancelMCPToolExecutionWithNote(execID, note) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
|
||||
return
|
||||
}
|
||||
h.logger.Info("对话页仅终止当前 MCP 工具",
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
zap.String("executionId", execID),
|
||||
zap.Bool("hasNote", note != ""),
|
||||
)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "tool_abort_requested",
|
||||
"conversationId": req.ConversationID,
|
||||
"executionId": execID,
|
||||
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
|
||||
"continueAfter": true,
|
||||
"interruptWithNote": note != "",
|
||||
"continueWithoutTool": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
h.logger.Info("对话页仅终止当前 MCP 工具",
|
||||
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
|
||||
h.tasks.SetInterruptContinueNote(req.ConversationID, note)
|
||||
ok, err := h.tasks.CancelTask(req.ConversationID, multiagent.ErrInterruptContinue)
|
||||
if err != nil {
|
||||
h.logger.Error("中断并继续(无工具)失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"})
|
||||
return
|
||||
}
|
||||
h.logger.Info("对话页中断并继续(无 MCP 工具,将自动续跑)",
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
zap.String("executionId", execID),
|
||||
zap.Bool("hasNote", note != ""),
|
||||
)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "tool_abort_requested",
|
||||
"conversationId": req.ConversationID,
|
||||
"executionId": execID,
|
||||
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
|
||||
"continueAfter": true,
|
||||
"interruptWithNote": note != "",
|
||||
"status": "interrupt_continue_scheduled",
|
||||
"conversationId": req.ConversationID,
|
||||
"message": "已请求暂停当前推理;用户补充将合并到上下文并自动继续执行(无需整轮停止)。",
|
||||
"continueAfter": true,
|
||||
"interruptWithNote": note != "",
|
||||
"continueWithoutTool": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -2006,6 +2138,11 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
|
||||
queue = refreshed
|
||||
}
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "create_queue", "创建批量任务队列", "batch_queue", queue.ID, map[string]interface{}{
|
||||
"task_count": len(validTasks), "started": started,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"queueId": queue.ID,
|
||||
"queue": queue,
|
||||
@@ -2113,6 +2250,9 @@ func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "start_queue", "启动批量任务队列", "batch_queue", queueID, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID})
|
||||
}
|
||||
|
||||
@@ -2141,6 +2281,9 @@ func (h *AgentHandler) RerunBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "启动失败"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "rerun_queue", "重跑批量任务队列", "batch_queue", queueID, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已重新开始执行", "queueId": queueID})
|
||||
}
|
||||
|
||||
@@ -2152,6 +2295,9 @@ func (h *AgentHandler) PauseBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法暂停"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "pause_queue", "暂停批量任务队列", "batch_queue", queueID, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已暂停"})
|
||||
}
|
||||
|
||||
@@ -2247,6 +2393,16 @@ func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "task",
|
||||
Action: "delete_queue",
|
||||
Result: "success",
|
||||
ResourceType: "batch_queue",
|
||||
ResourceID: queueID,
|
||||
Message: "删除批量任务队列",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务队列已删除"})
|
||||
}
|
||||
|
||||
@@ -2332,6 +2488,11 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "task", "delete_batch_task", "删除批量子任务", "batch_task", taskID, map[string]interface{}{
|
||||
"batch_queue_id": queueID,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
|
||||
}
|
||||
|
||||
@@ -2490,7 +2651,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
|
||||
// 创建新对话
|
||||
title := safeTruncateString(task.Message, 50)
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMeta("batch_task"))
|
||||
var conversationID string
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
@@ -2640,12 +2801,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
var runErr error
|
||||
switch {
|
||||
case useBatchMulti:
|
||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch)
|
||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil)
|
||||
case useEinoSingle:
|
||||
if h.config == nil {
|
||||
runErr = fmt.Errorf("服务器配置未加载")
|
||||
} else {
|
||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback)
|
||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil)
|
||||
}
|
||||
default:
|
||||
result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools)
|
||||
@@ -2744,17 +2905,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
|
||||
// 更新助手消息内容
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(mcpIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(mcpIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
resText,
|
||||
mcpIDsJSON,
|
||||
time.Now(), assistantMessageID,
|
||||
); updateErr != nil {
|
||||
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
|
||||
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
// 如果更新失败,尝试创建新消息
|
||||
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
|
||||
@@ -2846,6 +2997,10 @@ func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent
|
||||
if content, ok := msgMap["content"].(string); ok {
|
||||
msg.Content = content
|
||||
}
|
||||
// DeepSeek 思考模式:含工具调用的 assistant 须在后续请求中回传 reasoning_content
|
||||
if rc, ok := msgMap["reasoning_content"].(string); ok && strings.TrimSpace(rc) != "" {
|
||||
msg.ReasoningContent = rc
|
||||
}
|
||||
|
||||
// 解析tool_calls(如果存在)
|
||||
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
|
||||
@@ -2901,6 +3056,11 @@ func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent
|
||||
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
|
||||
msg.ToolCallID = toolCallID
|
||||
}
|
||||
if tn, ok := msgMap["tool_name"].(string); ok && strings.TrimSpace(tn) != "" {
|
||||
msg.ToolName = strings.TrimSpace(tn)
|
||||
} else if tn, ok := msgMap["name"].(string); ok && strings.TrimSpace(tn) != "" && strings.EqualFold(msg.Role, "tool") {
|
||||
msg.ToolName = strings.TrimSpace(tn)
|
||||
}
|
||||
|
||||
agentMessages = append(agentMessages, msg)
|
||||
}
|
||||
@@ -2946,3 +3106,18 @@ func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent
|
||||
)
|
||||
return agentMessages, nil
|
||||
}
|
||||
|
||||
// dbMessagesToAgentChatMessages maps DB rows to agent ChatMessage for history fallback
|
||||
// (includes reasoning_content for DeepSeek thinking + tool replay).
|
||||
func dbMessagesToAgentChatMessages(msgs []database.Message) []agent.ChatMessage {
|
||||
out := make([]agent.ChatMessage, 0, len(msgs))
|
||||
for i := range msgs {
|
||||
m := msgs[i]
|
||||
out = append(out, agent.ChatMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AuditHandler serves platform audit log APIs.
|
||||
type AuditHandler struct {
|
||||
db *database.DB
|
||||
audit *audit.Service
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAuditHandler creates an audit log handler.
|
||||
func NewAuditHandler(db *database.DB, auditSvc *audit.Service, logger *zap.Logger) *AuditHandler {
|
||||
return &AuditHandler{db: db, audit: auditSvc, logger: logger}
|
||||
}
|
||||
|
||||
// Meta GET /api/audit/meta
|
||||
func (h *AuditHandler) Meta(c *gin.Context) {
|
||||
enabled := false
|
||||
retentionDays := 0
|
||||
if h.audit != nil {
|
||||
enabled = h.audit.Enabled()
|
||||
retentionDays = h.audit.RetentionDays()
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": enabled,
|
||||
"retention_days": retentionDays,
|
||||
"default_page_size": 20,
|
||||
"max_page_size": 100,
|
||||
"max_export": 5000,
|
||||
})
|
||||
}
|
||||
|
||||
// Summary GET /api/audit/summary
|
||||
func (h *AuditHandler) Summary(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
base := auditFilterFromQuery(c)
|
||||
total, err := h.db.CountAuditLogs(base)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
failFilter := base
|
||||
failFilter.Result = "failure"
|
||||
failures, err := h.db.CountAuditLogs(failFilter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
since := time.Now().AddDate(0, 0, -7)
|
||||
recentFilter := base
|
||||
recentFilter.Since = &since
|
||||
recent7d, err := h.db.CountAuditLogs(recentFilter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"total": total,
|
||||
"failures": failures,
|
||||
"recent_7d": recent7d,
|
||||
"has_filters": c.Query("category") != "" || c.Query("action") != "" || c.Query("result") != "" ||
|
||||
c.Query("q") != "" || c.Query("since") != "" || c.Query("until") != "",
|
||||
})
|
||||
}
|
||||
|
||||
// ListLogs GET /api/audit/logs
|
||||
func (h *AuditHandler) ListLogs(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
filter := auditFilterFromQuery(c)
|
||||
page, pageSize := auditPaginationFromQuery(c)
|
||||
filter.Limit = pageSize
|
||||
filter.Offset = (page - 1) * pageSize
|
||||
|
||||
logs, err := h.db.ListAuditLogs(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
total, err := h.db.CountAuditLogs(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"logs": logs,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// GetLog GET /api/audit/logs/:id
|
||||
func (h *AuditHandler) GetLog(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
row, err := h.db.GetAuditLogByID(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "审计记录不存在"})
|
||||
return
|
||||
}
|
||||
audit.ApplyResourceAvailability(h.db, row)
|
||||
c.JSON(http.StatusOK, gin.H{"log": row})
|
||||
}
|
||||
|
||||
// ExportLogs GET /api/audit/logs/export — JSON or CSV (?format=csv), max 5000 rows.
|
||||
func (h *AuditHandler) ExportLogs(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
filter := auditFilterFromQuery(c)
|
||||
filter.Limit = 5000
|
||||
filter.Offset = 0
|
||||
|
||||
logs, err := h.db.ListAuditLogs(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if c.Query("format") == "csv" {
|
||||
writeAuditLogsCSV(c, logs)
|
||||
return
|
||||
}
|
||||
c.Header("Content-Disposition", `attachment; filename="audit-logs.json"`)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"exported_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"logs": logs,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func writeAuditLogsCSV(c *gin.Context, logs []*database.AuditLog) {
|
||||
c.Header("Content-Type", "text/csv; charset=utf-8")
|
||||
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="audit-logs-%s.csv"`, time.Now().Format("20060102")))
|
||||
|
||||
w := csv.NewWriter(c.Writer)
|
||||
_ = w.Write([]string{
|
||||
"id", "created_at", "level", "category", "action", "result", "actor",
|
||||
"session_hint", "client_ip", "resource_type", "resource_id", "message",
|
||||
})
|
||||
for _, row := range logs {
|
||||
if row == nil {
|
||||
continue
|
||||
}
|
||||
_ = w.Write([]string{
|
||||
row.ID,
|
||||
row.CreatedAt.UTC().Format(time.RFC3339),
|
||||
row.Level,
|
||||
row.Category,
|
||||
row.Action,
|
||||
row.Result,
|
||||
row.Actor,
|
||||
row.SessionHint,
|
||||
row.ClientIP,
|
||||
row.ResourceType,
|
||||
row.ResourceID,
|
||||
row.Message,
|
||||
})
|
||||
}
|
||||
w.Flush()
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func auditFilterFromQuery(c *gin.Context) database.ListAuditLogsFilter {
|
||||
filter := database.ListAuditLogsFilter{
|
||||
Level: c.Query("level"),
|
||||
Category: c.Query("category"),
|
||||
Action: c.Query("action"),
|
||||
Result: c.Query("result"),
|
||||
Query: c.Query("q"),
|
||||
ResourceType: c.Query("resource_type"),
|
||||
ResourceID: c.Query("resource_id"),
|
||||
}
|
||||
if since := c.Query("since"); since != "" {
|
||||
if t, err := time.Parse(time.RFC3339, since); err == nil {
|
||||
filter.Since = &t
|
||||
}
|
||||
}
|
||||
if until := c.Query("until"); until != "" {
|
||||
if t, err := time.Parse(time.RFC3339, until); err == nil {
|
||||
filter.Until = &t
|
||||
}
|
||||
}
|
||||
return filter
|
||||
}
|
||||
|
||||
func auditPaginationFromQuery(c *gin.Context) (page, pageSize int) {
|
||||
page = 1
|
||||
pageSize = 20
|
||||
if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 {
|
||||
page = p
|
||||
}
|
||||
if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "20")); err == nil && ps > 0 {
|
||||
pageSize = ps
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
}
|
||||
return page, pageSize
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
@@ -18,6 +19,12 @@ type AuthHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *AuthHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler.
|
||||
@@ -49,10 +56,32 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
|
||||
token, expiresAt, err := h.manager.Authenticate(req.Password)
|
||||
if err != nil {
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Level: "warn",
|
||||
Category: "auth",
|
||||
Action: "login",
|
||||
Result: "failure",
|
||||
Message: "登录失败:密码错误",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"})
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "auth",
|
||||
Action: "login",
|
||||
Result: "success",
|
||||
SessionHint: audit.HintFromToken(token),
|
||||
Message: "登录成功",
|
||||
Detail: map[string]interface{}{
|
||||
"expires_at": expiresAt.UTC().Format(time.RFC3339),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"token": token,
|
||||
"expires_at": expiresAt.UTC().Format(time.RFC3339),
|
||||
@@ -73,6 +102,14 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.manager.RevokeToken(token)
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "auth",
|
||||
Action: "logout",
|
||||
Result: "success",
|
||||
Message: "退出登录",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已退出登录"})
|
||||
}
|
||||
|
||||
@@ -103,6 +140,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||||
}
|
||||
|
||||
if !h.manager.CheckPassword(oldPassword) {
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Level: "warn",
|
||||
Category: "auth",
|
||||
Action: "change_password",
|
||||
Result: "failure",
|
||||
Message: "修改密码失败:当前密码不正确",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"})
|
||||
return
|
||||
}
|
||||
@@ -132,6 +178,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
||||
h.logger.Info("登录密码已更新,所有会话已失效")
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "auth",
|
||||
Action: "change_password",
|
||||
Result: "success",
|
||||
Message: "登录密码已修改",
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"})
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/c2"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
@@ -25,6 +26,12 @@ import (
|
||||
type C2Handler struct {
|
||||
mgrPtr atomic.Pointer[c2.Manager]
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *C2Handler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewC2Handler 创建 C2 处理器;manager 可为 nil(功能关闭时)
|
||||
@@ -104,6 +111,11 @@ func (h *C2Handler) CreateListener(c *gin.Context) {
|
||||
implantToken := listener.ImplantToken
|
||||
listener.EncryptionKey = ""
|
||||
listener.ImplantToken = ""
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "listener_create", "创建 C2 监听器", "c2_listener", listener.ID, map[string]interface{}{
|
||||
"name": listener.Name, "bind": listener.BindHost, "port": listener.BindPort,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"listener": listener, "implant_token": implantToken})
|
||||
}
|
||||
|
||||
@@ -205,6 +217,9 @@ func (h *C2Handler) DeleteListener(c *gin.Context) {
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "listener_delete", "删除 C2 监听器", "c2_listener", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
@@ -222,6 +237,9 @@ func (h *C2Handler) StartListener(c *gin.Context) {
|
||||
}
|
||||
listener.EncryptionKey = ""
|
||||
listener.ImplantToken = ""
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "listener_start", "启动 C2 监听器", "c2_listener", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"listener": listener})
|
||||
}
|
||||
|
||||
@@ -236,6 +254,9 @@ func (h *C2Handler) StopListener(c *gin.Context) {
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "listener_stop", "停止 C2 监听器", "c2_listener", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"stopped": true})
|
||||
}
|
||||
|
||||
@@ -297,6 +318,9 @@ func (h *C2Handler) DeleteSession(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "session_delete", "删除 C2 会话", "c2_session", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
@@ -407,6 +431,11 @@ func (h *C2Handler) DeleteTasks(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "task_delete", "批量删除 C2 任务", "c2_task", "", map[string]interface{}{
|
||||
"count": n, "ids": req.IDs,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": n})
|
||||
}
|
||||
|
||||
@@ -457,6 +486,11 @@ func (h *C2Handler) CreateTask(c *gin.Context) {
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "task_create", "创建 C2 任务", "c2_task", task.ID, map[string]interface{}{
|
||||
"session_id": req.SessionID, "task_type": req.TaskType,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"task": task})
|
||||
}
|
||||
|
||||
@@ -471,6 +505,9 @@ func (h *C2Handler) CancelTask(c *gin.Context) {
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "c2", "task_cancel", "取消 C2 任务", "c2_task", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"cancelled": true})
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -24,6 +26,12 @@ const (
|
||||
// ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API
|
||||
type ChatUploadsHandler struct {
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ChatUploadsHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewChatUploadsHandler 创建处理器
|
||||
@@ -230,6 +238,9 @@ func (h *ChatUploadsHandler) Delete(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "file", "delete", "删除对话附件", "chat_upload", body.Path, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
@@ -503,6 +514,11 @@ func (h *ChatUploadsHandler) Upload(c *gin.Context) {
|
||||
}
|
||||
rel, _ := filepath.Rel(root, fullPath)
|
||||
absSaved, _ := filepath.Abs(fullPath)
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "file", "upload", "上传对话附件", "chat_upload", filepath.ToSlash(rel), map[string]interface{}{
|
||||
"name": unique,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"ok": true,
|
||||
"relativePath": filepath.ToSlash(rel),
|
||||
|
||||
+154
-19
@@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agents"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
@@ -87,6 +88,7 @@ type ConfigHandler struct {
|
||||
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
|
||||
appUpdater AppUpdater // App更新器(可选)
|
||||
robotRestarter RobotRestarter // 机器人连接重启器(可选),ApplyConfig 时重启钉钉/飞书
|
||||
audit *audit.Service
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
|
||||
@@ -206,6 +208,32 @@ func (h *ConfigHandler) SetRobotRestarter(restarter RobotRestarter) {
|
||||
h.robotRestarter = restarter
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ConfigHandler) SetAudit(s *audit.Service) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// ApplyWechatRobotBinding 微信 iLink 扫码绑定成功后写入配置并重启机器人连接
|
||||
func (h *ConfigHandler) ApplyWechatRobotBinding(wc config.RobotWechatConfig) error {
|
||||
h.mu.Lock()
|
||||
wc.Enabled = true
|
||||
h.config.Robots.Wechat = wc
|
||||
h.mu.Unlock()
|
||||
if err := h.saveConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
if h.robotRestarter != nil {
|
||||
h.robotRestarter.RestartRobotConnections()
|
||||
}
|
||||
h.logger.Info("微信机器人绑定已保存",
|
||||
zap.String("ilink_bot_id", wc.ILinkBotID),
|
||||
zap.Bool("enabled", wc.Enabled),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConfigResponse 获取配置响应
|
||||
type GetConfigResponse struct {
|
||||
OpenAI config.OpenAIConfig `json:"openai"`
|
||||
@@ -291,7 +319,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
}
|
||||
multiPub := config.MultiAgentPublic{
|
||||
Enabled: h.config.MultiAgent.Enabled,
|
||||
RobotUseMultiAgent: h.config.MultiAgent.RobotUseMultiAgent,
|
||||
RobotDefaultAgentMode: config.NormalizeRobotAgentMode(h.config.MultiAgent),
|
||||
BatchUseMultiAgent: h.config.MultiAgent.BatchUseMultiAgent,
|
||||
SubAgentCount: subAgentCount,
|
||||
Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration),
|
||||
@@ -609,15 +637,46 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
|
||||
// UpdateConfigRequest 更新配置请求
|
||||
type UpdateConfigRequest struct {
|
||||
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
||||
FOFA *config.FofaConfig `json:"fofa,omitempty"`
|
||||
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
||||
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
||||
Agent *config.AgentConfig `json:"agent,omitempty"`
|
||||
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
|
||||
Robots *config.RobotsConfig `json:"robots,omitempty"`
|
||||
MultiAgent *config.MultiAgentAPIUpdate `json:"multi_agent,omitempty"`
|
||||
C2 *config.C2APIUpdate `json:"c2,omitempty"`
|
||||
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
||||
FOFA *config.FofaConfig `json:"fofa,omitempty"`
|
||||
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
||||
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
||||
Agent *AgentConfigUpdate `json:"agent,omitempty"`
|
||||
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
|
||||
Robots *config.RobotsConfig `json:"robots,omitempty"`
|
||||
MultiAgent *config.MultiAgentAPIUpdate `json:"multi_agent,omitempty"`
|
||||
C2 *config.C2APIUpdate `json:"c2,omitempty"`
|
||||
}
|
||||
|
||||
// AgentConfigUpdate 用于 PATCH /api/config 的 agent 段:仅 JSON 中出现的字段(指针非 nil)覆盖内存配置。
|
||||
// 避免旧版「整包替换 *AgentConfig」时,未传的整型字段被反序列化为 0 误覆盖(例如 tool_timeout_minutes 变成 0)。
|
||||
type AgentConfigUpdate struct {
|
||||
MaxIterations *int `json:"max_iterations,omitempty"`
|
||||
LargeResultThreshold *int `json:"large_result_threshold,omitempty"`
|
||||
ResultStorageDir *string `json:"result_storage_dir,omitempty"`
|
||||
ToolTimeoutMinutes *int `json:"tool_timeout_minutes,omitempty"`
|
||||
SystemPromptPath *string `json:"system_prompt_path,omitempty"`
|
||||
}
|
||||
|
||||
func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
if src.MaxIterations != nil {
|
||||
dst.MaxIterations = *src.MaxIterations
|
||||
}
|
||||
if src.LargeResultThreshold != nil {
|
||||
dst.LargeResultThreshold = *src.LargeResultThreshold
|
||||
}
|
||||
if src.ResultStorageDir != nil {
|
||||
dst.ResultStorageDir = *src.ResultStorageDir
|
||||
}
|
||||
if src.ToolTimeoutMinutes != nil {
|
||||
dst.ToolTimeoutMinutes = *src.ToolTimeoutMinutes
|
||||
}
|
||||
if src.SystemPromptPath != nil {
|
||||
dst.SystemPromptPath = *src.SystemPromptPath
|
||||
}
|
||||
}
|
||||
|
||||
// ToolEnableStatus 工具启用状态
|
||||
@@ -664,12 +723,19 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
)
|
||||
}
|
||||
|
||||
// 更新Agent配置
|
||||
// 更新Agent配置(按字段合并,避免部分 JSON 把未出现的字段写成 0)
|
||||
if req.Agent != nil {
|
||||
h.config.Agent = *req.Agent
|
||||
applyAgentConfigUpdate(&h.config.Agent, req.Agent)
|
||||
h.logger.Info("更新Agent配置",
|
||||
zap.Int("max_iterations", h.config.Agent.MaxIterations),
|
||||
zap.Int("tool_timeout_minutes", h.config.Agent.ToolTimeoutMinutes),
|
||||
)
|
||||
if h.agent != nil && req.Agent.MaxIterations != nil {
|
||||
h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations)
|
||||
}
|
||||
if h.mcpServer != nil {
|
||||
h.mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(h.config.Agent.ToolTimeoutMinutes)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新Knowledge配置
|
||||
@@ -697,6 +763,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
if req.Robots != nil {
|
||||
h.config.Robots = *req.Robots
|
||||
h.logger.Info("更新机器人配置",
|
||||
zap.Bool("wechat_enabled", h.config.Robots.Wechat.Enabled),
|
||||
zap.Bool("wecom_enabled", h.config.Robots.Wecom.Enabled),
|
||||
zap.Bool("dingtalk_enabled", h.config.Robots.Dingtalk.Enabled),
|
||||
zap.Bool("lark_enabled", h.config.Robots.Lark.Enabled),
|
||||
@@ -712,15 +779,21 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
// 多代理标量(sub_agents 等仍由 config.yaml 维护)
|
||||
if req.MultiAgent != nil {
|
||||
h.config.MultiAgent.Enabled = req.MultiAgent.Enabled
|
||||
h.config.MultiAgent.RobotUseMultiAgent = req.MultiAgent.RobotUseMultiAgent
|
||||
h.config.MultiAgent.BatchUseMultiAgent = req.MultiAgent.BatchUseMultiAgent
|
||||
if mode := strings.TrimSpace(req.MultiAgent.RobotDefaultAgentMode); mode != "" {
|
||||
h.config.MultiAgent.RobotDefaultAgentMode = mode
|
||||
} else {
|
||||
h.config.MultiAgent.RobotDefaultAgentMode = "react"
|
||||
}
|
||||
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
|
||||
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
|
||||
}
|
||||
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(req.MultiAgent.ToolSearchAlwaysVisibleTools)
|
||||
if req.MultiAgent.ToolSearchAlwaysVisibleTools != nil {
|
||||
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(*req.MultiAgent.ToolSearchAlwaysVisibleTools)
|
||||
}
|
||||
h.logger.Info("更新多代理配置",
|
||||
zap.Bool("enabled", h.config.MultiAgent.Enabled),
|
||||
zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent),
|
||||
zap.String("robot_default_agent_mode", config.NormalizeRobotAgentMode(h.config.MultiAgent)),
|
||||
zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent),
|
||||
zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations),
|
||||
zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)),
|
||||
@@ -843,6 +916,9 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "config", "update", "更新内存配置", "config", "", nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||||
}
|
||||
|
||||
@@ -973,6 +1049,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件")
|
||||
if _, err := knowledgeInitializer(); err != nil {
|
||||
h.logger.Error("动态初始化知识库失败", zap.Error(err))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordFail(c, "config", "apply", "应用配置失败:初始化知识库", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -1007,6 +1086,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)")
|
||||
if _, err := reinitKnowledgeInitializer(); err != nil {
|
||||
h.logger.Error("重新初始化知识库失败", zap.Error(err))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordFail(c, "config", "apply", "应用配置失败:重新初始化知识库", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -1020,6 +1102,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
if c2Rt != nil {
|
||||
if err := c2Rt.ReconcileC2AfterConfigApply(); err != nil {
|
||||
h.logger.Error("C2 配置应用失败", zap.Error(err))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordFail(c, "config", "apply", "应用配置失败:C2", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "C2 启动失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -1116,6 +1201,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.agent.UpdateToolDescriptionMode(h.config.Security.ToolDescriptionMode)
|
||||
h.logger.Info("Agent配置已更新")
|
||||
}
|
||||
if h.mcpServer != nil {
|
||||
h.mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(h.config.Agent.ToolTimeoutMinutes)
|
||||
}
|
||||
|
||||
// 更新AttackChainHandler的OpenAI配置
|
||||
if h.attackChainHandler != nil {
|
||||
@@ -1158,6 +1246,20 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
zap.Int("tools_count", len(h.config.Security.Tools)),
|
||||
)
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "config",
|
||||
Action: "apply",
|
||||
Result: "success",
|
||||
Message: "配置已应用",
|
||||
Detail: map[string]interface{}{
|
||||
"tools_count": len(h.config.Security.Tools),
|
||||
"knowledge_enabled": h.config.Knowledge.Enabled,
|
||||
"c2_enabled": h.config.C2.EnabledEffective(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "配置已应用",
|
||||
"tools_count": len(h.config.Security.Tools),
|
||||
@@ -1181,7 +1283,7 @@ func (h *ConfigHandler) saveConfig() error {
|
||||
return fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
updateAgentConfig(root, h.config.Agent.MaxIterations)
|
||||
updateAgentConfig(root, h.config.Agent)
|
||||
updateMCPConfig(root, h.config.MCP)
|
||||
updateOpenAIConfig(root, h.config.OpenAI)
|
||||
updateFOFAConfig(root, h.config.FOFA)
|
||||
@@ -1286,10 +1388,14 @@ func writeYAMLDocument(path string, doc *yaml.Node) error {
|
||||
return os.WriteFile(path, buf.Bytes(), 0644)
|
||||
}
|
||||
|
||||
func updateAgentConfig(doc *yaml.Node, maxIterations int) {
|
||||
func updateAgentConfig(doc *yaml.Node, agent config.AgentConfig) {
|
||||
root := doc.Content[0]
|
||||
agentNode := ensureMap(root, "agent")
|
||||
setIntInMap(agentNode, "max_iterations", maxIterations)
|
||||
setIntInMap(agentNode, "max_iterations", agent.MaxIterations)
|
||||
setIntInMap(agentNode, "tool_timeout_minutes", agent.ToolTimeoutMinutes)
|
||||
setIntInMap(agentNode, "large_result_threshold", agent.LargeResultThreshold)
|
||||
setStringInMap(agentNode, "result_storage_dir", agent.ResultStorageDir)
|
||||
setStringInMap(agentNode, "system_prompt_path", agent.SystemPromptPath)
|
||||
}
|
||||
|
||||
func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) {
|
||||
@@ -1312,6 +1418,19 @@ func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) {
|
||||
if cfg.MaxTotalTokens > 0 {
|
||||
setIntInMap(openaiNode, "max_total_tokens", cfg.MaxTotalTokens)
|
||||
}
|
||||
rn := ensureMap(openaiNode, "reasoning")
|
||||
if strings.TrimSpace(cfg.Reasoning.Mode) != "" {
|
||||
setStringInMap(rn, "mode", cfg.Reasoning.Mode)
|
||||
}
|
||||
if strings.TrimSpace(cfg.Reasoning.Effort) != "" {
|
||||
setStringInMap(rn, "effort", cfg.Reasoning.Effort)
|
||||
}
|
||||
if cfg.Reasoning.AllowClientReasoning != nil {
|
||||
setBoolInMap(rn, "allow_client_reasoning", *cfg.Reasoning.AllowClientReasoning)
|
||||
}
|
||||
if strings.TrimSpace(cfg.Reasoning.Profile) != "" {
|
||||
setStringInMap(rn, "profile", cfg.Reasoning.Profile)
|
||||
}
|
||||
}
|
||||
|
||||
func updateFOFAConfig(doc *yaml.Node, cfg config.FofaConfig) {
|
||||
@@ -1416,6 +1535,20 @@ func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
||||
root := doc.Content[0]
|
||||
robotsNode := ensureMap(root, "robots")
|
||||
|
||||
if cfg.Session.StrictUserIdentity != nil {
|
||||
sessionNode := ensureMap(robotsNode, "session")
|
||||
setBoolInMap(sessionNode, "strict_user_identity", *cfg.Session.StrictUserIdentity)
|
||||
}
|
||||
|
||||
wechatNode := ensureMap(robotsNode, "wechat")
|
||||
setBoolInMap(wechatNode, "enabled", cfg.Wechat.Enabled)
|
||||
setStringInMap(wechatNode, "bot_token", cfg.Wechat.BotToken)
|
||||
setStringInMap(wechatNode, "ilink_bot_id", cfg.Wechat.ILinkBotID)
|
||||
setStringInMap(wechatNode, "ilink_user_id", cfg.Wechat.ILinkUserID)
|
||||
setStringInMap(wechatNode, "base_url", cfg.Wechat.BaseURL)
|
||||
setStringInMap(wechatNode, "bot_type", cfg.Wechat.BotType)
|
||||
setStringInMap(wechatNode, "bot_agent", cfg.Wechat.BotAgent)
|
||||
|
||||
wecomNode := ensureMap(robotsNode, "wecom")
|
||||
setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled)
|
||||
setStringInMap(wecomNode, "token", cfg.Wecom.Token)
|
||||
@@ -1428,19 +1561,21 @@ func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
||||
setBoolInMap(dingtalkNode, "enabled", cfg.Dingtalk.Enabled)
|
||||
setStringInMap(dingtalkNode, "client_id", cfg.Dingtalk.ClientID)
|
||||
setStringInMap(dingtalkNode, "client_secret", cfg.Dingtalk.ClientSecret)
|
||||
setBoolInMap(dingtalkNode, "allow_conversation_id_fallback", cfg.Dingtalk.AllowConversationIDFallback)
|
||||
|
||||
larkNode := ensureMap(robotsNode, "lark")
|
||||
setBoolInMap(larkNode, "enabled", cfg.Lark.Enabled)
|
||||
setStringInMap(larkNode, "app_id", cfg.Lark.AppID)
|
||||
setStringInMap(larkNode, "app_secret", cfg.Lark.AppSecret)
|
||||
setStringInMap(larkNode, "verify_token", cfg.Lark.VerifyToken)
|
||||
setBoolInMap(larkNode, "allow_chat_id_fallback", cfg.Lark.AllowChatIDFallback)
|
||||
}
|
||||
|
||||
func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
|
||||
root := doc.Content[0]
|
||||
maNode := ensureMap(root, "multi_agent")
|
||||
setBoolInMap(maNode, "enabled", cfg.Enabled)
|
||||
setBoolInMap(maNode, "robot_use_multi_agent", cfg.RobotUseMultiAgent)
|
||||
setStringInMap(maNode, "robot_default_agent_mode", config.NormalizeRobotAgentMode(cfg))
|
||||
setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent)
|
||||
setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations)
|
||||
mwNode := ensureMap(maNode, "eino_middleware")
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@@ -14,6 +15,12 @@ import (
|
||||
type ConversationHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ConversationHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewConversationHandler 创建新的对话处理器
|
||||
@@ -42,7 +49,7 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
|
||||
title = "新对话"
|
||||
}
|
||||
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "api"))
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -117,6 +124,8 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
details = database.DedupeConsecutiveProcessDetails(details)
|
||||
|
||||
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致)
|
||||
out := make([]map[string]interface{}, 0, len(details))
|
||||
for _, d := range details {
|
||||
@@ -187,6 +196,17 @@ func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "conversation",
|
||||
Action: "delete",
|
||||
Result: "success",
|
||||
ResourceType: "conversation",
|
||||
ResourceID: id,
|
||||
Message: "删除对话",
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
@@ -225,6 +245,12 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "conversation", "delete_turn", "删除对话轮次", "conversation", conversationID, map[string]interface{}{
|
||||
"message_id": req.MessageID,
|
||||
"deleted": len(deletedIDs),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"deletedMessageIds": deletedIDs,
|
||||
"message": "ok",
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
)
|
||||
|
||||
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
|
||||
if h.config != nil {
|
||||
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
|
||||
}
|
||||
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
|
||||
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
|
||||
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
|
||||
func (h *AgentHandler) applyEinoTraceResumeSegment(
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
*curHistory = hist
|
||||
}
|
||||
if segmentUserMessage != "" {
|
||||
*curFinalMessage = segmentUserMessage
|
||||
}
|
||||
}
|
||||
|
||||
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
|
||||
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
|
||||
func (h *AgentHandler) applyEinoTransientRetrySegment(
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
*curHistory = hist
|
||||
}
|
||||
if s := strings.TrimSpace(segmentUserMessage); s != "" {
|
||||
*curFinalMessage = segmentUserMessage
|
||||
}
|
||||
}
|
||||
|
||||
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
|
||||
func (h *AgentHandler) handleEinoTransientRetryContinue(
|
||||
baseCtx context.Context,
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
runErr error,
|
||||
transientAttempts *int,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
progressCallback func(eventType, message string, data interface{}),
|
||||
sendProgress func(msg string, extra map[string]interface{}),
|
||||
) (handled bool, fatal error) {
|
||||
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
|
||||
return false, nil
|
||||
}
|
||||
maxAttempts := h.einoRunRetryMaxAttempts()
|
||||
*transientAttempts++
|
||||
if *transientAttempts > maxAttempts {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
return false, errors.New("transient retry exhausted: " + runErr.Error())
|
||||
}
|
||||
attemptNo := *transientAttempts
|
||||
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
"maxAttempts": maxAttempts,
|
||||
"backoffSec": int(backoff.Seconds()),
|
||||
})
|
||||
}
|
||||
select {
|
||||
case <-baseCtx.Done():
|
||||
return false, context.Cause(baseCtx)
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
})
|
||||
}
|
||||
if sendProgress != nil {
|
||||
sendProgress("正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "transient_retry",
|
||||
})
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
@@ -46,7 +46,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
if eventType == "error" && baseCtx != nil {
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -90,7 +90,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
)
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req)
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent_stream")
|
||||
if err != nil {
|
||||
sendEvent("error", err.Error(), nil)
|
||||
sendEvent("done", "", nil)
|
||||
@@ -119,6 +119,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
|
||||
@@ -175,29 +176,125 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
taskOwned = true
|
||||
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
||||
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
var transientRunAttempts int
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
|
||||
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
curFinalMessage,
|
||||
curHistory,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
)
|
||||
timeoutCancel()
|
||||
for {
|
||||
segmentMainIterationMax := 0
|
||||
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
progressCallback := func(eventType, message string, data interface{}) {
|
||||
if eventType == "iteration" {
|
||||
if m, ok := data.(map[string]interface{}); ok {
|
||||
if scope, _ := m["einoScope"].(string); scope == "main" {
|
||||
raw := 0
|
||||
switch v := m["iteration"].(type) {
|
||||
case int:
|
||||
raw = v
|
||||
case int32:
|
||||
raw = int(v)
|
||||
case int64:
|
||||
raw = int(v)
|
||||
case float64:
|
||||
raw = int(v)
|
||||
case float32:
|
||||
raw = int(v)
|
||||
}
|
||||
if raw > 0 {
|
||||
if raw > segmentMainIterationMax {
|
||||
segmentMainIterationMax = raw
|
||||
}
|
||||
m["iteration"] = raw + mainIterationOffset
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rawProgressCallback(eventType, message, data)
|
||||
}
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
|
||||
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtxLoop,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
curFinalMessage,
|
||||
curHistory,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
)
|
||||
|
||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
|
||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if handled {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
|
||||
if runErr != nil {
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
note := h.tasks.TakeInterruptContinueNote(conversationID)
|
||||
icSummary := interruptContinueTimelineSummary(note)
|
||||
progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"rawReason": strings.TrimSpace(note),
|
||||
"emptyReason": strings.TrimSpace(note) == "",
|
||||
"kind": "no_active_mcp_tool",
|
||||
})
|
||||
inject := formatInterruptContinueUserMessage(note)
|
||||
// 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
curHistory = hist
|
||||
}
|
||||
curFinalMessage = inject
|
||||
sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
@@ -221,6 +318,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -238,6 +336,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"errorType": "timeout",
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -254,22 +353,14 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
timeoutCancel()
|
||||
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, _ = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
result.Response,
|
||||
mcpIDsJSON,
|
||||
time.Now(),
|
||||
assistantMessageID,
|
||||
)
|
||||
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
@@ -279,7 +370,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
|
||||
sendEvent("response", result.Response, map[string]interface{}{
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"mcpExecutionIds": cumulativeMCPExecutionIDs,
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"agentMode": "eino_single",
|
||||
@@ -297,7 +388,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
|
||||
h.logger.Info("收到 Eino ADK 单代理非流式请求", zap.String("conversationId", req.ConversationID))
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req)
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -337,6 +428,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
prep.History,
|
||||
prep.RoleTools,
|
||||
progressCallback,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
)
|
||||
if runErr != nil {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
@@ -347,18 +439,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
}
|
||||
|
||||
if prep.AssistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, _ = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
result.Response,
|
||||
mcpIDsJSON,
|
||||
time.Now(),
|
||||
prep.AssistantMessageID,
|
||||
)
|
||||
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
_ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
@@ -20,9 +21,15 @@ type ExternalMCPHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *ExternalMCPHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewExternalMCPHandler 创建外部MCP处理器
|
||||
func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler {
|
||||
return &ExternalMCPHandler{
|
||||
@@ -180,6 +187,16 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "external_mcp",
|
||||
Action: "upsert",
|
||||
Result: "success",
|
||||
ResourceType: "external_mcp",
|
||||
ResourceID: name,
|
||||
Message: "更新外部 MCP 配置",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||||
}
|
||||
|
||||
@@ -209,6 +226,16 @@ func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "external_mcp",
|
||||
Action: "delete",
|
||||
Result: "success",
|
||||
ResourceType: "external_mcp",
|
||||
ResourceID: name,
|
||||
Message: "删除外部 MCP 配置",
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
|
||||
}
|
||||
|
||||
|
||||
@@ -616,6 +616,11 @@ func (h *AgentHandler) DecideHITLInterrupt(c *gin.Context) {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "hitl", "decision", "HITL 审批决策", "hitl_interrupt", req.InterruptID, map[string]interface{}{
|
||||
"decision": req.Decision,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
|
||||
@@ -20,6 +21,12 @@ type KnowledgeHandler struct {
|
||||
indexer *knowledge.Indexer
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *KnowledgeHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewKnowledgeHandler 创建新的知识库处理器
|
||||
@@ -303,6 +310,9 @@ func (h *KnowledgeHandler) DeleteItem(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "knowledge", "item_delete", "删除知识项", "knowledge_item", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
@@ -316,6 +326,9 @@ func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
|
||||
}
|
||||
}()
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "knowledge", "index_rebuild", "重建知识库索引", "knowledge", "", nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"})
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agents"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -18,7 +19,8 @@ var markdownAgentFilenameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]*\.m
|
||||
|
||||
// MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。
|
||||
type MarkdownAgentsHandler struct {
|
||||
dir string
|
||||
dir string
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// NewMarkdownAgentsHandler dir 须为已解析的绝对路径。
|
||||
@@ -26,6 +28,11 @@ func NewMarkdownAgentsHandler(dir string) *MarkdownAgentsHandler {
|
||||
return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)}
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *MarkdownAgentsHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) {
|
||||
filename = strings.TrimSpace(filename)
|
||||
if filename == "" || !markdownAgentFilenameRe.MatchString(filename) {
|
||||
@@ -227,6 +234,9 @@ func (h *MarkdownAgentsHandler) CreateMarkdownAgent(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "agent", "markdown_create", "创建 Markdown 子代理", "markdown_agent", filepath.Base(path), nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"})
|
||||
}
|
||||
|
||||
@@ -294,6 +304,9 @@ func (h *MarkdownAgentsHandler) UpdateMarkdownAgent(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "agent", "markdown_update", "更新 Markdown 子代理", "markdown_agent", filename, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已保存"})
|
||||
}
|
||||
|
||||
@@ -313,5 +326,8 @@ func (h *MarkdownAgentsHandler) DeleteMarkdownAgent(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "agent", "markdown_delete", "删除 Markdown 子代理", "markdown_agent", filename, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已删除"})
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
@@ -23,6 +24,12 @@ type MonitorHandler struct {
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *MonitorHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewMonitorHandler 创建新的监控处理器
|
||||
@@ -365,6 +372,11 @@ func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "tool", "execution_delete", "删除工具执行记录", "tool_execution", id, map[string]interface{}{
|
||||
"tool_name": exec.ToolName,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"})
|
||||
return
|
||||
}
|
||||
@@ -440,6 +452,11 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs)))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "tool", "execution_delete_batch", "批量删除工具执行记录", "tool_execution", "", map[string]interface{}{
|
||||
"count": len(request.IDs),
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)})
|
||||
return
|
||||
}
|
||||
|
||||
+178
-50
@@ -63,7 +63,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
// 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。
|
||||
if eventType == "error" && baseCtx != nil {
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -107,7 +107,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
)
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req)
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent_stream")
|
||||
if err != nil {
|
||||
sendEvent("error", err.Error(), nil)
|
||||
sendEvent("done", "", nil)
|
||||
@@ -136,6 +136,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
orch := strings.TrimSpace(req.Orchestration)
|
||||
@@ -184,31 +185,128 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
taskOwned = true
|
||||
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
||||
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
var transientRunAttempts int
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
|
||||
result, runErr = multiagent.RunDeepAgent(
|
||||
taskCtx,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
curFinalMessage,
|
||||
curHistory,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
orch,
|
||||
)
|
||||
timeoutCancel()
|
||||
for {
|
||||
segmentMainIterationMax := 0
|
||||
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
progressCallback := func(eventType, message string, data interface{}) {
|
||||
if eventType == "iteration" {
|
||||
if m, ok := data.(map[string]interface{}); ok {
|
||||
if scope, _ := m["einoScope"].(string); scope == "main" {
|
||||
raw := 0
|
||||
switch v := m["iteration"].(type) {
|
||||
case int:
|
||||
raw = v
|
||||
case int32:
|
||||
raw = int(v)
|
||||
case int64:
|
||||
raw = int(v)
|
||||
case float64:
|
||||
raw = int(v)
|
||||
case float32:
|
||||
raw = int(v)
|
||||
}
|
||||
if raw > 0 {
|
||||
if raw > segmentMainIterationMax {
|
||||
segmentMainIterationMax = raw
|
||||
}
|
||||
m["iteration"] = raw + mainIterationOffset
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rawProgressCallback(eventType, message, data)
|
||||
}
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
|
||||
result, runErr = multiagent.RunDeepAgent(
|
||||
taskCtxLoop,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
curFinalMessage,
|
||||
curHistory,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
orch,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
)
|
||||
|
||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
|
||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if handled {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
|
||||
if runErr != nil {
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
note := h.tasks.TakeInterruptContinueNote(conversationID)
|
||||
icSummary := interruptContinueTimelineSummary(note)
|
||||
progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"rawReason": strings.TrimSpace(note),
|
||||
"emptyReason": strings.TrimSpace(note) == "",
|
||||
"kind": "no_active_mcp_tool",
|
||||
})
|
||||
inject := formatInterruptContinueUserMessage(note)
|
||||
// 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
curHistory = hist
|
||||
}
|
||||
curFinalMessage = inject
|
||||
sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
@@ -232,6 +330,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -249,6 +348,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"errorType": "timeout",
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -265,22 +365,14 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"messageId": assistantMessageID,
|
||||
})
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
|
||||
timeoutCancel()
|
||||
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, _ = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
result.Response,
|
||||
mcpIDsJSON,
|
||||
time.Now(),
|
||||
assistantMessageID,
|
||||
)
|
||||
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
@@ -294,7 +386,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
effectiveOrch = config.NormalizeMultiAgentOrchestration(o)
|
||||
}
|
||||
sendEvent("response", result.Response, map[string]interface{}{
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"mcpExecutionIds": cumulativeMCPExecutionIDs,
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"agentMode": "eino_" + effectiveOrch,
|
||||
@@ -317,7 +409,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
|
||||
h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID))
|
||||
|
||||
prep, err := h.prepareMultiAgentSession(&req)
|
||||
prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent")
|
||||
if err != nil {
|
||||
status, msg := multiAgentHTTPErrorStatus(err)
|
||||
c.JSON(status, gin.H{"error": msg})
|
||||
@@ -350,6 +442,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
strings.TrimSpace(req.Orchestration),
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
)
|
||||
if runErr != nil {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
@@ -365,18 +458,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
}
|
||||
|
||||
if prep.AssistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, _ = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, updated_at = ? WHERE id = ?",
|
||||
result.Response,
|
||||
mcpIDsJSON,
|
||||
time.Now(),
|
||||
prep.AssistantMessageID,
|
||||
)
|
||||
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
@@ -406,6 +488,52 @@ func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, res
|
||||
}
|
||||
}
|
||||
|
||||
// mergeMCPExecutionIDLists 去重合并多段 Run 的 MCP execution id(顺序:先 dst 后 more)。
|
||||
func mergeMCPExecutionIDLists(dst []string, more []string) []string {
|
||||
seen := make(map[string]struct{}, len(dst)+len(more))
|
||||
out := make([]string, 0, len(dst)+len(more))
|
||||
add := func(ids []string) {
|
||||
for _, id := range ids {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
}
|
||||
add(dst)
|
||||
add(more)
|
||||
return out
|
||||
}
|
||||
|
||||
// interruptContinueTimelineSummary 时间线 / process_details 中展示的简短正文(完整模板已写入另一条用户消息)。
|
||||
func interruptContinueTimelineSummary(note string) string {
|
||||
note = strings.TrimSpace(note)
|
||||
if note == "" {
|
||||
return "用户选择「中断并继续」,未填写说明;已按默认渗透补充模板合并上下文并续跑。"
|
||||
}
|
||||
return "用户中断说明(原文):\n\n" + note
|
||||
}
|
||||
|
||||
// formatInterruptContinueUserMessage 将「中断并继续」弹窗中的说明格式化为新一轮 user 消息(渗透场景下强调路径补充与端口复扫)。
|
||||
func formatInterruptContinueUserMessage(note string) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("【用户补充 / 中断后继续】\n")
|
||||
if s := strings.TrimSpace(note); s != "" {
|
||||
b.WriteString(s)
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
b.WriteString("【请在本轮落实】\n")
|
||||
b.WriteString("- 将用户提供的接口路径、参数、业务变化纳入后续测试与推理。\n")
|
||||
b.WriteString("- 若资产或目标信息有更新,请对目标重新执行端口/服务探测,再基于新结果规划下一步。\n")
|
||||
b.WriteString("- 在已有轨迹基础上推进,避免无意义重复已完成的步骤。\n")
|
||||
return strings.TrimSpace(b.String())
|
||||
}
|
||||
|
||||
func multiAgentHTTPErrorStatus(err error) (int, string) {
|
||||
msg := err.Error()
|
||||
switch {
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -22,7 +24,7 @@ type multiAgentPrepared struct {
|
||||
UserMessageID string
|
||||
}
|
||||
|
||||
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPrepared, error) {
|
||||
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context, source string) (*multiAgentPrepared, error) {
|
||||
if len(req.Attachments) > maxAttachments {
|
||||
return nil, fmt.Errorf("附件最多 %d 个", maxAttachments)
|
||||
}
|
||||
@@ -33,10 +35,13 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
|
||||
title := safeTruncateString(req.Message, 50)
|
||||
var conv *database.Conversation
|
||||
var err error
|
||||
meta := audit.ConversationCreateMetaFromGin(c, source)
|
||||
if strings.TrimSpace(req.WebShellConnectionID) != "" {
|
||||
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title)
|
||||
meta.Source = source + "_webshell"
|
||||
meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID)
|
||||
conv, err = h.db.CreateConversationWithWebshell(meta.WebShellConnectionID, title, meta)
|
||||
} else {
|
||||
conv, err = h.db.CreateConversation(title)
|
||||
conv, err = h.db.CreateConversation(title, meta)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建对话失败: %w", err)
|
||||
@@ -55,13 +60,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
|
||||
if getErr != nil {
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6254,7 +6254,7 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取漏洞列表
|
||||
vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "", "", "", "")
|
||||
vulnList, err := h.db.ListVulnerabilities(1000, 0, database.VulnerabilityListFilter{ConversationID: conversationID})
|
||||
if err != nil {
|
||||
h.logger.Warn("获取漏洞列表失败", zap.Error(err))
|
||||
vulnList = []*database.Vulnerability{}
|
||||
|
||||
@@ -133,7 +133,7 @@ func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (
|
||||
} else {
|
||||
t = safeTruncateString(t, 50)
|
||||
}
|
||||
conv, err := h.db.CreateConversation(t)
|
||||
conv, err := h.db.CreateConversation(t, database.ConversationCreateMeta{Source: "robot:" + platform})
|
||||
if err != nil {
|
||||
h.logger.Warn("创建机器人会话失败", zap.Error(err))
|
||||
return "", false
|
||||
@@ -188,7 +188,7 @@ func (h *RobotHandler) setRole(platform, userID, roleName string) {
|
||||
// clearConversation 清空当前会话(切换到新对话)
|
||||
func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) {
|
||||
title := "新对话 " + time.Now().Format("01-02 15:04")
|
||||
conv, err := h.db.CreateConversation(title)
|
||||
conv, err := h.db.CreateConversation(title, database.ConversationCreateMeta{Source: "robot:" + platform + ":new"})
|
||||
if err != nil {
|
||||
h.logger.Warn("创建新对话失败", zap.Error(err))
|
||||
return ""
|
||||
@@ -242,7 +242,7 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
|
||||
h.cancelMu.Unlock()
|
||||
}()
|
||||
role := h.getRole(platform, userID)
|
||||
resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, convID, text, role)
|
||||
resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, platform, convID, text, role)
|
||||
if err != nil {
|
||||
h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err))
|
||||
if errors.Is(err, context.Canceled) {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
@@ -21,6 +22,12 @@ type RoleHandler struct {
|
||||
config *config.Config
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *RoleHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewRoleHandler 创建新的角色处理器
|
||||
@@ -174,6 +181,9 @@ func (h *RoleHandler) UpdateRole(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "role", "update", "更新角色", "role", finalKey, map[string]interface{}{"name": req.Name})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已更新",
|
||||
"role": req,
|
||||
@@ -219,6 +229,9 @@ func (h *RoleHandler) CreateRole(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("创建角色", zap.String("roleName", req.Name))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "role", "create", "创建角色", "role", req.Name, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已创建",
|
||||
"role": req,
|
||||
@@ -287,6 +300,9 @@ func (h *RoleHandler) DeleteRole(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("删除角色", zap.String("roleName", roleName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "role", "delete", "删除角色", "role", roleName, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "角色已删除",
|
||||
})
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/skillpackage"
|
||||
@@ -23,6 +24,12 @@ type SkillsHandler struct {
|
||||
configPath string
|
||||
logger *zap.Logger
|
||||
db *database.DB // 数据库连接(遗留统计;MCP list/read 已移除)
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *SkillsHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewSkillsHandler 创建新的Skills处理器
|
||||
@@ -365,6 +372,9 @@ func (h *SkillsHandler) CreateSkill(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("创建skill成功", zap.String("skill", req.Name))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "skill", "create", "创建 Skill", "skill", req.Name, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "skill已创建",
|
||||
"skill": map[string]interface{}{
|
||||
@@ -425,6 +435,9 @@ func (h *SkillsHandler) UpdateSkill(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("更新skill成功", zap.String("skill", skillName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "skill", "update", "更新 Skill", "skill", skillName, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "skill已更新",
|
||||
})
|
||||
@@ -459,6 +472,11 @@ func (h *SkillsHandler) DeleteSkill(c *gin.Context) {
|
||||
}
|
||||
|
||||
h.logger.Info("删除skill成功", zap.String("skill", skillName))
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "skill", "delete", "删除 Skill", "skill", skillName, map[string]interface{}{
|
||||
"affected_roles": affectedRoles,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": responseMsg,
|
||||
"affected_roles": affectedRoles,
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
)
|
||||
|
||||
// ErrTaskCancelled 用户取消任务的错误
|
||||
@@ -32,6 +34,9 @@ type AgentTask struct {
|
||||
// ActiveMCPExecutionID 当前正在执行的 MCP 工具 executionId(仅内存,供「中断并继续」= 仅掐当前工具)
|
||||
ActiveMCPExecutionID string `json:"-"`
|
||||
|
||||
// InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空)
|
||||
InterruptContinueNote string `json:"-"`
|
||||
|
||||
cancel func(error)
|
||||
}
|
||||
|
||||
@@ -65,6 +70,50 @@ func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID str
|
||||
}
|
||||
}
|
||||
|
||||
// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。
|
||||
func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.InterruptContinueNote = note
|
||||
}
|
||||
}
|
||||
|
||||
// TakeInterruptContinueNote 读取并清空补充说明(续跑开始时调用一次)。
|
||||
func (m *AgentTaskManager) TakeInterruptContinueNote(conversationID string) string {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
n := t.InterruptContinueNote
|
||||
t.InterruptContinueNote = ""
|
||||
return n
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// BindTaskCancel 在同一运行任务内替换与 context 绑定的 cancel 函数(用于中断后继续时换新 baseCtx)。
|
||||
func (m *AgentTaskManager) BindTaskCancel(conversationID string, cancel context.CancelCauseFunc) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" || cancel == nil {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.cancel = func(err error) {
|
||||
cancel(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ActiveMCPExecutionID 返回当前会话进行中的工具 executionId,无则空串。
|
||||
func (m *AgentTaskManager) ActiveMCPExecutionID(conversationID string) string {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
@@ -210,8 +259,16 @@ func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool,
|
||||
return true, nil
|
||||
}
|
||||
|
||||
task.Status = "cancelling"
|
||||
task.CancellingAt = time.Now()
|
||||
// ErrInterruptContinue:仅掐断当前推理步骤,随后由处理器续跑,不进入长时间「取消中」态。
|
||||
if cause != nil && errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
task.Status = "running"
|
||||
} else {
|
||||
task.Status = "cancelling"
|
||||
task.CancellingAt = time.Now()
|
||||
}
|
||||
if cause != nil && errors.Is(cause, ErrTaskCancelled) {
|
||||
task.InterruptContinueNote = ""
|
||||
}
|
||||
cancel := task.cancel
|
||||
m.mu.Unlock()
|
||||
|
||||
|
||||
@@ -253,5 +253,5 @@ func (h *TerminalHandler) RunCommandStream(c *gin.Context) {
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
runCommandStreamImpl(cmd, sendEvent, ctx)
|
||||
_ = runCommandStreamImpl(cmd, sendEvent, ctx)
|
||||
}
|
||||
|
||||
@@ -15,11 +15,11 @@ const ptyCols = 256
|
||||
const ptyRows = 40
|
||||
|
||||
// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真)
|
||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) {
|
||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
|
||||
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows})
|
||||
if err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return
|
||||
return -1
|
||||
}
|
||||
defer ptmx.Close()
|
||||
|
||||
@@ -43,4 +43,5 @@ func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx contex
|
||||
exitCode = -1
|
||||
}
|
||||
sendEvent(streamEvent{T: "exit", C: exitCode})
|
||||
return exitCode
|
||||
}
|
||||
|
||||
@@ -11,20 +11,20 @@ import (
|
||||
)
|
||||
|
||||
// runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行
|
||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) {
|
||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return
|
||||
return -1
|
||||
}
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return
|
||||
return -1
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
sendEvent(streamEvent{T: "exit", C: -1})
|
||||
return
|
||||
return -1
|
||||
}
|
||||
|
||||
normalize := func(s string) string {
|
||||
@@ -62,4 +62,5 @@ func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx contex
|
||||
exitCode = -1
|
||||
}
|
||||
sendEvent(streamEvent{T: "exit", C: exitCode})
|
||||
return exitCode
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@@ -16,6 +17,12 @@ import (
|
||||
type VulnerabilityHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *VulnerabilityHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewVulnerabilityHandler 创建新的漏洞处理器
|
||||
@@ -72,6 +79,11 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "vulnerability", "create", "创建漏洞记录", "vulnerability", created.ID, map[string]interface{}{
|
||||
"severity": created.Severity, "title": created.Title,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
}
|
||||
|
||||
@@ -98,18 +110,29 @@ type ListVulnerabilitiesResponse struct {
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
|
||||
func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilter {
|
||||
q := strings.TrimSpace(c.Query("q"))
|
||||
if q == "" {
|
||||
q = strings.TrimSpace(c.Query("search"))
|
||||
}
|
||||
return database.VulnerabilityListFilter{
|
||||
ID: c.Query("id"),
|
||||
Search: q,
|
||||
ConversationID: c.Query("conversation_id"),
|
||||
Severity: c.Query("severity"),
|
||||
Status: c.Query("status"),
|
||||
TaskID: c.Query("task_id"),
|
||||
ConversationTag: c.Query("conversation_tag"),
|
||||
TaskTag: c.Query("task_tag"),
|
||||
}
|
||||
}
|
||||
|
||||
// ListVulnerabilities 列出漏洞
|
||||
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
limitStr := c.DefaultQuery("limit", "20")
|
||||
offsetStr := c.DefaultQuery("offset", "0")
|
||||
pageStr := c.Query("page")
|
||||
id := c.Query("id")
|
||||
conversationID := c.Query("conversation_id")
|
||||
severity := c.Query("severity")
|
||||
status := c.Query("status")
|
||||
taskID := c.Query("task_id")
|
||||
conversationTag := c.Query("conversation_tag")
|
||||
taskTag := c.Query("task_tag")
|
||||
filter := parseVulnerabilityListFilter(c)
|
||||
|
||||
limit, _ := strconv.Atoi(limitStr)
|
||||
offset, _ := strconv.Atoi(offsetStr)
|
||||
@@ -131,7 +154,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
||||
total, err := h.db.CountVulnerabilities(filter)
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞总数失败", zap.Error(err))
|
||||
// 继续执行,使用0作为总数
|
||||
@@ -139,7 +162,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 获取漏洞列表
|
||||
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
||||
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, filter)
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞列表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -249,6 +272,11 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, map[string]interface{}{
|
||||
"severity": updated.Severity, "status": updated.Status,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
}
|
||||
|
||||
@@ -262,15 +290,25 @@ func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if h.audit != nil {
|
||||
h.audit.Record(c, audit.Entry{
|
||||
Category: "vulnerability",
|
||||
Action: "delete",
|
||||
Result: "success",
|
||||
ResourceType: "vulnerability",
|
||||
ResourceID: id,
|
||||
Message: "删除漏洞记录",
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||
}
|
||||
|
||||
// GetVulnerabilityStats 获取漏洞统计
|
||||
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
|
||||
conversationID := c.Query("conversation_id")
|
||||
taskID := c.Query("task_id")
|
||||
filter := parseVulnerabilityListFilter(c)
|
||||
|
||||
stats, err := h.db.GetVulnerabilityStats(conversationID, taskID)
|
||||
stats, err := h.db.GetVulnerabilityStats(filter)
|
||||
if err != nil {
|
||||
h.logger.Error("获取漏洞统计失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -304,15 +342,9 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
id := c.Query("id")
|
||||
conversationID := c.Query("conversation_id")
|
||||
severity := c.Query("severity")
|
||||
status := c.Query("status")
|
||||
taskID := c.Query("task_id")
|
||||
conversationTag := c.Query("conversation_tag")
|
||||
taskTag := c.Query("task_tag")
|
||||
filter := parseVulnerabilityListFilter(c)
|
||||
|
||||
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
||||
total, err := h.db.CountVulnerabilities(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -322,7 +354,7 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
items, err := h.db.ListVulnerabilities(total, 0, id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
||||
items, err := h.db.ListVulnerabilities(total, 0, filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -304,6 +306,12 @@ type WebShellHandler struct {
|
||||
logger *zap.Logger
|
||||
client *http.Client
|
||||
db *database.DB
|
||||
audit *audit.Service
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
func (h *WebShellHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// NewWebShellHandler 创建 WebShell 处理器,db 可为 nil(连接配置接口将不可用)
|
||||
@@ -311,8 +319,12 @@ func NewWebShellHandler(logger *zap.Logger, db *database.DB) *WebShellHandler {
|
||||
return &WebShellHandler{
|
||||
logger: logger,
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{DisableKeepAlives: false},
|
||||
Timeout: 30 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
DisableKeepAlives: false,
|
||||
// WebShell 场景常见自签证书或 IP 访问(证书无 IP SAN);默认跳过校验,与蚁剑等客户端一致。
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // intentional for webshell proxy
|
||||
},
|
||||
},
|
||||
db: db,
|
||||
}
|
||||
@@ -403,6 +415,15 @@ func (h *WebShellHandler) CreateConnection(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
host := req.URL
|
||||
if u, err := url.Parse(req.URL); err == nil {
|
||||
host = u.Host
|
||||
}
|
||||
h.audit.RecordOK(c, "webshell", "connection_create", "创建 WebShell 连接", "webshell_connection", conn.ID, map[string]interface{}{
|
||||
"host": host, "type": shellType,
|
||||
})
|
||||
}
|
||||
c.JSON(http.StatusOK, conn)
|
||||
}
|
||||
|
||||
@@ -485,6 +506,9 @@ func (h *WebShellHandler) DeleteConnection(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
h.audit.RecordOK(c, "webshell", "connection_delete", "删除 WebShell 连接", "webshell_connection", id, nil)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
}
|
||||
|
||||
@@ -714,8 +738,9 @@ func (h *WebShellHandler) Exec(c *gin.Context) {
|
||||
output := decodeWebshellOutput(out, req.Encoding)
|
||||
httpCode := resp.StatusCode
|
||||
|
||||
ok := resp.StatusCode == http.StatusOK
|
||||
c.JSON(http.StatusOK, ExecResponse{
|
||||
OK: resp.StatusCode == http.StatusOK,
|
||||
OK: ok,
|
||||
Output: output,
|
||||
HTTPCode: httpCode,
|
||||
})
|
||||
|
||||
@@ -0,0 +1,293 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/robot/ilink"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const wechatLoginTTL = 5 * time.Minute
|
||||
|
||||
// WechatConfigSaver 绑定成功后写入配置并重启机器人连接
|
||||
type WechatConfigSaver interface {
|
||||
ApplyWechatRobotBinding(cfg config.RobotWechatConfig) error
|
||||
}
|
||||
|
||||
type wechatLoginSession struct {
|
||||
QRCode string
|
||||
QRCodeImgURL string
|
||||
PendingVerify string
|
||||
CurrentBaseURL string
|
||||
StartedAt time.Time
|
||||
}
|
||||
|
||||
// WechatRobotHandler 微信 iLink 机器人(扫码绑定 + 配置)
|
||||
type WechatRobotHandler struct {
|
||||
config *config.Config
|
||||
configSaver WechatConfigSaver
|
||||
logger *zap.Logger
|
||||
mu sync.Mutex
|
||||
logins map[string]*wechatLoginSession
|
||||
}
|
||||
|
||||
// NewWechatRobotHandler 创建微信机器人处理器
|
||||
func NewWechatRobotHandler(cfg *config.Config, saver WechatConfigSaver, logger *zap.Logger) *WechatRobotHandler {
|
||||
return &WechatRobotHandler{
|
||||
config: cfg,
|
||||
configSaver: saver,
|
||||
logger: logger,
|
||||
logins: make(map[string]*wechatLoginSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WechatRobotHandler) purgeExpiredLogins() {
|
||||
now := time.Now()
|
||||
for k, v := range h.logins {
|
||||
if now.Sub(v.StartedAt) > wechatLoginTTL {
|
||||
delete(h.logins, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *WechatRobotHandler) ilinkClient(baseURL string) *ilink.Client {
|
||||
ver := h.config.Version
|
||||
if ver == "" {
|
||||
ver = "1.0.0"
|
||||
}
|
||||
ver = strings.TrimPrefix(strings.TrimSpace(ver), "v")
|
||||
ver = strings.TrimPrefix(ver, "V")
|
||||
wc := h.config.Robots.Wechat
|
||||
return ilink.NewClient(baseURL, wc.BotToken, wc.BotAgent, ilink.BuildClientVersion(ver))
|
||||
}
|
||||
|
||||
// HandleWechatQRCode POST /api/robot/wechat/qrcode — 生成绑定二维码
|
||||
func (h *WechatRobotHandler) HandleWechatQRCode(c *gin.Context) {
|
||||
h.mu.Lock()
|
||||
h.purgeExpiredLogins()
|
||||
h.mu.Unlock()
|
||||
|
||||
var req struct {
|
||||
BotType string `json:"bot_type"`
|
||||
}
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
|
||||
botType := req.BotType
|
||||
if botType == "" {
|
||||
botType = h.config.Robots.Wechat.BotType
|
||||
}
|
||||
if botType == "" {
|
||||
botType = ilink.DefaultBotType
|
||||
}
|
||||
baseURL := h.config.Robots.Wechat.BaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = ilink.DefaultBaseURL
|
||||
}
|
||||
|
||||
var localTokens []string
|
||||
if t := h.config.Robots.Wechat.BotToken; t != "" {
|
||||
localTokens = []string{t}
|
||||
}
|
||||
|
||||
client := h.ilinkClient(baseURL)
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
qr, err := client.GetBotQRCode(ctx, botType, localTokens)
|
||||
if err != nil {
|
||||
h.logger.Warn("获取微信二维码失败", zap.Error(err))
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "获取二维码失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
if qr.QRCode == "" || qr.QRCodeImgContent == "" {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "微信服务器未返回有效二维码"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionKey := uuid.New().String()
|
||||
h.mu.Lock()
|
||||
h.logins[sessionKey] = &wechatLoginSession{
|
||||
QRCode: qr.QRCode,
|
||||
QRCodeImgURL: qr.QRCodeImgContent,
|
||||
CurrentBaseURL: baseURL,
|
||||
StartedAt: time.Now(),
|
||||
}
|
||||
h.mu.Unlock()
|
||||
|
||||
resp := gin.H{
|
||||
"session_key": sessionKey,
|
||||
"qrcode": qr.QRCode,
|
||||
"qrcode_open_url": qr.QRCodeImgContent,
|
||||
"message": "请使用微信扫描二维码并确认绑定",
|
||||
}
|
||||
if dataURL, err := ilink.QRCodeDataURL(qr.QRCodeImgContent, 256); err != nil {
|
||||
h.logger.Warn("生成二维码图片失败", zap.Error(err))
|
||||
} else {
|
||||
resp["qrcode_image_data_url"] = dataURL
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// HandleWechatQRCodeStatus GET /api/robot/wechat/qrcode/status — 轮询扫码状态
|
||||
func (h *WechatRobotHandler) HandleWechatQRCodeStatus(c *gin.Context) {
|
||||
sessionKey := c.Query("session_key")
|
||||
verifyCode := c.Query("verify_code")
|
||||
if sessionKey == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 session_key"})
|
||||
return
|
||||
}
|
||||
|
||||
h.mu.Lock()
|
||||
sess, ok := h.logins[sessionKey]
|
||||
h.mu.Unlock()
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期,请重新生成二维码"})
|
||||
return
|
||||
}
|
||||
if time.Since(sess.StartedAt) > wechatLoginTTL {
|
||||
h.mu.Lock()
|
||||
delete(h.logins, sessionKey)
|
||||
h.mu.Unlock()
|
||||
c.JSON(http.StatusGone, gin.H{"error": "二维码已过期,请重新生成"})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL := sess.CurrentBaseURL
|
||||
if baseURL == "" {
|
||||
baseURL = ilink.DefaultBaseURL
|
||||
}
|
||||
vc := verifyCode
|
||||
if vc == "" {
|
||||
vc = sess.PendingVerify
|
||||
}
|
||||
|
||||
client := h.ilinkClient(baseURL)
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 40*time.Second)
|
||||
defer cancel()
|
||||
|
||||
st, err := client.GetQRCodeStatus(ctx, sess.QRCode, vc)
|
||||
if err != nil {
|
||||
h.logger.Warn("轮询微信二维码状态失败", zap.Error(err))
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
switch st.Status {
|
||||
case "wait", "scaned":
|
||||
c.JSON(http.StatusOK, gin.H{"status": st.Status})
|
||||
return
|
||||
case "need_verifycode":
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": st.Status,
|
||||
"message": "请在手机微信查看配对数字,并在下方输入",
|
||||
})
|
||||
return
|
||||
case "scaned_but_redirect":
|
||||
if st.RedirectHost != "" {
|
||||
h.mu.Lock()
|
||||
if s, ok := h.logins[sessionKey]; ok {
|
||||
s.CurrentBaseURL = "https://" + st.RedirectHost
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": st.Status})
|
||||
return
|
||||
case "binded_redirect":
|
||||
h.mu.Lock()
|
||||
delete(h.logins, sessionKey)
|
||||
h.mu.Unlock()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": st.Status,
|
||||
"already_connected": true,
|
||||
"message": "该微信已绑定过,无需重复绑定",
|
||||
})
|
||||
return
|
||||
case "confirmed":
|
||||
if st.BotToken == "" || st.ILinkBotID == "" {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": "绑定确认成功但缺少 bot_token"})
|
||||
return
|
||||
}
|
||||
saveBase := st.BaseURL
|
||||
if saveBase == "" {
|
||||
saveBase = baseURL
|
||||
}
|
||||
wc := h.config.Robots.Wechat
|
||||
wc.Enabled = true
|
||||
wc.BotToken = st.BotToken
|
||||
wc.ILinkBotID = st.ILinkBotID
|
||||
wc.ILinkUserID = st.ILinkUserID
|
||||
wc.BaseURL = saveBase
|
||||
if wc.BotType == "" {
|
||||
wc.BotType = ilink.DefaultBotType
|
||||
}
|
||||
if wc.BotAgent == "" {
|
||||
wc.BotAgent = ilink.DefaultBotAgent
|
||||
}
|
||||
if h.configSaver != nil {
|
||||
if err := h.configSaver.ApplyWechatRobotBinding(wc); err != nil {
|
||||
h.logger.Warn("保存微信机器人配置失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
h.config.Robots.Wechat = wc
|
||||
}
|
||||
h.mu.Lock()
|
||||
delete(h.logins, sessionKey)
|
||||
h.mu.Unlock()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "confirmed",
|
||||
"message": "绑定成功,微信机器人已启用",
|
||||
"ilink_bot_id": st.ILinkBotID,
|
||||
"ilink_user_id": st.ILinkUserID,
|
||||
})
|
||||
return
|
||||
default:
|
||||
c.JSON(http.StatusOK, gin.H{"status": st.Status})
|
||||
}
|
||||
}
|
||||
|
||||
// HandleWechatVerifyCode POST /api/robot/wechat/qrcode/verify — 提交手机配对数字
|
||||
func (h *WechatRobotHandler) HandleWechatVerifyCode(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionKey string `json:"session_key"`
|
||||
VerifyCode string `json:"verify_code"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil || req.SessionKey == "" || req.VerifyCode == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "需要 session_key 与 verify_code"})
|
||||
return
|
||||
}
|
||||
h.mu.Lock()
|
||||
sess, ok := h.logins[req.SessionKey]
|
||||
if ok {
|
||||
sess.PendingVerify = req.VerifyCode
|
||||
}
|
||||
h.mu.Unlock()
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已提交配对码,请继续等待绑定"})
|
||||
}
|
||||
|
||||
// HandleWechatStatus GET /api/robot/wechat/status — 当前绑定状态(供前端展示)
|
||||
func (h *WechatRobotHandler) HandleWechatStatus(c *gin.Context) {
|
||||
wc := h.config.Robots.Wechat
|
||||
bound := wc.BotToken != "" && wc.ILinkBotID != ""
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"enabled": wc.Enabled,
|
||||
"bound": bound,
|
||||
"ilink_bot_id": wc.ILinkBotID,
|
||||
"ilink_user_id": wc.ILinkUserID,
|
||||
"base_url": wc.BaseURL,
|
||||
})
|
||||
}
|
||||
+81
-1
@@ -44,6 +44,10 @@ type Server struct {
|
||||
runningCancels map[string]context.CancelFunc
|
||||
runningCancelsMu sync.Mutex
|
||||
abortUserNotes map[string]string // 监控页终止时附带的用户说明,与 executionID 对应
|
||||
// httpToolTimeoutMinutes 同步 agent.tool_timeout_minutes,用于 POST /api/mcp 的 tools/call(不经 Agent 包装的路径)。
|
||||
// nil 表示未配置,沿用默认 30 分钟;指向 0 表示不限制;>0 为分钟数。
|
||||
httpToolTimeoutMinutes *int
|
||||
httpToolTimeoutMu sync.RWMutex
|
||||
}
|
||||
|
||||
type sseClient struct {
|
||||
@@ -90,6 +94,39 @@ func NewServerWithStorage(logger *zap.Logger, storage MonitorStorage) *Server {
|
||||
return s
|
||||
}
|
||||
|
||||
// ConfigureHTTPToolCallTimeoutFromAgentMinutes 将 agent.tool_timeout_minutes 同步到经 HTTP POST /api/mcp 触发的 tools/call。
|
||||
// minutes<=0 表示不设置硬性截止时间(与配置「0 不限制」一致);minutes>0 为该次调用的最长等待时间。
|
||||
// 未调用前对 tools/call 使用默认 30 分钟(与历史硬编码一致)。
|
||||
func (s *Server) ConfigureHTTPToolCallTimeoutFromAgentMinutes(minutes int) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
v := minutes
|
||||
if v < 0 {
|
||||
v = 0
|
||||
}
|
||||
s.httpToolTimeoutMu.Lock()
|
||||
defer s.httpToolTimeoutMu.Unlock()
|
||||
s.httpToolTimeoutMinutes = &v
|
||||
}
|
||||
|
||||
func (s *Server) effectiveHTTPToolCallDeadline() (context.Context, context.CancelFunc) {
|
||||
const defaultDur = 30 * time.Minute
|
||||
if s == nil {
|
||||
return context.WithTimeout(context.Background(), defaultDur)
|
||||
}
|
||||
s.httpToolTimeoutMu.RLock()
|
||||
mPtr := s.httpToolTimeoutMinutes
|
||||
s.httpToolTimeoutMu.RUnlock()
|
||||
if mPtr == nil {
|
||||
return context.WithTimeout(context.Background(), defaultDur)
|
||||
}
|
||||
if *mPtr <= 0 {
|
||||
return context.WithCancel(context.Background())
|
||||
}
|
||||
return context.WithTimeout(context.Background(), time.Duration(*mPtr)*time.Minute)
|
||||
}
|
||||
|
||||
// RegisterTool 注册工具
|
||||
func (s *Server) RegisterTool(tool Tool, handler ToolHandler) {
|
||||
s.mu.Lock()
|
||||
@@ -457,7 +494,7 @@ func (s *Server) handleCallTool(msg *Message) *Message {
|
||||
}
|
||||
}
|
||||
|
||||
baseCtx, timeoutCancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
baseCtx, timeoutCancel := s.effectiveHTTPToolCallDeadline()
|
||||
defer timeoutCancel()
|
||||
execCtx, runCancel := context.WithCancel(baseCtx)
|
||||
s.registerRunningCancel(executionID, runCancel)
|
||||
@@ -883,6 +920,49 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]
|
||||
return finalResult, executionID, nil
|
||||
}
|
||||
|
||||
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致),
|
||||
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。
|
||||
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
if args == nil {
|
||||
args = map[string]interface{}{}
|
||||
}
|
||||
executionID := uuid.New().String()
|
||||
now := time.Now()
|
||||
failed := invokeErr != nil
|
||||
exec := &ToolExecution{
|
||||
ID: executionID,
|
||||
ToolName: toolName,
|
||||
Arguments: args,
|
||||
StartTime: now,
|
||||
EndTime: &now,
|
||||
Duration: 0,
|
||||
}
|
||||
if failed {
|
||||
exec.Status = "failed"
|
||||
exec.Error = invokeErr.Error()
|
||||
if strings.TrimSpace(resultText) != "" {
|
||||
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}}
|
||||
}
|
||||
} else {
|
||||
exec.Status = "completed"
|
||||
text := resultText
|
||||
if strings.TrimSpace(text) == "" {
|
||||
text = "(无输出)"
|
||||
}
|
||||
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: text}}}
|
||||
}
|
||||
if s.storage != nil {
|
||||
if err := s.storage.SaveToolExecution(exec); err != nil {
|
||||
s.logger.Warn("RecordCompletedToolInvocation 保存失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
s.updateStats(toolName, failed)
|
||||
return executionID
|
||||
}
|
||||
|
||||
// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长
|
||||
func (s *Server) cleanupOldExecutions() {
|
||||
if len(s.executions) <= s.maxExecutionsInMemory {
|
||||
|
||||
@@ -11,8 +11,13 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/einoobserve"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
@@ -20,7 +25,9 @@ import (
|
||||
)
|
||||
|
||||
// normalizeStreamingDelta 将可能是“累计片段”的 chunk 归一化为“纯增量”。
|
||||
// 一些模型/桥接层在流式过程中会重复发送已输出前缀,前端若直接 buffer+=chunk 会出现“结巴”重复。
|
||||
// 一些模型/桥接层在流式过程中会重复发送已输出前缀,前端若直接 buffer+=chunk 会出现重复文本。
|
||||
//
|
||||
// 注意:与 internal/openai.normalizeStreamingDelta 保持一致。
|
||||
func normalizeStreamingDelta(current, incoming string) (next, delta string) {
|
||||
if incoming == "" {
|
||||
return current, ""
|
||||
@@ -28,31 +35,22 @@ func normalizeStreamingDelta(current, incoming string) (next, delta string) {
|
||||
if current == "" {
|
||||
return incoming, incoming
|
||||
}
|
||||
if incoming == current {
|
||||
return current, ""
|
||||
}
|
||||
// incoming 是累计全文(包含 current 前缀)
|
||||
if strings.HasPrefix(incoming, current) {
|
||||
if strings.HasPrefix(incoming, current) && len(incoming) > len(current) {
|
||||
return incoming, incoming[len(current):]
|
||||
}
|
||||
// incoming 完全是已输出尾部重发
|
||||
if strings.HasSuffix(current, incoming) {
|
||||
if incoming == current && utf8.RuneCountInString(current) > 1 {
|
||||
return current, ""
|
||||
}
|
||||
|
||||
// 处理边界重叠:current 后缀与 incoming 前缀重叠,只追加非重叠部分。
|
||||
max := len(current)
|
||||
if len(incoming) < max {
|
||||
max = len(incoming)
|
||||
}
|
||||
for overlap := max; overlap > 0; overlap-- {
|
||||
if current[len(current)-overlap:] == incoming[:overlap] {
|
||||
return current + incoming[overlap:], incoming[overlap:]
|
||||
}
|
||||
}
|
||||
return current + incoming, incoming
|
||||
}
|
||||
|
||||
func isInterruptContinue(ctx context.Context) bool {
|
||||
if ctx == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(context.Cause(ctx), ErrInterruptContinue)
|
||||
}
|
||||
|
||||
func isEinoIterationLimitError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
@@ -79,14 +77,32 @@ type einoADKRunLoopArgs struct {
|
||||
StreamsMainAssistant func(agent string) bool
|
||||
EinoRoleTag func(agent string) string
|
||||
CheckpointDir string
|
||||
// RunRetryMaxAttempts / RunRetryMaxBackoffSec:429、5xx、网络抖动时的指数退避续跑(0=默认 10 次 / 30s 上限)。
|
||||
RunRetryMaxAttempts int
|
||||
RunRetryMaxBackoffSec int
|
||||
|
||||
McpIDsMu *sync.Mutex
|
||||
McpIDs *[]string
|
||||
|
||||
// FilesystemMonitorAgent / FilesystemMonitorRecord 非 nil 时,将 Eino ADK filesystem 中间件工具(ls/read_file/write_file/edit_file/glob/grep)
|
||||
// 在完成时写入 MCP 监控;execute 仍由 eino_execute_monitor 记录,此处跳过。
|
||||
FilesystemMonitorAgent *agent.Agent
|
||||
FilesystemMonitorRecord einomcp.ExecutionRecorder
|
||||
|
||||
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,MCP 桥 Fire 以补全 tool_result。
|
||||
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
|
||||
|
||||
DA adk.Agent
|
||||
|
||||
// EmptyResponseMessage 当未捕获到助手正文时的占位(多代理与单代理文案不同)。
|
||||
EmptyResponseMessage string
|
||||
|
||||
// ModelFacingTrace 可选:由各 ChatModelAgent Handlers 链末尾中间件写入「即将送入模型」的消息快照;
|
||||
// 非空时优先用于 LastAgentTraceInput 序列化,使续跑与 summarization/reduction 后的上下文一致。
|
||||
ModelFacingTrace *modelFacingTraceHolder
|
||||
|
||||
// EinoCallbacks 可选:为 ADK Runner 注入 eino [callbacks] 全链路观测(见 internal/einoobserve)。
|
||||
EinoCallbacks *config.MultiAgentEinoCallbacksConfig
|
||||
}
|
||||
|
||||
func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs []adk.Message) (*RunResult, error) {
|
||||
@@ -164,6 +180,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
var einoMainRound int
|
||||
var einoLastAgent string
|
||||
subAgentToolStep := make(map[string]int)
|
||||
// mainAgentToolStep:主代理每次工具调用批次递增,供 UI 显示「第 N 轮」(单代理无子代理切换时原先会一直停在第 1 轮)。
|
||||
mainAgentToolStep := make(map[string]int)
|
||||
pendingByID := make(map[string]toolCallPendingInfo)
|
||||
pendingQueueByAgent := make(map[string][]string)
|
||||
markPending := func(tc toolCallPendingInfo) {
|
||||
@@ -224,6 +242,82 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
pendingQueueByAgent = make(map[string][]string)
|
||||
}
|
||||
|
||||
// 最近一次成功的 Eino filesystem execute 的标准输出(trim):用于抑制模型紧接着复述同一字符串时的重复「助手输出」时间线。
|
||||
var executeStdoutDupMu sync.Mutex
|
||||
var pendingExecuteStdoutDup string
|
||||
recordPendingExecuteStdoutDup := func(toolName, stdout string, isErr bool) {
|
||||
if isErr || !strings.EqualFold(strings.TrimSpace(toolName), "execute") {
|
||||
return
|
||||
}
|
||||
t := strings.TrimSpace(stdout)
|
||||
if t == "" {
|
||||
return
|
||||
}
|
||||
executeStdoutDupMu.Lock()
|
||||
pendingExecuteStdoutDup = t
|
||||
executeStdoutDupMu.Unlock()
|
||||
}
|
||||
|
||||
var toolResultSent sync.Map // toolCallID -> struct{};与 ADK Tool 消息去重,避免 bridge 与事件流各推一次
|
||||
if args.ToolInvokeNotify != nil {
|
||||
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||
tid := strings.TrimSpace(toolCallID)
|
||||
removePendingByID(tid)
|
||||
if tid == "" || progress == nil {
|
||||
return
|
||||
}
|
||||
if _, loaded := toolResultSent.LoadOrStore(tid, struct{}{}); loaded {
|
||||
return
|
||||
}
|
||||
isErr := !success || invokeErr != nil
|
||||
body := content
|
||||
if invokeErr != nil {
|
||||
// 保留已流式累计的 stdout(如 execute 超时前的一半输出),避免 tool_result 只剩错误串、模型与 UI 丢失上下文
|
||||
tail := friendlyEinoExecuteInvokeTail(invokeErr)
|
||||
// execute 流式包装可能已把超时句写入 content(供 ADK tool 与流式 delta);勿重复拼接
|
||||
if tail != "" && strings.Contains(content, tail) {
|
||||
body = content
|
||||
} else if strings.TrimSpace(content) != "" {
|
||||
body = strings.TrimRight(content, "\n") + "\n\n" + tail
|
||||
} else {
|
||||
body = tail
|
||||
}
|
||||
isErr = true
|
||||
}
|
||||
recordPendingExecuteStdoutDup(toolName, body, isErr)
|
||||
preview := body
|
||||
if len(preview) > 200 {
|
||||
preview = preview[:200] + "..."
|
||||
}
|
||||
agentTag := strings.TrimSpace(einoAgent)
|
||||
if agentTag == "" {
|
||||
agentTag = orchestratorName
|
||||
}
|
||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), map[string]interface{}{
|
||||
"toolName": toolName,
|
||||
"success": !isErr,
|
||||
"isError": isErr,
|
||||
"result": body,
|
||||
"resultPreview": preview,
|
||||
"toolCallId": tid,
|
||||
"conversationId": conversationID,
|
||||
"einoAgent": agentTag,
|
||||
"einoRole": einoRoleTag(agentTag),
|
||||
"source": "eino",
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
if args.EinoCallbacks != nil {
|
||||
ctx = einoobserve.AttachAgentRunCallbacks(ctx, args.EinoCallbacks, einoobserve.Params{
|
||||
Logger: logger,
|
||||
Progress: progress,
|
||||
ConversationID: conversationID,
|
||||
OrchMode: orchMode,
|
||||
OrchestratorName: orchestratorName,
|
||||
})
|
||||
}
|
||||
|
||||
runnerCfg := adk.RunnerConfig{
|
||||
Agent: da,
|
||||
EnableStreaming: true,
|
||||
@@ -346,13 +440,36 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
return runErr
|
||||
}
|
||||
|
||||
// maybeRetryTransientRun:不在此层 runner.Run/Resume;由 handler 落库 + loadHistoryFromAgentTrace 分段续跑(同中断并继续)。
|
||||
maybeRetryTransientRun := func(runErr error) (retry bool, fatal error) {
|
||||
if runErr == nil || !isEinoTransientRunError(runErr) {
|
||||
return false, handleRunErr(runErr)
|
||||
}
|
||||
if logger != nil {
|
||||
logger.Warn("eino transient error, ending run segment for handler resume",
|
||||
zap.Error(runErr),
|
||||
zap.String("orchestration", orchMode))
|
||||
}
|
||||
if progress != nil {
|
||||
progress("eino_run_retry", "遇到临时错误(限流或网络波动),将保存上下文并重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"orchestration": orchMode,
|
||||
"error": runErr.Error(),
|
||||
"resumeKind": "trace_segment",
|
||||
})
|
||||
}
|
||||
return false, ErrTransientRetryContinue
|
||||
}
|
||||
|
||||
takePartial := func(runErr error) (*RunResult, error) {
|
||||
if len(runAccumulatedMsgs) <= baseAccumulatedCount {
|
||||
return nil, runErr
|
||||
}
|
||||
ids := snapshotMCPIDs()
|
||||
return buildEinoRunResultFromAccumulated(
|
||||
orchMode, runAccumulatedMsgs, lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, true,
|
||||
orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs),
|
||||
lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, true,
|
||||
), runErr
|
||||
}
|
||||
|
||||
@@ -362,10 +479,18 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
case <-ctx.Done():
|
||||
flushAllPendingAsFailed(ctx.Err())
|
||||
if progress != nil {
|
||||
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
if isInterruptContinue(ctx) {
|
||||
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"kind": "interrupt_continue",
|
||||
})
|
||||
} else {
|
||||
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
}
|
||||
return takePartial(ctx.Err())
|
||||
default:
|
||||
@@ -379,10 +504,18 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
flushAllPendingAsFailed(ctxErr)
|
||||
if progress != nil {
|
||||
progress("error", ctxErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
if isInterruptContinue(ctx) {
|
||||
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"kind": "interrupt_continue",
|
||||
})
|
||||
} else {
|
||||
progress("error", ctxErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
}
|
||||
return takePartial(ctxErr)
|
||||
}
|
||||
@@ -411,7 +544,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
continue
|
||||
}
|
||||
if ev.Err != nil {
|
||||
if retErr := handleRunErr(ev.Err); retErr != nil {
|
||||
if _, retErr := maybeRetryTransientRun(ev.Err); retErr != nil {
|
||||
return takePartial(retErr)
|
||||
}
|
||||
}
|
||||
@@ -423,8 +556,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
}
|
||||
if streamsMainAssistant(ev.AgentName) {
|
||||
mainIterKey := einoMainIterationKey(iterEinoAgent, orchestratorName)
|
||||
if einoMainRound == 0 {
|
||||
einoMainRound = 1
|
||||
mainAgentToolStep[mainIterKey] = 1
|
||||
progress("iteration", "", map[string]interface{}{
|
||||
"iteration": 1,
|
||||
"einoScope": "main",
|
||||
@@ -434,17 +569,26 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
} else if einoLastAgent != "" && !streamsMainAssistant(einoLastAgent) {
|
||||
einoMainRound++
|
||||
progress("iteration", "", map[string]interface{}{
|
||||
"iteration": einoMainRound,
|
||||
"einoScope": "main",
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": iterEinoAgent,
|
||||
"orchestration": orchMode,
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
} else if einoLastAgent != "" {
|
||||
needBump := false
|
||||
if !streamsMainAssistant(einoLastAgent) {
|
||||
needBump = true // 子代理 → 主代理
|
||||
} else if einoLastAgent != ev.AgentName {
|
||||
needBump = true // plan_execute:planner ↔ executor 等主代理切换
|
||||
}
|
||||
if needBump {
|
||||
einoMainRound++
|
||||
mainAgentToolStep[mainIterKey] = einoMainRound
|
||||
progress("iteration", "", map[string]interface{}{
|
||||
"iteration": einoMainRound,
|
||||
"einoScope": "main",
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": iterEinoAgent,
|
||||
"orchestration": orchMode,
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
einoLastAgent = ev.AgentName
|
||||
@@ -467,98 +611,193 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
var subAssistantBuf string
|
||||
var subReplyStreamID string
|
||||
var mainAssistantBuf string
|
||||
// 已通过 response_delta 推到前端的正文(与 monitor.js normalizeStreamingDeltaJs 累积一致)
|
||||
var mainAssistWireAccum string
|
||||
var mainAssistDupTarget string // 非空表示本段主助手流需缓冲至 EOF,与 execute 输出比对去重
|
||||
var reasoningBuf string
|
||||
var prevReasoningDisplay string // UI 用:剥离 Claude 内部 signature 尾缀后的累计展示
|
||||
var streamRecvErr error
|
||||
type streamMsg struct {
|
||||
chunk *schema.Message
|
||||
err error
|
||||
}
|
||||
recvCh := make(chan streamMsg, 8)
|
||||
go func() {
|
||||
defer close(recvCh)
|
||||
for {
|
||||
ch, rerr := mv.MessageStream.Recv()
|
||||
recvCh <- streamMsg{chunk: ch, err: rerr}
|
||||
if rerr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
streamRecvLoop:
|
||||
for {
|
||||
chunk, rerr := mv.MessageStream.Recv()
|
||||
if rerr != nil {
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
streamRecvErr = ctx.Err()
|
||||
break streamRecvLoop
|
||||
case sm, ok := <-recvCh:
|
||||
if !ok {
|
||||
break streamRecvLoop
|
||||
}
|
||||
if logger != nil {
|
||||
logger.Warn("eino stream recv error, flushing incomplete stream",
|
||||
zap.Error(rerr),
|
||||
zap.String("agent", ev.AgentName),
|
||||
zap.Int("toolFragments", len(toolStreamFragments)))
|
||||
}
|
||||
streamRecvErr = rerr
|
||||
break
|
||||
}
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
if progress != nil && strings.TrimSpace(chunk.ReasoningContent) != "" {
|
||||
var reasoningDelta string
|
||||
reasoningBuf, reasoningDelta = normalizeStreamingDelta(reasoningBuf, chunk.ReasoningContent)
|
||||
if reasoningDelta != "" {
|
||||
if reasoningStreamID == "" {
|
||||
reasoningStreamID = fmt.Sprintf("eino-reasoning-%s-%d", conversationID, atomic.AddInt64(&reasoningStreamSeq, 1))
|
||||
progress("thinking_stream_start", " ", map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
chunk, rerr := sm.chunk, sm.err
|
||||
if rerr != nil {
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break streamRecvLoop
|
||||
}
|
||||
progress("thinking_stream_delta", reasoningDelta, map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
})
|
||||
if logger != nil {
|
||||
logger.Warn("eino stream recv error, flushing incomplete stream",
|
||||
zap.Error(rerr),
|
||||
zap.String("agent", ev.AgentName),
|
||||
zap.Int("toolFragments", len(toolStreamFragments)))
|
||||
}
|
||||
streamRecvErr = rerr
|
||||
break streamRecvLoop
|
||||
}
|
||||
}
|
||||
if chunk.Content != "" {
|
||||
if progress != nil && streamsMainAssistant(ev.AgentName) {
|
||||
var contentDelta string
|
||||
mainAssistantBuf, contentDelta = normalizeStreamingDelta(mainAssistantBuf, chunk.Content)
|
||||
if contentDelta != "" {
|
||||
if !streamHeaderSent {
|
||||
progress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"messageGeneratedBy": "eino:" + ev.AgentName,
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
streamHeaderSent = true
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
if progress != nil && strings.TrimSpace(chunk.ReasoningContent) != "" {
|
||||
var reasoningDelta string
|
||||
reasoningBuf, reasoningDelta = normalizeStreamingDelta(reasoningBuf, chunk.ReasoningContent)
|
||||
if reasoningDelta != "" {
|
||||
fullDisplay := openai.DisplayReasoningContent(reasoningBuf)
|
||||
var displayDelta string
|
||||
if strings.HasPrefix(fullDisplay, prevReasoningDisplay) {
|
||||
displayDelta = fullDisplay[len(prevReasoningDisplay):]
|
||||
} else {
|
||||
displayDelta = fullDisplay
|
||||
}
|
||||
progress("response_delta", contentDelta, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
} else if !streamsMainAssistant(ev.AgentName) {
|
||||
var subDelta string
|
||||
subAssistantBuf, subDelta = normalizeStreamingDelta(subAssistantBuf, chunk.Content)
|
||||
if subDelta != "" {
|
||||
if progress != nil {
|
||||
if subReplyStreamID == "" {
|
||||
subReplyStreamID = fmt.Sprintf("eino-sub-reply-%s-%d", conversationID, atomic.AddInt64(&einoSubReplyStreamSeq, 1))
|
||||
progress("eino_agent_reply_stream_start", "", map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": "sub",
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
prevReasoningDisplay = fullDisplay
|
||||
if displayDelta != "" {
|
||||
if reasoningStreamID == "" {
|
||||
reasoningStreamID = fmt.Sprintf("eino-reasoning-%s-%d", conversationID, atomic.AddInt64(&reasoningStreamSeq, 1))
|
||||
progress("reasoning_chain_stream_start", " ", map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
progress("eino_agent_reply_stream_delta", subDelta, map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
progress("reasoning_chain_stream_delta", displayDelta, openai.WithSSEAccumulated(map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
}, fullDisplay))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(chunk.ToolCalls) > 0 {
|
||||
toolStreamFragments = append(toolStreamFragments, chunk.ToolCalls...)
|
||||
if chunk.Content != "" {
|
||||
if progress != nil && streamsMainAssistant(ev.AgentName) {
|
||||
var contentDelta string
|
||||
mainAssistantBuf, contentDelta = normalizeStreamingDelta(mainAssistantBuf, chunk.Content)
|
||||
if contentDelta != "" {
|
||||
if mainAssistDupTarget == "" {
|
||||
executeStdoutDupMu.Lock()
|
||||
if pendingExecuteStdoutDup != "" {
|
||||
mainAssistDupTarget = pendingExecuteStdoutDup
|
||||
}
|
||||
executeStdoutDupMu.Unlock()
|
||||
}
|
||||
if mainAssistDupTarget != "" {
|
||||
// 已展示过 tool_result,缓冲全文;EOF 后与 execute 输出相同则不再发助手流
|
||||
} else {
|
||||
if !streamHeaderSent {
|
||||
progress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"messageGeneratedBy": "eino:" + ev.AgentName,
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
streamHeaderSent = true
|
||||
}
|
||||
progress("response_delta", contentDelta, openai.WithSSEAccumulated(map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
}, mainAssistantBuf))
|
||||
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, contentDelta)
|
||||
}
|
||||
}
|
||||
} else if !streamsMainAssistant(ev.AgentName) {
|
||||
var subDelta string
|
||||
subAssistantBuf, subDelta = normalizeStreamingDelta(subAssistantBuf, chunk.Content)
|
||||
if subDelta != "" {
|
||||
if progress != nil {
|
||||
if subReplyStreamID == "" {
|
||||
subReplyStreamID = fmt.Sprintf("eino-sub-reply-%s-%d", conversationID, atomic.AddInt64(&einoSubReplyStreamSeq, 1))
|
||||
progress("eino_agent_reply_stream_start", "", map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": "sub",
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
progress("eino_agent_reply_stream_delta", subDelta, openai.WithSSEAccumulated(map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
"conversationId": conversationID,
|
||||
}, subAssistantBuf))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(chunk.ToolCalls) > 0 {
|
||||
toolStreamFragments = append(toolStreamFragments, chunk.ToolCalls...)
|
||||
}
|
||||
}
|
||||
}
|
||||
if streamsMainAssistant(ev.AgentName) {
|
||||
if s := strings.TrimSpace(mainAssistantBuf); s != "" {
|
||||
s := strings.TrimSpace(mainAssistantBuf)
|
||||
if mainAssistDupTarget != "" {
|
||||
executeStdoutDupMu.Lock()
|
||||
pendingExecuteStdoutDup = ""
|
||||
executeStdoutDupMu.Unlock()
|
||||
if s != "" && s == mainAssistDupTarget {
|
||||
// 与刚展示的 execute 结果完全一致:不再发助手流式事件,仍写入轨迹与最终回复字段
|
||||
lastAssistant = s
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(s)
|
||||
}
|
||||
} else if s != "" {
|
||||
if progress != nil {
|
||||
// 仅用 TrimSpace 与 execute 比对;推到 UI 的必须是 mainAssistantBuf,
|
||||
// 否则尾部空白/换行与已流式前缀不一致时,前端 normalize 会走拼接路径造成叠字。
|
||||
_, eofTail := normalizeStreamingDelta(mainAssistWireAccum, mainAssistantBuf)
|
||||
if eofTail != "" {
|
||||
if !streamHeaderSent {
|
||||
progress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"messageGeneratedBy": "eino:" + ev.AgentName,
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
progress("response_delta", eofTail, openai.WithSSEAccumulated(map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
}, mainAssistantBuf))
|
||||
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, eofTail)
|
||||
}
|
||||
}
|
||||
lastAssistant = s
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(s)
|
||||
}
|
||||
}
|
||||
} else if s != "" {
|
||||
lastAssistant = s
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
@@ -588,10 +827,17 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
var lastToolChunk *schema.Message
|
||||
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
|
||||
lastToolChunk = &schema.Message{ToolCalls: merged}
|
||||
lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged})
|
||||
}
|
||||
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
|
||||
// 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。
|
||||
if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 {
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls))
|
||||
}
|
||||
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
|
||||
if streamRecvErr != nil {
|
||||
if isInterruptContinue(ctx) {
|
||||
return takePartial(streamRecvErr)
|
||||
}
|
||||
if progress != nil {
|
||||
progress("eino_stream_error", streamRecvErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
@@ -600,7 +846,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
})
|
||||
}
|
||||
if retErr := handleRunErr(streamRecvErr); retErr != nil {
|
||||
if _, retErr := maybeRetryTransientRun(streamRecvErr); retErr != nil {
|
||||
return takePartial(retErr)
|
||||
}
|
||||
}
|
||||
@@ -612,11 +858,11 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
continue
|
||||
}
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
|
||||
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
|
||||
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
|
||||
|
||||
if mv.Role == schema.Assistant {
|
||||
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
||||
progress("thinking", strings.TrimSpace(msg.ReasoningContent), map[string]interface{}{
|
||||
progress("reasoning_chain", openai.DisplayReasoningContent(strings.TrimSpace(msg.ReasoningContent)), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
@@ -627,26 +873,42 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
body := strings.TrimSpace(msg.Content)
|
||||
if body != "" {
|
||||
if streamsMainAssistant(ev.AgentName) {
|
||||
if progress != nil {
|
||||
progress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"messageGeneratedBy": "eino:" + ev.AgentName,
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
progress("response_delta", body, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
lastAssistant = body
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(body)
|
||||
executeStdoutDupMu.Lock()
|
||||
dup := pendingExecuteStdoutDup
|
||||
if dup != "" && body == dup {
|
||||
pendingExecuteStdoutDup = ""
|
||||
executeStdoutDupMu.Unlock()
|
||||
lastAssistant = body
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(body)
|
||||
}
|
||||
// 非流式:与 execute 输出相同则跳过助手通道展示(msg 已在上方写入 runAccumulatedMsgs)
|
||||
} else {
|
||||
if dup != "" {
|
||||
pendingExecuteStdoutDup = ""
|
||||
}
|
||||
executeStdoutDupMu.Unlock()
|
||||
if progress != nil {
|
||||
progress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"messageGeneratedBy": "eino:" + ev.AgentName,
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
progress("response_delta", body, openai.WithSSEAccumulated(map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
}, body))
|
||||
}
|
||||
lastAssistant = body
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(body)
|
||||
}
|
||||
}
|
||||
} else if progress != nil {
|
||||
progress("eino_agent_reply", body, map[string]interface{}{
|
||||
@@ -702,12 +964,19 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
removePendingByID(toolCallID)
|
||||
}
|
||||
if toolCallID != "" {
|
||||
removePendingByID(toolCallID)
|
||||
if _, loaded := toolResultSent.LoadOrStore(toolCallID, struct{}{}); loaded {
|
||||
// ToolInvokeNotify 可能已推过 tool_result(如 execute 流式包装里 Fire 仅携带截断后的 stdout),
|
||||
// 此处仍应用 ADK Tool 消息中的完整内容刷新去重基准,避免模型复述全文时与截断串比对失败而重复展示「助手输出」。
|
||||
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
||||
continue
|
||||
}
|
||||
data["toolCallId"] = toolCallID
|
||||
}
|
||||
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
||||
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
|
||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
|
||||
}
|
||||
}
|
||||
@@ -717,26 +986,52 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
mcpIDsMu.Unlock()
|
||||
|
||||
out := buildEinoRunResultFromAccumulated(
|
||||
orchMode, runAccumulatedMsgs, lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
|
||||
orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs),
|
||||
lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
|
||||
)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func persistTraceSource(args *einoADKRunLoopArgs, fallback []adk.Message) []adk.Message {
|
||||
if args != nil && args.ModelFacingTrace != nil {
|
||||
if snap := args.ModelFacingTrace.Snapshot(); len(snap) > 0 {
|
||||
return snap
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func einoPartialRunLastOutputHint() string {
|
||||
return "[执行未正常结束(用户停止、超时或异常)。续跑时请基于上文已产生的工具与结果继续,勿重复已完成步骤。]\n" +
|
||||
"[Run ended abnormally; continue from the trace above without repeating completed steps.]"
|
||||
}
|
||||
|
||||
// friendlyEinoExecuteInvokeTail 将 Eino execute 等非 MCP 路径的结尾错误转成简短提示;其它情况保留原 error 文本。
|
||||
func friendlyEinoExecuteInvokeTail(invokeErr error) string {
|
||||
if invokeErr == nil {
|
||||
return ""
|
||||
}
|
||||
if errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||
return einoExecuteTimeoutUserHint()
|
||||
}
|
||||
return "[执行未正常结束] " + invokeErr.Error()
|
||||
}
|
||||
|
||||
func buildEinoRunResultFromAccumulated(
|
||||
orchMode string,
|
||||
runAccumulatedMsgs []adk.Message,
|
||||
persistMsgs []adk.Message,
|
||||
lastAssistant string,
|
||||
lastPlanExecuteExecutor string,
|
||||
emptyHint string,
|
||||
mcpIDs []string,
|
||||
partial bool,
|
||||
) *RunResult {
|
||||
histJSON, _ := json.Marshal(runAccumulatedMsgs)
|
||||
traceForJSON := persistMsgs
|
||||
if len(traceForJSON) == 0 {
|
||||
traceForJSON = runAccumulatedMsgs
|
||||
}
|
||||
histJSON, _ := json.Marshal(traceForJSON)
|
||||
cleaned := strings.TrimSpace(lastAssistant)
|
||||
if orchMode == "plan_execute" {
|
||||
if e := strings.TrimSpace(lastPlanExecuteExecutor); e != "" {
|
||||
@@ -745,6 +1040,11 @@ func buildEinoRunResultFromAccumulated(
|
||||
cleaned = UnwrapPlanExecuteUserText(cleaned)
|
||||
}
|
||||
}
|
||||
if cleaned == "" {
|
||||
if fb := strings.TrimSpace(einoExtractFallbackAssistantFromMsgs(runAccumulatedMsgs)); fb != "" {
|
||||
cleaned = fb
|
||||
}
|
||||
}
|
||||
cleaned = dedupeRepeatedParagraphs(cleaned, 80)
|
||||
cleaned = dedupeParagraphsByLineFingerprint(cleaned, 100)
|
||||
// 防止超长响应导致 JSON 序列化慢或 OOM(多代理拼接大量工具输出时可能触发)。
|
||||
@@ -771,6 +1071,79 @@ func buildEinoRunResultFromAccumulated(
|
||||
return out
|
||||
}
|
||||
|
||||
// einoExtractFallbackAssistantFromMsgs 在「主通道未产出助手正文」时,从 Eino ADK 轨迹中回填用户可见回复。
|
||||
// 典型场景:监督者仅调用 exit(final_result 落在 Tool 消息中),或工具结果已写入历史但 lastAssistant 未更新。
|
||||
//
|
||||
// 优先级:最后一次 exit 工具输出 → 最后一条含 exit 的助手 tool_calls 参数中的 final_result。
|
||||
func einoExtractFallbackAssistantFromMsgs(msgs []adk.Message) string {
|
||||
for i := len(msgs) - 1; i >= 0; i-- {
|
||||
m := msgs[i]
|
||||
if m == nil || m.Role != schema.Tool {
|
||||
continue
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(m.ToolName), adk.ToolInfoExit.Name) {
|
||||
continue
|
||||
}
|
||||
content := strings.TrimSpace(m.Content)
|
||||
if content == "" || strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||
continue
|
||||
}
|
||||
return content
|
||||
}
|
||||
for i := len(msgs) - 1; i >= 0; i-- {
|
||||
m := msgs[i]
|
||||
if m == nil || m.Role != schema.Assistant {
|
||||
continue
|
||||
}
|
||||
if s := einoExtractExitFinalFromAssistantToolCalls(m); s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func einoExtractExitFinalFromAssistantToolCalls(msg *schema.Message) string {
|
||||
if msg == nil || len(msg.ToolCalls) == 0 {
|
||||
return ""
|
||||
}
|
||||
for i := len(msg.ToolCalls) - 1; i >= 0; i-- {
|
||||
tc := msg.ToolCalls[i]
|
||||
if !strings.EqualFold(strings.TrimSpace(tc.Function.Name), adk.ToolInfoExit.Name) {
|
||||
continue
|
||||
}
|
||||
if s := einoParseExitFinalResultArguments(tc.Function.Arguments); s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func einoParseExitFinalResultArguments(arguments string) string {
|
||||
arguments = strings.TrimSpace(arguments)
|
||||
if arguments == "" {
|
||||
return ""
|
||||
}
|
||||
var wrap struct {
|
||||
FinalResult json.RawMessage `json:"final_result"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(arguments), &wrap); err != nil || len(wrap.FinalResult) == 0 {
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(wrap.FinalResult, &s); err == nil {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
var anyVal interface{}
|
||||
if err := json.Unmarshal(wrap.FinalResult, &anyVal); err != nil {
|
||||
return ""
|
||||
}
|
||||
b, err := json.Marshal(anyVal)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(b))
|
||||
}
|
||||
|
||||
func buildEinoCheckpointID(orchMode string) string {
|
||||
mode := sanitizeEinoPathSegment(strings.TrimSpace(orchMode))
|
||||
if mode == "" {
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
)
|
||||
|
||||
// newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId),
|
||||
// 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片。
|
||||
func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(command, stdout string, success bool, invokeErr error) {
|
||||
return func(command, stdout string, success bool, invokeErr error) {
|
||||
if ag == nil || recorder == nil {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
if !success {
|
||||
if invokeErr != nil {
|
||||
err = invokeErr
|
||||
} else {
|
||||
err = fmt.Errorf("execute failed")
|
||||
}
|
||||
}
|
||||
args := map[string]interface{}{"command": command}
|
||||
id := ag.RecordLocalToolExecution("execute", args, stdout, err)
|
||||
if id != "" {
|
||||
recorder(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,20 +2,58 @@ package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/cloudwego/eino/adk/filesystem"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// prependPythonUnbufferedEnv 为 /bin/sh -c 注入 PYTHONUNBUFFERED=1。
|
||||
// eino-ext local 对流式 stdout 使用 bufio 按「行」推送;python3 写管道时默认块缓冲,print 长期留在用户态缓冲,
|
||||
// 管道里收不到换行,表现为长时间无输出直至超时或退出。若命令里已出现 PYTHONUNBUFFERED 则不再覆盖。
|
||||
func prependPythonUnbufferedEnv(shellCommand string) string {
|
||||
if strings.TrimSpace(shellCommand) == "" {
|
||||
return shellCommand
|
||||
}
|
||||
if strings.Contains(strings.ToUpper(shellCommand), "PYTHONUNBUFFERED") {
|
||||
return shellCommand
|
||||
}
|
||||
return "export PYTHONUNBUFFERED=1\n" + shellCommand
|
||||
}
|
||||
|
||||
// einoExecuteTimeoutUserHint 与写入 ADK 工具消息(模型可见)及 SSE tool_result 尾标一致。
|
||||
func einoExecuteTimeoutUserHint() string {
|
||||
return "已超时终止 · Timed out"
|
||||
}
|
||||
|
||||
// einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShell(cloudwego eino-ext local.Local)。
|
||||
// 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连,
|
||||
// streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。
|
||||
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
|
||||
//
|
||||
// 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire,
|
||||
// 保证 run loop 在模型开始下一轮输出前已记录 execute 结果(用于 UI 与「重复助手复述」去重)。
|
||||
//
|
||||
// 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire;
|
||||
// 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。
|
||||
type einoStreamingShellWrap struct {
|
||||
inner filesystem.StreamingShell
|
||||
inner filesystem.StreamingShell
|
||||
invokeNotify *einomcp.ToolInvokeNotifyHolder
|
||||
einoAgentName string
|
||||
// outputChunk 可选;非 nil 时在收到内层 ExecuteResponse 片段时推送,与 MCP 工具的 tool_result_delta 一致(需有效 toolCallId)。
|
||||
outputChunk func(toolName, toolCallID, chunk string)
|
||||
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
|
||||
toolTimeoutMinutes int
|
||||
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
|
||||
recordMonitor func(command, stdout string, success bool, invokeErr error)
|
||||
}
|
||||
|
||||
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||
@@ -26,8 +64,123 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
return w.inner.ExecuteStreaming(ctx, nil)
|
||||
}
|
||||
req := *input
|
||||
userCmd := strings.TrimSpace(req.Command)
|
||||
if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround {
|
||||
req.RunInBackendGround = true
|
||||
}
|
||||
return w.inner.ExecuteStreaming(ctx, &req)
|
||||
req.Command = prependPythonUnbufferedEnv(req.Command)
|
||||
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
||||
agentTag := strings.TrimSpace(w.einoAgentName)
|
||||
|
||||
execCtx := ctx
|
||||
var execCancel context.CancelFunc
|
||||
if w.toolTimeoutMinutes > 0 {
|
||||
execCtx, execCancel = context.WithTimeout(ctx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
|
||||
}
|
||||
|
||||
sr, err := w.inner.ExecuteStreaming(execCtx, &req)
|
||||
if err != nil {
|
||||
if execCancel != nil {
|
||||
execCancel()
|
||||
}
|
||||
if w.recordMonitor != nil {
|
||||
w.recordMonitor(userCmd, "", false, err)
|
||||
}
|
||||
if w.invokeNotify != nil && tid != "" {
|
||||
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if sr == nil || w.invokeNotify == nil || tid == "" {
|
||||
if execCancel != nil {
|
||||
execCancel()
|
||||
}
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
||||
|
||||
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, tctx context.Context) {
|
||||
defer inner.Close()
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
const maxCapture = 16 * 1024
|
||||
success := true
|
||||
var invokeErr error
|
||||
exitCode := 0
|
||||
hasExitCode := false
|
||||
|
||||
for {
|
||||
resp, rerr := inner.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
success = false
|
||||
invokeErr = rerr
|
||||
_ = outW.Send(nil, rerr)
|
||||
break
|
||||
}
|
||||
if resp != nil {
|
||||
if resp.ExitCode != nil {
|
||||
hasExitCode = true
|
||||
exitCode = *resp.ExitCode
|
||||
}
|
||||
var appended string
|
||||
if remain := maxCapture - sb.Len(); remain > 0 {
|
||||
out := resp.Output
|
||||
if len(out) > remain {
|
||||
out = out[:remain]
|
||||
}
|
||||
sb.WriteString(out)
|
||||
appended = out
|
||||
}
|
||||
// 仅推送写入 sb 的片段,与末尾 Fire/recordMonitor 的截断累计一致,避免最终 tool_result 短于已展示增量。
|
||||
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
||||
w.outputChunk("execute", tid, appended)
|
||||
}
|
||||
if outW.Send(resp, nil) {
|
||||
success = false
|
||||
invokeErr = fmt.Errorf("execute stream closed by consumer")
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if success && hasExitCode && exitCode != 0 {
|
||||
success = false
|
||||
invokeErr = fmt.Errorf("execute exited with code %d", exitCode)
|
||||
}
|
||||
// WithTimeout 触发后,子进程常被信号结束,local 侧多报 exit -1 / canceled,错误链里不一定带 DeadlineExceeded。
|
||||
// 用执行所用 ctx 归一化,便于 UI 展示「超时」而非含糊的 -1。
|
||||
if tctx != nil && errors.Is(tctx.Err(), context.DeadlineExceeded) {
|
||||
success = false
|
||||
invokeErr = context.DeadlineExceeded
|
||||
}
|
||||
// ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。
|
||||
if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: hint}, nil)
|
||||
if w.outputChunk != nil && tid != "" {
|
||||
w.outputChunk("execute", tid, hint)
|
||||
}
|
||||
if remain := maxCapture - sb.Len(); remain > 0 {
|
||||
h := hint
|
||||
if len(h) > remain {
|
||||
h = h[:remain]
|
||||
}
|
||||
sb.WriteString(h)
|
||||
}
|
||||
}
|
||||
if w.recordMonitor != nil {
|
||||
w.recordMonitor(command, sb.String(), success, invokeErr)
|
||||
}
|
||||
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
||||
outW.Close()
|
||||
}(sr, userCmd, execCancel, execCtx)
|
||||
|
||||
return outR, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestEinoExtractFallbackAssistantFromMsgs_exitToolMessage(t *testing.T) {
|
||||
u := schema.UserMessage("hi")
|
||||
tm := schema.ToolMessage("answer for user", "call-exit-1")
|
||||
tm.ToolName = "exit"
|
||||
if got := einoExtractFallbackAssistantFromMsgs([]*schema.Message{u, tm}); got != "answer for user" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoExtractFallbackAssistantFromMsgs_lastExitWins(t *testing.T) {
|
||||
msgs := []*schema.Message{
|
||||
schema.UserMessage("hi"),
|
||||
toolExitMsg("first", "c1"),
|
||||
toolExitMsg("second", "c2"),
|
||||
}
|
||||
if got := einoExtractFallbackAssistantFromMsgs(msgs); got != "second" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoExtractFallbackAssistantFromMsgs_fromAssistantToolCalls(t *testing.T) {
|
||||
m := schema.AssistantMessage("", []schema.ToolCall{{
|
||||
ID: "x",
|
||||
Type: "function",
|
||||
Function: schema.FunctionCall{
|
||||
Name: "exit",
|
||||
Arguments: `{"final_result":"from args"}`,
|
||||
},
|
||||
}})
|
||||
if got := einoExtractFallbackAssistantFromMsgs([]*schema.Message{m}); got != "from args" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoExtractFallbackAssistantFromMsgs_prefersToolOverEarlierAssistant(t *testing.T) {
|
||||
asst := schema.AssistantMessage("", []schema.ToolCall{{
|
||||
ID: "x",
|
||||
Type: "function",
|
||||
Function: schema.FunctionCall{
|
||||
Name: "exit",
|
||||
Arguments: `{"final_result":"from args"}`,
|
||||
},
|
||||
}})
|
||||
tool := toolExitMsg("from tool", "c1")
|
||||
if got := einoExtractFallbackAssistantFromMsgs([]*schema.Message{asst, tool}); got != "from tool" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func toolExitMsg(content, callID string) *schema.Message {
|
||||
m := schema.ToolMessage(content, callID)
|
||||
m.ToolName = "exit"
|
||||
return m
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// einoADKFilesystemToolNames 与 cloudwego/eino/adk/middlewares/filesystem 默认 ToolName* 一致。
|
||||
// execute 已由 eino_execute_monitor 落库,此处不包含。
|
||||
var einoADKFilesystemToolNames = map[string]struct{}{
|
||||
"ls": {},
|
||||
"read_file": {},
|
||||
"write_file": {},
|
||||
"edit_file": {},
|
||||
"glob": {},
|
||||
"grep": {},
|
||||
}
|
||||
|
||||
func isBuiltinEinoADKFilesystemToolName(name string) bool {
|
||||
n := strings.ToLower(strings.TrimSpace(name))
|
||||
_, ok := einoADKFilesystemToolNames[n]
|
||||
return ok
|
||||
}
|
||||
|
||||
func toolCallArgsFromAccumulated(msgs []adk.Message, toolCallID, expectToolName string) map[string]interface{} {
|
||||
tid := strings.TrimSpace(toolCallID)
|
||||
expect := strings.TrimSpace(expectToolName)
|
||||
for i := len(msgs) - 1; i >= 0; i-- {
|
||||
m := msgs[i]
|
||||
if m == nil || m.Role != schema.Assistant || len(m.ToolCalls) == 0 {
|
||||
continue
|
||||
}
|
||||
for j := len(m.ToolCalls) - 1; j >= 0; j-- {
|
||||
tc := m.ToolCalls[j]
|
||||
if tid != "" && strings.TrimSpace(tc.ID) != tid {
|
||||
continue
|
||||
}
|
||||
fn := strings.TrimSpace(tc.Function.Name)
|
||||
if expect != "" && !strings.EqualFold(fn, expect) {
|
||||
continue
|
||||
}
|
||||
raw := strings.TrimSpace(tc.Function.Arguments)
|
||||
if raw == "" {
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(raw), &args); err != nil {
|
||||
return map[string]interface{}{"arguments_raw": raw}
|
||||
}
|
||||
if args == nil {
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
return args
|
||||
}
|
||||
}
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
|
||||
// recordEinoADKFilesystemToolMonitor 将 Eino ADK filesystem 中间件工具结果写入 MCP 监控(与 execute / MCP 桥芯片一致)。
|
||||
func recordEinoADKFilesystemToolMonitor(
|
||||
ag *agent.Agent,
|
||||
rec einomcp.ExecutionRecorder,
|
||||
toolName string,
|
||||
toolCallID string,
|
||||
msgs []adk.Message,
|
||||
resultText string,
|
||||
isErr bool,
|
||||
) {
|
||||
if ag == nil || rec == nil {
|
||||
return
|
||||
}
|
||||
name := strings.TrimSpace(toolName)
|
||||
if name == "" || strings.EqualFold(name, "execute") {
|
||||
return
|
||||
}
|
||||
if !isBuiltinEinoADKFilesystemToolName(name) {
|
||||
return
|
||||
}
|
||||
args := toolCallArgsFromAccumulated(msgs, toolCallID, name)
|
||||
storedName := "eino_fs::" + strings.ToLower(name)
|
||||
var invErr error
|
||||
if isErr {
|
||||
t := strings.TrimSpace(resultText)
|
||||
if t == "" {
|
||||
invErr = errors.New("tool error")
|
||||
} else {
|
||||
invErr = errors.New(t)
|
||||
}
|
||||
}
|
||||
id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr)
|
||||
if id != "" {
|
||||
rec(id)
|
||||
}
|
||||
}
|
||||
@@ -161,6 +161,8 @@ func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddl
|
||||
}
|
||||
|
||||
// prependEinoMiddlewares returns handlers to prepend (outermost first) and optionally replaces tools when tool_search is used.
|
||||
// toolSearchActive is true when the toolsearch middleware was mounted (dynamic tools split off); callers should pass this to
|
||||
// injectToolNamesOnlyInstruction — tool_search is not part of the pre-middleware tools list, so name-scanning alone cannot detect it.
|
||||
func prependEinoMiddlewares(
|
||||
ctx context.Context,
|
||||
mw *config.MultiAgentEinoMiddlewareConfig,
|
||||
@@ -170,16 +172,16 @@ func prependEinoMiddlewares(
|
||||
skillsRoot string,
|
||||
conversationID string,
|
||||
logger *zap.Logger,
|
||||
) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, err error) {
|
||||
) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, toolSearchActive bool, err error) {
|
||||
if mw == nil {
|
||||
return tools, nil, nil
|
||||
return tools, nil, false, nil
|
||||
}
|
||||
outTools = tools
|
||||
|
||||
if mw.PatchToolCallsEffective() {
|
||||
patchMW, perr := patchtoolcalls.New(ctx, &patchtoolcalls.Config{})
|
||||
if perr != nil {
|
||||
return nil, nil, fmt.Errorf("patchtoolcalls: %w", perr)
|
||||
return nil, nil, false, fmt.Errorf("patchtoolcalls: %w", perr)
|
||||
}
|
||||
extraHandlers = append(extraHandlers, patchMW)
|
||||
}
|
||||
@@ -190,7 +192,7 @@ func prependEinoMiddlewares(
|
||||
} else {
|
||||
redMW, rerr := buildReductionMiddleware(ctx, *mw, conversationID, einoLoc, logger)
|
||||
if rerr != nil {
|
||||
return nil, nil, rerr
|
||||
return nil, nil, false, rerr
|
||||
}
|
||||
extraHandlers = append(extraHandlers, redMW)
|
||||
}
|
||||
@@ -209,10 +211,11 @@ func prependEinoMiddlewares(
|
||||
if split && len(dynamic) > 0 {
|
||||
ts, terr := toolsearch.New(ctx, &toolsearch.Config{DynamicTools: dynamic})
|
||||
if terr != nil {
|
||||
return nil, nil, fmt.Errorf("toolsearch: %w", terr)
|
||||
return nil, nil, false, fmt.Errorf("toolsearch: %w", terr)
|
||||
}
|
||||
extraHandlers = append(extraHandlers, ts)
|
||||
outTools = static
|
||||
toolSearchActive = true
|
||||
if logger != nil {
|
||||
logger.Info("eino middleware: tool_search enabled",
|
||||
zap.Int("static_tools", len(static)),
|
||||
@@ -233,12 +236,12 @@ func prependEinoMiddlewares(
|
||||
}
|
||||
baseDir := filepath.Join(skillsRoot, rel, sanitizeEinoPathSegment(conversationID))
|
||||
if mk := os.MkdirAll(baseDir, 0o755); mk != nil {
|
||||
return nil, nil, fmt.Errorf("plantask mkdir: %w", mk)
|
||||
return nil, nil, toolSearchActive, fmt.Errorf("plantask mkdir: %w", mk)
|
||||
}
|
||||
ptBE := &localPlantaskBackend{Local: einoLoc}
|
||||
pt, perr := plantask.New(ctx, &plantask.Config{Backend: ptBE, BaseDir: baseDir})
|
||||
if perr != nil {
|
||||
return nil, nil, fmt.Errorf("plantask: %w", perr)
|
||||
return nil, nil, toolSearchActive, fmt.Errorf("plantask: %w", perr)
|
||||
}
|
||||
extraHandlers = append(extraHandlers, pt)
|
||||
if logger != nil {
|
||||
@@ -247,7 +250,7 @@ func prependEinoMiddlewares(
|
||||
}
|
||||
}
|
||||
|
||||
return outTools, extraHandlers, nil
|
||||
return outTools, extraHandlers, toolSearchActive, nil
|
||||
}
|
||||
|
||||
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry *adk.ModelRetryConfig, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
)
|
||||
|
||||
// modelFacingTraceHolder 保存「即将送入 ChatModel」的消息快照(已走 summarization / reduction / orphan 修剪等),
|
||||
// 用于 last_react_input 落库,使续跑与「上下文压缩后」的模型视角一致,而非仅依赖事件流 append 的 runAccumulatedMsgs。
|
||||
type modelFacingTraceHolder struct {
|
||||
mu sync.Mutex
|
||||
// msgs 为深拷贝后的切片,避免框架后续原地修改污染快照
|
||||
msgs []adk.Message
|
||||
}
|
||||
|
||||
func newModelFacingTraceHolder() *modelFacingTraceHolder {
|
||||
return &modelFacingTraceHolder{}
|
||||
}
|
||||
|
||||
// Snapshot 返回当前快照的再一次深拷贝(供序列化落库,避免与 holder 互斥长期持锁)。
|
||||
func (h *modelFacingTraceHolder) Snapshot() []adk.Message {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
return cloneADKMessagesForTrace(h.msgs)
|
||||
}
|
||||
|
||||
func (h *modelFacingTraceHolder) storeFromState(state *adk.ChatModelAgentState) {
|
||||
if h == nil || state == nil || len(state.Messages) == 0 {
|
||||
return
|
||||
}
|
||||
cloned := cloneADKMessagesForTrace(state.Messages)
|
||||
if len(cloned) == 0 {
|
||||
return
|
||||
}
|
||||
h.mu.Lock()
|
||||
h.msgs = cloned
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
func cloneADKMessagesForTrace(msgs []adk.Message) []adk.Message {
|
||||
if len(msgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
b, err := json.Marshal(msgs)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var out []adk.Message
|
||||
if err := json.Unmarshal(b, &out); err != nil {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// modelFacingTraceMiddleware 必须在 Handlers 链中处于 **BeforeModel 最后**(telemetry 之后),
|
||||
// 此时 state.Messages 即为本次 LLM 调用的最终入参。
|
||||
type modelFacingTraceMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
holder *modelFacingTraceHolder
|
||||
}
|
||||
|
||||
func newModelFacingTraceMiddleware(holder *modelFacingTraceHolder) adk.ChatModelAgentMiddleware {
|
||||
if holder == nil {
|
||||
return nil
|
||||
}
|
||||
return &modelFacingTraceMiddleware{holder: holder}
|
||||
}
|
||||
|
||||
func (m *modelFacingTraceMiddleware) BeforeModelRewriteState(
|
||||
ctx context.Context,
|
||||
state *adk.ChatModelAgentState,
|
||||
mc *adk.ModelContext,
|
||||
) (context.Context, *adk.ChatModelAgentState, error) {
|
||||
if m.holder != nil && state != nil {
|
||||
m.holder.storeFromState(state)
|
||||
}
|
||||
return ctx, state, nil
|
||||
}
|
||||
@@ -41,6 +41,8 @@ type PlanExecuteRootArgs struct {
|
||||
FilesystemMiddleware adk.ChatModelAgentMiddleware
|
||||
// PlannerReplannerRewriteHandlers applies BeforeModelRewriteState pipeline for planner/replanner input.
|
||||
PlannerReplannerRewriteHandlers []adk.ChatModelAgentMiddleware
|
||||
// ModelFacingTrace 可选:由 Executor Handlers 链末尾写入,供 last_react 与 summarization 后上下文对齐。
|
||||
ModelFacingTrace *modelFacingTraceHolder
|
||||
}
|
||||
|
||||
// NewPlanExecuteRoot 返回 plan → execute → replan 预置编排根节点(与 Deep / Supervisor 并列)。
|
||||
@@ -101,6 +103,11 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil {
|
||||
execHandlers = append(execHandlers, teleMw)
|
||||
}
|
||||
if a.ModelFacingTrace != nil {
|
||||
if capMw := newModelFacingTraceMiddleware(a.ModelFacingTrace); capMw != nil {
|
||||
execHandlers = append(execHandlers, capMw)
|
||||
}
|
||||
}
|
||||
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
|
||||
Model: a.ExecModel,
|
||||
ToolsConfig: a.ToolsCfg,
|
||||
|
||||
@@ -13,11 +13,11 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
|
||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@@ -37,6 +37,7 @@ func RunEinoSingleChatModelAgent(
|
||||
history []agent.ChatMessage,
|
||||
roleTools []string,
|
||||
progress func(eventType, message string, data interface{}),
|
||||
reasoningClient *reasoning.ClientIntent,
|
||||
) (*RunResult, error) {
|
||||
if appCfg == nil || ag == nil {
|
||||
return nil, fmt.Errorf("eino single: 配置或 Agent 为空")
|
||||
@@ -86,13 +87,15 @@ func RunEinoSingleChatModelAgent(
|
||||
})
|
||||
}
|
||||
|
||||
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
||||
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
||||
mainDefs := ag.ToolsForRole(roleTools)
|
||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk)
|
||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk, toolInvokeNotify, einoSingleAgentName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mainToolsForCfg, mainOrchestratorPre, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
||||
mainToolsForCfg, mainOrchestratorPre, singleToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("eino single eino 中间件: %w", err)
|
||||
}
|
||||
@@ -119,6 +122,7 @@ func RunEinoSingleChatModelAgent(
|
||||
Model: appCfg.OpenAI.Model,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
reasoning.ApplyToEinoChatModelConfig(baseModelCfg, &appCfg.OpenAI, reasoningClient)
|
||||
|
||||
mainModel, err := einoopenai.NewChatModel(ctx, baseModelCfg)
|
||||
if err != nil {
|
||||
@@ -130,13 +134,15 @@ func RunEinoSingleChatModelAgent(
|
||||
return nil, fmt.Errorf("eino single summarization: %w", err)
|
||||
}
|
||||
|
||||
handlers := make([]adk.ChatModelAgentMiddleware, 0, 4)
|
||||
modelFacingTrace := newModelFacingTraceHolder()
|
||||
|
||||
handlers := make([]adk.ChatModelAgentMiddleware, 0, 8)
|
||||
if len(mainOrchestratorPre) > 0 {
|
||||
handlers = append(handlers, mainOrchestratorPre...)
|
||||
}
|
||||
if einoSkillMW != nil {
|
||||
if einoFSTools && einoLoc != nil {
|
||||
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc)
|
||||
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
||||
if fsErr != nil {
|
||||
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
||||
}
|
||||
@@ -148,6 +154,9 @@ func RunEinoSingleChatModelAgent(
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "eino_single"); teleMw != nil {
|
||||
handlers = append(handlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
handlers = append(handlers, capMw)
|
||||
}
|
||||
|
||||
maxIter := ma.MaxIteration
|
||||
if maxIter <= 0 {
|
||||
@@ -162,28 +171,21 @@ func RunEinoSingleChatModelAgent(
|
||||
Tools: mainToolsForCfg,
|
||||
UnknownToolsHandler: einomcp.UnknownToolReminderHandler(),
|
||||
ToolCallMiddlewares: []compose.ToolMiddleware{
|
||||
{Invokable: hitlToolCallMiddleware()},
|
||||
{Invokable: softRecoveryToolCallMiddleware()},
|
||||
hitlToolCallMiddleware(),
|
||||
softRecoveryToolMiddleware(),
|
||||
},
|
||||
},
|
||||
EmitInternalEvents: true,
|
||||
}
|
||||
ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools)
|
||||
ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools, singleToolSearchActive)
|
||||
if logger != nil {
|
||||
names := collectToolNames(ctx, mainTools)
|
||||
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
||||
hasToolSearch := false
|
||||
for _, n := range names {
|
||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||
hasToolSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
logger.Info("eino tool-name injection",
|
||||
zap.String("scope", "eino_single"),
|
||||
zap.Int("tool_names", len(names)),
|
||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||
zap.Bool("has_tool_search", hasToolSearch),
|
||||
zap.Bool("tool_search_middleware", singleToolSearchActive),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -210,7 +212,7 @@ func RunEinoSingleChatModelAgent(
|
||||
}
|
||||
|
||||
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
|
||||
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
|
||||
baseMsgs = appendUserMessageIfNeeded(baseMsgs, userMessage)
|
||||
|
||||
streamsMainAssistant := func(agent string) bool {
|
||||
return agent == "" || agent == einoSingleAgentName
|
||||
@@ -221,18 +223,25 @@ func RunEinoSingleChatModelAgent(
|
||||
}
|
||||
|
||||
return runEinoADKAgentLoop(ctx, &einoADKRunLoopArgs{
|
||||
OrchMode: "eino_single",
|
||||
OrchestratorName: einoSingleAgentName,
|
||||
ConversationID: conversationID,
|
||||
Progress: progress,
|
||||
Logger: logger,
|
||||
SnapshotMCPIDs: snapshotMCPIDs,
|
||||
StreamsMainAssistant: streamsMainAssistant,
|
||||
EinoRoleTag: einoRoleTag,
|
||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||
McpIDsMu: &mcpIDsMu,
|
||||
McpIDs: &mcpIDs,
|
||||
DA: chatAgent,
|
||||
OrchMode: "eino_single",
|
||||
OrchestratorName: einoSingleAgentName,
|
||||
ConversationID: conversationID,
|
||||
Progress: progress,
|
||||
Logger: logger,
|
||||
SnapshotMCPIDs: snapshotMCPIDs,
|
||||
StreamsMainAssistant: streamsMainAssistant,
|
||||
EinoRoleTag: einoRoleTag,
|
||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||
RunRetryMaxAttempts: ma.EinoMiddleware.RunRetryMaxAttempts,
|
||||
RunRetryMaxBackoffSec: ma.EinoMiddleware.RunRetryMaxBackoffSec,
|
||||
McpIDsMu: &mcpIDsMu,
|
||||
McpIDs: &mcpIDs,
|
||||
FilesystemMonitorAgent: ag,
|
||||
FilesystemMonitorRecord: recorder,
|
||||
ToolInvokeNotify: toolInvokeNotify,
|
||||
DA: chatAgent,
|
||||
ModelFacingTrace: modelFacingTrace,
|
||||
EinoCallbacks: &ma.EinoCallbacks,
|
||||
EmptyResponseMessage: "(Eino ADK single-agent session completed but no assistant text was captured. Check process details or logs.) " +
|
||||
"(Eino ADK 单代理会话已完成,但未捕获到助手文本输出。请查看过程详情或日志。)",
|
||||
}, baseMsgs)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
|
||||
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
@@ -75,12 +76,35 @@ func prepareEinoSkills(
|
||||
// subAgentFilesystemMiddleware returns filesystem middleware for a sub-agent when Deep itself
|
||||
// does not set Backend (fsTools false on orchestrator) but we still want tools on subs — not used;
|
||||
// when orchestrator has Backend, builtin FS is only on outer agent; subs need explicit FS for parity.
|
||||
func subAgentFilesystemMiddleware(ctx context.Context, loc *localbk.Local) (adk.ChatModelAgentMiddleware, error) {
|
||||
func subAgentFilesystemMiddleware(
|
||||
ctx context.Context,
|
||||
loc *localbk.Local,
|
||||
invokeNotify *einomcp.ToolInvokeNotifyHolder,
|
||||
einoAgentName string,
|
||||
recordMonitor func(command, stdout string, success bool, invokeErr error),
|
||||
toolTimeoutMinutes int,
|
||||
outputChunk func(toolName, toolCallID, chunk string),
|
||||
) (adk.ChatModelAgentMiddleware, error) {
|
||||
if loc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return filesystem.New(ctx, &filesystem.MiddlewareConfig{
|
||||
Backend: loc,
|
||||
StreamingShell: &einoStreamingShellWrap{inner: loc},
|
||||
Backend: loc,
|
||||
StreamingShell: &einoStreamingShellWrap{
|
||||
inner: loc,
|
||||
invokeNotify: invokeNotify,
|
||||
einoAgentName: strings.TrimSpace(einoAgentName),
|
||||
outputChunk: outputChunk,
|
||||
recordMonitor: recordMonitor,
|
||||
toolTimeoutMinutes: toolTimeoutMinutes,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// agentToolTimeoutMinutes 返回 agent.tool_timeout_minutes(与 executeToolViaMCP 一致);cfg 为 nil 时 0。
|
||||
func agentToolTimeoutMinutes(cfg *config.Config) int {
|
||||
if cfg == nil {
|
||||
return 0
|
||||
}
|
||||
return cfg.Agent.ToolTimeoutMinutes
|
||||
}
|
||||
|
||||
@@ -214,7 +214,7 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
selectedCount++
|
||||
}
|
||||
|
||||
// 还原时间顺序
|
||||
// 还原时间顺序。round 内为原始 *schema.Message 指针,保留 ReasoningContent(DeepSeek 工具续跑所必需)。
|
||||
selectedMsgs := make([]adk.Message, 0, 8)
|
||||
for i := len(selectedRoundsReverse) - 1; i >= 0; i-- {
|
||||
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
|
||||
|
||||
@@ -9,34 +9,43 @@ import (
|
||||
|
||||
// injectToolNamesOnlyInstruction prepends a compact tool-name-only section into
|
||||
// the system instruction so the model can reference current callable names.
|
||||
func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, tools []tool.BaseTool) string {
|
||||
// toolSearchMiddlewareActive must be true when prependEinoMiddlewares mounted toolsearch (dynamic tools); do not infer this
|
||||
// by scanning tool names — tool_search is injected by middleware and is usually absent from the pre-split tools list.
|
||||
func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, tools []tool.BaseTool, toolSearchMiddlewareActive bool) string {
|
||||
names := collectToolNames(ctx, tools)
|
||||
if len(names) == 0 {
|
||||
return strings.TrimSpace(instruction)
|
||||
}
|
||||
hasToolSearch := false
|
||||
for _, n := range names {
|
||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||
hasToolSearch = true
|
||||
break
|
||||
hasToolSearch := toolSearchMiddlewareActive
|
||||
if !hasToolSearch {
|
||||
for _, n := range names {
|
||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||
hasToolSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("以下是当前会话中可调用的工具名称列表(仅名称,无参数定义):\n")
|
||||
sb.WriteString("以下是当前会话绑定的工具名称索引(仅名称,无参数 JSON Schema)。\n")
|
||||
sb.WriteString("说明:若启用了 tool_search,则列表里可能含「非常驻」工具——它们不一定出现在当前轮次下发给模型的工具定义中;在未看到该工具的完整 schema 前,禁止凭名称臆测参数。\n")
|
||||
for _, name := range names {
|
||||
sb.WriteString("- ")
|
||||
sb.WriteString(name)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
sb.WriteString("\n使用规则:\n")
|
||||
sb.WriteString("1) 上述仅为名称列表,不包含参数定义。\n")
|
||||
sb.WriteString("1) 上表仅为名称索引,不含参数定义。禁止猜测参数名、类型、枚举取值或是否必填。\n")
|
||||
if hasToolSearch {
|
||||
sb.WriteString("2) 在调用具体工具前,应先使用 tool_search 查看工具详情与参数要求,再发起调用。\n")
|
||||
sb.WriteString("【强制 / 最高优先级】本会话已启用 tool_search(动态工具池)。凡名称索引里出现、但你在「当前请求所附 tools 定义」中看不到其完整参数 schema 的工具,一律必须先调用 tool_search;为省 token 或赶进度而跳过 tool_search、直接调用业务工具,属于明确禁止的错误流程。\n")
|
||||
sb.WriteString("2) 默认策略:只要对目标工具的参数定义有任何不确定,就先 tool_search;宁可多一次 tool_search,也不要在未见 schema 时盲调业务工具。\n")
|
||||
sb.WriteString("3) 调用顺序:先 tool_search(唯一必填参数 regex_pattern:按工具名匹配的正则,如子串 nuclei 或 ^exact_tool_name$)→ 在后续轮次确认目标工具已出现在 tools 列表且已阅读其 schema → 再发起对该工具的真实调用。\n")
|
||||
sb.WriteString("4) tool_search 的返回仅为匹配到的工具名列表;schema 在解锁后的下一轮才会下发。禁止在 schema 未出现时编造 JSON 参数。\n")
|
||||
sb.WriteString("5) 不要臆造不存在的工具名。\n\n")
|
||||
} else {
|
||||
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求;不确定时先澄清再调用。\n")
|
||||
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求(以当前请求中的工具定义为准);不确定时先澄清再调用。\n")
|
||||
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
|
||||
}
|
||||
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
|
||||
if s := strings.TrimSpace(instruction); s != "" {
|
||||
sb.WriteString(s)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultEinoRunRetryMaxAttempts = 10
|
||||
defaultEinoRunRetryMaxBackoff = 30 * time.Second
|
||||
)
|
||||
|
||||
// isEinoTransientRunError 判断 ADK 运行期错误是否适合指数退避续跑(429、5xx、网络抖动等)。
|
||||
// 用户取消、超时、迭代上限等由 run loop 单独处理,不在此列。
|
||||
func isEinoTransientRunError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return false
|
||||
}
|
||||
if isEinoIterationLimitError(err) {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||
if msg == "" {
|
||||
return false
|
||||
}
|
||||
transientMarkers := []string{
|
||||
"406",
|
||||
"429",
|
||||
"too many requests",
|
||||
"rate limit",
|
||||
"rate_limit",
|
||||
"ratelimit",
|
||||
"quota exceeded",
|
||||
"overloaded",
|
||||
"capacity",
|
||||
"temporarily unavailable",
|
||||
"service unavailable",
|
||||
"bad gateway",
|
||||
"gateway timeout",
|
||||
"internal server error",
|
||||
"connection reset",
|
||||
"connection refused",
|
||||
"connection closed",
|
||||
"i/o timeout",
|
||||
"no such host",
|
||||
"network is unreachable",
|
||||
"broken pipe",
|
||||
"eof",
|
||||
"read tcp",
|
||||
"write tcp",
|
||||
"dial tcp",
|
||||
"tls handshake timeout",
|
||||
"stream error",
|
||||
"unexpected eof",
|
||||
"unexpected end of json",
|
||||
"status code: 406",
|
||||
"status code: 502",
|
||||
"502",
|
||||
"503",
|
||||
"504",
|
||||
"500",
|
||||
}
|
||||
for _, m := range transientMarkers {
|
||||
if strings.Contains(msg, m) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
||||
if args != nil && args.RunRetryMaxAttempts > 0 {
|
||||
return args.RunRetryMaxAttempts
|
||||
}
|
||||
return defaultEinoRunRetryMaxAttempts
|
||||
}
|
||||
|
||||
// RunRetryMaxAttemptsFromConfig 供 handler 分段续跑计数(与 eino_middleware.run_retry_max_attempts 一致)。
|
||||
func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
|
||||
if mw != nil && mw.RunRetryMaxAttempts > 0 {
|
||||
return mw.RunRetryMaxAttempts
|
||||
}
|
||||
return defaultEinoRunRetryMaxAttempts
|
||||
}
|
||||
|
||||
// TransientRetryBackoff 供 handler 在分段续跑前退避。
|
||||
func TransientRetryBackoff(attempt int, maxBackoffSec int) time.Duration {
|
||||
max := defaultEinoRunRetryMaxBackoff
|
||||
if maxBackoffSec > 0 {
|
||||
max = time.Duration(maxBackoffSec) * time.Second
|
||||
}
|
||||
return einoTransientRetryBackoff(attempt, max)
|
||||
}
|
||||
|
||||
func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration {
|
||||
if args != nil && args.RunRetryMaxBackoffSec > 0 {
|
||||
return time.Duration(args.RunRetryMaxBackoffSec) * time.Second
|
||||
}
|
||||
return defaultEinoRunRetryMaxBackoff
|
||||
}
|
||||
|
||||
// einoRunRestartContextSource 描述无 checkpoint Resume 时 Run 使用的消息来源(日志/SSE)。
|
||||
type einoRunRestartContextSource string
|
||||
|
||||
const (
|
||||
einoRestartContextInitial einoRunRestartContextSource = "initial"
|
||||
einoRestartContextAccumulated einoRunRestartContextSource = "accumulated"
|
||||
einoRestartContextModelTrace einoRunRestartContextSource = "model_trace"
|
||||
)
|
||||
|
||||
// einoMessagesForRunRestart 在退避后重新 Run 时选用最完整的上下文:
|
||||
// 1) ModelFacingTrace(与模型实际入参一致) 2) 事件流累积的 runAccumulatedMsgs 3) 初始 msgs。
|
||||
func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated []adk.Message, baseCount int) ([]adk.Message, einoRunRestartContextSource) {
|
||||
if trace := persistTraceSource(args, nil); len(trace) > 0 {
|
||||
return append([]adk.Message(nil), trace...), einoRestartContextModelTrace
|
||||
}
|
||||
if len(accumulated) > baseCount {
|
||||
return append([]adk.Message(nil), accumulated...), einoRestartContextAccumulated
|
||||
}
|
||||
return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial
|
||||
}
|
||||
|
||||
// adkMessagesHasUserContent 从尾部向前查找,是否已有与 want 相同的 user 消息(避免重复 append)。
|
||||
func adkMessagesHasUserContent(msgs []adk.Message, want string) bool {
|
||||
want = strings.TrimSpace(want)
|
||||
if want == "" {
|
||||
return true
|
||||
}
|
||||
for i := len(msgs) - 1; i >= 0; i-- {
|
||||
m := msgs[i]
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if m.Role == schema.User {
|
||||
return strings.TrimSpace(m.Content) == want
|
||||
}
|
||||
if m.Role == schema.Assistant || m.Role == schema.Tool {
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// appendUserMessageIfNeeded 在 history 轨迹之后追加本轮 user 消息(仅当轨迹中尚未包含该句)。
|
||||
func appendUserMessageIfNeeded(msgs []adk.Message, userMessage string) []adk.Message {
|
||||
if strings.TrimSpace(userMessage) == "" || adkMessagesHasUserContent(msgs, userMessage) {
|
||||
return msgs
|
||||
}
|
||||
return append(msgs, schema.UserMessage(userMessage))
|
||||
}
|
||||
|
||||
// einoTransientRetryBackoff 指数退避:2s, 4s, 8s… capped by maxBackoff。
|
||||
func einoTransientRetryBackoff(attempt int, maxBackoff time.Duration) time.Duration {
|
||||
if attempt < 0 {
|
||||
attempt = 0
|
||||
}
|
||||
backoff := time.Duration(1<<uint(attempt+1)) * time.Second
|
||||
if maxBackoff > 0 && backoff > maxBackoff {
|
||||
backoff = maxBackoff
|
||||
}
|
||||
return backoff
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestIsEinoTransientRunError(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{"nil", nil, false},
|
||||
{"429", errors.New("HTTP 429 Too Many Requests"), true},
|
||||
{"rate limit", errors.New(`{"error":"rate limit exceeded"}`), true},
|
||||
{"connection reset", errors.New("read tcp: connection reset by peer"), true},
|
||||
{"503", errors.New("upstream returned 503"), true},
|
||||
{"iteration limit", errors.New("max iteration reached"), false},
|
||||
{"canceled", context.Canceled, false},
|
||||
{"deadline", context.DeadlineExceeded, false},
|
||||
{"auth", errors.New("invalid api key"), false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := isEinoTransientRunError(tc.err); got != tc.want {
|
||||
t.Fatalf("isEinoTransientRunError(%v) = %v, want %v", tc.err, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoTransientRetryBackoff(t *testing.T) {
|
||||
t.Parallel()
|
||||
max := 30 * time.Second
|
||||
if got := einoTransientRetryBackoff(0, max); got != 2*time.Second {
|
||||
t.Fatalf("attempt 0: got %v", got)
|
||||
}
|
||||
if got := einoTransientRetryBackoff(4, max); got != 30*time.Second {
|
||||
t.Fatalf("attempt 4 capped: got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoMessagesForRunRestart(t *testing.T) {
|
||||
t.Parallel()
|
||||
base := []adk.Message{schema.UserMessage("hi")}
|
||||
acc := append([]adk.Message(nil), base...)
|
||||
acc = append(acc, schema.AssistantMessage("step1", nil))
|
||||
|
||||
got, src := einoMessagesForRunRestart(nil, base, acc, len(base))
|
||||
if src != einoRestartContextAccumulated || len(got) != 2 {
|
||||
t.Fatalf("accumulated: src=%s len=%d", src, len(got))
|
||||
}
|
||||
|
||||
holder := newModelFacingTraceHolder()
|
||||
holder.storeFromState(&adk.ChatModelAgentState{
|
||||
Messages: []adk.Message{schema.UserMessage("u"), schema.AssistantMessage("model-view", nil)},
|
||||
})
|
||||
got2, src2 := einoMessagesForRunRestart(&einoADKRunLoopArgs{ModelFacingTrace: holder}, base, acc, len(base))
|
||||
if src2 != einoRestartContextModelTrace || len(got2) != 2 {
|
||||
t.Fatalf("model trace: src=%s len=%d", src2, len(got2))
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoRunRetryMaxAttemptsFromArgs(t *testing.T) {
|
||||
t.Parallel()
|
||||
if einoRunRetryMaxAttempts(nil) != defaultEinoRunRetryMaxAttempts {
|
||||
t.Fatal("nil args should use default")
|
||||
}
|
||||
if einoRunRetryMaxAttempts(&einoADKRunLoopArgs{RunRetryMaxAttempts: 3}) != 3 {
|
||||
t.Fatal("custom max attempts")
|
||||
}
|
||||
if RunRetryMaxAttemptsFromConfig(nil) != defaultEinoRunRetryMaxAttempts {
|
||||
t.Fatal("config nil should use default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUserMessageIfNeeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
msgs := []adk.Message{schema.UserMessage("old task")}
|
||||
out := appendUserMessageIfNeeded(msgs, "你好,你是谁")
|
||||
if len(out) != 2 || out[1].Content != "你好,你是谁" {
|
||||
t.Fatalf("should append user: len=%d", len(out))
|
||||
}
|
||||
dup := appendUserMessageIfNeeded(out, "你好,你是谁")
|
||||
if len(dup) != 2 {
|
||||
t.Fatalf("should not duplicate user message: len=%d", len(dup))
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrTransientRetryContinue(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !errors.Is(ErrTransientRetryContinue, ErrTransientRetryContinue) {
|
||||
t.Fatal("sentinel should match")
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
type hitlInterceptorKey struct{}
|
||||
@@ -41,7 +42,31 @@ func WithHITLToolInterceptor(ctx context.Context, fn HITLToolInterceptor) contex
|
||||
return context.WithValue(ctx, hitlInterceptorKey{}, fn)
|
||||
}
|
||||
|
||||
func hitlToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||
// hitlToolCallMiddleware 同时注册 Invokable 与 Streamable。
|
||||
// Eino filesystem 的 execute 为流式工具(StreamableTool),仅挂 Invokable 时人机协同不会拦截,会直接执行。
|
||||
func hitlToolCallMiddleware() compose.ToolMiddleware {
|
||||
return compose.ToolMiddleware{
|
||||
Invokable: hitlInvokableToolCallMiddleware(),
|
||||
Streamable: hitlStreamableToolCallMiddleware(),
|
||||
}
|
||||
}
|
||||
|
||||
func hitlClearReturnDirectlyIfTransfer(ctx context.Context, toolName string) {
|
||||
if !strings.EqualFold(strings.TrimSpace(toolName), adk.TransferToAgentToolName) {
|
||||
return
|
||||
}
|
||||
_ = compose.ProcessState[*adk.State](ctx, func(_ context.Context, st *adk.State) error {
|
||||
if st == nil {
|
||||
return nil
|
||||
}
|
||||
st.ReturnDirectlyToolCallID = ""
|
||||
st.HasReturnDirectly = false
|
||||
st.ReturnDirectlyEvent = nil
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func hitlInvokableToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||
return func(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint {
|
||||
return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
if input != nil {
|
||||
@@ -55,17 +80,7 @@ func hitlToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||
// transfer_to_agent 在 Eino 中标记为 returnDirectly:工具成功后 ReAct 子图会直接 END,
|
||||
// 并依赖真实工具内的 SendToolGenAction 触发移交。HITL 拒绝时不会执行真实工具,
|
||||
// 若仍走 returnDirectly 分支,监督者会在无 Transfer 动作的情况下结束,模型不再迭代。
|
||||
if strings.EqualFold(strings.TrimSpace(input.Name), adk.TransferToAgentToolName) {
|
||||
_ = compose.ProcessState[*adk.State](ctx, func(_ context.Context, st *adk.State) error {
|
||||
if st == nil {
|
||||
return nil
|
||||
}
|
||||
st.ReturnDirectlyToolCallID = ""
|
||||
st.HasReturnDirectly = false
|
||||
st.ReturnDirectlyEvent = nil
|
||||
return nil
|
||||
})
|
||||
}
|
||||
hitlClearReturnDirectlyIfTransfer(ctx, input.Name)
|
||||
return &compose.ToolOutput{Result: msg}, nil
|
||||
}
|
||||
return nil, err
|
||||
@@ -79,3 +94,30 @@ func hitlToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func hitlStreamableToolCallMiddleware() compose.StreamableToolMiddleware {
|
||||
return func(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint {
|
||||
return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) {
|
||||
if input != nil {
|
||||
if fn, ok := ctx.Value(hitlInterceptorKey{}).(HITLToolInterceptor); ok && fn != nil {
|
||||
edited, err := fn(ctx, input.Name, input.Arguments)
|
||||
if err != nil {
|
||||
if IsHumanRejectError(err) {
|
||||
msg := fmt.Sprintf("[HITL Reject] Tool '%s' was rejected by human reviewer. Reason: %s\nPlease adjust parameters/plan and continue without this call.",
|
||||
input.Name, strings.TrimSpace(err.Error()))
|
||||
hitlClearReturnDirectlyIfTransfer(ctx, input.Name)
|
||||
return &compose.StreamToolOutput{
|
||||
Result: schema.StreamReaderFromArray([]string{msg}),
|
||||
}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if edited != "" {
|
||||
input.Arguments = edited
|
||||
}
|
||||
}
|
||||
}
|
||||
return next(ctx, input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package multiagent
|
||||
|
||||
import "errors"
|
||||
|
||||
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
|
||||
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
|
||||
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
|
||||
|
||||
// ErrTransientRetryContinue 表示 Run 因 429/网络等临时错误结束,应由 handler 落库轨迹后
|
||||
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue 同级的「分段续跑」语义)。
|
||||
var ErrTransientRetryContinue = errors.New("agent transient: retry after persisting trace")
|
||||
@@ -0,0 +1,22 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Eino execute 去重分支 EOF flush 须以 mainAssistantBuf 为基准计算 tail,
|
||||
// 若误用 TrimSpace(mainAssistantBuf),会与已推前缀在空白处失配,normalize 走拼接路径叠字。
|
||||
func TestNormalizeStreamingDelta_eofTailUsesRawBufNotTrim(t *testing.T) {
|
||||
wireAccum := "phrase "
|
||||
rawFull := "phrase \n"
|
||||
_, tail := normalizeStreamingDelta(wireAccum, rawFull)
|
||||
if want := "\n"; tail != want {
|
||||
t.Fatalf("tail=%q want %q", tail, want)
|
||||
}
|
||||
|
||||
nextWrong, badTail := normalizeStreamingDelta(wireAccum, strings.TrimSpace(rawFull))
|
||||
if badTail != "phrase" || nextWrong != "phrase phrase" {
|
||||
t.Fatalf("trimmed full vs wire prefix mismatch should concat-append; got next=%q badTail=%q", nextWrong, badTail)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AggregatedReasoningFromTraceJSON concatenates non-empty assistant `reasoning_content`
|
||||
// fields from last_react-style JSON (slice of message objects) in document order.
|
||||
// Used to persist on the single assistant bubble row for audit and for GetMessages fallback
|
||||
// when the full trace JSON is unavailable. For strict per-message replay, prefer last_react_input.
|
||||
func AggregatedReasoningFromTraceJSON(traceJSON string) string {
|
||||
traceJSON = strings.TrimSpace(traceJSON)
|
||||
if traceJSON == "" {
|
||||
return ""
|
||||
}
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(traceJSON), &arr); err != nil {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, m := range arr {
|
||||
role, _ := m["role"].(string)
|
||||
if !strings.EqualFold(strings.TrimSpace(role), "assistant") {
|
||||
continue
|
||||
}
|
||||
rc := reasoningContentFromMessageMap(m)
|
||||
if rc == "" {
|
||||
continue
|
||||
}
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString(rc)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func reasoningContentFromMessageMap(m map[string]interface{}) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := m["reasoning_content"].(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case nil:
|
||||
return ""
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package multiagent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAggregatedReasoningFromTraceJSON(t *testing.T) {
|
||||
const j = `[
|
||||
{"role":"user","content":"hi"},
|
||||
{"role":"assistant","content":"c1","reasoning_content":"r1","tool_calls":[{"id":"1","type":"function","function":{"name":"f","arguments":"{}"}}]},
|
||||
{"role":"tool","tool_call_id":"1","content":"out"},
|
||||
{"role":"assistant","content":"c2","reasoning_content":"r2"}
|
||||
]`
|
||||
got := AggregatedReasoningFromTraceJSON(j)
|
||||
want := "r1\nr2"
|
||||
if got != want {
|
||||
t.Fatalf("got %q want %q", got, want)
|
||||
}
|
||||
if AggregatedReasoningFromTraceJSON("") != "" || AggregatedReasoningFromTraceJSON("[]") != "" {
|
||||
t.Fatal("empty expected")
|
||||
}
|
||||
}
|
||||
+160
-111
@@ -17,6 +17,7 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
|
||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
@@ -48,6 +49,7 @@ type toolCallPendingInfo struct {
|
||||
|
||||
// RunDeepAgent 使用 Eino 多代理预置编排执行一轮对话(deep / plan_execute / supervisor;流式事件通过 progress 回调输出)。
|
||||
// orchestrationOverride 非空时优先(如聊天/WebShell 请求体);否则用 multi_agent.orchestration(遗留 yaml);皆空则按 deep。
|
||||
// reasoningClient 来自 ChatRequest.reasoning;可为 nil(机器人/批量等走全局 openai.reasoning)。
|
||||
func RunDeepAgent(
|
||||
ctx context.Context,
|
||||
appCfg *config.Config,
|
||||
@@ -61,6 +63,7 @@ func RunDeepAgent(
|
||||
progress func(eventType, message string, data interface{}),
|
||||
agentsMarkdownDir string,
|
||||
orchestrationOverride string,
|
||||
reasoningClient *reasoning.ClientIntent,
|
||||
) (*RunResult, error) {
|
||||
if appCfg == nil || ma == nil || ag == nil {
|
||||
return nil, fmt.Errorf("multiagent: 配置或 Agent 为空")
|
||||
@@ -110,6 +113,7 @@ func RunDeepAgent(
|
||||
mcpIDs = append(mcpIDs, id)
|
||||
mcpIDsMu.Unlock()
|
||||
}
|
||||
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
||||
|
||||
// 与单代理流式一致:在 response_start / response_delta 的 data 中带当前 mcpExecutionIds,供主聊天绑定复制与展示。
|
||||
snapshotMCPIDs := func() []string {
|
||||
@@ -120,6 +124,7 @@ func RunDeepAgent(
|
||||
return out
|
||||
}
|
||||
|
||||
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
||||
mainDefs := ag.ToolsForRole(roleTools)
|
||||
toolOutputChunk := func(toolName, toolCallID, chunk string) {
|
||||
// When toolCallId is missing, frontend ignores tool_result_delta.
|
||||
@@ -137,16 +142,6 @@ func RunDeepAgent(
|
||||
})
|
||||
}
|
||||
|
||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mainToolsForCfg, mainOrchestratorPre, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 30 * time.Minute,
|
||||
Transport: &http.Transport{
|
||||
@@ -171,6 +166,7 @@ func RunDeepAgent(
|
||||
Model: appCfg.OpenAI.Model,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
reasoning.ApplyToEinoChatModelConfig(baseModelCfg, &appCfg.OpenAI, reasoningClient)
|
||||
|
||||
deepMaxIter := ma.MaxIteration
|
||||
if deepMaxIter <= 0 {
|
||||
@@ -222,12 +218,12 @@ func RunDeepAgent(
|
||||
}
|
||||
|
||||
subDefs := ag.ToolsForRole(roleTools)
|
||||
subTools, err := einomcp.ToolsFromDefinitions(ag, holder, subDefs, recorder, toolOutputChunk)
|
||||
subTools, err := einomcp.ToolsFromDefinitions(ag, holder, subDefs, recorder, toolOutputChunk, toolInvokeNotify, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("子代理 %q 工具: %w", id, err)
|
||||
}
|
||||
|
||||
subToolsForCfg, subPre, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWSub, subTools, einoLoc, skillsRoot, conversationID, logger)
|
||||
subToolsForCfg, subPre, subToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWSub, subTools, einoLoc, skillsRoot, conversationID, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("子代理 %q eino 中间件: %w", id, err)
|
||||
}
|
||||
@@ -248,7 +244,7 @@ func RunDeepAgent(
|
||||
}
|
||||
if einoSkillMW != nil {
|
||||
if einoFSTools && einoLoc != nil {
|
||||
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc)
|
||||
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
||||
if fsErr != nil {
|
||||
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
|
||||
}
|
||||
@@ -264,23 +260,16 @@ func RunDeepAgent(
|
||||
subHandlers = append(subHandlers, teleMw)
|
||||
}
|
||||
|
||||
subInstrFinal := injectToolNamesOnlyInstruction(ctx, instr, subTools)
|
||||
subInstrFinal := injectToolNamesOnlyInstruction(ctx, instr, subTools, subToolSearchActive)
|
||||
if logger != nil {
|
||||
subNames := collectToolNames(ctx, subTools)
|
||||
mountedNames := collectToolNames(ctx, subToolsForCfg)
|
||||
hasToolSearch := false
|
||||
for _, n := range subNames {
|
||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||
hasToolSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
logger.Info("eino tool-name injection",
|
||||
zap.String("scope", "sub_agent"),
|
||||
zap.String("agent", id),
|
||||
zap.Int("tool_names", len(subNames)),
|
||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||
zap.Bool("has_tool_search", hasToolSearch),
|
||||
zap.Bool("tool_search_middleware", subToolSearchActive),
|
||||
)
|
||||
}
|
||||
sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
|
||||
@@ -293,8 +282,8 @@ func RunDeepAgent(
|
||||
Tools: subToolsForCfg,
|
||||
UnknownToolsHandler: einomcp.UnknownToolReminderHandler(),
|
||||
ToolCallMiddlewares: []compose.ToolMiddleware{
|
||||
{Invokable: hitlToolCallMiddleware()},
|
||||
{Invokable: softRecoveryToolCallMiddleware()},
|
||||
hitlToolCallMiddleware(),
|
||||
softRecoveryToolMiddleware(),
|
||||
},
|
||||
},
|
||||
EmitInternalEvents: true,
|
||||
@@ -319,6 +308,8 @@ func RunDeepAgent(
|
||||
return nil, fmt.Errorf("多代理主 summarization 中间件: %w", err)
|
||||
}
|
||||
|
||||
modelFacingTrace := newModelFacingTraceHolder()
|
||||
|
||||
// 与 deep.Config.Name / supervisor 主代理 Name 一致。
|
||||
orchestratorName := "cyberstrike-deep"
|
||||
orchDescription := "Coordinates specialist agents and MCP tools for authorized security testing."
|
||||
@@ -338,23 +329,26 @@ func RunDeepAgent(
|
||||
orchDescription = d
|
||||
}
|
||||
}
|
||||
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools)
|
||||
|
||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk, toolInvokeNotify, orchestratorName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mainToolsForCfg, mainOrchestratorPre, mainToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools, mainToolSearchActive)
|
||||
if logger != nil {
|
||||
mainNames := collectToolNames(ctx, mainTools)
|
||||
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
||||
hasToolSearch := false
|
||||
for _, n := range mainNames {
|
||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||
hasToolSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
logger.Info("eino tool-name injection",
|
||||
zap.String("scope", "orchestrator"),
|
||||
zap.String("orchestration", orchMode),
|
||||
zap.Int("tool_names", len(mainNames)),
|
||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||
zap.Bool("has_tool_search", hasToolSearch),
|
||||
zap.Bool("tool_search_middleware", mainToolSearchActive),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -381,7 +375,14 @@ func RunDeepAgent(
|
||||
var deepShell filesystem.StreamingShell
|
||||
if einoLoc != nil && einoFSTools {
|
||||
deepBackend = einoLoc
|
||||
deepShell = einoLoc
|
||||
deepShell = &einoStreamingShellWrap{
|
||||
inner: einoLoc,
|
||||
invokeNotify: toolInvokeNotify,
|
||||
einoAgentName: orchestratorName,
|
||||
outputChunk: toolOutputChunk,
|
||||
recordMonitor: einoExecMonitor,
|
||||
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
|
||||
}
|
||||
}
|
||||
|
||||
// noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。
|
||||
@@ -400,6 +401,9 @@ func RunDeepAgent(
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil {
|
||||
deepHandlers = append(deepHandlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
deepHandlers = append(deepHandlers, capMw)
|
||||
}
|
||||
|
||||
supHandlers := []adk.ChatModelAgentMiddleware{}
|
||||
if len(mainOrchestratorPre) > 0 {
|
||||
@@ -413,14 +417,17 @@ func RunDeepAgent(
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil {
|
||||
supHandlers = append(supHandlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
supHandlers = append(supHandlers, capMw)
|
||||
}
|
||||
|
||||
mainToolsCfg := adk.ToolsConfig{
|
||||
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||
Tools: mainToolsForCfg,
|
||||
UnknownToolsHandler: einomcp.UnknownToolReminderHandler(),
|
||||
ToolCallMiddlewares: []compose.ToolMiddleware{
|
||||
{Invokable: hitlToolCallMiddleware()},
|
||||
{Invokable: softRecoveryToolCallMiddleware()},
|
||||
hitlToolCallMiddleware(),
|
||||
softRecoveryToolMiddleware(),
|
||||
},
|
||||
},
|
||||
EmitInternalEvents: true,
|
||||
@@ -438,7 +445,7 @@ func RunDeepAgent(
|
||||
// 构建 filesystem 中间件(与 Deep sub-agent 一致)
|
||||
var peFsMw adk.ChatModelAgentMiddleware
|
||||
if einoSkillMW != nil && einoFSTools && einoLoc != nil {
|
||||
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc)
|
||||
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
|
||||
}
|
||||
@@ -458,6 +465,7 @@ func RunDeepAgent(
|
||||
ExecPreMiddlewares: mainOrchestratorPre,
|
||||
SkillMiddleware: einoSkillMW,
|
||||
FilesystemMiddleware: peFsMw,
|
||||
ModelFacingTrace: modelFacingTrace,
|
||||
PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{
|
||||
mainSumMw,
|
||||
// 孤儿 tool 消息兜底:必须挂在 summarization 之后、telemetry 之前。
|
||||
@@ -530,7 +538,7 @@ func RunDeepAgent(
|
||||
}
|
||||
|
||||
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
|
||||
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
|
||||
baseMsgs = appendUserMessageIfNeeded(baseMsgs, userMessage)
|
||||
|
||||
streamsMainAssistant := func(agent string) bool {
|
||||
if orchMode == "plan_execute" {
|
||||
@@ -549,95 +557,109 @@ func RunDeepAgent(
|
||||
}
|
||||
|
||||
return runEinoADKAgentLoop(ctx, &einoADKRunLoopArgs{
|
||||
OrchMode: orchMode,
|
||||
OrchestratorName: orchestratorName,
|
||||
ConversationID: conversationID,
|
||||
Progress: progress,
|
||||
Logger: logger,
|
||||
SnapshotMCPIDs: snapshotMCPIDs,
|
||||
StreamsMainAssistant: streamsMainAssistant,
|
||||
EinoRoleTag: einoRoleTag,
|
||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||
McpIDsMu: &mcpIDsMu,
|
||||
McpIDs: &mcpIDs,
|
||||
DA: da,
|
||||
OrchMode: orchMode,
|
||||
OrchestratorName: orchestratorName,
|
||||
ConversationID: conversationID,
|
||||
Progress: progress,
|
||||
Logger: logger,
|
||||
SnapshotMCPIDs: snapshotMCPIDs,
|
||||
StreamsMainAssistant: streamsMainAssistant,
|
||||
EinoRoleTag: einoRoleTag,
|
||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||
RunRetryMaxAttempts: ma.EinoMiddleware.RunRetryMaxAttempts,
|
||||
RunRetryMaxBackoffSec: ma.EinoMiddleware.RunRetryMaxBackoffSec,
|
||||
McpIDsMu: &mcpIDsMu,
|
||||
McpIDs: &mcpIDs,
|
||||
FilesystemMonitorAgent: ag,
|
||||
FilesystemMonitorRecord: recorder,
|
||||
ToolInvokeNotify: toolInvokeNotify,
|
||||
DA: da,
|
||||
ModelFacingTrace: modelFacingTrace,
|
||||
EinoCallbacks: &ma.EinoCallbacks,
|
||||
EmptyResponseMessage: "(Eino multi-agent orchestration completed but no assistant text was captured. Check process details or logs.) " +
|
||||
"(Eino 多代理编排已完成,但未捕获到助手文本输出。请查看过程详情或日志。)",
|
||||
}, baseMsgs)
|
||||
}
|
||||
|
||||
func chatToolCallsToSchema(tcs []agent.ToolCall) []schema.ToolCall {
|
||||
if len(tcs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]schema.ToolCall, 0, len(tcs))
|
||||
for _, tc := range tcs {
|
||||
if strings.TrimSpace(tc.ID) == "" {
|
||||
continue
|
||||
}
|
||||
argsStr := ""
|
||||
if tc.Function.Arguments != nil {
|
||||
b, err := json.Marshal(tc.Function.Arguments)
|
||||
if err == nil {
|
||||
argsStr = string(b)
|
||||
}
|
||||
}
|
||||
// Some OpenAI-compatible gateways require `function.arguments` to exist
|
||||
// on every assistant tool_call message. When args are empty, omitempty may
|
||||
// drop the field during serialization and cause "missing field arguments"
|
||||
// on the next turn history replay.
|
||||
if strings.TrimSpace(argsStr) == "" {
|
||||
argsStr = "{}"
|
||||
}
|
||||
typ := tc.Type
|
||||
if typ == "" {
|
||||
typ = "function"
|
||||
}
|
||||
out = append(out, schema.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: typ,
|
||||
Function: schema.FunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: argsStr,
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// historyToMessages 将轨迹恢复的 ChatMessage 转为 Eino ADK 消息:**不裁剪条数、不按 token 预算截断**,
|
||||
// 并保留 user / assistant(含仅 tool_calls)/ tool,与库中 last_react 轨迹一致。
|
||||
func historyToMessages(history []agent.ChatMessage, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) []adk.Message {
|
||||
_ = appCfg
|
||||
_ = mwCfg
|
||||
if len(history) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Keep a bounded tail first; then enforce a token budget.
|
||||
const maxHistoryMessages = 200
|
||||
start := 0
|
||||
if len(history) > maxHistoryMessages {
|
||||
start = len(history) - maxHistoryMessages
|
||||
}
|
||||
raw := make([]adk.Message, 0, len(history[start:]))
|
||||
for _, h := range history[start:] {
|
||||
switch h.Role {
|
||||
raw := make([]adk.Message, 0, len(history))
|
||||
for _, h := range history {
|
||||
role := strings.ToLower(strings.TrimSpace(h.Role))
|
||||
switch role {
|
||||
case "user":
|
||||
if strings.TrimSpace(h.Content) != "" {
|
||||
raw = append(raw, schema.UserMessage(h.Content))
|
||||
}
|
||||
case "assistant":
|
||||
if strings.TrimSpace(h.Content) == "" && len(h.ToolCalls) > 0 {
|
||||
toolSchema := chatToolCallsToSchema(h.ToolCalls)
|
||||
hasRC := strings.TrimSpace(h.ReasoningContent) != ""
|
||||
if len(toolSchema) > 0 || strings.TrimSpace(h.Content) != "" || hasRC {
|
||||
am := schema.AssistantMessage(h.Content, toolSchema)
|
||||
if hasRC {
|
||||
am.ReasoningContent = strings.TrimSpace(h.ReasoningContent)
|
||||
}
|
||||
raw = append(raw, am)
|
||||
}
|
||||
case "tool":
|
||||
if strings.TrimSpace(h.ToolCallID) == "" && strings.TrimSpace(h.Content) == "" {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(h.Content) != "" {
|
||||
raw = append(raw, schema.AssistantMessage(h.Content, nil))
|
||||
var opts []schema.ToolMessageOption
|
||||
if tn := strings.TrimSpace(h.ToolName); tn != "" {
|
||||
opts = append(opts, schema.WithToolName(tn))
|
||||
}
|
||||
raw = append(raw, schema.ToolMessage(h.Content, h.ToolCallID, opts...))
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(raw) == 0 {
|
||||
return raw
|
||||
}
|
||||
maxTotal := 120000
|
||||
modelName := "gpt-4o"
|
||||
if appCfg != nil {
|
||||
if appCfg.OpenAI.MaxTotalTokens > 0 {
|
||||
maxTotal = appCfg.OpenAI.MaxTotalTokens
|
||||
}
|
||||
if m := strings.TrimSpace(appCfg.OpenAI.Model); m != "" {
|
||||
modelName = m
|
||||
}
|
||||
}
|
||||
ratio := 0.35
|
||||
if mwCfg != nil {
|
||||
ratio = mwCfg.HistoryInputBudgetRatioEffective()
|
||||
}
|
||||
budget := int(float64(maxTotal) * ratio)
|
||||
if budget < 4096 {
|
||||
budget = 4096
|
||||
}
|
||||
tc := agent.NewTikTokenCounter()
|
||||
outRev := make([]adk.Message, 0, len(raw))
|
||||
used := 0
|
||||
for i := len(raw) - 1; i >= 0; i-- {
|
||||
msg := raw[i]
|
||||
n, err := tc.Count(modelName, string(msg.Role)+"\n"+msg.Content)
|
||||
if err != nil {
|
||||
n = (len(msg.Content) + 3) / 4
|
||||
}
|
||||
if n <= 0 {
|
||||
n = 1
|
||||
}
|
||||
if used+n > budget {
|
||||
break
|
||||
}
|
||||
used += n
|
||||
outRev = append(outRev, msg)
|
||||
}
|
||||
out := make([]adk.Message, 0, len(outRev))
|
||||
for i := len(outRev) - 1; i >= 0; i-- {
|
||||
out = append(out, outRev[i])
|
||||
}
|
||||
return out
|
||||
return raw
|
||||
}
|
||||
|
||||
// mergeStreamingToolCallFragments 将流式多帧的 ToolCall 按 index 合并 arguments(与 schema.concatToolCalls 行为一致)。
|
||||
@@ -724,12 +746,23 @@ func toolCallsRichSignature(msg *schema.Message) string {
|
||||
return base + "|" + strings.Join(parts, ";")
|
||||
}
|
||||
|
||||
func einoMainIterationKey(agentName, orchestratorName string) string {
|
||||
key := strings.TrimSpace(agentName)
|
||||
if key == "" {
|
||||
key = strings.TrimSpace(orchestratorName)
|
||||
}
|
||||
if key == "" {
|
||||
return "_main"
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func tryEmitToolCallsOnce(
|
||||
msg *schema.Message,
|
||||
agentName, orchestratorName, conversationID string,
|
||||
agentName, orchestratorName, conversationID, orchMode string,
|
||||
progress func(string, string, interface{}),
|
||||
seen map[string]struct{},
|
||||
subAgentToolStep map[string]int,
|
||||
subAgentToolStep, mainAgentToolStep map[string]int,
|
||||
markPending func(toolCallPendingInfo),
|
||||
) {
|
||||
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil || seen == nil {
|
||||
@@ -743,14 +776,14 @@ func tryEmitToolCallsOnce(
|
||||
return
|
||||
}
|
||||
seen[sig] = struct{}{}
|
||||
emitToolCallsFromMessage(msg, agentName, orchestratorName, conversationID, progress, subAgentToolStep, markPending)
|
||||
emitToolCallsFromMessage(msg, agentName, orchestratorName, conversationID, orchMode, progress, subAgentToolStep, mainAgentToolStep, markPending)
|
||||
}
|
||||
|
||||
func emitToolCallsFromMessage(
|
||||
msg *schema.Message,
|
||||
agentName, orchestratorName, conversationID string,
|
||||
agentName, orchestratorName, conversationID, orchMode string,
|
||||
progress func(string, string, interface{}),
|
||||
subAgentToolStep map[string]int,
|
||||
subAgentToolStep, mainAgentToolStep map[string]int,
|
||||
markPending func(toolCallPendingInfo),
|
||||
) {
|
||||
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil {
|
||||
@@ -771,6 +804,22 @@ func emitToolCallsFromMessage(
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
} else if mainAgentToolStep != nil {
|
||||
key := einoMainIterationKey(agentName, orchestratorName)
|
||||
mainAgentToolStep[key]++
|
||||
n := mainAgentToolStep[key]
|
||||
// 第 1 轮已在主代理进入时发出;此后每次工具批次对应新一轮 ReAct(与子代理按工具计步一致)。
|
||||
if n > 1 {
|
||||
progress("iteration", "", map[string]interface{}{
|
||||
"iteration": n,
|
||||
"einoScope": "main",
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": agentName,
|
||||
"orchestration": orchMode,
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
}
|
||||
role := "orchestrator"
|
||||
if isSubToolRound {
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
)
|
||||
|
||||
func TestHistoryToMessagesPreservesReasoningContent(t *testing.T) {
|
||||
h := []agent.ChatMessage{
|
||||
{Role: "user", Content: "u"},
|
||||
{Role: "assistant", Content: "c", ReasoningContent: "r1", ToolCalls: []agent.ToolCall{{ID: "t1", Type: "function", Function: agent.FunctionCall{Name: "f", Arguments: map[string]interface{}{}}}}},
|
||||
}
|
||||
msgs := historyToMessages(h, nil, nil)
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("len=%d", len(msgs))
|
||||
}
|
||||
am := msgs[1]
|
||||
if am.ReasoningContent != "r1" || am.Content != "c" {
|
||||
t.Fatalf("got reasoning=%q content=%q", am.ReasoningContent, am.Content)
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// softRecoveryToolCallMiddleware returns an InvokableToolMiddleware that catches
|
||||
@@ -16,8 +17,9 @@ import (
|
||||
// returned to the LLM. This allows the model to self-correct within the same
|
||||
// iteration rather than crashing the entire graph and requiring a full replay.
|
||||
//
|
||||
// Without this middleware, a JSON parse failure in any tool's InvokableRun propagates
|
||||
// as a hard error through the Eino ToolsNode → [NodeRunError] → ev.Err, which
|
||||
// Without Invokable (+ Streamable where applicable) registration, a JSON parse failure
|
||||
// in InvokableRun / StreamableRun propagates as a hard error through the Eino ToolsNode
|
||||
// → [NodeRunError] → ev.Err, which
|
||||
// either triggers the full-replay retry loop (expensive) or terminates the run
|
||||
// entirely once retries are exhausted. With it, the LLM simply sees an error message
|
||||
// in the tool result and can adjust its next tool call accordingly.
|
||||
@@ -39,6 +41,44 @@ func softRecoveryToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||
}
|
||||
}
|
||||
|
||||
// softRecoveryStreamableToolCallMiddleware mirrors softRecoveryToolCallMiddleware for
|
||||
// tools that implement StreamableTool only (e.g. Eino ADK filesystem execute).
|
||||
// Eino applies Invokable vs Streamable middleware to disjoint code paths in ToolsNode;
|
||||
// registering only Invokable leaves streaming tools uncovered — empty/malformed JSON
|
||||
// then fails inside [LocalStreamFunc] before the inner endpoint runs.
|
||||
func softRecoveryStreamableToolCallMiddleware() compose.StreamableToolMiddleware {
|
||||
return func(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint {
|
||||
return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) {
|
||||
out, err := next(ctx, input)
|
||||
if err == nil {
|
||||
return out, nil
|
||||
}
|
||||
if !isSoftRecoverableToolError(err) {
|
||||
return out, err
|
||||
}
|
||||
toolName := ""
|
||||
args := ""
|
||||
if input != nil {
|
||||
toolName = input.Name
|
||||
args = input.Arguments
|
||||
}
|
||||
msg := buildSoftRecoveryMessage(toolName, args, err)
|
||||
return &compose.StreamToolOutput{
|
||||
Result: schema.StreamReaderFromArray([]string{msg}),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// softRecoveryToolMiddleware returns a ToolMiddleware with both Invokable and Streamable
|
||||
// soft recovery (same semantics as hitlToolCallMiddleware bundling).
|
||||
func softRecoveryToolMiddleware() compose.ToolMiddleware {
|
||||
return compose.ToolMiddleware{
|
||||
Invokable: softRecoveryToolCallMiddleware(),
|
||||
Streamable: softRecoveryStreamableToolCallMiddleware(),
|
||||
}
|
||||
}
|
||||
|
||||
// isSoftRecoverableToolError determines whether a tool execution error should be
|
||||
// silently converted to a tool-result message rather than crashing the graph.
|
||||
//
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
@@ -108,6 +110,39 @@ func TestSoftRecoveryToolCallMiddleware_PassesThrough(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftRecoveryStreamableToolCallMiddleware_LocalStreamFuncJSONError(t *testing.T) {
|
||||
mw := softRecoveryStreamableToolCallMiddleware()
|
||||
next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) {
|
||||
return nil, errors.New(`[LocalStreamFunc] failed to unmarshal arguments in json, toolName=execute, err="Syntax error no sources available, the input json is empty`)
|
||||
}
|
||||
wrapped := mw(next)
|
||||
out, err := wrapped(context.Background(), &compose.ToolInput{
|
||||
Name: "execute",
|
||||
Arguments: "",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error (soft recovery), got: %v", err)
|
||||
}
|
||||
if out == nil || out.Result == nil {
|
||||
t.Fatal("expected stream result")
|
||||
}
|
||||
var sb strings.Builder
|
||||
for {
|
||||
chunk, rerr := out.Result.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("recv: %v", rerr)
|
||||
}
|
||||
sb.WriteString(chunk)
|
||||
}
|
||||
text := sb.String()
|
||||
if !containsAll(text, "[Tool Error]", "execute", "JSON") {
|
||||
t.Fatalf("recovery message missing expected content: %s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftRecoveryToolCallMiddleware_ConvertsJSONError(t *testing.T) {
|
||||
mw := softRecoveryToolCallMiddleware()
|
||||
next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
|
||||
@@ -9,6 +9,9 @@ package openai
|
||||
// Stream: Claude SSE (event: content_block_delta / message_delta) → OpenAI SSE 格式
|
||||
// Auth: Bearer → x-api-key
|
||||
// Tools: OpenAI tools[] → Claude tools[] (input_schema)
|
||||
//
|
||||
// Extended thinking: 顶层 `thinking` 从 OpenAI 请求体透传;响应中 `thinking` block 映射为
|
||||
// `reasoning_content`(可读前缀 + 内部 JSON 尾缀以保留 signature,供多轮工具续跑;UI 用 openai.DisplayReasoningContent 剥离)。
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -38,6 +41,7 @@ type claudeRequest struct {
|
||||
Messages []claudeMessage `json:"messages"`
|
||||
Tools []claudeTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Thinking json.RawMessage `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
type claudeMessage struct {
|
||||
@@ -76,6 +80,10 @@ type claudeContentBlock struct {
|
||||
// text block
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// thinking block (extended thinking)
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
|
||||
// tool_use block (assistant 返回)
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
@@ -176,7 +184,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) {
|
||||
|
||||
// tool_calls (assistant 消息中包含工具调用)
|
||||
if role == "assistant" {
|
||||
rc, _ := mm["reasoning_content"].(string)
|
||||
_, thinkingReplay := parseClaudeReasoningAssistantBlocks(rc)
|
||||
|
||||
var blocks []claudeContentBlock
|
||||
for _, tb := range thinkingReplay {
|
||||
blocks = append(blocks, tb)
|
||||
}
|
||||
if content != "" {
|
||||
blocks = append(blocks, claudeContentBlock{Type: "text", Text: content})
|
||||
}
|
||||
@@ -290,6 +304,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Extended thinking (Anthropic top-level); merged from Eino ExtraFields / admin extras.
|
||||
if th, ok := oai["thinking"]; ok && th != nil {
|
||||
if raw, err := json.Marshal(th); err == nil && len(raw) > 0 && string(raw) != "null" {
|
||||
req.Thinking = json.RawMessage(raw)
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -318,9 +339,12 @@ func claudeToOpenAIResponseJSON(claudeBody []byte) ([]byte, error) {
|
||||
|
||||
var textContent string
|
||||
var toolCalls []interface{}
|
||||
var thinkingBlocks []claudeContentBlock
|
||||
|
||||
for _, block := range cr.Content {
|
||||
switch block.Type {
|
||||
case "thinking":
|
||||
thinkingBlocks = append(thinkingBlocks, block)
|
||||
case "text":
|
||||
textContent += block.Text
|
||||
case "tool_use":
|
||||
@@ -344,6 +368,18 @@ func claudeToOpenAIResponseJSON(claudeBody []byte) ([]byte, error) {
|
||||
if len(toolCalls) > 0 {
|
||||
message["tool_calls"] = toolCalls
|
||||
}
|
||||
if len(thinkingBlocks) > 0 {
|
||||
var parts []string
|
||||
for _, tb := range thinkingBlocks {
|
||||
if strings.TrimSpace(tb.Thinking) != "" {
|
||||
parts = append(parts, tb.Thinking)
|
||||
}
|
||||
}
|
||||
rc := appendClaudeReasoningRoundTrip(strings.Join(parts, "\n\n"), thinkingBlocks)
|
||||
if rc != "" {
|
||||
message["reasoning_content"] = rc
|
||||
}
|
||||
}
|
||||
|
||||
choice := map[string]interface{}{
|
||||
"index": 0,
|
||||
@@ -901,8 +937,16 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
blockToToolIndex := make(map[int]int)
|
||||
blockIndexToType := make(map[int]string)
|
||||
nextToolIndex := 0
|
||||
|
||||
type thinkingAcc struct {
|
||||
text strings.Builder
|
||||
sig strings.Builder
|
||||
}
|
||||
thinkingByIndex := make(map[int]*thinkingAcc)
|
||||
var finishedThinking []claudeContentBlock
|
||||
|
||||
for {
|
||||
line, readErr := reader.ReadString('\n')
|
||||
if readErr != nil {
|
||||
@@ -947,6 +991,11 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
blockIdx := int(blockIdxFlt)
|
||||
cb, _ := event["content_block"].(map[string]interface{})
|
||||
bt, _ := cb["type"].(string)
|
||||
blockIndexToType[blockIdx] = bt
|
||||
|
||||
if bt == "thinking" {
|
||||
thinkingByIndex[blockIdx] = &thinkingAcc{}
|
||||
}
|
||||
|
||||
if bt == "tool_use" {
|
||||
id, _ := cb["id"].(string)
|
||||
@@ -986,7 +1035,35 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
delta, _ := event["delta"].(map[string]interface{})
|
||||
dt, _ := delta["type"].(string)
|
||||
|
||||
if dt == "text_delta" {
|
||||
if dt == "thinking_delta" {
|
||||
tPart, _ := delta["thinking"].(string)
|
||||
if tPart != "" {
|
||||
if acc := thinkingByIndex[blockIdx]; acc != nil {
|
||||
acc.text.WriteString(tPart)
|
||||
}
|
||||
oaiChunk := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"delta": map[string]interface{}{
|
||||
"reasoning_content": tPart,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(oaiChunk)
|
||||
if !writeLine("data: " + string(b) + "\n\n") {
|
||||
pw.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
} else if dt == "signature_delta" {
|
||||
sigPart, _ := delta["signature"].(string)
|
||||
if sigPart != "" {
|
||||
if acc := thinkingByIndex[blockIdx]; acc != nil {
|
||||
acc.sig.WriteString(sigPart)
|
||||
}
|
||||
}
|
||||
} else if dt == "text_delta" {
|
||||
text, _ := delta["text"].(string)
|
||||
oaiChunk := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
@@ -1031,6 +1108,21 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
blockIdxFlt, _ := event["index"].(float64)
|
||||
blockIdx := int(blockIdxFlt)
|
||||
bt := blockIndexToType[blockIdx]
|
||||
if bt == "thinking" {
|
||||
if acc := thinkingByIndex[blockIdx]; acc != nil {
|
||||
finishedThinking = append(finishedThinking, claudeContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: acc.text.String(),
|
||||
Signature: acc.sig.String(),
|
||||
})
|
||||
delete(thinkingByIndex, blockIdx)
|
||||
}
|
||||
}
|
||||
|
||||
case "message_delta":
|
||||
d, _ := event["delta"].(map[string]interface{})
|
||||
if sr, ok := d["stop_reason"].(string); ok {
|
||||
@@ -1051,6 +1143,25 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
}
|
||||
|
||||
case "message_stop":
|
||||
if len(finishedThinking) > 0 {
|
||||
suffix := appendClaudeReasoningRoundTrip("", finishedThinking)
|
||||
if strings.TrimSpace(suffix) != "" {
|
||||
oaiChunk := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"delta": map[string]interface{}{
|
||||
"reasoning_content": suffix,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(oaiChunk)
|
||||
if !writeLine("data: " + string(b) + "\n\n") {
|
||||
pw.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
writeLine("data: [DONE]\n\n")
|
||||
pw.Close()
|
||||
return
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// claudeReasoningRoundTripSep separates human-readable reasoning from a JSON payload of
|
||||
// Anthropic thinking blocks (with signatures) for multi-turn extended thinking + tools.
|
||||
// Not shown in UI (see DisplayReasoningContent).
|
||||
const claudeReasoningRoundTripSep = "\n---CSAI_CLAUDE_THINKING_BLOCKS---\n"
|
||||
|
||||
// DisplayReasoningContent returns reasoning text suitable for the UI (strips internal
|
||||
// Claude round-trip JSON suffix). Safe for DeepSeek/plain reasoning strings (no-op).
|
||||
func DisplayReasoningContent(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
i := strings.LastIndex(s, claudeReasoningRoundTripSep)
|
||||
if i < 0 {
|
||||
return s
|
||||
}
|
||||
return strings.TrimSpace(s[:i])
|
||||
}
|
||||
|
||||
func appendClaudeReasoningRoundTrip(display string, blocks []claudeContentBlock) string {
|
||||
var payload []map[string]string
|
||||
for _, b := range blocks {
|
||||
if b.Type != "thinking" {
|
||||
continue
|
||||
}
|
||||
payload = append(payload, map[string]string{
|
||||
"type": b.Type,
|
||||
"thinking": b.Thinking,
|
||||
"signature": b.Signature,
|
||||
})
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return strings.TrimSpace(display)
|
||||
}
|
||||
js, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(display)
|
||||
}
|
||||
d := strings.TrimSpace(display)
|
||||
if d == "" {
|
||||
return claudeReasoningRoundTripSep + string(js)
|
||||
}
|
||||
return d + claudeReasoningRoundTripSep + string(js)
|
||||
}
|
||||
|
||||
// parseClaudeReasoningAssistantBlocks extracts Anthropic thinking blocks from an OpenAI-style
|
||||
// reasoning_content string. When no suffix is present, blocks is nil (caller must not invent signatures).
|
||||
func parseClaudeReasoningAssistantBlocks(reasoningContent string) (display string, blocks []claudeContentBlock) {
|
||||
reasoningContent = strings.TrimSpace(reasoningContent)
|
||||
if reasoningContent == "" {
|
||||
return "", nil
|
||||
}
|
||||
idx := strings.LastIndex(reasoningContent, claudeReasoningRoundTripSep)
|
||||
if idx < 0 {
|
||||
return reasoningContent, nil
|
||||
}
|
||||
display = strings.TrimSpace(reasoningContent[:idx])
|
||||
jsonPart := strings.TrimSpace(reasoningContent[idx+len(claudeReasoningRoundTripSep):])
|
||||
var arr []struct {
|
||||
Type string `json:"type"`
|
||||
Thinking string `json:"thinking"`
|
||||
Signature string `json:"signature"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(jsonPart), &arr); err != nil {
|
||||
return reasoningContent, nil
|
||||
}
|
||||
for _, x := range arr {
|
||||
if x.Type != "thinking" {
|
||||
continue
|
||||
}
|
||||
blocks = append(blocks, claudeContentBlock{Type: "thinking", Thinking: x.Thinking, Signature: x.Signature})
|
||||
}
|
||||
return display, blocks
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDisplayReasoningContent(t *testing.T) {
|
||||
raw := "hello" + claudeReasoningRoundTripSep + `[{"type":"thinking","thinking":"x","signature":"sig"}]`
|
||||
if d := DisplayReasoningContent(raw); d != "hello" {
|
||||
t.Fatalf("got %q", d)
|
||||
}
|
||||
if DisplayReasoningContent("plain") != "plain" {
|
||||
t.Fatal()
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendParseClaudeReasoningRoundTrip(t *testing.T) {
|
||||
blocks := []claudeContentBlock{
|
||||
{Type: "thinking", Thinking: "a", Signature: "s1"},
|
||||
{Type: "thinking", Thinking: "b", Signature: "s2"},
|
||||
}
|
||||
s := appendClaudeReasoningRoundTrip("sum", blocks)
|
||||
if !strings.Contains(s, claudeReasoningRoundTripSep) {
|
||||
t.Fatal("missing sep")
|
||||
}
|
||||
display, back := parseClaudeReasoningAssistantBlocks(s)
|
||||
if display != "sum" || len(back) != 2 {
|
||||
t.Fatalf("display=%q len=%d", display, len(back))
|
||||
}
|
||||
if back[0].Signature != "s1" || back[1].Thinking != "b" {
|
||||
t.Fatalf("%+v", back)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIToClaude_AssistantReasoningReplay(t *testing.T) {
|
||||
rc := appendClaudeReasoningRoundTrip("vis", []claudeContentBlock{
|
||||
{Type: "thinking", Thinking: "t1", Signature: "sig1"},
|
||||
})
|
||||
payload := map[string]interface{}{
|
||||
"model": "claude-3-5-sonnet-latest",
|
||||
"messages": []interface{}{
|
||||
map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": "out",
|
||||
"reasoning_content": rc,
|
||||
},
|
||||
},
|
||||
}
|
||||
req, err := convertOpenAIToClaude(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(req.Messages) != 1 {
|
||||
t.Fatalf("messages=%d", len(req.Messages))
|
||||
}
|
||||
blocks := req.Messages[0].Content.Blocks
|
||||
if len(blocks) < 2 {
|
||||
t.Fatalf("blocks=%d", len(blocks))
|
||||
}
|
||||
if blocks[0].Type != "thinking" || blocks[0].Signature != "sig1" {
|
||||
t.Fatalf("first block %+v", blocks[0])
|
||||
}
|
||||
foundText := false
|
||||
for _, b := range blocks {
|
||||
if b.Type == "text" && b.Text == "out" {
|
||||
foundText = true
|
||||
}
|
||||
}
|
||||
if !foundText {
|
||||
t.Fatalf("blocks=%+v", blocks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeToOpenAIResponseJSON_Thinking(t *testing.T) {
|
||||
claudeBody := []byte(`{
|
||||
"id":"msg_1","type":"message","role":"assistant","model":"x","stop_reason":"end_turn",
|
||||
"content":[
|
||||
{"type":"thinking","thinking":"step","signature":"sigx"},
|
||||
{"type":"text","text":"hi"}
|
||||
]
|
||||
}`)
|
||||
oai, err := claudeToOpenAIResponseJSON(claudeBody)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var wrap map[string]interface{}
|
||||
if err := json.Unmarshal(oai, &wrap); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
choices := wrap["choices"].([]interface{})
|
||||
ch0 := choices[0].(map[string]interface{})
|
||||
msg := ch0["message"].(map[string]interface{})
|
||||
rc, _ := msg["reasoning_content"].(string)
|
||||
if !strings.Contains(rc, "step") || !strings.Contains(rc, claudeReasoningRoundTripSep) {
|
||||
t.Fatalf("reasoning_content=%q", rc)
|
||||
}
|
||||
if msg["content"] != "hi" {
|
||||
t.Fatal()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeStreamingDelta_RepeatedCharBoundary(t *testing.T) {
|
||||
// 流式在重复数字边界分片:不得把 "43" 的首字符与 "194" 尾字符误合并。
|
||||
cur, d := normalizeStreamingDelta("https://x:194", "43")
|
||||
if want := "https://x:19443"; cur != want {
|
||||
t.Fatalf("next: want %q got %q", want, cur)
|
||||
}
|
||||
if d != "43" {
|
||||
t.Fatalf("delta: want %q got %q", "43", d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStreamingDelta_CumulativePrefix(t *testing.T) {
|
||||
cur, d := normalizeStreamingDelta("今天", "今天天气")
|
||||
if cur != "今天天气" || d != "天气" {
|
||||
t.Fatalf("got cur=%q d=%q", cur, d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStreamingDelta_FullRetransmit(t *testing.T) {
|
||||
cur, d := normalizeStreamingDelta("今天", "今天")
|
||||
if d != "" || cur != "今天" {
|
||||
t.Fatalf("got cur=%q d=%q", cur, d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStreamingDelta_SingleRuneRepeated(t *testing.T) {
|
||||
cur, d := normalizeStreamingDelta("呀", "呀")
|
||||
if want := "呀呀"; cur != want {
|
||||
t.Fatalf("next: want %q got %q", want, cur)
|
||||
}
|
||||
if d != "呀" {
|
||||
t.Fatalf("delta: want %q got %q", "呀", d)
|
||||
}
|
||||
cur, d = normalizeStreamingDelta("4", "4")
|
||||
if want := "44"; cur != want {
|
||||
t.Fatalf("next: want %q got %q", want, cur)
|
||||
}
|
||||
if d != "4" {
|
||||
t.Fatalf("delta: want %q got %q", "4", d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStreamingDelta_CumulativeExtendsNumber(t *testing.T) {
|
||||
// 已缓冲 "194" 后收到累计串 "19443"(注意 "1943" 并非 "19443" 的前缀,不能靠误写的中间态测 HasPrefix)。
|
||||
cur, d := normalizeStreamingDelta("194", "19443")
|
||||
if want := "19443"; cur != want {
|
||||
t.Fatalf("next: want %q got %q", want, cur)
|
||||
}
|
||||
if d != "43" {
|
||||
t.Fatalf("delta: want %q got %q", "43", d)
|
||||
}
|
||||
}
|
||||
+12
-17
@@ -10,6 +10,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
@@ -34,7 +35,15 @@ func (e *APIError) Error() string {
|
||||
}
|
||||
|
||||
// normalizeStreamingDelta 将可能是“累计片段/重发片段”的内容归一化为“纯增量”。
|
||||
// 部分兼容网关会返回累计 content;若直接 append 会出现重复文本(结巴)。
|
||||
// 部分兼容网关会返回累计 content;若直接 append 会出现重复文本。
|
||||
//
|
||||
// 注意:
|
||||
// - 不做「任意后缀与前缀重叠」合并;流式可能在重复字符边界分片("194"+"43"→"19443")。
|
||||
// - HasPrefix 仅在 incoming 严格长于 current 时视为累计全文,否则会把分片产生的第二个相同
|
||||
// 单字/单码点(叠字、44、22 等)误判为「整段重复」而吞字。
|
||||
// - incoming==current 仅当 current 长度 >1 个码点时才视为整包重发;单码点重复必须走拼接。
|
||||
// - 不再使用「current 以 incoming 结尾则丢弃」:否则 "1943"+"43" 会误吞增量(19443 显示成 1943)。
|
||||
// 若网关重复发送尾部片段,应重复送完整累计串,由 HasPrefix 分支去重。
|
||||
func normalizeStreamingDelta(current, incoming string) (next, delta string) {
|
||||
if incoming == "" {
|
||||
return current, ""
|
||||
@@ -42,26 +51,12 @@ func normalizeStreamingDelta(current, incoming string) (next, delta string) {
|
||||
if current == "" {
|
||||
return incoming, incoming
|
||||
}
|
||||
if incoming == current {
|
||||
return current, ""
|
||||
}
|
||||
if strings.HasPrefix(incoming, current) {
|
||||
if strings.HasPrefix(incoming, current) && len(incoming) > len(current) {
|
||||
return incoming, incoming[len(current):]
|
||||
}
|
||||
if strings.HasSuffix(current, incoming) {
|
||||
if incoming == current && utf8.RuneCountInString(current) > 1 {
|
||||
return current, ""
|
||||
}
|
||||
|
||||
// 边界重叠:current 后缀与 incoming 前缀重合,仅追加非重叠部分。
|
||||
max := len(current)
|
||||
if len(incoming) < max {
|
||||
max = len(incoming)
|
||||
}
|
||||
for overlap := max; overlap > 0; overlap-- {
|
||||
if current[len(current)-overlap:] == incoming[:overlap] {
|
||||
return current + incoming[overlap:], incoming[overlap:]
|
||||
}
|
||||
}
|
||||
return current + incoming, incoming
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
package openai
|
||||
|
||||
// SSEAccumulatedKey 为 SSE progress 事件 data 中的服务端权威流式全文快照字段。
|
||||
// 前端应优先用该字段更新 buffer,避免对 delta 二次 normalize 导致叠字。
|
||||
const SSEAccumulatedKey = "accumulated"
|
||||
|
||||
// WithSSEAccumulated 在 progress data 中附带当前流式累计全文(权威快照)。
|
||||
func WithSSEAccumulated(data map[string]interface{}, accumulated string) map[string]interface{} {
|
||||
if data == nil {
|
||||
data = make(map[string]interface{}, 1)
|
||||
}
|
||||
data[SSEAccumulatedKey] = accumulated
|
||||
return data
|
||||
}
|
||||
|
||||
// NormalizeStreamingDelta 将可能是“累计片段/重发片段”的内容归一化为“纯增量”。
|
||||
// 与 unexported normalizeStreamingDelta 相同,供 agent / multiagent 等包在发 SSE 前累计正文。
|
||||
func NormalizeStreamingDelta(current, incoming string) (next, delta string) {
|
||||
return normalizeStreamingDelta(current, incoming)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user