mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-05 22:06:41 +02:00
Compare commits
142 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 97834c162e | |||
| 9276f2f144 | |||
| a454cada6a | |||
| 99b53d4fbc | |||
| a43a9deaea | |||
| ce88da84c9 | |||
| 15855c7073 | |||
| 43eb3e546b | |||
| 2d52c9b6ac | |||
| d5401b8b4c | |||
| 5fd4393a2e | |||
| a049f6b5c2 | |||
| acba8e5a39 | |||
| f826b91362 | |||
| 98c2de2a60 | |||
| 1c4d4b305b | |||
| f210ac9a03 | |||
| 6685076dfb | |||
| 7f322653f6 | |||
| 66ac2f1357 | |||
| c446e22d0c | |||
| 0358d3a67d | |||
| 9b82f265fd | |||
| 3d9cae58e4 | |||
| 1f1eadee5e | |||
| 0569255189 | |||
| 8ccf90d067 | |||
| b3be89f47d | |||
| b9bf8f62d4 | |||
| 05ca0c1480 | |||
| 47a4f3fc5b | |||
| a3b378ae9e | |||
| a904d26e78 | |||
| 7ba7476c4f | |||
| ae25a243ac | |||
| 23bd6288ff | |||
| fef21d3a24 | |||
| 933bba4517 | |||
| e1d65437cc | |||
| 9325aed1eb | |||
| dee2b3ab42 | |||
| a69bc93fa1 | |||
| b1a620bfce | |||
| 61b164eec2 | |||
| ba77e1837e | |||
| eacad60fd6 | |||
| 70bf5c93bf | |||
| 08bd278d8c | |||
| 22746d64a3 | |||
| 199392a5d5 | |||
| aafb4cb584 | |||
| 96e3dd397c | |||
| ec0f17145b | |||
| ed53da0999 | |||
| dc440fc511 | |||
| 009ae59033 | |||
| f348b3245a | |||
| 0018c5219c | |||
| 01a3e3677a | |||
| a12ecdb46f | |||
| 9f59230d74 | |||
| 085c6a1c72 | |||
| 7b3860971f | |||
| f6f7b7b237 | |||
| d5cf4b3b16 | |||
| 3e58d8355b | |||
| eb01ade63b | |||
| d1dc15fa44 | |||
| 73a39ef868 | |||
| a022baef03 | |||
| 59312d428e | |||
| 951d14ef14 | |||
| 0eb22da6e9 | |||
| 5fd9ef0514 | |||
| 9a4f3c7d35 | |||
| ead2ce3ecc | |||
| 8733f3a2d2 | |||
| 8642f3ba31 | |||
| 6a262a7367 | |||
| eb9192ddb3 | |||
| 5587e75628 | |||
| 74bbb453e2 | |||
| 66842f6206 | |||
| dc1779275d | |||
| 10dff937b1 | |||
| d4e1fe3bbe | |||
| 179976ae57 | |||
| 1c758bb98c | |||
| 17c4f38ee3 | |||
| cd7e57d121 | |||
| 0f2c3f65cc | |||
| 7779666e27 | |||
| c74bd4403b | |||
| 04d23ddb43 | |||
| 0874e84393 | |||
| 57f57f30b1 | |||
| f37d613a0c | |||
| 87d0ff9154 | |||
| b3418f39b8 | |||
| f9e1ca0e2d | |||
| 2c45879669 | |||
| 1cdcfa2c2d | |||
| eab5b73846 | |||
| d961ba1ec7 | |||
| 1ba5e57ec6 | |||
| 1216d25f96 | |||
| fde693408e | |||
| 352a81a869 | |||
| b2562b1010 | |||
| 0d8ba51087 | |||
| 0b847fcea3 | |||
| bf2f49fe62 | |||
| 75e64b1a86 | |||
| 2167735022 | |||
| 4ee292cc1f | |||
| 961205940f | |||
| ffe797bd06 | |||
| b6c864547e | |||
| da369c2edc | |||
| 54dc31a616 | |||
| 9e0b985221 | |||
| eb47077082 | |||
| f9a482857d | |||
| 679a68b12f | |||
| 840a26c7ef | |||
| 030e69c02d | |||
| d9683cdb44 | |||
| 60a063dd7d | |||
| 5f0c1805a7 | |||
| cb7e66001b | |||
| 4ea838f1d7 | |||
| 573648fc4b | |||
| f0e090abea | |||
| 549dcf518c | |||
| c74e20c54a | |||
| c94a9fd9e9 | |||
| ce9749a8ef | |||
| 145da12017 | |||
| 5111f4c311 | |||
| 8f6384a083 | |||
| 762f778e1e | |||
| 4a11ba8f14 |
@@ -174,9 +174,11 @@ The `run.sh` script will automatically:
|
|||||||
- ✅ Build the project
|
- ✅ Build the project
|
||||||
- ✅ Start the server
|
- ✅ Start the server
|
||||||
|
|
||||||
|
**Networking defaults:** `run.sh` starts the server with **`--https`** and the repo **`config.yaml`** (local self-signed TLS; better for many concurrent streams). Use **`./run.sh --http`** for plain HTTP. In production, set **`server.tls_cert_path`** / **`server.tls_key_path`** in **`config.yaml`** (see comments there). For manual runs, add **`--https`** or **`CYBERSTRIKE_HTTPS=1`**; if **`-config`** is wrong, the binary prints a short usage hint on stderr.
|
||||||
|
|
||||||
**First-Time Configuration:**
|
**First-Time Configuration:**
|
||||||
1. **Configure OpenAI-compatible API** (required before first use)
|
1. **Configure OpenAI-compatible API** (required before first use)
|
||||||
- Open http://localhost:8080 after launch
|
- After launch, open **`https://127.0.0.1:8080/`** (or **`https://localhost:8080/`**; replace **8080** with `server.port` in `config.yaml`) and accept the self-signed certificate warning once. If you used `./run.sh --http`, use **`http://`** instead.
|
||||||
- Go to `Settings` → Fill in your API credentials:
|
- Go to `Settings` → Fill in your API credentials:
|
||||||
```yaml
|
```yaml
|
||||||
openai:
|
openai:
|
||||||
@@ -197,21 +199,23 @@ The `run.sh` script will automatically:
|
|||||||
|
|
||||||
**Alternative Launch Methods:**
|
**Alternative Launch Methods:**
|
||||||
```bash
|
```bash
|
||||||
# Direct Go run (requires manual setup)
|
# Direct Go run (set up env yourself); add --https to match run.sh defaults
|
||||||
go run cmd/server/main.go
|
go run cmd/server/main.go --https
|
||||||
|
|
||||||
# Manual build
|
# Manual build
|
||||||
go build -o cyberstrike-ai cmd/server/main.go
|
go build -o cyberstrike-ai cmd/server/main.go
|
||||||
./cyberstrike-ai
|
./cyberstrike-ai --https
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If server logs show `client sent an HTTP request to an HTTPS server`, a client is still using **`http://`** on a TLS-only port—switch the URL to **`https://`**.
|
||||||
|
|
||||||
**Note:** The Python virtual environment (`venv/`) is automatically created and managed by `run.sh`. Tools that require Python (like `api-fuzzer`, `http-framework-test`, etc.) will automatically use this environment.
|
**Note:** The Python virtual environment (`venv/`) is automatically created and managed by `run.sh`. Tools that require Python (like `api-fuzzer`, `http-framework-test`, etc.) will automatically use this environment.
|
||||||
|
|
||||||
### Version Update (No Breaking Changes)
|
### Version Update (No Breaking Changes)
|
||||||
|
|
||||||
**CyberStrikeAI one-click upgrade (recommended):**
|
**CyberStrikeAI one-click upgrade (recommended):**
|
||||||
1. (First time) enable the script: `chmod +x upgrade.sh`
|
1. (First time) enable the script: `chmod +x upgrade.sh`
|
||||||
2. Upgrade with: `./upgrade.sh` (optional flags: `--tag vX.Y.Z`, `--no-venv`, `--preserve-custom`, `--yes`)
|
2. Upgrade with: `./upgrade.sh` (optional flags: `--tag vX.Y.Z`, `--no-venv`, `--yes`). Local `tools/`, `roles/`, and `skills/` are always preserved.
|
||||||
3. The script will back up your `config.yaml` and `data/`, upgrade the code from GitHub Release, update `config.yaml`'s `version`, then restart the server.
|
3. The script will back up your `config.yaml` and `data/`, upgrade the code from GitHub Release, update `config.yaml`'s `version`, then restart the server.
|
||||||
|
|
||||||
Recommended one-liner:
|
Recommended one-liner:
|
||||||
@@ -281,7 +285,7 @@ Requirements / tips:
|
|||||||
- **Supervisor orchestrator**: fixed name **`orchestrator-supervisor.md`** (plus optional `orchestrator_instruction_supervisor`); requires at least one sub-agent.
|
- **Supervisor orchestrator**: fixed name **`orchestrator-supervisor.md`** (plus optional `orchestrator_instruction_supervisor`); requires at least one sub-agent.
|
||||||
- **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only.
|
- **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only.
|
||||||
- **Management** – Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`.
|
- **Management** – Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`.
|
||||||
- **Config** – `multi_agent` in `config.yaml`: `enabled`, `default_mode`, `robot_use_multi_agent`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
|
- **Config** – `multi_agent` in `config.yaml`: `enabled`, `robot_default_agent_mode`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
|
||||||
- **Details** – **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats).
|
- **Details** – **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats).
|
||||||
|
|
||||||
### Skills System (Agent Skills + Eino)
|
### Skills System (Agent Skills + Eino)
|
||||||
@@ -532,7 +536,7 @@ agents_dir: "agents" # Multi-agent Markdown definitions (orchestrator + sub-age
|
|||||||
multi_agent:
|
multi_agent:
|
||||||
enabled: false
|
enabled: false
|
||||||
default_mode: "single" # single | multi (UI default when multi-agent is enabled)
|
default_mode: "single" # single | multi (UI default when multi-agent is enabled)
|
||||||
robot_use_multi_agent: false
|
robot_default_agent_mode: react
|
||||||
batch_use_multi_agent: false
|
batch_use_multi_agent: false
|
||||||
orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty
|
orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty
|
||||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
|
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
|
||||||
|
|||||||
+11
-7
@@ -173,9 +173,11 @@ chmod +x run.sh && ./run.sh
|
|||||||
- ✅ 编译构建项目
|
- ✅ 编译构建项目
|
||||||
- ✅ 启动服务器
|
- ✅ 启动服务器
|
||||||
|
|
||||||
|
**网络默认:** `run.sh` 会以 **`--https`** 并传入项目根 **`config.yaml`** 启动(本机自签证书,多路流式场景更稳)。只要明文 HTTP 用 **`./run.sh --http`**。生产环境在 **`config.yaml`** 的 **`server.tls_cert_path` / `server.tls_key_path`** 配正式证书(见文件内注释)。手动启动可加 **`--https`** 或环境变量 **`CYBERSTRIKE_HTTPS=1`**;`-config` 写错时程序会在终端提示正确写法。
|
||||||
|
|
||||||
**首次配置:**
|
**首次配置:**
|
||||||
1. **配置 AI 模型 API**(首次使用前必填)
|
1. **配置 AI 模型 API**(首次使用前必填)
|
||||||
- 启动后访问 http://localhost:8080
|
- 启动后在浏览器打开 **`https://127.0.0.1:8080/`**(或 **`https://localhost:8080/`**;端口以 `config.yaml` 中 **`server.port`** 为准,默认 8080),并按提示信任自签证书。若使用 **`./run.sh --http`**,则改用 **`http://`** 访问。
|
||||||
- 进入 `设置` → 填写 API 配置信息:
|
- 进入 `设置` → 填写 API 配置信息:
|
||||||
```yaml
|
```yaml
|
||||||
openai:
|
openai:
|
||||||
@@ -196,20 +198,22 @@ chmod +x run.sh && ./run.sh
|
|||||||
|
|
||||||
**其他启动方式:**
|
**其他启动方式:**
|
||||||
```bash
|
```bash
|
||||||
# 直接运行(需手动配置环境)
|
# 直接运行(需自行配环境);与 run.sh 默认一致可加 --https
|
||||||
go run cmd/server/main.go
|
go run cmd/server/main.go --https
|
||||||
|
|
||||||
# 手动编译
|
# 手动编译
|
||||||
go build -o cyberstrike-ai cmd/server/main.go
|
go build -o cyberstrike-ai cmd/server/main.go
|
||||||
./cyberstrike-ai
|
./cyberstrike-ai --https
|
||||||
```
|
```
|
||||||
|
|
||||||
|
若日志出现 `client sent an HTTP request to an HTTPS server`,说明仍有客户端用 **`http://`** 访问只提供 HTTPS 的端口,请改为 **`https://`**。
|
||||||
|
|
||||||
**说明:** Python 虚拟环境(`venv/`)由 `run.sh` 自动创建和管理。需要 Python 的工具(如 `api-fuzzer`、`http-framework-test` 等)会自动使用该环境。
|
**说明:** Python 虚拟环境(`venv/`)由 `run.sh` 自动创建和管理。需要 Python 的工具(如 `api-fuzzer`、`http-framework-test` 等)会自动使用该环境。
|
||||||
|
|
||||||
### CyberStrikeAI 版本更新(无兼容性问题)
|
### CyberStrikeAI 版本更新(无兼容性问题)
|
||||||
|
|
||||||
1. (首次使用)启用脚本:`chmod +x upgrade.sh`
|
1. (首次使用)启用脚本:`chmod +x upgrade.sh`
|
||||||
2. 一键升级:`./upgrade.sh`(可选参数:`--tag vX.Y.Z`、`--no-venv`、`--preserve-custom`、`--yes`)
|
2. 一键升级:`./upgrade.sh`(可选参数:`--tag vX.Y.Z`、`--no-venv`、`--yes`)。本地的 `tools/`、`roles/`、`skills/` 会始终保留不被覆盖。
|
||||||
3. 脚本会备份你的 `config.yaml` 和 `data/`,从 GitHub Release 升级代码,更新 `config.yaml` 的 `version` 字段后重启服务。
|
3. 脚本会备份你的 `config.yaml` 和 `data/`,从 GitHub Release 升级代码,更新 `config.yaml` 的 `version` 字段后重启服务。
|
||||||
|
|
||||||
推荐的一键指令:
|
推荐的一键指令:
|
||||||
@@ -279,7 +283,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
|||||||
- **Supervisor 主代理**:固定 **`orchestrator-supervisor.md`**(另可配 `orchestrator_instruction_supervisor`);至少需一名子代理。
|
- **Supervisor 主代理**:固定 **`orchestrator-supervisor.md`**(另可配 `orchestrator_instruction_supervisor`);至少需一名子代理。
|
||||||
- **子代理**(**deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。
|
- **子代理**(**deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。
|
||||||
- **界面管理**:**Agents → Agent 管理**;API `/api/multi-agent/markdown-agents`。
|
- **界面管理**:**Agents → Agent 管理**;API `/api/multi-agent/markdown-agents`。
|
||||||
- **配置项**:`multi_agent`:`enabled`、`default_mode`、`robot_use_multi_agent`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
|
- **配置项**:`multi_agent`:`enabled`、`robot_default_agent_mode`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
|
||||||
- **更多细节**:[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。
|
- **更多细节**:[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。
|
||||||
|
|
||||||
### Skills 技能系统(Agent Skills + Eino)
|
### Skills 技能系统(Agent Skills + Eino)
|
||||||
@@ -530,7 +534,7 @@ agents_dir: "agents" # 多代理 Markdown(主代理 orchestrator.md + 子代
|
|||||||
multi_agent:
|
multi_agent:
|
||||||
enabled: false
|
enabled: false
|
||||||
default_mode: "single" # single | multi(开启多代理时的界面默认模式)
|
default_mode: "single" # single | multi(开启多代理时的界面默认模式)
|
||||||
robot_use_multi_agent: false
|
robot_default_agent_mode: react
|
||||||
batch_use_multi_agent: false
|
batch_use_multi_agent: false
|
||||||
orchestrator_instruction: "" # Deep;orchestrator.md 正文为空时使用
|
orchestrator_instruction: "" # Deep;orchestrator.md 正文为空时使用
|
||||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
|
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
|
||||||
|
|||||||
+43
-3
@@ -9,22 +9,62 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var configPath = flag.String("config", "config.yaml", "配置文件路径")
|
var configPath = flag.String("config", "config.yaml", "配置文件路径")
|
||||||
|
var httpsBootstrap = flag.Bool("https", false, "启用主站 HTTPS:未配置 tls_cert_path/tls_key_path 时使用内存自签证书(本地测试);与 run.sh 默认行为一致")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
|
// 环境变量兼容(便于 systemd/docker 等不传参场景)
|
||||||
|
if !*httpsBootstrap {
|
||||||
|
v := strings.TrimSpace(os.Getenv("CYBERSTRIKE_HTTPS"))
|
||||||
|
if v == "1" || strings.EqualFold(v, "true") || strings.EqualFold(v, "yes") {
|
||||||
|
*httpsBootstrap = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 加载配置
|
// 加载配置
|
||||||
cfg, err := config.Load(*configPath)
|
cp := strings.TrimSpace(*configPath)
|
||||||
|
if cp == "" {
|
||||||
|
cp = "config.yaml"
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(cp, "-") {
|
||||||
|
fmt.Fprintf(os.Stderr, "无效的 -config 路径 %q。\n若同时需要 HTTPS,请写成: ./cyberstrike-ai --https -config config.yaml(-config 后必须是 yaml 文件路径)。\n", cp)
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
cfg, err := config.Load(cp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("加载配置失败: %v\n", err)
|
fmt.Printf("加载配置失败: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if *httpsBootstrap {
|
||||||
|
config.ApplyDevHTTPSBootstrap(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
port := cfg.Server.Port
|
||||||
|
if port <= 0 {
|
||||||
|
port = 8080
|
||||||
|
}
|
||||||
|
scheme := "http"
|
||||||
|
if config.MainWebUIUsesHTTPS(&cfg.Server) {
|
||||||
|
scheme = "https"
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Printf("→ Web 界面: %s://127.0.0.1:%d/\n", scheme, port)
|
||||||
|
if scheme == "https" && cfg.Server.TLSAutoSelfSign {
|
||||||
|
fmt.Println(" (内存自签证书:浏览器首次需确认「继续访问」)")
|
||||||
|
}
|
||||||
|
if scheme == "https" && config.ServerHTTPRedirectEnabled(&cfg.Server) {
|
||||||
|
fmt.Printf(" (http://127.0.0.1:%d/ 将自动跳转到 HTTPS)\n", port)
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
// MCP 启用且 auth_header_value 为空时,自动生成随机密钥并写回配置
|
// MCP 启用且 auth_header_value 为空时,自动生成随机密钥并写回配置
|
||||||
if err := config.EnsureMCPAuth(*configPath, cfg); err != nil {
|
if err := config.EnsureMCPAuth(cp, cfg); err != nil {
|
||||||
fmt.Printf("MCP 鉴权配置失败: %v\n", err)
|
fmt.Printf("MCP 鉴权配置失败: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -44,7 +84,7 @@ func main() {
|
|||||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
// 创建应用
|
// 创建应用
|
||||||
application, err := app.New(cfg, log)
|
application, err := app.New(cfg, log, cp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal("应用初始化失败", "error", err)
|
log.Fatal("应用初始化失败", "error", err)
|
||||||
}
|
}
|
||||||
|
|||||||
+39
-8
@@ -10,11 +10,22 @@
|
|||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||||
version: "v1.6.9"
|
version: "v1.6.22"
|
||||||
# 服务器配置
|
# 服务器配置
|
||||||
server:
|
server:
|
||||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||||
port: 8080 # HTTP 服务端口,可通过浏览器访问 http://localhost:8080
|
port: 8080 # 服务端口;未启用 TLS 时为 http://localhost:8080
|
||||||
|
# --- 可选:HTTPS + HTTP/2(缓解浏览器对同源 HTTP/1.1 的并发连接数限制,多路 Deep 流式更稳)---
|
||||||
|
# 启用 TLS 的条件(满足其一即可):tls_enabled: true,或 tls_auto_self_sign: true,或同时配置了 tls_cert_path + tls_key_path。
|
||||||
|
# 启用后请用 https://127.0.0.1:<本端口>/ 访问;若仍用 http:// 访问同端口,将自动 308 跳转到 HTTPS(可用 tls_http_redirect: false 关闭)。
|
||||||
|
tls_enabled: true
|
||||||
|
# 启用 HTTPS 时,明文 HTTP 是否自动跳转到 HTTPS(默认 true;同端口嗅探 TLS/HTTP 后分流)
|
||||||
|
# tls_http_redirect: true
|
||||||
|
# 方式 A(推荐生产):PEM 证书与私钥路径
|
||||||
|
# tls_cert_path: /path/to/fullchain.pem
|
||||||
|
# tls_key_path: /path/to/privkey.pem
|
||||||
|
# 方式 B(仅本地/测试):无证书文件时内存自签(浏览器会提示不受信任;SAN 含 localhost / 127.0.0.1)
|
||||||
|
tls_auto_self_sign: true
|
||||||
# 认证配置
|
# 认证配置
|
||||||
auth:
|
auth:
|
||||||
password: # Web 登录密码,请修改为强密码
|
password: # Web 登录密码,请修改为强密码
|
||||||
@@ -23,6 +34,12 @@ auth:
|
|||||||
log:
|
log:
|
||||||
level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误)
|
level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误)
|
||||||
output: stdout # 日志输出位置: stdout(标准输出), stderr(标准错误), 或文件路径
|
output: stdout # 日志输出位置: stdout(标准输出), stderr(标准错误), 或文件路径
|
||||||
|
# 平台操作审计(系统设置 -> 日志审计;不记录对话正文与每次工具调用)
|
||||||
|
audit:
|
||||||
|
enabled: true
|
||||||
|
retention_days: 15 # 0 表示不自动清理
|
||||||
|
max_detail_bytes: 8192
|
||||||
|
auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流
|
||||||
# ============================================
|
# ============================================
|
||||||
# 对话相关配置
|
# 对话相关配置
|
||||||
# ============================================
|
# ============================================
|
||||||
@@ -43,8 +60,8 @@ openai:
|
|||||||
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
|
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
|
||||||
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinking(extended thinking),mode: off 关闭
|
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinking(extended thinking),mode: off 关闭
|
||||||
reasoning:
|
reasoning:
|
||||||
mode: off # auto | on | off;off 时不附加任何推理扩展字段
|
mode: on # auto | on | off;off 时不附加任何推理扩展字段
|
||||||
effort: max # low | medium | high | max;空表示不指定(openai_compat 下 auto 且无强度时不发请求扩展)
|
effort: high # low | medium | high | max | xhigh(最高档:OpenAI 常用 xhigh,部分网关用 max,原样下发);空表示不指定
|
||||||
allow_client_reasoning: true # false 时忽略对话请求体 reasoning,仅以下方为准
|
allow_client_reasoning: true # false 时忽略对话请求体 reasoning,仅以下方为准
|
||||||
profile: openai_compat # auto | deepseek_compat | openai_compat | output_config_effort
|
profile: openai_compat # auto | deepseek_compat | openai_compat | output_config_effort
|
||||||
# extra_request_fields: {} # 可选:管理员自定义根级 JSON 片段(高级)
|
# extra_request_fields: {} # 可选:管理员自定义根级 JSON 片段(高级)
|
||||||
@@ -60,21 +77,23 @@ fofa:
|
|||||||
# Agent 配置
|
# Agent 配置
|
||||||
# 达到最大迭代次数时,AI 会自动总结测试结果
|
# 达到最大迭代次数时,AI 会自动总结测试结果
|
||||||
agent:
|
agent:
|
||||||
max_iterations: 120 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
max_iterations: 1200 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
||||||
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
|
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
|
||||||
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
|
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
|
||||||
tool_timeout_minutes: 30 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||||
# system_prompt_path: prompts/single-react.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
# system_prompt_path: prompts/single-react.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
||||||
|
|
||||||
|
system_prompt_path: ""
|
||||||
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
|
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
|
||||||
hitl:
|
hitl:
|
||||||
# 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 []
|
# 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 []
|
||||||
tool_whitelist: [read_file, list_dir, glob, grep]
|
tool_whitelist: [read_file, list_dir, glob, grep]
|
||||||
# 多代理(CloudWeGo Eino DeepAgent,与上方单 Agent /api/agent-loop 并存)
|
# 多代理(CloudWeGo Eino DeepAgent,与上方单 Agent /api/agent-loop 并存)
|
||||||
# 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct
|
# 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct
|
||||||
# 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/stream;Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人/批量无请求体时固定按 deep
|
# 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/stream;Deep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人按 robot_default_agent_mode
|
||||||
multi_agent:
|
multi_agent:
|
||||||
enabled: true
|
enabled: true
|
||||||
robot_use_multi_agent: true # true 时企业微信/钉钉/飞书机器人也走 Eino 多代理(成本更高)
|
robot_default_agent_mode: eino_single # 企微/钉钉/飞书机器人默认对话模式:react | eino_single | deep | plan_execute | supervisor
|
||||||
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
|
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
|
||||||
max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations
|
max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations
|
||||||
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件。
|
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件。
|
||||||
@@ -114,6 +133,8 @@ multi_agent:
|
|||||||
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
|
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
|
||||||
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
|
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
|
||||||
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接
|
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接
|
||||||
|
run_retry_max_attempts: 0 # >0:429/5xx/网络抖动时 ADK 运行循环指数退避续跑次数;0=默认 10
|
||||||
|
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
|
||||||
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
|
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
|
||||||
deep_model_retry_max_retries: 0 # >0:ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
|
deep_model_retry_max_retries: 0 # >0:ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
|
||||||
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
||||||
@@ -224,6 +245,14 @@ knowledge:
|
|||||||
# 用于在手机端通过企业微信/钉钉/飞书与 CyberStrikeAI 对话,无需部署在服务器上也可使用
|
# 用于在手机端通过企业微信/钉钉/飞书与 CyberStrikeAI 对话,无需部署在服务器上也可使用
|
||||||
# 在系统设置 -> 机器人设置 中可配置
|
# 在系统设置 -> 机器人设置 中可配置
|
||||||
robots:
|
robots:
|
||||||
|
wechat: # 微信 iLink(个人微信 ClawBot,扫码绑定)
|
||||||
|
enabled: false
|
||||||
|
bot_token: ""
|
||||||
|
ilink_bot_id: ""
|
||||||
|
ilink_user_id: ""
|
||||||
|
base_url: https://ilinkai.weixin.qq.com
|
||||||
|
bot_type: "3"
|
||||||
|
bot_agent: CyberStrikeAI/1.0
|
||||||
wecom: # 企业微信
|
wecom: # 企业微信
|
||||||
enabled: false
|
enabled: false
|
||||||
token: ""
|
token: ""
|
||||||
@@ -235,11 +264,13 @@ robots:
|
|||||||
enabled: false
|
enabled: false
|
||||||
client_id: ""
|
client_id: ""
|
||||||
client_secret: ""
|
client_secret: ""
|
||||||
|
allow_conversation_id_fallback: false
|
||||||
lark: # 飞书
|
lark: # 飞书
|
||||||
enabled: false
|
enabled: false
|
||||||
app_id: ""
|
app_id: ""
|
||||||
app_secret: ""
|
app_secret: ""
|
||||||
verify_token: ""
|
verify_token: ""
|
||||||
|
allow_chat_id_fallback: false
|
||||||
# ============================================
|
# ============================================
|
||||||
# Skills 相关配置
|
# Skills 相关配置
|
||||||
# ============================================
|
# ============================================
|
||||||
|
|||||||
@@ -27,12 +27,14 @@ require (
|
|||||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||||
github.com/pkoukk/tiktoken-go v0.1.8
|
github.com/pkoukk/tiktoken-go v0.1.8
|
||||||
github.com/robfig/cron/v3 v3.0.1
|
github.com/robfig/cron/v3 v3.0.1
|
||||||
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
|
||||||
go.opentelemetry.io/otel v1.34.0
|
go.opentelemetry.io/otel v1.34.0
|
||||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0
|
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0
|
||||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0
|
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0
|
||||||
go.opentelemetry.io/otel/sdk v1.34.0
|
go.opentelemetry.io/otel/sdk v1.34.0
|
||||||
go.opentelemetry.io/otel/trace v1.34.0
|
go.opentelemetry.io/otel/trace v1.34.0
|
||||||
go.uber.org/zap v1.26.0
|
go.uber.org/zap v1.26.0
|
||||||
|
golang.org/x/net v0.35.0
|
||||||
golang.org/x/text v0.26.0
|
golang.org/x/text v0.26.0
|
||||||
golang.org/x/time v0.14.0
|
golang.org/x/time v0.14.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
@@ -88,7 +90,6 @@ require (
|
|||||||
golang.org/x/arch v0.15.0 // indirect
|
golang.org/x/arch v0.15.0 // indirect
|
||||||
golang.org/x/crypto v0.39.0 // indirect
|
golang.org/x/crypto v0.39.0 // indirect
|
||||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
|
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
|
||||||
golang.org/x/net v0.34.0 // indirect
|
|
||||||
golang.org/x/oauth2 v0.30.0 // indirect
|
golang.org/x/oauth2 v0.30.0 // indirect
|
||||||
golang.org/x/sys v0.33.0 // indirect
|
golang.org/x/sys v0.33.0 // indirect
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f // indirect
|
||||||
|
|||||||
@@ -163,6 +163,8 @@ github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtIS
|
|||||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||||
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
|
||||||
|
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
|
||||||
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI=
|
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI=
|
||||||
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg=
|
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg=
|
||||||
github.com/smarty/assertions v1.16.0 h1:EvHNkdRA4QHMrn75NZSoUQ/mAUXAYWfatfB01yTCzfY=
|
github.com/smarty/assertions v1.16.0 h1:EvHNkdRA4QHMrn75NZSoUQ/mAUXAYWfatfB01yTCzfY=
|
||||||
@@ -245,8 +247,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
|
|||||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||||
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
|
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||||
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
|
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
|||||||
+32
-8
@@ -598,11 +598,17 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
thinkingStreamSeq++
|
thinkingStreamSeq++
|
||||||
thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq)
|
thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq)
|
||||||
thinkingStreamStarted := false
|
thinkingStreamStarted := false
|
||||||
|
var thinkingWire string
|
||||||
|
|
||||||
response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error {
|
response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error {
|
||||||
if delta == "" {
|
if delta == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
var deltaOut string
|
||||||
|
thinkingWire, deltaOut = openai.NormalizeStreamingDelta(thinkingWire, delta)
|
||||||
|
if deltaOut == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
if !thinkingStreamStarted {
|
if !thinkingStreamStarted {
|
||||||
thinkingStreamStarted = true
|
thinkingStreamStarted = true
|
||||||
sendProgress("thinking_stream_start", " ", map[string]interface{}{
|
sendProgress("thinking_stream_start", " ", map[string]interface{}{
|
||||||
@@ -611,10 +617,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
"toolStream": false,
|
"toolStream": false,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
sendProgress("thinking_stream_delta", delta, map[string]interface{}{
|
sendProgress("thinking_stream_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
"streamId": thinkingStreamId,
|
"streamId": thinkingStreamId,
|
||||||
"iteration": i + 1,
|
"iteration": i + 1,
|
||||||
})
|
}, thinkingWire))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -827,10 +833,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||||
"messageGeneratedBy": "summary",
|
"messageGeneratedBy": "summary",
|
||||||
})
|
})
|
||||||
|
var summaryWire string
|
||||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||||
sendProgress("response_delta", delta, map[string]interface{}{
|
var deltaOut string
|
||||||
|
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
|
||||||
|
if deltaOut == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
})
|
}, summaryWire))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if strings.TrimSpace(streamText) != "" {
|
if strings.TrimSpace(streamText) != "" {
|
||||||
@@ -874,10 +886,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||||
"messageGeneratedBy": "summary",
|
"messageGeneratedBy": "summary",
|
||||||
})
|
})
|
||||||
|
var summaryWire string
|
||||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||||
sendProgress("response_delta", delta, map[string]interface{}{
|
var deltaOut string
|
||||||
|
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
|
||||||
|
if deltaOut == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
})
|
}, summaryWire))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if strings.TrimSpace(streamText) != "" {
|
if strings.TrimSpace(streamText) != "" {
|
||||||
@@ -921,10 +939,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
|||||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||||
"messageGeneratedBy": "max_iter_summary",
|
"messageGeneratedBy": "max_iter_summary",
|
||||||
})
|
})
|
||||||
|
var summaryWire string
|
||||||
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
|
||||||
sendProgress("response_delta", delta, map[string]interface{}{
|
var deltaOut string
|
||||||
|
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
|
||||||
|
if deltaOut == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
})
|
}, summaryWire))
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
if strings.TrimSpace(streamText) != "" {
|
if strings.TrimSpace(streamText) != "" {
|
||||||
|
|||||||
@@ -0,0 +1,167 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseTraceMessages 解析落库的 last_react_input(OpenAI 风格 messages JSON 数组)。
|
||||||
|
func ParseTraceMessages(traceInputJSON string) ([]ChatMessage, error) {
|
||||||
|
traceInputJSON = strings.TrimSpace(traceInputJSON)
|
||||||
|
if traceInputJSON == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var raw []map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(traceInputJSON), &raw); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out := make([]ChatMessage, 0, len(raw))
|
||||||
|
for _, msgMap := range raw {
|
||||||
|
msg := ChatMessage{}
|
||||||
|
role, _ := msgMap["role"].(string)
|
||||||
|
if role == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
msg.Role = role
|
||||||
|
if content, ok := msgMap["content"].(string); ok {
|
||||||
|
msg.Content = content
|
||||||
|
}
|
||||||
|
if rc, ok := msgMap["reasoning_content"].(string); ok && strings.TrimSpace(rc) != "" {
|
||||||
|
msg.ReasoningContent = rc
|
||||||
|
}
|
||||||
|
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
|
||||||
|
if toolCallsArray, ok := toolCallsRaw.([]interface{}); ok {
|
||||||
|
for _, tcRaw := range toolCallsArray {
|
||||||
|
tcMap, ok := tcRaw.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
toolCall := ToolCall{}
|
||||||
|
if id, ok := tcMap["id"].(string); ok {
|
||||||
|
toolCall.ID = id
|
||||||
|
}
|
||||||
|
if toolType, ok := tcMap["type"].(string); ok {
|
||||||
|
toolCall.Type = toolType
|
||||||
|
}
|
||||||
|
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
|
||||||
|
toolCall.Function = FunctionCall{}
|
||||||
|
if name, ok := funcMap["name"].(string); ok {
|
||||||
|
toolCall.Function.Name = name
|
||||||
|
}
|
||||||
|
if argsRaw, ok := funcMap["arguments"]; ok {
|
||||||
|
if argsStr, ok := argsRaw.(string); ok {
|
||||||
|
var argsMap map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil {
|
||||||
|
toolCall.Function.Arguments = argsMap
|
||||||
|
}
|
||||||
|
} else if argsMap, ok := argsRaw.(map[string]interface{}); ok {
|
||||||
|
toolCall.Function.Arguments = argsMap
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if toolCall.ID != "" {
|
||||||
|
msg.ToolCalls = append(msg.ToolCalls, toolCall)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
|
||||||
|
msg.ToolCallID = toolCallID
|
||||||
|
}
|
||||||
|
if tn, ok := msgMap["tool_name"].(string); ok && strings.TrimSpace(tn) != "" {
|
||||||
|
msg.ToolName = strings.TrimSpace(tn)
|
||||||
|
} else if tn, ok := msgMap["name"].(string); ok && strings.TrimSpace(tn) != "" && strings.EqualFold(msg.Role, "tool") {
|
||||||
|
msg.ToolName = strings.TrimSpace(tn)
|
||||||
|
}
|
||||||
|
out = append(out, msg)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractLastUserTurnMessages 仅保留最后一次 user 提问起的消息(不含更早的用户轮次;跳过 system)。
|
||||||
|
// 与「继续对话」续跑所用轨迹范围一致:当前任务轮次,而非整段多轮对话历史。
|
||||||
|
func ExtractLastUserTurnMessages(msgs []ChatMessage) []ChatMessage {
|
||||||
|
if len(msgs) == 0 {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
lastUser := -1
|
||||||
|
for i, m := range msgs {
|
||||||
|
if strings.EqualFold(m.Role, "user") {
|
||||||
|
lastUser = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lastUser < 0 {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
trimmed := msgs[lastUser:]
|
||||||
|
out := make([]ChatMessage, 0, len(trimmed))
|
||||||
|
for _, m := range trimmed {
|
||||||
|
if strings.EqualFold(m.Role, "system") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, m)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractLastUserTurnTraceJSON 在 JSON 轨迹上裁剪为最后一次 user 起的片段(供落库格式直接处理)。
|
||||||
|
func ExtractLastUserTurnTraceJSON(traceInputJSON string) string {
|
||||||
|
traceInputJSON = strings.TrimSpace(traceInputJSON)
|
||||||
|
if traceInputJSON == "" {
|
||||||
|
return traceInputJSON
|
||||||
|
}
|
||||||
|
var arr []map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(traceInputJSON), &arr); err != nil {
|
||||||
|
return traceInputJSON
|
||||||
|
}
|
||||||
|
lastUser := -1
|
||||||
|
for i, m := range arr {
|
||||||
|
if r, _ := m["role"].(string); strings.EqualFold(r, "user") {
|
||||||
|
lastUser = i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lastUser <= 0 {
|
||||||
|
return traceInputJSON
|
||||||
|
}
|
||||||
|
trimmed := arr[lastUser:]
|
||||||
|
b, err := json.Marshal(trimmed)
|
||||||
|
if err != nil {
|
||||||
|
return traceInputJSON
|
||||||
|
}
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeAssistantTraceOutput 将 last_react_output 合并进轨迹最后一条 assistant(与 loadHistoryFromAgentTrace 一致)。
|
||||||
|
func MergeAssistantTraceOutput(msgs []ChatMessage, assistantOut string) []ChatMessage {
|
||||||
|
assistantOut = strings.TrimSpace(assistantOut)
|
||||||
|
if assistantOut == "" || len(msgs) == 0 {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
out := append([]ChatMessage(nil), msgs...)
|
||||||
|
last := &out[len(out)-1]
|
||||||
|
if strings.EqualFold(last.Role, "assistant") && len(last.ToolCalls) == 0 {
|
||||||
|
last.Content = assistantOut
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
out = append(out, ChatMessage{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: assistantOut,
|
||||||
|
})
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessagesToTraceJSON 将消息带序列化为 JSON(跳过 system)。
|
||||||
|
func MessagesToTraceJSON(msgs []ChatMessage) (string, error) {
|
||||||
|
filtered := make([]ChatMessage, 0, len(msgs))
|
||||||
|
for _, m := range msgs {
|
||||||
|
if strings.EqualFold(m.Role, "system") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filtered = append(filtered, m)
|
||||||
|
}
|
||||||
|
b, err := json.Marshal(filtered)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
package agent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractLastUserTurnTraceJSON(t *testing.T) {
|
||||||
|
raw := []map[string]interface{}{
|
||||||
|
{"role": "user", "content": "old question"},
|
||||||
|
{"role": "assistant", "content": "old answer"},
|
||||||
|
{"role": "user", "content": "new target 1.1.1.1"},
|
||||||
|
{"role": "assistant", "tool_calls": []interface{}{map[string]interface{}{
|
||||||
|
"id": "c1", "type": "function",
|
||||||
|
"function": map[string]interface{}{"name": "nmap", "arguments": "{}"},
|
||||||
|
}}},
|
||||||
|
{"role": "tool", "tool_call_id": "c1", "content": "open ports"},
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(raw)
|
||||||
|
out := ExtractLastUserTurnTraceJSON(string(b))
|
||||||
|
var trimmed []map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(out), &trimmed); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(trimmed) != 3 {
|
||||||
|
t.Fatalf("expected 3 messages, got %d", len(trimmed))
|
||||||
|
}
|
||||||
|
if trimmed[0]["content"] != "new target 1.1.1.1" {
|
||||||
|
t.Fatalf("unexpected first message: %v", trimmed[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractLastUserTurnMessagesSkipsSystem(t *testing.T) {
|
||||||
|
msgs := []ChatMessage{
|
||||||
|
{Role: "system", Content: "sys"},
|
||||||
|
{Role: "user", Content: "q"},
|
||||||
|
{Role: "assistant", Content: "a"},
|
||||||
|
}
|
||||||
|
out := ExtractLastUserTurnMessages(msgs)
|
||||||
|
if len(out) != 2 {
|
||||||
|
t.Fatalf("expected 2, got %d", len(out))
|
||||||
|
}
|
||||||
|
if out[0].Role != "user" {
|
||||||
|
t.Fatal("expected user first")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeAssistantTraceOutput(t *testing.T) {
|
||||||
|
msgs := []ChatMessage{
|
||||||
|
{Role: "user", Content: "q"},
|
||||||
|
{Role: "assistant", Content: "draft"},
|
||||||
|
}
|
||||||
|
out := MergeAssistantTraceOutput(msgs, "final summary")
|
||||||
|
if out[len(out)-1].Content != "final summary" {
|
||||||
|
t.Fatalf("expected merged output, got %q", out[len(out)-1].Content)
|
||||||
|
}
|
||||||
|
}
|
||||||
+129
-11
@@ -3,8 +3,10 @@ package app
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
|
"crypto/tls"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -13,6 +15,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/c2"
|
"cyberstrike-ai/internal/c2"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
@@ -30,6 +33,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// App 应用
|
// App 应用
|
||||||
@@ -53,14 +57,16 @@ type App struct {
|
|||||||
robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel
|
robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel
|
||||||
dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启
|
dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启
|
||||||
larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启
|
larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启
|
||||||
|
wechatCancel context.CancelFunc // 微信 iLink 长轮询取消函数
|
||||||
c2Manager *c2.Manager // C2 管理器(未启用 C2 时为 nil)
|
c2Manager *c2.Manager // C2 管理器(未启用 C2 时为 nil)
|
||||||
c2Watchdog *c2.SessionWatchdog // C2 会话看门狗
|
c2Watchdog *c2.SessionWatchdog // C2 会话看门狗
|
||||||
c2WatchdogCancel context.CancelFunc // 看门狗取消函数
|
c2WatchdogCancel context.CancelFunc // 看门狗取消函数
|
||||||
c2Handler *handler.C2Handler // C2 REST(与 Manager 生命周期同步)
|
c2Handler *handler.C2Handler // C2 REST(与 Manager 生命周期同步)
|
||||||
|
auditSvc *audit.Service
|
||||||
}
|
}
|
||||||
|
|
||||||
// New 创建新应用
|
// New 创建新应用
|
||||||
func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error) {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
router := gin.Default()
|
router := gin.Default()
|
||||||
|
|
||||||
@@ -89,6 +95,11 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
|||||||
return nil, fmt.Errorf("初始化数据库失败: %w", err)
|
return nil, fmt.Errorf("初始化数据库失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auditSvc := audit.NewService(db, cfg, log.Logger)
|
||||||
|
audit.RegisterConversationCreateHook(auditSvc)
|
||||||
|
auditSvc.PurgeExpired()
|
||||||
|
audit.StartRetentionLoop(auditSvc, log.Logger)
|
||||||
|
|
||||||
// 创建MCP服务器(带数据库持久化)
|
// 创建MCP服务器(带数据库持久化)
|
||||||
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
||||||
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
||||||
@@ -218,6 +229,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
|||||||
|
|
||||||
// 创建知识库API处理器
|
// 创建知识库API处理器
|
||||||
knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger)
|
knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger)
|
||||||
|
knowledgeHandler.SetAudit(auditSvc)
|
||||||
log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
|
log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
|
||||||
|
|
||||||
// 扫描知识库并建立索引(异步)
|
// 扫描知识库并建立索引(异步)
|
||||||
@@ -292,10 +304,10 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取配置文件路径
|
// 配置文件路径必须由入口传入(与 flag -config 一致)。勿再用 os.Args[1],否则 ./cyberstrike-ai --https 会把 --https 当成路径。
|
||||||
configPath := "config.yaml"
|
configPath = strings.TrimSpace(configPath)
|
||||||
if len(os.Args) > 1 {
|
if configPath == "" {
|
||||||
configPath = os.Args[1]
|
configPath = "config.yaml"
|
||||||
}
|
}
|
||||||
|
|
||||||
skillsDir := skillpackage.SkillsRootFromConfig(cfg.SkillsDir, configPath)
|
skillsDir := skillpackage.SkillsRootFromConfig(cfg.SkillsDir, configPath)
|
||||||
@@ -314,31 +326,42 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
|||||||
log.Logger.Warn("创建 agents 目录失败", zap.String("path", agentsDir), zap.Error(err))
|
log.Logger.Warn("创建 agents 目录失败", zap.String("path", agentsDir), zap.Error(err))
|
||||||
}
|
}
|
||||||
markdownAgentsHandler := handler.NewMarkdownAgentsHandler(agentsDir)
|
markdownAgentsHandler := handler.NewMarkdownAgentsHandler(agentsDir)
|
||||||
|
markdownAgentsHandler.SetAudit(auditSvc)
|
||||||
log.Logger.Info("多代理 Markdown 子 Agent 目录", zap.String("agentsDir", agentsDir))
|
log.Logger.Info("多代理 Markdown 子 Agent 目录", zap.String("agentsDir", agentsDir))
|
||||||
|
|
||||||
// 创建处理器
|
// 创建处理器
|
||||||
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
|
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
|
||||||
|
agentHandler.SetAudit(auditSvc)
|
||||||
agentHandler.SetAgentsMarkdownDir(agentsDir)
|
agentHandler.SetAgentsMarkdownDir(agentsDir)
|
||||||
// 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志
|
// 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志
|
||||||
if knowledgeManager != nil {
|
if knowledgeManager != nil {
|
||||||
agentHandler.SetKnowledgeManager(knowledgeManager)
|
agentHandler.SetKnowledgeManager(knowledgeManager)
|
||||||
}
|
}
|
||||||
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
|
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
|
||||||
|
monitorHandler.SetAudit(auditSvc)
|
||||||
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
||||||
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
||||||
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
||||||
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
|
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
|
||||||
|
authHandler.SetAudit(auditSvc)
|
||||||
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
|
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
|
||||||
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
|
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
|
||||||
|
vulnerabilityHandler.SetAudit(auditSvc)
|
||||||
webshellHandler := handler.NewWebShellHandler(log.Logger, db)
|
webshellHandler := handler.NewWebShellHandler(log.Logger, db)
|
||||||
|
webshellHandler.SetAudit(auditSvc)
|
||||||
chatUploadsHandler := handler.NewChatUploadsHandler(log.Logger)
|
chatUploadsHandler := handler.NewChatUploadsHandler(log.Logger)
|
||||||
|
chatUploadsHandler.SetAudit(auditSvc)
|
||||||
registerWebshellTools(mcpServer, db, webshellHandler, log.Logger)
|
registerWebshellTools(mcpServer, db, webshellHandler, log.Logger)
|
||||||
registerWebshellManagementTools(mcpServer, db, webshellHandler, log.Logger)
|
registerWebshellManagementTools(mcpServer, db, webshellHandler, log.Logger)
|
||||||
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
|
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
|
||||||
|
configHandler.SetAudit(auditSvc)
|
||||||
agentHandler.SetHitlToolWhitelistSaver(configHandler)
|
agentHandler.SetHitlToolWhitelistSaver(configHandler)
|
||||||
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
|
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
|
||||||
|
externalMCPHandler.SetAudit(auditSvc)
|
||||||
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
|
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
|
||||||
|
roleHandler.SetAudit(auditSvc)
|
||||||
skillsHandler := handler.NewSkillsHandler(cfg, configPath, log.Logger)
|
skillsHandler := handler.NewSkillsHandler(cfg, configPath, log.Logger)
|
||||||
|
skillsHandler.SetAudit(auditSvc)
|
||||||
fofaHandler := handler.NewFofaHandler(cfg, log.Logger)
|
fofaHandler := handler.NewFofaHandler(cfg, log.Logger)
|
||||||
terminalHandler := handler.NewTerminalHandler(log.Logger)
|
terminalHandler := handler.NewTerminalHandler(log.Logger)
|
||||||
if db != nil {
|
if db != nil {
|
||||||
@@ -353,9 +376,12 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
|||||||
registerC2Tools(mcpServer, c2Manager, log.Logger, cfg.Server.Port)
|
registerC2Tools(mcpServer, c2Manager, log.Logger, cfg.Server.Port)
|
||||||
}
|
}
|
||||||
c2Handler := handler.NewC2Handler(c2Manager, log.Logger)
|
c2Handler := handler.NewC2Handler(c2Manager, log.Logger)
|
||||||
|
c2Handler.SetAudit(auditSvc)
|
||||||
|
|
||||||
// 创建OpenAPI处理器
|
// 创建OpenAPI处理器
|
||||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||||
|
conversationHandler.SetAudit(auditSvc)
|
||||||
|
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
|
||||||
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
|
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
|
||||||
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler)
|
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler)
|
||||||
|
|
||||||
@@ -381,6 +407,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
|||||||
c2Watchdog: c2Watchdog,
|
c2Watchdog: c2Watchdog,
|
||||||
c2WatchdogCancel: watchdogCancel,
|
c2WatchdogCancel: watchdogCancel,
|
||||||
c2Handler: c2Handler,
|
c2Handler: c2Handler,
|
||||||
|
auditSvc: auditSvc,
|
||||||
}
|
}
|
||||||
// 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启
|
// 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启
|
||||||
app.startRobotConnections()
|
app.startRobotConnections()
|
||||||
@@ -446,9 +473,11 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
|||||||
configHandler.SetRetrieverUpdater(knowledgeRetriever)
|
configHandler.SetRetrieverUpdater(knowledgeRetriever)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书新配置生效
|
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书/微信新配置生效
|
||||||
configHandler.SetRobotRestarter(app)
|
configHandler.SetRobotRestarter(app)
|
||||||
|
|
||||||
|
wechatRobotHandler := handler.NewWechatRobotHandler(cfg, configHandler, log.Logger)
|
||||||
|
|
||||||
configHandler.SetC2Runtime(app)
|
configHandler.SetC2Runtime(app)
|
||||||
configHandler.SetC2ToolRegistrar(func() error {
|
configHandler.SetC2ToolRegistrar(func() error {
|
||||||
if app.config.C2.EnabledEffective() && app.c2Manager != nil {
|
if app.config.C2.EnabledEffective() && app.c2Manager != nil {
|
||||||
@@ -466,6 +495,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
|||||||
notificationHandler,
|
notificationHandler,
|
||||||
conversationHandler,
|
conversationHandler,
|
||||||
robotHandler,
|
robotHandler,
|
||||||
|
wechatRobotHandler,
|
||||||
groupHandler,
|
groupHandler,
|
||||||
configHandler,
|
configHandler,
|
||||||
externalMCPHandler,
|
externalMCPHandler,
|
||||||
@@ -480,6 +510,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
|||||||
fofaHandler,
|
fofaHandler,
|
||||||
terminalHandler,
|
terminalHandler,
|
||||||
app.c2Handler,
|
app.c2Handler,
|
||||||
|
auditHandler,
|
||||||
mcpServer,
|
mcpServer,
|
||||||
authManager,
|
authManager,
|
||||||
openAPIHandler,
|
openAPIHandler,
|
||||||
@@ -530,18 +561,49 @@ func (a *App) RunWithContext(ctx context.Context) error {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 启动主服务器
|
// 启动主服务器(可选 HTTPS + HTTP/2,见 config server.tls_*)
|
||||||
addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
|
addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
|
||||||
a.logger.Info("启动HTTP服务器", zap.String("address", addr))
|
tlsMode, tlsConf, certFile, keyFile, tlsErr := prepareMainServerTLS(&a.config.Server)
|
||||||
|
if tlsErr != nil {
|
||||||
|
return tlsErr
|
||||||
|
}
|
||||||
|
|
||||||
srv := &http.Server{Addr: addr, Handler: a.router}
|
srv := &http.Server{Addr: addr, Handler: a.router}
|
||||||
|
var mainMux *mainServerMux
|
||||||
|
httpRedirect := config.ServerHTTPRedirectEnabled(&a.config.Server)
|
||||||
|
if tlsMode != mainTLSOff {
|
||||||
|
srv.TLSConfig = tlsConf
|
||||||
|
if err := http2.ConfigureServer(srv, &http2.Server{}); err != nil {
|
||||||
|
return fmt.Errorf("主服务 HTTP/2 配置失败: %w", err)
|
||||||
|
}
|
||||||
|
switch tlsMode {
|
||||||
|
case mainTLSFromFiles:
|
||||||
|
a.logger.Info("启动 HTTPS 主服务(已启用 HTTP/2 协商)",
|
||||||
|
zap.String("address", addr),
|
||||||
|
zap.String("cert", certFile),
|
||||||
|
)
|
||||||
|
case mainTLSInMemorySelfSigned:
|
||||||
|
a.logger.Info("启动 HTTPS 主服务(内存自签证书,仅测试;已启用 HTTP/2 协商)",
|
||||||
|
zap.String("address", addr),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if httpRedirect {
|
||||||
|
a.logger.Info("已启用 HTTP→HTTPS 自动跳转(同端口嗅探分流)", zap.String("address", addr))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
a.logger.Info("启动 HTTP 主服务", zap.String("address", addr))
|
||||||
|
}
|
||||||
|
|
||||||
// 监听 context 取消,优雅关闭 HTTP 服务器
|
// 监听 context 取消,优雅关闭 HTTP 服务器
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
if mainMux != nil {
|
||||||
|
if err := mainMux.Shutdown(shutdownCtx); err != nil {
|
||||||
|
a.logger.Error("HTTP/HTTPS 分流服务器关闭失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
} else if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||||
a.logger.Error("HTTP服务器关闭失败", zap.Error(err))
|
a.logger.Error("HTTP服务器关闭失败", zap.Error(err))
|
||||||
}
|
}
|
||||||
if mcpServer != nil {
|
if mcpServer != nil {
|
||||||
@@ -551,7 +613,36 @@ func (a *App) RunWithContext(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
var err error
|
||||||
|
switch {
|
||||||
|
case tlsMode != mainTLSOff && httpRedirect:
|
||||||
|
var tlsConfReady *tls.Config
|
||||||
|
tlsConfReady, err = ensureMainTLSConfigCerts(tlsMode, tlsConf, certFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("加载 TLS 证书: %w", err)
|
||||||
|
}
|
||||||
|
srv.TLSConfig = tlsConfReady
|
||||||
|
var ln net.Listener
|
||||||
|
ln, err = net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
mainMux = newMainServerMux(ln, srv, portFromListenAddr(addr), a.logger.Logger)
|
||||||
|
err = mainMux.Serve()
|
||||||
|
case tlsMode == mainTLSOff:
|
||||||
|
err = srv.ListenAndServe()
|
||||||
|
case tlsMode == mainTLSFromFiles:
|
||||||
|
err = srv.ListenAndServeTLS(certFile, keyFile)
|
||||||
|
case tlsMode == mainTLSInMemorySelfSigned:
|
||||||
|
var ln net.Listener
|
||||||
|
ln, err = tls.Listen("tcp", addr, srv.TLSConfig)
|
||||||
|
if err == nil {
|
||||||
|
err = srv.Serve(ln)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
err = srv.ListenAndServe()
|
||||||
|
}
|
||||||
|
if err != nil && err != http.ErrServerClosed {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -612,9 +703,14 @@ func (a *App) startRobotConnections() {
|
|||||||
a.dingCancel = cancel
|
a.dingCancel = cancel
|
||||||
go robot.StartDing(ctx, cfg.Robots, a.robotHandler, a.logger.Logger)
|
go robot.StartDing(ctx, cfg.Robots, a.robotHandler, a.logger.Logger)
|
||||||
}
|
}
|
||||||
|
if cfg.Robots.Wechat.Enabled && cfg.Robots.Wechat.BotToken != "" {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
a.wechatCancel = cancel
|
||||||
|
go robot.StartWechat(ctx, cfg.Robots, a.robotHandler, cfg.Version, a.logger.Logger)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RestartRobotConnections 重启钉钉/飞书长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter)
|
// RestartRobotConnections 重启钉钉/飞书/微信长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter)
|
||||||
func (a *App) RestartRobotConnections() {
|
func (a *App) RestartRobotConnections() {
|
||||||
a.robotMu.Lock()
|
a.robotMu.Lock()
|
||||||
if a.dingCancel != nil {
|
if a.dingCancel != nil {
|
||||||
@@ -625,6 +721,10 @@ func (a *App) RestartRobotConnections() {
|
|||||||
a.larkCancel()
|
a.larkCancel()
|
||||||
a.larkCancel = nil
|
a.larkCancel = nil
|
||||||
}
|
}
|
||||||
|
if a.wechatCancel != nil {
|
||||||
|
a.wechatCancel()
|
||||||
|
a.wechatCancel = nil
|
||||||
|
}
|
||||||
a.robotMu.Unlock()
|
a.robotMu.Unlock()
|
||||||
// 给旧 goroutine 一点时间退出
|
// 给旧 goroutine 一点时间退出
|
||||||
time.Sleep(200 * time.Millisecond)
|
time.Sleep(200 * time.Millisecond)
|
||||||
@@ -640,6 +740,7 @@ func setupRoutes(
|
|||||||
notificationHandler *handler.NotificationHandler,
|
notificationHandler *handler.NotificationHandler,
|
||||||
conversationHandler *handler.ConversationHandler,
|
conversationHandler *handler.ConversationHandler,
|
||||||
robotHandler *handler.RobotHandler,
|
robotHandler *handler.RobotHandler,
|
||||||
|
wechatRobotHandler *handler.WechatRobotHandler,
|
||||||
groupHandler *handler.GroupHandler,
|
groupHandler *handler.GroupHandler,
|
||||||
configHandler *handler.ConfigHandler,
|
configHandler *handler.ConfigHandler,
|
||||||
externalMCPHandler *handler.ExternalMCPHandler,
|
externalMCPHandler *handler.ExternalMCPHandler,
|
||||||
@@ -654,6 +755,7 @@ func setupRoutes(
|
|||||||
fofaHandler *handler.FofaHandler,
|
fofaHandler *handler.FofaHandler,
|
||||||
terminalHandler *handler.TerminalHandler,
|
terminalHandler *handler.TerminalHandler,
|
||||||
c2Handler *handler.C2Handler,
|
c2Handler *handler.C2Handler,
|
||||||
|
auditHandler *handler.AuditHandler,
|
||||||
mcpServer *mcp.Server,
|
mcpServer *mcp.Server,
|
||||||
authManager *security.AuthManager,
|
authManager *security.AuthManager,
|
||||||
openAPIHandler *handler.OpenAPIHandler,
|
openAPIHandler *handler.OpenAPIHandler,
|
||||||
@@ -688,6 +790,12 @@ func setupRoutes(
|
|||||||
// 机器人测试(需登录):POST /api/robot/test,body: {"platform":"dingtalk","user_id":"test","text":"帮助"},用于验证机器人逻辑
|
// 机器人测试(需登录):POST /api/robot/test,body: {"platform":"dingtalk","user_id":"test","text":"帮助"},用于验证机器人逻辑
|
||||||
protected.POST("/robot/test", robotHandler.HandleRobotTest)
|
protected.POST("/robot/test", robotHandler.HandleRobotTest)
|
||||||
|
|
||||||
|
// 微信 iLink 扫码绑定(需登录)
|
||||||
|
protected.POST("/robot/wechat/qrcode", wechatRobotHandler.HandleWechatQRCode)
|
||||||
|
protected.GET("/robot/wechat/qrcode/status", wechatRobotHandler.HandleWechatQRCodeStatus)
|
||||||
|
protected.POST("/robot/wechat/qrcode/verify", wechatRobotHandler.HandleWechatVerifyCode)
|
||||||
|
protected.GET("/robot/wechat/status", wechatRobotHandler.HandleWechatStatus)
|
||||||
|
|
||||||
// Agent Loop
|
// Agent Loop
|
||||||
protected.POST("/agent-loop", agentHandler.AgentLoop)
|
protected.POST("/agent-loop", agentHandler.AgentLoop)
|
||||||
// Agent Loop 流式输出
|
// Agent Loop 流式输出
|
||||||
@@ -784,6 +892,13 @@ func setupRoutes(
|
|||||||
protected.POST("/terminal/run/stream", terminalHandler.RunCommandStream)
|
protected.POST("/terminal/run/stream", terminalHandler.RunCommandStream)
|
||||||
protected.GET("/terminal/ws", terminalHandler.RunCommandWS)
|
protected.GET("/terminal/ws", terminalHandler.RunCommandWS)
|
||||||
|
|
||||||
|
// 平台审计日志
|
||||||
|
protected.GET("/audit/meta", auditHandler.Meta)
|
||||||
|
protected.GET("/audit/summary", auditHandler.Summary)
|
||||||
|
protected.GET("/audit/logs", auditHandler.ListLogs)
|
||||||
|
protected.GET("/audit/logs/export", auditHandler.ExportLogs)
|
||||||
|
protected.GET("/audit/logs/:id", auditHandler.GetLog)
|
||||||
|
|
||||||
// 外部MCP管理
|
// 外部MCP管理
|
||||||
protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs)
|
protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs)
|
||||||
protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats)
|
protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats)
|
||||||
@@ -1845,6 +1960,9 @@ func initializeKnowledge(
|
|||||||
|
|
||||||
// 创建知识库API处理器
|
// 创建知识库API处理器
|
||||||
knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger)
|
knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger)
|
||||||
|
if app != nil && app.auditSvc != nil {
|
||||||
|
knowledgeHandler.SetAudit(app.auditSvc)
|
||||||
|
}
|
||||||
logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
|
logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
|
||||||
|
|
||||||
// 设置知识库管理器到AgentHandler以便记录检索日志
|
// 设置知识库管理器到AgentHandler以便记录检索日志
|
||||||
|
|||||||
@@ -0,0 +1,196 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// peekedConn 在已预读首字节后仍将连接交给 net/http 或 crypto/tls。
|
||||||
|
type peekedConn struct {
|
||||||
|
net.Conn
|
||||||
|
r *bufio.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *peekedConn) Read(p []byte) (int, error) {
|
||||||
|
return c.r.Read(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// oneConnListener 供 http.Server.Serve 处理单条 TCP 连接(含 keep-alive)。
|
||||||
|
type oneConnListener struct {
|
||||||
|
conn net.Conn
|
||||||
|
addr net.Addr
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *oneConnListener) Accept() (net.Conn, error) {
|
||||||
|
var c net.Conn
|
||||||
|
l.once.Do(func() {
|
||||||
|
c = l.conn
|
||||||
|
l.conn = nil
|
||||||
|
})
|
||||||
|
if c == nil {
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *oneConnListener) Close() error { return nil }
|
||||||
|
func (l *oneConnListener) Addr() net.Addr { return l.addr }
|
||||||
|
|
||||||
|
func isTLSHandshakeRecord(b byte) bool {
|
||||||
|
return b == 0x16
|
||||||
|
}
|
||||||
|
|
||||||
|
func newHTTPToHTTPSRedirectHandler(httpsPort int) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
host := r.Host
|
||||||
|
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||||
|
host = h
|
||||||
|
}
|
||||||
|
var target string
|
||||||
|
if httpsPort == 443 {
|
||||||
|
target = fmt.Sprintf("https://%s%s", host, r.URL.RequestURI())
|
||||||
|
} else {
|
||||||
|
target = fmt.Sprintf("https://%s:%d%s", host, httpsPort, r.URL.RequestURI())
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, target, http.StatusPermanentRedirect)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func portFromListenAddr(addr string) int {
|
||||||
|
_, portStr, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return 443
|
||||||
|
}
|
||||||
|
p, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil || p <= 0 {
|
||||||
|
return 443
|
||||||
|
}
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureMainTLSConfigCerts(mode mainTLSMode, tlsConf *tls.Config, certFile, keyFile string) (*tls.Config, error) {
|
||||||
|
if mode != mainTLSFromFiles {
|
||||||
|
return tlsConf, nil
|
||||||
|
}
|
||||||
|
if tlsConf == nil {
|
||||||
|
tlsConf = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||||
|
}
|
||||||
|
if len(tlsConf.Certificates) > 0 {
|
||||||
|
return tlsConf, nil
|
||||||
|
}
|
||||||
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tlsConf.Certificates = []tls.Certificate{cert}
|
||||||
|
return tlsConf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mainServerMux struct {
|
||||||
|
ln net.Listener
|
||||||
|
httpsSrv *http.Server
|
||||||
|
redirectSrv *http.Server
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMainServerMux(ln net.Listener, httpsSrv *http.Server, httpsPort int, logger *zap.Logger) *mainServerMux {
|
||||||
|
return &mainServerMux{
|
||||||
|
ln: ln,
|
||||||
|
httpsSrv: httpsSrv,
|
||||||
|
redirectSrv: &http.Server{Handler: newHTTPToHTTPSRedirectHandler(httpsPort), ReadHeaderTimeout: 10 * time.Second},
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mainServerMux) Serve() error {
|
||||||
|
for {
|
||||||
|
conn, err := m.ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, net.ErrClosed) {
|
||||||
|
return http.ErrServerClosed
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go m.handleConn(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mainServerMux) handleConn(raw net.Conn) {
|
||||||
|
if err := raw.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||||
|
_ = raw.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
br := bufio.NewReader(raw)
|
||||||
|
b, err := br.Peek(1)
|
||||||
|
if err != nil {
|
||||||
|
_ = raw.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = raw.SetReadDeadline(time.Time{})
|
||||||
|
|
||||||
|
pc := &peekedConn{Conn: raw, r: br}
|
||||||
|
ocl := &oneConnListener{conn: pc, addr: raw.LocalAddr()}
|
||||||
|
|
||||||
|
if isTLSHandshakeRecord(b[0]) {
|
||||||
|
m.serveHTTPS(pc, raw.LocalAddr())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := m.redirectSrv.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
m.logger.Debug("HTTP 重定向连接处理结束", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// serveHTTPS 在已嗅探为 TLS 的连接上完成握手,再按 ALPN 走 HTTP/2 或 HTTP/1.1。
|
||||||
|
// 不能对同一 http.Server 并发调用 Serve(TLSConfig!=nil),否则握手/ALPN 会异常(浏览器 ERR_SSL_PROTOCOL_ERROR)。
|
||||||
|
func (m *mainServerMux) serveHTTPS(pc *peekedConn, localAddr net.Addr) {
|
||||||
|
tlsConn := tls.Server(pc, m.httpsSrv.TLSConfig)
|
||||||
|
handCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := tlsConn.HandshakeContext(handCtx); err != nil {
|
||||||
|
m.logger.Debug("TLS 握手失败", zap.Error(err))
|
||||||
|
_ = pc.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := m.httpsSrv
|
||||||
|
if srv.TLSNextProto != nil {
|
||||||
|
proto := tlsConn.ConnectionState().NegotiatedProtocol
|
||||||
|
if fn := srv.TLSNextProto[proto]; fn != nil {
|
||||||
|
fn(srv, tlsConn, srv.Handler)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
plain := *srv
|
||||||
|
plain.TLSConfig = nil
|
||||||
|
ocl := &oneConnListener{conn: tlsConn, addr: localAddr}
|
||||||
|
if err := plain.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
m.logger.Debug("HTTPS 连接处理结束", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mainServerMux) Shutdown(ctx context.Context) error {
|
||||||
|
_ = m.ln.Close()
|
||||||
|
var err1, err2 error
|
||||||
|
if m.httpsSrv != nil {
|
||||||
|
err1 = m.httpsSrv.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
if m.redirectSrv != nil {
|
||||||
|
err2 = m.redirectSrv.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
@@ -0,0 +1,150 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewHTTPToHTTPSRedirectHandler(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
httpsPort int
|
||||||
|
host string
|
||||||
|
uri string
|
||||||
|
wantTarget string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "non standard port",
|
||||||
|
httpsPort: 8080,
|
||||||
|
host: "127.0.0.1:8080",
|
||||||
|
uri: "/login?next=/",
|
||||||
|
wantTarget: "https://127.0.0.1:8080/login?next=/",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "standard port",
|
||||||
|
httpsPort: 443,
|
||||||
|
host: "example.com:80",
|
||||||
|
uri: "/",
|
||||||
|
wantTarget: "https://example.com/",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
h := newHTTPToHTTPSRedirectHandler(tt.httpsPort)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "http://"+tt.host+tt.uri, nil)
|
||||||
|
req.Host = tt.host
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
h.ServeHTTP(rec, req)
|
||||||
|
if rec.Code != http.StatusPermanentRedirect {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusPermanentRedirect)
|
||||||
|
}
|
||||||
|
if got := rec.Header().Get("Location"); got != tt.wantTarget {
|
||||||
|
t.Fatalf("Location = %q, want %q", got, tt.wantTarget)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsTLSHandshakeRecord(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if !isTLSHandshakeRecord(0x16) {
|
||||||
|
t.Fatal("expected TLS handshake record")
|
||||||
|
}
|
||||||
|
if isTLSHandshakeRecord('G') {
|
||||||
|
t.Fatal("GET should not be TLS")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerHTTPRedirectEnabled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
disabled := false
|
||||||
|
enabled := true
|
||||||
|
if config.ServerHTTPRedirectEnabled(nil) {
|
||||||
|
t.Fatal("nil config should disable redirect")
|
||||||
|
}
|
||||||
|
if !config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true}) {
|
||||||
|
t.Fatal("HTTPS without explicit flag should enable redirect")
|
||||||
|
}
|
||||||
|
if config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true, TLSHTTPRedirect: &disabled}) {
|
||||||
|
t.Fatal("explicit false should disable redirect")
|
||||||
|
}
|
||||||
|
if !config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true, TLSHTTPRedirect: &enabled}) {
|
||||||
|
t.Fatal("explicit true should enable redirect")
|
||||||
|
}
|
||||||
|
if config.ServerHTTPRedirectEnabled(&config.ServerConfig{}) {
|
||||||
|
t.Fatal("plain HTTP should not redirect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMainServerMuxHTTPRedirectAndHTTPS(t *testing.T) {
|
||||||
|
cert, err := generateMainServerSelfSignedCert()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generate cert: %v", err)
|
||||||
|
}
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
_, _ = io.WriteString(w, "ok")
|
||||||
|
})
|
||||||
|
srv := &http.Server{Handler: handler, TLSConfig: &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}}
|
||||||
|
if err := http2.ConfigureServer(srv, &http2.Server{}); err != nil {
|
||||||
|
t.Fatalf("configure http2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen: %v", err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
mux := newMainServerMux(ln, srv, portFromListenAddr(ln.Addr().String()), nil)
|
||||||
|
go func() { _ = mux.Serve() }()
|
||||||
|
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12},
|
||||||
|
},
|
||||||
|
CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
}
|
||||||
|
addr := ln.Addr().String()
|
||||||
|
|
||||||
|
httpResp, err := client.Get("http://" + addr + "/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("http get: %v", err)
|
||||||
|
}
|
||||||
|
_ = httpResp.Body.Close()
|
||||||
|
if httpResp.StatusCode != http.StatusPermanentRedirect {
|
||||||
|
t.Fatalf("http status = %d, want %d", httpResp.StatusCode, http.StatusPermanentRedirect)
|
||||||
|
}
|
||||||
|
if got := httpResp.Header.Get("Location"); got != "https://127.0.0.1:"+strconv.Itoa(portFromListenAddr(addr))+"/" {
|
||||||
|
t.Fatalf("Location = %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
httpsResp, err := client.Get("https://" + addr + "/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("https get: %v", err)
|
||||||
|
}
|
||||||
|
defer httpsResp.Body.Close()
|
||||||
|
if httpsResp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("https status = %d, want %d", httpsResp.StatusCode, http.StatusOK)
|
||||||
|
}
|
||||||
|
body, _ := io.ReadAll(httpsResp.Body)
|
||||||
|
if string(body) != "ok" {
|
||||||
|
t.Fatalf("body = %q, want ok", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
package app
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mainTLSMode 主 Web 服务 TLS 启动方式。
|
||||||
|
type mainTLSMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
mainTLSOff mainTLSMode = iota
|
||||||
|
mainTLSFromFiles
|
||||||
|
mainTLSInMemorySelfSigned
|
||||||
|
)
|
||||||
|
|
||||||
|
// prepareMainServerTLS 根据 server 配置决定主站是否启用 HTTPS(及 HTTP/2 协商)。
|
||||||
|
// fromFiles:使用 tls_cert_path + tls_key_path,由 http.Server.ListenAndServeTLS 加载 PEM。
|
||||||
|
// inMemory:tls_auto_self_sign 生成的自签证书,仅用于本地/测试。
|
||||||
|
func prepareMainServerTLS(cfg *config.ServerConfig) (mode mainTLSMode, tlsConf *tls.Config, certFile, keyFile string, err error) {
|
||||||
|
if cfg == nil || !config.MainWebUIUsesHTTPS(cfg) {
|
||||||
|
return mainTLSOff, nil, "", "", nil
|
||||||
|
}
|
||||||
|
certFile = strings.TrimSpace(cfg.TLSCertPath)
|
||||||
|
keyFile = strings.TrimSpace(cfg.TLSKeyPath)
|
||||||
|
if certFile != "" && keyFile != "" {
|
||||||
|
// 证书由 ListenAndServeTLS 从文件加载;此处仅提供最小 TLS 配置供 http2.ConfigureServer 合并 ALPN。
|
||||||
|
return mainTLSFromFiles, &tls.Config{MinVersion: tls.VersionTLS12}, certFile, keyFile, nil
|
||||||
|
}
|
||||||
|
if cfg.TLSAutoSelfSign {
|
||||||
|
cert, genErr := generateMainServerSelfSignedCert()
|
||||||
|
if genErr != nil {
|
||||||
|
return mainTLSOff, nil, "", "", fmt.Errorf("生成自签 TLS 证书: %w", genErr)
|
||||||
|
}
|
||||||
|
tlsConf = &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}
|
||||||
|
return mainTLSInMemorySelfSigned, tlsConf, "", "", nil
|
||||||
|
}
|
||||||
|
return mainTLSOff, nil, "", "", fmt.Errorf("server: 已启用 TLS(tls_enabled / tls_auto_self_sign / 证书路径),请设置 tls_cert_path 与 tls_key_path,或将 tls_auto_self_sign 设为 true(仅测试环境)")
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateMainServerSelfSignedCert() (tls.Certificate, error) {
|
||||||
|
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, err
|
||||||
|
}
|
||||||
|
serial, err := rand.Int(rand.Reader, big.NewInt(1<<62))
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, err
|
||||||
|
}
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: serial,
|
||||||
|
Subject: pkix.Name{CommonName: "CyberStrikeAI"},
|
||||||
|
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||||
|
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")},
|
||||||
|
DNSNames: []string{"localhost"},
|
||||||
|
}
|
||||||
|
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, err
|
||||||
|
}
|
||||||
|
keyDER, err := x509.MarshalECPrivateKey(priv)
|
||||||
|
if err != nil {
|
||||||
|
return tls.Certificate{}, err
|
||||||
|
}
|
||||||
|
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||||
|
return tls.X509KeyPair(certPEM, keyPEM)
|
||||||
|
}
|
||||||
@@ -82,7 +82,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildChainFromConversation 从对话构建攻击链(简化版本:用户输入+最后一轮ReAct输入+大模型输出)
|
// BuildChainFromConversation 从对话构建攻击链(单次 LLM 调用;输入为当前任务轮次的 last_react 轨迹,与继续对话续跑范围一致)。
|
||||||
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
|
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
|
||||||
b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID))
|
b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID))
|
||||||
|
|
||||||
@@ -157,33 +157,34 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
|||||||
var reactInputFinal string
|
var reactInputFinal string
|
||||||
var dataSource string // 记录数据来源
|
var dataSource string // 记录数据来源
|
||||||
|
|
||||||
// 如果成功获取到保存的ReAct数据,直接使用
|
// 优先使用落库的代理轨迹(与继续对话 loadHistoryFromAgentTrace 同源),并裁剪为「当前任务轮次」
|
||||||
if reactInputJSON != "" && modelOutput != "" {
|
if reactInputJSON != "" {
|
||||||
// 计算 ReAct 输入的哈希值,用于追踪
|
trimmedJSON := agent.ExtractLastUserTurnTraceJSON(reactInputJSON)
|
||||||
hash := sha256.Sum256([]byte(reactInputJSON))
|
hash := sha256.Sum256([]byte(trimmedJSON))
|
||||||
reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识
|
reactInputHash := hex.EncodeToString(hash[:])[:16]
|
||||||
|
|
||||||
// 统计消息数量
|
|
||||||
var messageCount int
|
var messageCount int
|
||||||
var tempMessages []interface{}
|
if msgs, parseErr := agent.ParseTraceMessages(trimmedJSON); parseErr == nil {
|
||||||
if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil {
|
messageCount = len(msgs)
|
||||||
messageCount = len(tempMessages)
|
msgs = agent.MergeAssistantTraceOutput(msgs, modelOutput)
|
||||||
|
reactInputFinal = b.formatAgentTraceFromChatMessages(msgs)
|
||||||
|
} else {
|
||||||
|
b.logger.Warn("解析代理轨迹失败,回退原始 JSON 格式化", zap.Error(parseErr))
|
||||||
|
reactInputFinal = b.formatAgentTraceInputFromJSON(trimmedJSON)
|
||||||
|
if strings.TrimSpace(modelOutput) != "" {
|
||||||
|
reactInputFinal += "\n\n## 助手结论(last_react_output)\n\n" + modelOutput
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dataSource = "database_last_agent_trace"
|
dataSource = "last_user_turn_agent_trace"
|
||||||
b.logger.Info("使用保存的ReAct数据构建攻击链",
|
b.logger.Info("使用当前任务轮次代理轨迹构建攻击链(与续跑上下文范围一致)",
|
||||||
zap.String("conversationId", conversationID),
|
zap.String("conversationId", conversationID),
|
||||||
zap.String("dataSource", dataSource),
|
zap.String("dataSource", dataSource),
|
||||||
zap.Int("reactInputSize", len(reactInputJSON)),
|
zap.Int("traceInputSizeBeforeTrim", len(reactInputJSON)),
|
||||||
|
zap.Int("traceInputSizeAfterTrim", len(trimmedJSON)),
|
||||||
zap.Int("messageCount", messageCount),
|
zap.Int("messageCount", messageCount),
|
||||||
zap.String("reactInputHash", reactInputHash),
|
zap.String("reactInputHash", reactInputHash),
|
||||||
zap.Int("modelOutputSize", len(modelOutput)))
|
zap.Int("modelOutputSize", len(modelOutput)))
|
||||||
|
|
||||||
// 从保存的ReAct输入(JSON格式)中提取用户输入
|
|
||||||
// userInput = b.extractUserInputFromReActInput(reactInputJSON)
|
|
||||||
|
|
||||||
// 将JSON格式的messages转换为可读格式
|
|
||||||
reactInputFinal = b.formatAgentTraceInputFromJSON(reactInputJSON)
|
|
||||||
} else {
|
} else {
|
||||||
// 2. 如果没有保存的ReAct数据,从对话消息构建
|
// 2. 如果没有保存的ReAct数据,从对话消息构建
|
||||||
dataSource = "messages_table"
|
dataSource = "messages_table"
|
||||||
@@ -243,8 +244,15 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 构建简化的prompt,一次性传递给大模型
|
// 3. 按 token 预算压缩输入,再构建 prompt(避免超出模型上下文)
|
||||||
prompt := b.buildSimplePrompt(reactInputFinal, modelOutput)
|
reactInputFinal, modelOutput, _ = b.fitAttackChainPayload(reactInputFinal, modelOutput)
|
||||||
|
|
||||||
|
// 4. 构建 prompt 并单次调用大模型(助手结论已并入轨迹时不再重复传入)
|
||||||
|
promptAssistantOut := modelOutput
|
||||||
|
if reactInputJSON != "" {
|
||||||
|
promptAssistantOut = ""
|
||||||
|
}
|
||||||
|
prompt := b.buildSimplePrompt(reactInputFinal, promptAssistantOut)
|
||||||
// fmt.Println(prompt)
|
// fmt.Println(prompt)
|
||||||
// 6. 调用AI生成攻击链(一次性,不做任何处理)
|
// 6. 调用AI生成攻击链(一次性,不做任何处理)
|
||||||
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
|
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
|
||||||
@@ -366,10 +374,17 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
|
|||||||
return strings.TrimSpace(sb.String())
|
return strings.TrimSpace(sb.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildAgentTraceInput 构建最后一轮ReAct的输入(历史消息+当前用户输入)
|
// buildAgentTraceInput 构建最后一轮 ReAct 的输入(从最后一条 user 消息起,不含更早轮次)。
|
||||||
func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
|
func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
|
||||||
|
start := 0
|
||||||
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
|
if strings.EqualFold(messages[i].Role, "user") {
|
||||||
|
start = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
for _, msg := range messages {
|
for _, msg := range messages[start:] {
|
||||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
|
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
|
||||||
}
|
}
|
||||||
return builder.String()
|
return builder.String()
|
||||||
@@ -396,67 +411,66 @@ func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
|
|||||||
// return ""
|
// return ""
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// formatAgentTraceInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
|
// formatAgentTraceInputFromJSON 将 JSON 轨迹转为可读文本(会先按当前任务轮次裁剪)。
|
||||||
func (b *Builder) formatAgentTraceInputFromJSON(reactInputJSON string) string {
|
func (b *Builder) formatAgentTraceInputFromJSON(reactInputJSON string) string {
|
||||||
var messages []map[string]interface{}
|
trimmed := agent.ExtractLastUserTurnTraceJSON(reactInputJSON)
|
||||||
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
|
msgs, err := agent.ParseTraceMessages(trimmed)
|
||||||
|
if err != nil {
|
||||||
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
||||||
return reactInputJSON // 如果解析失败,返回原始JSON
|
return trimmed
|
||||||
}
|
}
|
||||||
|
return b.formatAgentTraceFromChatMessages(msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatAgentTraceFromChatMessages 将代理消息带格式化为攻击链分析输入(与续跑轨迹字段一致)。
|
||||||
|
func (b *Builder) formatAgentTraceFromChatMessages(msgs []agent.ChatMessage) string {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
for _, msg := range messages {
|
for _, msg := range msgs {
|
||||||
role, _ := msg["role"].(string)
|
role := msg.Role
|
||||||
content, _ := msg["content"].(string)
|
content := msg.Content
|
||||||
|
|
||||||
// 处理assistant消息:提取tool_calls信息
|
if strings.EqualFold(role, "assistant") && len(msg.ToolCalls) > 0 {
|
||||||
if role == "assistant" {
|
if content != "" {
|
||||||
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 {
|
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
|
||||||
// 如果有文本内容,先显示
|
}
|
||||||
if content != "" {
|
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(msg.ToolCalls)))
|
||||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
|
for i, tc := range msg.ToolCalls {
|
||||||
}
|
args := ""
|
||||||
// 详细显示每个工具调用
|
if tc.Function.Arguments != nil {
|
||||||
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls)))
|
if b, err := json.Marshal(tc.Function.Arguments); err == nil {
|
||||||
for i, toolCall := range toolCalls {
|
args = string(b)
|
||||||
if tc, ok := toolCall.(map[string]interface{}); ok {
|
|
||||||
toolCallID, _ := tc["id"].(string)
|
|
||||||
if funcData, ok := tc["function"].(map[string]interface{}); ok {
|
|
||||||
toolName, _ := funcData["name"].(string)
|
|
||||||
arguments, _ := funcData["arguments"].(string)
|
|
||||||
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
|
|
||||||
builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID))
|
|
||||||
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName))
|
|
||||||
builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
builder.WriteString("\n")
|
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
|
||||||
continue
|
builder.WriteString(fmt.Sprintf(" ID: %s\n", tc.ID))
|
||||||
|
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", tc.Function.Name))
|
||||||
|
builder.WriteString(fmt.Sprintf(" 参数: %s\n", args))
|
||||||
}
|
}
|
||||||
|
builder.WriteString("\n")
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理tool消息:显示tool_call_id和完整内容
|
if strings.EqualFold(role, "tool") {
|
||||||
if role == "tool" {
|
if msg.ToolCallID != "" {
|
||||||
toolCallID, _ := msg["tool_call_id"].(string)
|
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, msg.ToolCallID, content))
|
||||||
if toolCallID != "" {
|
|
||||||
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content))
|
|
||||||
} else {
|
} else {
|
||||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 其他消息类型(system, user等)正常显示
|
|
||||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
|
||||||
}
|
}
|
||||||
|
|
||||||
return builder.String()
|
return builder.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildSimplePrompt 构建简化的prompt
|
// buildSimplePrompt 构建简化的prompt
|
||||||
func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
||||||
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据对话记录和工具执行结果,构建一个逻辑清晰、有教育意义的攻击链图,完整展现渗透测试的思维过程和执行路径。
|
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据**当前任务轮次**的对话记录和工具执行结果,一次性输出攻击链 JSON(不要分多轮追问)。
|
||||||
|
|
||||||
|
## 输入范围(与「继续对话」续跑一致)
|
||||||
|
- 下方「ReAct 轨迹」仅包含**最后一次用户提问之后**的消息与工具结果(last_react 当前任务轮次),不含更早的用户提问轮次。
|
||||||
|
- 「助手结论」为同轮任务的最终输出摘要(last_react_output);节点须与轨迹中的实际工具执行一致,严禁编造。
|
||||||
|
|
||||||
## 核心目标
|
## 核心目标
|
||||||
|
|
||||||
@@ -618,12 +632,9 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
|||||||
5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability)
|
5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability)
|
||||||
6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径)
|
6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径)
|
||||||
|
|
||||||
## 最后一轮ReAct输入
|
## 当前任务 ReAct 轨迹(含工具执行;助手结论见轨迹末尾 assistant)
|
||||||
|
|
||||||
%s
|
%s
|
||||||
|
|
||||||
## 大模型输出
|
|
||||||
|
|
||||||
%s
|
%s
|
||||||
|
|
||||||
## 输出格式
|
## 输出格式
|
||||||
@@ -752,7 +763,15 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
|
|||||||
9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。
|
9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。
|
||||||
10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。
|
10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。
|
||||||
|
|
||||||
现在开始分析并构建攻击链:`, reactInput, modelOutput)
|
现在开始分析并构建攻击链:`, reactInput, assistantOutSection(modelOutput))
|
||||||
|
}
|
||||||
|
|
||||||
|
func assistantOutSection(modelOutput string) string {
|
||||||
|
modelOutput = strings.TrimSpace(modelOutput)
|
||||||
|
if modelOutput == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return "\n## 助手结论(补充)\n\n" + modelOutput + "\n"
|
||||||
}
|
}
|
||||||
|
|
||||||
// saveChain 保存攻击链到数据库
|
// saveChain 保存攻击链到数据库
|
||||||
@@ -812,7 +831,7 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"temperature": 0.3,
|
"temperature": 0.3,
|
||||||
"max_completion_tokens": 80000,
|
"max_completion_tokens": attackChainMaxCompletionTokens(b.maxTokens),
|
||||||
}
|
}
|
||||||
|
|
||||||
var apiResponse struct {
|
var apiResponse struct {
|
||||||
|
|||||||
@@ -0,0 +1,248 @@
|
|||||||
|
package attackchain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
attackChainTruncationMarker = "\n\n...[攻击链输入已截断 / attack chain input truncated]...\n\n"
|
||||||
|
attackChainSystemReserve = 256
|
||||||
|
attackChainSafetyReserve = 2048
|
||||||
|
)
|
||||||
|
|
||||||
|
// attackChainMaxCompletionTokens 为攻击链 JSON 输出预留的 completion token 上限。
|
||||||
|
func attackChainMaxCompletionTokens(maxTotal int) int {
|
||||||
|
const capTokens = 16384
|
||||||
|
if maxTotal <= 0 {
|
||||||
|
return 8192
|
||||||
|
}
|
||||||
|
v := maxTotal / 8
|
||||||
|
if v < 4096 {
|
||||||
|
v = 4096
|
||||||
|
}
|
||||||
|
if v > capTokens {
|
||||||
|
v = capTokens
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Builder) modelName() string {
|
||||||
|
if b.openAIConfig != nil && b.openAIConfig.Model != "" {
|
||||||
|
return b.openAIConfig.Model
|
||||||
|
}
|
||||||
|
return "gpt-4"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Builder) countTokens(text string) int {
|
||||||
|
if text == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
n, err := b.tokenCounter.Count(b.modelName(), text)
|
||||||
|
if err != nil {
|
||||||
|
return utf8.RuneCountInString(text) / 4
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
// attackChainPayloadTokenBudget 计算 reactInput + modelOutput 可用的 token 预算。
|
||||||
|
func (b *Builder) attackChainPayloadTokenBudget() int {
|
||||||
|
maxTotal := b.maxTokens
|
||||||
|
if maxTotal <= 0 {
|
||||||
|
maxTotal = 100000
|
||||||
|
}
|
||||||
|
templateTok := b.countTokens(b.buildSimplePrompt("", ""))
|
||||||
|
completion := attackChainMaxCompletionTokens(maxTotal)
|
||||||
|
reserve := templateTok + attackChainSystemReserve + completion + attackChainSafetyReserve
|
||||||
|
budget := maxTotal - reserve
|
||||||
|
minBudget := maxTotal * 35 / 100
|
||||||
|
if budget < minBudget {
|
||||||
|
budget = minBudget
|
||||||
|
}
|
||||||
|
if budget < 4096 {
|
||||||
|
budget = 4096
|
||||||
|
}
|
||||||
|
return budget
|
||||||
|
}
|
||||||
|
|
||||||
|
// fitAttackChainPayload 在构建最终 prompt 前压缩 ReAct 轨迹与模型输出,避免超出模型上下文。
|
||||||
|
func (b *Builder) fitAttackChainPayload(reactInput, modelOutput string) (string, string, bool) {
|
||||||
|
budget := b.attackChainPayloadTokenBudget()
|
||||||
|
modelBudget := budget * 15 / 100
|
||||||
|
if modelBudget < 512 {
|
||||||
|
modelBudget = 512
|
||||||
|
}
|
||||||
|
reactBudget := budget - modelBudget
|
||||||
|
|
||||||
|
origReactTok := b.countTokens(reactInput)
|
||||||
|
origModelTok := b.countTokens(modelOutput)
|
||||||
|
truncated := false
|
||||||
|
|
||||||
|
outModel := modelOutput
|
||||||
|
if origModelTok > modelBudget {
|
||||||
|
outModel = truncateTextByTokens(b, modelOutput, modelBudget)
|
||||||
|
truncated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
outReact := reactInput
|
||||||
|
perToolLimits := []int{12000, 6000, 3000, 1500, 800}
|
||||||
|
for _, lim := range perToolLimits {
|
||||||
|
compact := compactFormattedToolBodies(outReact, lim)
|
||||||
|
if compact != outReact {
|
||||||
|
outReact = compact
|
||||||
|
truncated = true
|
||||||
|
}
|
||||||
|
if b.countTokens(outReact) <= reactBudget {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if b.countTokens(outReact) > reactBudget {
|
||||||
|
outReact = truncateTextByTokens(b, outReact, reactBudget)
|
||||||
|
truncated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if truncated {
|
||||||
|
b.logger.Info("攻击链输入已按 token 预算截断",
|
||||||
|
zap.Int("maxTotalTokens", b.maxTokens),
|
||||||
|
zap.Int("payloadBudget", budget),
|
||||||
|
zap.Int("reactBudget", reactBudget),
|
||||||
|
zap.Int("modelBudget", modelBudget),
|
||||||
|
zap.Int("reactInputTokensBefore", origReactTok),
|
||||||
|
zap.Int("reactInputTokensAfter", b.countTokens(outReact)),
|
||||||
|
zap.Int("modelOutputTokensBefore", origModelTok),
|
||||||
|
zap.Int("modelOutputTokensAfter", b.countTokens(outModel)),
|
||||||
|
zap.Int("maxCompletionTokens", attackChainMaxCompletionTokens(b.maxTokens)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return outReact, outModel, truncated
|
||||||
|
}
|
||||||
|
|
||||||
|
// compactFormattedToolBodies 缩短格式化 trace 中 [tool] 消息的正文,保留工具头与调用 ID。
|
||||||
|
func compactFormattedToolBodies(s string, maxRunesPerBody int) string {
|
||||||
|
if maxRunesPerBody <= 0 || s == "" {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
const marker = "[tool]"
|
||||||
|
var out strings.Builder
|
||||||
|
remaining := s
|
||||||
|
changed := false
|
||||||
|
for {
|
||||||
|
idx := strings.Index(remaining, marker)
|
||||||
|
if idx < 0 {
|
||||||
|
out.WriteString(remaining)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
out.WriteString(remaining[:idx])
|
||||||
|
remaining = remaining[idx:]
|
||||||
|
nl := strings.IndexByte(remaining, '\n')
|
||||||
|
if nl < 0 {
|
||||||
|
out.WriteString(remaining)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
header := remaining[:nl+1]
|
||||||
|
remaining = remaining[nl+1:]
|
||||||
|
bodyEnd := strings.Index(remaining, "\n\n[")
|
||||||
|
var body, rest string
|
||||||
|
if bodyEnd < 0 {
|
||||||
|
body = remaining
|
||||||
|
rest = ""
|
||||||
|
} else {
|
||||||
|
body = remaining[:bodyEnd]
|
||||||
|
rest = remaining[bodyEnd:]
|
||||||
|
}
|
||||||
|
if runeLen(body) > maxRunesPerBody {
|
||||||
|
body = truncateRunesWithNotice(body, maxRunesPerBody)
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
out.WriteString(header)
|
||||||
|
out.WriteString(body)
|
||||||
|
remaining = rest
|
||||||
|
if rest == "" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !changed {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return out.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateTextByTokens(b *Builder, text string, maxTokens int) string {
|
||||||
|
if maxTokens <= 0 || text == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if b.countTokens(text) <= maxTokens {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
markerTok := b.countTokens(attackChainTruncationMarker)
|
||||||
|
usable := maxTokens - markerTok
|
||||||
|
if usable < 256 {
|
||||||
|
usable = maxTokens / 2
|
||||||
|
}
|
||||||
|
headBudget := usable * 60 / 100
|
||||||
|
tailBudget := usable - headBudget
|
||||||
|
head := takeTokensFromStart(b, text, headBudget)
|
||||||
|
tail := takeTokensFromEnd(b, text, tailBudget)
|
||||||
|
return head + attackChainTruncationMarker + tail
|
||||||
|
}
|
||||||
|
|
||||||
|
func takeTokensFromStart(b *Builder, text string, maxTokens int) string {
|
||||||
|
rs := []rune(text)
|
||||||
|
if len(rs) == 0 || maxTokens <= 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
lo, hi := 0, len(rs)
|
||||||
|
for lo < hi {
|
||||||
|
mid := (lo + hi + 1) / 2
|
||||||
|
if b.countTokens(string(rs[:mid])) <= maxTokens {
|
||||||
|
lo = mid
|
||||||
|
} else {
|
||||||
|
hi = mid - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(rs[:lo])
|
||||||
|
}
|
||||||
|
|
||||||
|
func takeTokensFromEnd(b *Builder, text string, maxTokens int) string {
|
||||||
|
rs := []rune(text)
|
||||||
|
if len(rs) == 0 || maxTokens <= 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
lo, hi := 0, len(rs)
|
||||||
|
for lo < hi {
|
||||||
|
mid := (lo + hi) / 2
|
||||||
|
if b.countTokens(string(rs[mid:])) <= maxTokens {
|
||||||
|
hi = mid
|
||||||
|
} else {
|
||||||
|
lo = mid + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(rs[lo:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateRunesWithNotice(s string, maxRunes int) string {
|
||||||
|
rs := []rune(s)
|
||||||
|
if len(rs) <= maxRunes {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
const notice = "\n...[工具输出已截断 / tool output truncated]...\n"
|
||||||
|
noticeRunes := []rune(notice)
|
||||||
|
keep := maxRunes - len(noticeRunes)
|
||||||
|
if keep < 200 {
|
||||||
|
keep = maxRunes * 2 / 3
|
||||||
|
}
|
||||||
|
if keep < 1 {
|
||||||
|
return notice
|
||||||
|
}
|
||||||
|
head := keep * 70 / 100
|
||||||
|
tail := keep - head
|
||||||
|
return string(rs[:head]) + notice + string(rs[len(rs)-tail:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func runeLen(s string) int {
|
||||||
|
return len([]rune(s))
|
||||||
|
}
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
package attackchain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testBuilder(maxTotal int) *Builder {
|
||||||
|
return &Builder{
|
||||||
|
logger: zap.NewNop(),
|
||||||
|
openAIConfig: &config.OpenAIConfig{Model: "gpt-4"},
|
||||||
|
tokenCounter: agent.NewTikTokenCounter(),
|
||||||
|
maxTokens: maxTotal,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompactFormattedToolBodies(t *testing.T) {
|
||||||
|
long := strings.Repeat("x", 20000)
|
||||||
|
in := "[user]: hi\n\n[tool] (tool_call_id: abc):\n" + long + "\n\n[assistant]: done\n"
|
||||||
|
out := compactFormattedToolBodies(in, 500)
|
||||||
|
if strings.Contains(out, strings.Repeat("x", 10000)) {
|
||||||
|
t.Fatal("expected tool body to be truncated")
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "[user]: hi") {
|
||||||
|
t.Fatal("expected user header preserved")
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "[assistant]: done") {
|
||||||
|
t.Fatal("expected assistant header preserved")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFitAttackChainPayloadWithinBudget(t *testing.T) {
|
||||||
|
b := testBuilder(32000)
|
||||||
|
react := strings.Repeat("scan ", 50000)
|
||||||
|
model := strings.Repeat("result ", 10000)
|
||||||
|
r, m, truncated := b.fitAttackChainPayload(react, model)
|
||||||
|
if !truncated {
|
||||||
|
t.Fatal("expected truncation for large payload")
|
||||||
|
}
|
||||||
|
prompt := b.buildSimplePrompt(r, m)
|
||||||
|
total := b.countTokens(prompt) + attackChainMaxCompletionTokens(b.maxTokens) + attackChainSystemReserve
|
||||||
|
if total > b.maxTokens+attackChainSafetyReserve {
|
||||||
|
t.Fatalf("prompt still too large: estimated %d > max %d", total, b.maxTokens)
|
||||||
|
}
|
||||||
|
_ = m
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAttackChainMaxCompletionTokens(t *testing.T) {
|
||||||
|
if got := attackChainMaxCompletionTokens(120000); got != 15000 && got != 16384 {
|
||||||
|
// 120000/8 = 15000
|
||||||
|
if got < 4096 || got > 16384 {
|
||||||
|
t.Fatalf("unexpected completion cap: %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if got := attackChainMaxCompletionTokens(0); got != 8192 {
|
||||||
|
t.Fatalf("expected default 8192, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
package audit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/security"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RegisterConversationCreateHook records platform audit rows for every new conversation.
|
||||||
|
func RegisterConversationCreateHook(s *Service) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
database.SetConversationCreateHook(func(conv *database.Conversation, meta database.ConversationCreateMeta) {
|
||||||
|
detail := map[string]interface{}{
|
||||||
|
"title": conv.Title,
|
||||||
|
"source": meta.Source,
|
||||||
|
}
|
||||||
|
if meta.WebShellConnectionID != "" {
|
||||||
|
detail["webshell_connection_id"] = meta.WebShellConnectionID
|
||||||
|
}
|
||||||
|
s.Record(nil, Entry{
|
||||||
|
Category: "conversation",
|
||||||
|
Action: "create",
|
||||||
|
Result: "success",
|
||||||
|
Message: "创建对话",
|
||||||
|
ResourceType: "conversation",
|
||||||
|
ResourceID: conv.ID,
|
||||||
|
Detail: detail,
|
||||||
|
ClientIP: meta.ClientIP,
|
||||||
|
SessionHint: meta.SessionHint,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationCreateMeta builds audit metadata for conversation creation.
|
||||||
|
func ConversationCreateMeta(source string) database.ConversationCreateMeta {
|
||||||
|
return database.ConversationCreateMeta{Source: strings.TrimSpace(source)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationCreateMetaFromGin includes client IP and session hint when available.
|
||||||
|
func ConversationCreateMetaFromGin(c *gin.Context, source string) database.ConversationCreateMeta {
|
||||||
|
m := ConversationCreateMeta(source)
|
||||||
|
if c == nil {
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
m.ClientIP = c.ClientIP()
|
||||||
|
if token := c.GetString(security.ContextAuthTokenKey); token != "" {
|
||||||
|
m.SessionHint = sessionHint(token)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
package audit
|
||||||
|
|
||||||
|
// RetentionDays returns configured retention; 0 means keep forever.
|
||||||
|
func (s *Service) RetentionDays() int {
|
||||||
|
if s == nil || s.cfg == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return s.cfg.Audit.RetentionDaysEffective()
|
||||||
|
}
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
package audit
|
||||||
|
|
||||||
|
import "github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
// RecordAction writes a platform audit row with common defaults.
|
||||||
|
func (s *Service) RecordAction(c *gin.Context, category, action, result, message, resourceType, resourceID string, detail map[string]interface{}) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.Record(c, Entry{
|
||||||
|
Category: category,
|
||||||
|
Action: action,
|
||||||
|
Result: result,
|
||||||
|
Message: message,
|
||||||
|
ResourceType: resourceType,
|
||||||
|
ResourceID: resourceID,
|
||||||
|
Detail: detail,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordOK is a shorthand for successful operations.
|
||||||
|
func (s *Service) RecordOK(c *gin.Context, category, action, message, resourceType, resourceID string, detail map[string]interface{}) {
|
||||||
|
s.RecordAction(c, category, action, "success", message, resourceType, resourceID, detail)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordFail is a shorthand for failed operations.
|
||||||
|
func (s *Service) RecordFail(c *gin.Context, category, action, message string, detail map[string]interface{}) {
|
||||||
|
s.RecordAction(c, category, action, "failure", message, "", "", detail)
|
||||||
|
}
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
package audit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
var auditActionsResourceRemoved = map[string]bool{
|
||||||
|
"delete": true,
|
||||||
|
"item_delete": true,
|
||||||
|
"connection_delete": true,
|
||||||
|
"listener_delete": true,
|
||||||
|
"session_delete": true,
|
||||||
|
"task_delete": true,
|
||||||
|
"execution_delete": true,
|
||||||
|
"execution_delete_batch": true,
|
||||||
|
"delete_queue": true,
|
||||||
|
"delete_batch_task": true,
|
||||||
|
"markdown_delete": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyResourceAvailability sets log.ResourceAvailable when the linked resource can be checked.
|
||||||
|
func ApplyResourceAvailability(db *database.DB, log *database.AuditLog) {
|
||||||
|
if log == nil || strings.TrimSpace(log.ResourceID) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if auditActionsResourceRemoved[log.Action] {
|
||||||
|
f := false
|
||||||
|
log.ResourceAvailable = &f
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if db == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
available, known := resourceStillExists(db, log.ResourceType, log.ResourceID)
|
||||||
|
if known {
|
||||||
|
log.ResourceAvailable = &available
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resourceStillExists(db *database.DB, resourceType, resourceID string) (bool, bool) {
|
||||||
|
resourceID = strings.TrimSpace(resourceID)
|
||||||
|
if resourceID == "" {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
t := strings.TrimSpace(resourceType)
|
||||||
|
if t == "" {
|
||||||
|
if len(resourceID) > 8 && !strings.HasPrefix(resourceID, "c2_") {
|
||||||
|
t = "conversation"
|
||||||
|
} else {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch t {
|
||||||
|
case "conversation":
|
||||||
|
ok, err := db.ConversationExists(resourceID)
|
||||||
|
return ok, err == nil
|
||||||
|
case "vulnerability":
|
||||||
|
_, err := db.GetVulnerability(resourceID)
|
||||||
|
if err != nil {
|
||||||
|
return false, strings.Contains(err.Error(), "不存在")
|
||||||
|
}
|
||||||
|
return true, true
|
||||||
|
case "batch_queue":
|
||||||
|
_, err := db.GetBatchQueue(resourceID)
|
||||||
|
return err == nil, true
|
||||||
|
case "c2_listener":
|
||||||
|
_, err := db.GetC2Listener(resourceID)
|
||||||
|
return err == nil, true
|
||||||
|
case "c2_session":
|
||||||
|
_, err := db.GetC2Session(resourceID)
|
||||||
|
return err == nil, true
|
||||||
|
case "c2_task":
|
||||||
|
_, err := db.GetC2Task(resourceID)
|
||||||
|
return err == nil, true
|
||||||
|
case "webshell_connection":
|
||||||
|
c, err := db.GetWebshellConnection(resourceID)
|
||||||
|
return err == nil && c != nil, true
|
||||||
|
case "tool_execution":
|
||||||
|
_, err := db.GetToolExecution(resourceID)
|
||||||
|
return err == nil, true
|
||||||
|
default:
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
package audit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// auditRetentionPurgeInterval is how often PurgeExpired runs while the process is up (startup also purges once).
|
||||||
|
const auditRetentionPurgeInterval = time.Hour
|
||||||
|
|
||||||
|
// StartRetentionLoop periodically purges expired audit rows.
|
||||||
|
func StartRetentionLoop(s *Service, logger *zap.Logger) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(auditRetentionPurgeInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for range ticker.C {
|
||||||
|
s.PurgeExpired()
|
||||||
|
if logger != nil {
|
||||||
|
logger.Debug("audit retention tick completed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package audit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var sensitiveKeySubstrings = []string{
|
||||||
|
"password", "api_key", "apikey", "secret", "token", "authorization",
|
||||||
|
"credential", "private_key", "access_key",
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeDetail redacts sensitive keys and truncates serialized size.
|
||||||
|
func SanitizeDetail(detail map[string]interface{}, maxBytes int) map[string]interface{} {
|
||||||
|
if detail == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 8192
|
||||||
|
}
|
||||||
|
out := sanitizeValue("", detail)
|
||||||
|
if m, ok := out.(map[string]interface{}); ok {
|
||||||
|
b, _ := json.Marshal(m)
|
||||||
|
if len(b) > maxBytes {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"_truncated": true,
|
||||||
|
"_preview": string(b[:maxBytes]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
return map[string]interface{}{"value": out}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeValue(key string, v interface{}) interface{} {
|
||||||
|
kl := strings.ToLower(key)
|
||||||
|
for _, sub := range sensitiveKeySubstrings {
|
||||||
|
if strings.Contains(kl, sub) {
|
||||||
|
return "***"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch t := v.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
m := make(map[string]interface{}, len(t))
|
||||||
|
for k, val := range t {
|
||||||
|
m[k] = sanitizeValue(k, val)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
case []interface{}:
|
||||||
|
arr := make([]interface{}, len(t))
|
||||||
|
for i, val := range t {
|
||||||
|
arr[i] = sanitizeValue(key, val)
|
||||||
|
}
|
||||||
|
return arr
|
||||||
|
default:
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,172 @@
|
|||||||
|
package audit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/security"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Service persists platform audit logs.
|
||||||
|
type Service struct {
|
||||||
|
db *database.DB
|
||||||
|
cfg *config.Config
|
||||||
|
logger *zap.Logger
|
||||||
|
failThrottle *failureThrottle
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewService creates an audit service.
|
||||||
|
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
|
||||||
|
return &Service{
|
||||||
|
db: db,
|
||||||
|
cfg: cfg,
|
||||||
|
logger: logger,
|
||||||
|
failThrottle: newFailureThrottle(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enabled reports whether audit persistence is on.
|
||||||
|
func (s *Service) Enabled() bool {
|
||||||
|
if s == nil || s.cfg == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.cfg.Audit.EnabledEffective()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record writes one audit row from a Gin request context.
|
||||||
|
func (s *Service) Record(c *gin.Context, e Entry) {
|
||||||
|
if s == nil || !s.Enabled() || s.db == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(e.Category) == "" || strings.TrimSpace(e.Action) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if e.Result == "failure" && !s.allowFailureAudit(c, e) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(e.Result) == "" {
|
||||||
|
e.Result = "success"
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(e.Level) == "" {
|
||||||
|
if e.Result == "failure" {
|
||||||
|
e.Level = "warn"
|
||||||
|
} else {
|
||||||
|
e.Level = "info"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(e.Actor) == "" {
|
||||||
|
e.Actor = "admin"
|
||||||
|
}
|
||||||
|
maxDetail := s.cfg.Audit.MaxDetailBytesEffective()
|
||||||
|
detail := SanitizeDetail(e.Detail, maxDetail)
|
||||||
|
|
||||||
|
sessionHintVal := e.SessionHint
|
||||||
|
if sessionHintVal == "" && c != nil {
|
||||||
|
if token := c.GetString(security.ContextAuthTokenKey); token != "" {
|
||||||
|
sessionHintVal = sessionHint(token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
clientIPVal := e.ClientIP
|
||||||
|
if clientIPVal == "" {
|
||||||
|
clientIPVal = clientIP(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
row := &database.AuditLog{
|
||||||
|
ID: "audit_" + strings.ReplaceAll(uuid.New().String(), "-", ""),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
Level: e.Level,
|
||||||
|
Category: e.Category,
|
||||||
|
Action: e.Action,
|
||||||
|
Result: e.Result,
|
||||||
|
Actor: e.Actor,
|
||||||
|
SessionHint: sessionHintVal,
|
||||||
|
ClientIP: clientIPVal,
|
||||||
|
UserAgent: userAgent(c),
|
||||||
|
ResourceType: e.ResourceType,
|
||||||
|
ResourceID: e.ResourceID,
|
||||||
|
Message: e.Message,
|
||||||
|
Detail: detail,
|
||||||
|
}
|
||||||
|
if err := s.db.AppendAuditLog(row); err != nil && s.logger != nil {
|
||||||
|
s.logger.Warn("写入审计日志失败",
|
||||||
|
zap.String("action", e.Action),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordSystem writes an audit row without HTTP context (e.g. retention cleanup).
|
||||||
|
func (s *Service) RecordSystem(e Entry) {
|
||||||
|
s.Record(nil, e)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PurgeExpired deletes rows older than retention_days when configured.
|
||||||
|
func (s *Service) PurgeExpired() {
|
||||||
|
if s == nil || s.db == nil || s.cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
days := s.cfg.Audit.RetentionDaysEffective()
|
||||||
|
if days <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cutoff := time.Now().AddDate(0, 0, -days)
|
||||||
|
n, err := s.db.DeleteAuditLogsBefore(cutoff)
|
||||||
|
if err != nil {
|
||||||
|
if s.logger != nil {
|
||||||
|
s.logger.Warn("清理过期审计日志失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n > 0 && s.logger != nil {
|
||||||
|
s.logger.Info("已清理过期审计日志", zap.Int64("deleted", n))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HintFromToken returns a short stable hash prefix for a session token.
|
||||||
|
func HintFromToken(token string) string {
|
||||||
|
return sessionHint(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sessionHint(token string) string {
|
||||||
|
token = strings.TrimSpace(token)
|
||||||
|
if token == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
sum := sha256.Sum256([]byte(token))
|
||||||
|
return hex.EncodeToString(sum[:4])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) allowFailureAudit(c *gin.Context, e Entry) bool {
|
||||||
|
if !isAuthFailureThrottled(e.Category, e.Action) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
cooldown := time.Duration(s.cfg.Audit.AuthFailureCooldownEffective()) * time.Second
|
||||||
|
key := authFailureThrottleKey(e.Category, e.Action, clientIP(c))
|
||||||
|
return s.failThrottle.allow(key, cooldown)
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientIP(c *gin.Context) string {
|
||||||
|
if c == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.ClientIP()
|
||||||
|
}
|
||||||
|
|
||||||
|
func userAgent(c *gin.Context) string {
|
||||||
|
if c == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
ua := c.GetHeader("User-Agent")
|
||||||
|
if len(ua) > 512 {
|
||||||
|
return ua[:512]
|
||||||
|
}
|
||||||
|
return ua
|
||||||
|
}
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
package audit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// failureThrottle deduplicates high-frequency failure audit rows (e.g. wrong password).
|
||||||
|
type failureThrottle struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
last map[string]time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFailureThrottle() *failureThrottle {
|
||||||
|
return &failureThrottle{last: make(map[string]time.Time)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// allow reports whether a row with the given key may be written now.
|
||||||
|
func (t *failureThrottle) allow(key string, cooldown time.Duration) bool {
|
||||||
|
if t == nil || cooldown <= 0 || key == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
if prev, ok := t.last[key]; ok && now.Sub(prev) < cooldown {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
t.last[key] = now
|
||||||
|
if len(t.last) > 4096 {
|
||||||
|
for k, ts := range t.last {
|
||||||
|
if now.Sub(ts) > cooldown*2 {
|
||||||
|
delete(t.last, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// authFailureThrottleKey builds a per-IP key for auth failure deduplication.
|
||||||
|
func authFailureThrottleKey(category, action, clientIP string) string {
|
||||||
|
return category + ":" + action + ":" + clientIP
|
||||||
|
}
|
||||||
|
|
||||||
|
func isAuthFailureThrottled(category, action string) bool {
|
||||||
|
if category != "auth" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch action {
|
||||||
|
case "login", "change_password":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package audit
|
||||||
|
|
||||||
|
// Entry describes one platform audit record (not chat/tool execution bodies).
|
||||||
|
type Entry struct {
|
||||||
|
Level string
|
||||||
|
Category string
|
||||||
|
Action string
|
||||||
|
Result string // success | failure
|
||||||
|
Actor string
|
||||||
|
SessionHint string
|
||||||
|
ResourceType string
|
||||||
|
ResourceID string
|
||||||
|
Message string
|
||||||
|
Detail map[string]interface{}
|
||||||
|
ClientIP string // optional when c is nil (robot, batch, DB hook)
|
||||||
|
}
|
||||||
@@ -239,13 +239,15 @@ func (m *Manager) StartListener(id string) (*database.C2Listener, error) {
|
|||||||
}
|
}
|
||||||
cfg.ApplyDefaults()
|
cfg.ApplyDefaults()
|
||||||
|
|
||||||
// 通过工厂创建具体实现
|
// 通过工厂创建具体实现。必须使用 rec 的副本:HTTP handler 在返回 JSON 前会清空
|
||||||
|
// rec.ImplantToken / EncryptionKey 做脱敏,若 listener 实现持有同一指针会导致 beacon 鉴权永久失败。
|
||||||
|
listenerRec := *rec
|
||||||
factory := m.registry.Get(rec.Type)
|
factory := m.registry.Get(rec.Type)
|
||||||
if factory == nil {
|
if factory == nil {
|
||||||
return nil, ErrUnsupportedType
|
return nil, ErrUnsupportedType
|
||||||
}
|
}
|
||||||
inst, err := factory(ListenerCreationCtx{
|
inst, err := factory(ListenerCreationCtx{
|
||||||
Listener: rec,
|
Listener: &listenerRec,
|
||||||
Config: cfg,
|
Config: cfg,
|
||||||
Manager: m,
|
Manager: m,
|
||||||
Logger: m.logger.With(zap.String("listener_id", rec.ID), zap.String("type", rec.Type)),
|
Logger: m.logger.With(zap.String("listener_id", rec.ID), zap.String("type", rec.Type)),
|
||||||
|
|||||||
@@ -0,0 +1,74 @@
|
|||||||
|
package c2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"path/filepath"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 回归:StartListener 返回的 rec 被 handler 脱敏清空 ImplantToken 后,运行中的 HTTP listener 仍能鉴权。
|
||||||
|
func TestStartListener_ImplantTokenSurvivesHandlerRedaction(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
|
||||||
|
lnPick, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
port := lnPick.Addr().(*net.TCPAddr).Port
|
||||||
|
_ = lnPick.Close()
|
||||||
|
|
||||||
|
mgr := NewManager(db, zap.NewNop(), tmp)
|
||||||
|
mgr.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
|
||||||
|
rec, err := mgr.CreateListener(CreateListenerInput{
|
||||||
|
Name: "t",
|
||||||
|
Type: string(ListenerTypeHTTPBeacon),
|
||||||
|
BindHost: "127.0.0.1",
|
||||||
|
BindPort: port,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
token := rec.ImplantToken
|
||||||
|
|
||||||
|
rec, err = mgr.StartListener(rec.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
// 模拟 internal/handler/c2.go StartListener 在 JSON 响应前的脱敏
|
||||||
|
rec.ImplantToken = ""
|
||||||
|
rec.EncryptionKey = ""
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
body := `{"hostname":"n","username":"u","os":"Linux","arch":"amd64","internal_ip":"10.0.0.1","pid":42}`
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, "http://127.0.0.1:"+strconv.Itoa(port)+"/check_in", strings.NewReader(body))
|
||||||
|
req.Header.Set("X-Implant-Token", token)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
b, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
t.Fatalf("status=%d body=%s", resp.StatusCode, b)
|
||||||
|
}
|
||||||
|
if !strings.Contains(string(b), "session_id") {
|
||||||
|
t.Fatalf("expected session_id in body: %s", b)
|
||||||
|
}
|
||||||
|
_ = mgr.StopListener(rec.ID)
|
||||||
|
}
|
||||||
+110
-14
@@ -26,6 +26,7 @@ type Config struct {
|
|||||||
Security SecurityConfig `yaml:"security"`
|
Security SecurityConfig `yaml:"security"`
|
||||||
Database DatabaseConfig `yaml:"database"`
|
Database DatabaseConfig `yaml:"database"`
|
||||||
Auth AuthConfig `yaml:"auth"`
|
Auth AuthConfig `yaml:"auth"`
|
||||||
|
Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"`
|
||||||
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
||||||
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
||||||
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
|
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
|
||||||
@@ -39,9 +40,9 @@ type Config struct {
|
|||||||
|
|
||||||
// MultiAgentConfig 基于 CloudWeGo Eino adk/prebuilt 的多代理编排(deep | plan_execute | supervisor,与单 Agent /agent-loop 并存)。
|
// MultiAgentConfig 基于 CloudWeGo Eino adk/prebuilt 的多代理编排(deep | plan_execute | supervisor,与单 Agent /agent-loop 并存)。
|
||||||
type MultiAgentConfig struct {
|
type MultiAgentConfig struct {
|
||||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||||
RobotUseMultiAgent bool `yaml:"robot_use_multi_agent" json:"robot_use_multi_agent"` // 为 true 时钉钉/飞书/企微机器人走 Eino 多代理
|
RobotDefaultAgentMode string `yaml:"robot_default_agent_mode,omitempty" json:"robot_default_agent_mode,omitempty"` // react | eino_single | deep | plan_execute | supervisor
|
||||||
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
|
||||||
// Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。
|
// Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。
|
||||||
Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"`
|
Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"`
|
||||||
MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // 主代理 / 执行器最大推理轮次(Deep、Supervisor、plan_execute 的 Executor)
|
MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // 主代理 / 执行器最大推理轮次(Deep、Supervisor、plan_execute 的 Executor)
|
||||||
@@ -227,6 +228,10 @@ type MultiAgentEinoMiddlewareConfig struct {
|
|||||||
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
|
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
|
||||||
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries).
|
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries).
|
||||||
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
|
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
|
||||||
|
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时 handler 分段续跑次数;0=默认 10。
|
||||||
|
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
||||||
|
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
||||||
|
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
||||||
// TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended).
|
// TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended).
|
||||||
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
|
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -362,9 +367,9 @@ type MultiAgentSubConfig struct {
|
|||||||
|
|
||||||
// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。
|
// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。
|
||||||
type MultiAgentPublic struct {
|
type MultiAgentPublic struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
|
||||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||||
SubAgentCount int `json:"sub_agent_count"`
|
SubAgentCount int `json:"sub_agent_count"`
|
||||||
Orchestration string `json:"orchestration,omitempty"`
|
Orchestration string `json:"orchestration,omitempty"`
|
||||||
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
|
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
|
||||||
@@ -372,6 +377,18 @@ type MultiAgentPublic struct {
|
|||||||
ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"`
|
ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NormalizeRobotAgentMode 解析机器人默认对话模式(react | eino_single | deep | plan_execute | supervisor);空值视为 react。
|
||||||
|
func NormalizeRobotAgentMode(ma MultiAgentConfig) string {
|
||||||
|
s := strings.TrimSpace(strings.ToLower(ma.RobotDefaultAgentMode))
|
||||||
|
if s == "" || s == "single" || s == "react" {
|
||||||
|
return "react"
|
||||||
|
}
|
||||||
|
if s == "eino_single" {
|
||||||
|
return "eino_single"
|
||||||
|
}
|
||||||
|
return NormalizeMultiAgentOrchestration(s)
|
||||||
|
}
|
||||||
|
|
||||||
// NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。
|
// NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。
|
||||||
func NormalizeMultiAgentOrchestration(s string) string {
|
func NormalizeMultiAgentOrchestration(s string) string {
|
||||||
v := strings.TrimSpace(strings.ToLower(s))
|
v := strings.TrimSpace(strings.ToLower(s))
|
||||||
@@ -387,21 +404,35 @@ func NormalizeMultiAgentOrchestration(s string) string {
|
|||||||
|
|
||||||
// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。
|
// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。
|
||||||
type MultiAgentAPIUpdate struct {
|
type MultiAgentAPIUpdate struct {
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
|
||||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||||
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
|
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
|
||||||
ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"`
|
// 指针区分「JSON 未传该字段」与「传空数组要清空」;省略时不应覆盖 YAML 中的常驻工具白名单。
|
||||||
|
ToolSearchAlwaysVisibleTools *[]string `json:"tool_search_always_visible_tools,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RobotsConfig 机器人配置(企业微信、钉钉、飞书等)
|
// RobotsConfig 机器人配置(企业微信、钉钉、飞书、微信 iLink 等)
|
||||||
type RobotsConfig struct {
|
type RobotsConfig struct {
|
||||||
Session RobotSessionConfig `yaml:"session,omitempty" json:"session,omitempty"` // 机器人会话隔离策略
|
Session RobotSessionConfig `yaml:"session,omitempty" json:"session,omitempty"` // 机器人会话隔离策略
|
||||||
|
Wechat RobotWechatConfig `yaml:"wechat,omitempty" json:"wechat,omitempty"` // 微信(iLink 扫码绑定)
|
||||||
Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信
|
Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信
|
||||||
Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉
|
Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉
|
||||||
Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书
|
Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RobotWechatConfig 微信 iLink 机器人配置(个人微信 ClawBot / iLink 协议)
|
||||||
|
type RobotWechatConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||||
|
BotToken string `yaml:"bot_token,omitempty" json:"bot_token,omitempty"`
|
||||||
|
ILinkBotID string `yaml:"ilink_bot_id,omitempty" json:"ilink_bot_id,omitempty"`
|
||||||
|
ILinkUserID string `yaml:"ilink_user_id,omitempty" json:"ilink_user_id,omitempty"`
|
||||||
|
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://ilinkai.weixin.qq.com
|
||||||
|
BotType string `yaml:"bot_type,omitempty" json:"bot_type,omitempty"` // get_bot_qrcode 参数,默认 3
|
||||||
|
BotAgent string `yaml:"bot_agent,omitempty" json:"bot_agent,omitempty"` // base_info.bot_agent
|
||||||
|
GetUpdatesBuf string `yaml:"get_updates_buf,omitempty" json:"get_updates_buf,omitempty"` // 长轮询游标(运行时)
|
||||||
|
}
|
||||||
|
|
||||||
// RobotSessionConfig 机器人会话隔离策略
|
// RobotSessionConfig 机器人会话隔离策略
|
||||||
type RobotSessionConfig struct {
|
type RobotSessionConfig struct {
|
||||||
StrictUserIdentity *bool `yaml:"strict_user_identity,omitempty" json:"strict_user_identity,omitempty"` // true 时只允许真实用户标识,不允许会话/群 ID 兜底
|
StrictUserIdentity *bool `yaml:"strict_user_identity,omitempty" json:"strict_user_identity,omitempty"` // true 时只允许真实用户标识,不允许会话/群 ID 兜底
|
||||||
@@ -443,8 +474,17 @@ type RobotLarkConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
Host string `yaml:"host"`
|
Host string `yaml:"host" json:"host"`
|
||||||
Port int `yaml:"port"`
|
Port int `yaml:"port" json:"port"`
|
||||||
|
// TLSEnabled 为 true 时主 Web UI 使用 HTTPS;现代浏览器在同源下会协商 HTTP/2,缓解 HTTP/1.1 每源并发连接数限制。
|
||||||
|
TLSEnabled bool `yaml:"tls_enabled,omitempty" json:"tls_enabled,omitempty"`
|
||||||
|
// TLSCertPath / TLSKeyPath 非空时从 PEM 文件加载证书(生产环境推荐)。
|
||||||
|
TLSCertPath string `yaml:"tls_cert_path,omitempty" json:"tls_cert_path,omitempty"`
|
||||||
|
TLSKeyPath string `yaml:"tls_key_path,omitempty" json:"tls_key_path,omitempty"`
|
||||||
|
// TLSAutoSelfSign 为 true 且未配置有效证书路径时,启动时生成内存自签证书(仅本地/测试;浏览器会提示不受信任)。
|
||||||
|
TLSAutoSelfSign bool `yaml:"tls_auto_self_sign,omitempty" json:"tls_auto_self_sign,omitempty"`
|
||||||
|
// TLSHTTPRedirect 为 false 时禁用 HTTP→HTTPS 跳转;省略或为 true 且已启用 HTTPS 时,明文 HTTP 访问将 308 跳转到 HTTPS(同端口嗅探分流)。
|
||||||
|
TLSHTTPRedirect *bool `yaml:"tls_http_redirect,omitempty" json:"tls_http_redirect,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type LogConfig struct {
|
type LogConfig struct {
|
||||||
@@ -474,7 +514,7 @@ type OpenAIConfig struct {
|
|||||||
type OpenAIReasoningConfig struct {
|
type OpenAIReasoningConfig struct {
|
||||||
// Mode: auto(默认)| on | off | default(与 auto 相同)。off 时不向模型附加推理扩展字段。
|
// Mode: auto(默认)| on | off | default(与 auto 相同)。off 时不向模型附加推理扩展字段。
|
||||||
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
||||||
// Effort: low | medium | high | max;空表示不单独指定强度(各 profile 行为见 internal/reasoning)。
|
// Effort: low | medium | high | max | xhigh;max/xhigh 为不同网关最高档命名,原样下发、不互转。空表示不单独指定强度。
|
||||||
Effort string `yaml:"effort,omitempty" json:"effort,omitempty"`
|
Effort string `yaml:"effort,omitempty" json:"effort,omitempty"`
|
||||||
// AllowClientReasoning 为 false 时忽略请求体 reasoning;nil 或未设置等同于 true。
|
// AllowClientReasoning 为 false 时忽略请求体 reasoning;nil 或未设置等同于 true。
|
||||||
AllowClientReasoning *bool `yaml:"allow_client_reasoning,omitempty" json:"allow_client_reasoning,omitempty"`
|
AllowClientReasoning *bool `yaml:"allow_client_reasoning,omitempty" json:"allow_client_reasoning,omitempty"`
|
||||||
@@ -552,6 +592,51 @@ type AuthConfig struct {
|
|||||||
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
|
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AuditConfig platform operation audit log settings (not chat/tool execution bodies).
|
||||||
|
type AuditConfig struct {
|
||||||
|
// Enabled nil or true enables persistence; explicit false disables.
|
||||||
|
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
|
||||||
|
RetentionDays int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
|
||||||
|
MaxDetailBytes int `yaml:"max_detail_bytes,omitempty" json:"max_detail_bytes,omitempty"`
|
||||||
|
// AuthFailureCooldownSeconds: per-IP cooldown for auth login/change_password failure audit rows; -1 disables; 0 uses default 60.
|
||||||
|
AuthFailureCooldownSeconds int `yaml:"auth_failure_cooldown_seconds,omitempty" json:"auth_failure_cooldown_seconds,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnabledEffective returns true unless audit.enabled is explicitly false.
|
||||||
|
func (a AuditConfig) EnabledEffective() bool {
|
||||||
|
if a.Enabled == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return *a.Enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// RetentionDaysEffective returns retention; 0 means keep forever.
|
||||||
|
func (a AuditConfig) RetentionDaysEffective() int {
|
||||||
|
if a.RetentionDays < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return a.RetentionDays
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxDetailBytesEffective caps serialized detail JSON size.
|
||||||
|
func (a AuditConfig) MaxDetailBytesEffective() int {
|
||||||
|
if a.MaxDetailBytes <= 0 {
|
||||||
|
return 8192
|
||||||
|
}
|
||||||
|
return a.MaxDetailBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthFailureCooldownEffective returns seconds between duplicate auth-failure audit rows per IP (default 60; -1 disables).
|
||||||
|
func (a AuditConfig) AuthFailureCooldownEffective() int {
|
||||||
|
if a.AuthFailureCooldownSeconds < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if a.AuthFailureCooldownSeconds == 0 {
|
||||||
|
return 60
|
||||||
|
}
|
||||||
|
return a.AuthFailureCooldownSeconds
|
||||||
|
}
|
||||||
|
|
||||||
// ExternalMCPConfig 外部MCP配置
|
// ExternalMCPConfig 外部MCP配置
|
||||||
type ExternalMCPConfig struct {
|
type ExternalMCPConfig struct {
|
||||||
Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"`
|
Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"`
|
||||||
@@ -644,6 +729,9 @@ func Load(path string) (*Config, error) {
|
|||||||
if cfg.Auth.SessionDurationHours <= 0 {
|
if cfg.Auth.SessionDurationHours <= 0 {
|
||||||
cfg.Auth.SessionDurationHours = 12
|
cfg.Auth.SessionDurationHours = 12
|
||||||
}
|
}
|
||||||
|
if cfg.Audit.MaxDetailBytes <= 0 {
|
||||||
|
cfg.Audit.MaxDetailBytes = 8192
|
||||||
|
}
|
||||||
if strings.TrimSpace(cfg.Auth.Password) == "" {
|
if strings.TrimSpace(cfg.Auth.Password) == "" {
|
||||||
password, err := generateStrongPassword(24)
|
password, err := generateStrongPassword(24)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1147,6 +1235,14 @@ func Default() *Config {
|
|||||||
Auth: AuthConfig{
|
Auth: AuthConfig{
|
||||||
SessionDurationHours: 12,
|
SessionDurationHours: 12,
|
||||||
},
|
},
|
||||||
|
Audit: func() AuditConfig {
|
||||||
|
on := true
|
||||||
|
return AuditConfig{
|
||||||
|
RetentionDays: 90,
|
||||||
|
MaxDetailBytes: 8192,
|
||||||
|
Enabled: &on,
|
||||||
|
}
|
||||||
|
}(),
|
||||||
Robots: RobotsConfig{
|
Robots: RobotsConfig{
|
||||||
Session: RobotSessionConfig{
|
Session: RobotSessionConfig{
|
||||||
StrictUserIdentity: &strictRobotIdentity,
|
StrictUserIdentity: &strictRobotIdentity,
|
||||||
|
|||||||
@@ -0,0 +1,46 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// MainWebUIUsesHTTPS 判断主 Web UI 是否以 HTTPS 监听(与 internal/app.prepareMainServerTLS 前置条件一致)。
|
||||||
|
func MainWebUIUsesHTTPS(s *ServerConfig) bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.TLSEnabled {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if s.TLSAutoSelfSign {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
cert := strings.TrimSpace(s.TLSCertPath)
|
||||||
|
key := strings.TrimSpace(s.TLSKeyPath)
|
||||||
|
return cert != "" && key != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerHTTPRedirectEnabled 是否在主站启用 HTTPS 时把明文 HTTP 请求重定向到 HTTPS(默认开启)。
|
||||||
|
func ServerHTTPRedirectEnabled(s *ServerConfig) bool {
|
||||||
|
if s == nil || !MainWebUIUsesHTTPS(s) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.TLSHTTPRedirect == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return *s.TLSHTTPRedirect
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyDevHTTPSBootstrap 供 --https / 一键脚本使用:强制开启主站 TLS。
|
||||||
|
// 若已配置 tls_cert_path 与 tls_key_path 则仅用 PEM,不开启自签;否则启用 tls_auto_self_sign(内存证书,仅本地测试)。
|
||||||
|
func ApplyDevHTTPSBootstrap(cfg *Config) {
|
||||||
|
if cfg == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.Server.TLSEnabled = true
|
||||||
|
cert := strings.TrimSpace(cfg.Server.TLSCertPath)
|
||||||
|
key := strings.TrimSpace(cfg.Server.TLSKeyPath)
|
||||||
|
if cert != "" && key != "" {
|
||||||
|
cfg.Server.TLSAutoSelfSign = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.Server.TLSAutoSelfSign = true
|
||||||
|
}
|
||||||
@@ -0,0 +1,210 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuditLog platform operation audit record.
|
||||||
|
type AuditLog struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
Level string `json:"level"`
|
||||||
|
Category string `json:"category"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
Result string `json:"result"`
|
||||||
|
Actor string `json:"actor"`
|
||||||
|
SessionHint string `json:"sessionHint,omitempty"`
|
||||||
|
ClientIP string `json:"clientIp,omitempty"`
|
||||||
|
UserAgent string `json:"userAgent,omitempty"`
|
||||||
|
ResourceType string `json:"resourceType,omitempty"`
|
||||||
|
ResourceID string `json:"resourceId,omitempty"`
|
||||||
|
ResourceAvailable *bool `json:"resourceAvailable,omitempty"` // API-only: whether linked resource still exists
|
||||||
|
Message string `json:"message"`
|
||||||
|
Detail map[string]interface{} `json:"detail,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAuditLogsFilter query parameters.
|
||||||
|
type ListAuditLogsFilter struct {
|
||||||
|
Level string
|
||||||
|
Category string
|
||||||
|
Action string
|
||||||
|
Result string
|
||||||
|
Query string
|
||||||
|
ResourceType string
|
||||||
|
ResourceID string
|
||||||
|
Since *time.Time
|
||||||
|
Until *time.Time
|
||||||
|
Limit int
|
||||||
|
Offset int
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAuditLogsWhere(filter ListAuditLogsFilter) (string, []interface{}) {
|
||||||
|
conditions := []string{"1=1"}
|
||||||
|
args := []interface{}{}
|
||||||
|
if filter.Level != "" {
|
||||||
|
conditions = append(conditions, "level = ?")
|
||||||
|
args = append(args, filter.Level)
|
||||||
|
}
|
||||||
|
if filter.Category != "" {
|
||||||
|
conditions = append(conditions, "category = ?")
|
||||||
|
args = append(args, filter.Category)
|
||||||
|
}
|
||||||
|
if filter.Action != "" {
|
||||||
|
conditions = append(conditions, "action = ?")
|
||||||
|
args = append(args, filter.Action)
|
||||||
|
}
|
||||||
|
if filter.Result != "" {
|
||||||
|
conditions = append(conditions, "result = ?")
|
||||||
|
args = append(args, filter.Result)
|
||||||
|
}
|
||||||
|
if filter.ResourceType != "" {
|
||||||
|
conditions = append(conditions, "resource_type = ?")
|
||||||
|
args = append(args, filter.ResourceType)
|
||||||
|
}
|
||||||
|
if filter.ResourceID != "" {
|
||||||
|
conditions = append(conditions, "resource_id = ?")
|
||||||
|
args = append(args, filter.ResourceID)
|
||||||
|
}
|
||||||
|
if filter.Since != nil {
|
||||||
|
conditions = append(conditions, "created_at >= ?")
|
||||||
|
args = append(args, *filter.Since)
|
||||||
|
}
|
||||||
|
if filter.Until != nil {
|
||||||
|
conditions = append(conditions, "created_at <= ?")
|
||||||
|
args = append(args, *filter.Until)
|
||||||
|
}
|
||||||
|
if q := strings.TrimSpace(filter.Query); q != "" {
|
||||||
|
like := "%" + q + "%"
|
||||||
|
conditions = append(conditions, "(message LIKE ? OR resource_id LIKE ? OR action LIKE ? OR category LIKE ?)")
|
||||||
|
args = append(args, like, like, like, like)
|
||||||
|
}
|
||||||
|
return strings.Join(conditions, " AND "), args
|
||||||
|
}
|
||||||
|
|
||||||
|
// AppendAuditLog inserts one audit row.
|
||||||
|
func (db *DB) AppendAuditLog(row *AuditLog) error {
|
||||||
|
if row == nil {
|
||||||
|
return errors.New("audit log is nil")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(row.ID) == "" {
|
||||||
|
return errors.New("audit id is required")
|
||||||
|
}
|
||||||
|
if row.CreatedAt.IsZero() {
|
||||||
|
row.CreatedAt = time.Now()
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(row.Level) == "" {
|
||||||
|
row.Level = "info"
|
||||||
|
}
|
||||||
|
detailJSON := ""
|
||||||
|
if len(row.Detail) > 0 {
|
||||||
|
if b, err := json.Marshal(row.Detail); err == nil {
|
||||||
|
detailJSON = string(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
query := `
|
||||||
|
INSERT INTO audit_logs (
|
||||||
|
id, created_at, level, category, action, result, actor, session_hint,
|
||||||
|
client_ip, user_agent, resource_type, resource_id, message, detail_json
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
_, err := db.Exec(query,
|
||||||
|
row.ID, row.CreatedAt, row.Level, row.Category, row.Action, row.Result,
|
||||||
|
row.Actor, row.SessionHint, row.ClientIP, row.UserAgent,
|
||||||
|
row.ResourceType, row.ResourceID, row.Message, detailJSON,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuditLogByID returns one row.
|
||||||
|
func (db *DB) GetAuditLogByID(id string) (*AuditLog, error) {
|
||||||
|
id = strings.TrimSpace(id)
|
||||||
|
if id == "" {
|
||||||
|
return nil, errors.New("id is required")
|
||||||
|
}
|
||||||
|
query := `
|
||||||
|
SELECT id, created_at, level, category, action, result, actor,
|
||||||
|
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
|
||||||
|
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
|
||||||
|
FROM audit_logs WHERE id = ?
|
||||||
|
`
|
||||||
|
var row AuditLog
|
||||||
|
var detailJSON string
|
||||||
|
err := db.QueryRow(query, id).Scan(
|
||||||
|
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
|
||||||
|
&row.SessionHint, &row.ClientIP, &row.UserAgent,
|
||||||
|
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if detailJSON != "" {
|
||||||
|
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
|
||||||
|
}
|
||||||
|
return &row, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountAuditLogs counts rows matching filter.
|
||||||
|
func (db *DB) CountAuditLogs(filter ListAuditLogsFilter) (int64, error) {
|
||||||
|
where, args := buildAuditLogsWhere(filter)
|
||||||
|
query := `SELECT COUNT(*) FROM audit_logs WHERE ` + where
|
||||||
|
var n int64
|
||||||
|
err := db.QueryRow(query, args...).Scan(&n)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAuditLogs lists audit rows newest first.
|
||||||
|
func (db *DB) ListAuditLogs(filter ListAuditLogsFilter) ([]*AuditLog, error) {
|
||||||
|
where, args := buildAuditLogsWhere(filter)
|
||||||
|
limit := filter.Limit
|
||||||
|
if limit <= 0 || limit > 500 {
|
||||||
|
limit = 50
|
||||||
|
}
|
||||||
|
offset := filter.Offset
|
||||||
|
if offset < 0 {
|
||||||
|
offset = 0
|
||||||
|
}
|
||||||
|
query := `
|
||||||
|
SELECT id, created_at, level, category, action, result, actor,
|
||||||
|
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
|
||||||
|
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
|
||||||
|
FROM audit_logs
|
||||||
|
WHERE ` + where + `
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
`
|
||||||
|
args = append(args, limit, offset)
|
||||||
|
rows, err := db.Query(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var list []*AuditLog
|
||||||
|
for rows.Next() {
|
||||||
|
var row AuditLog
|
||||||
|
var detailJSON string
|
||||||
|
if err := rows.Scan(
|
||||||
|
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
|
||||||
|
&row.SessionHint, &row.ClientIP, &row.UserAgent,
|
||||||
|
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
|
||||||
|
); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if detailJSON != "" {
|
||||||
|
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
|
||||||
|
}
|
||||||
|
list = append(list, &row)
|
||||||
|
}
|
||||||
|
return list, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAuditLogsBefore removes rows older than cutoff.
|
||||||
|
func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) {
|
||||||
|
res, err := db.Exec(`DELETE FROM audit_logs WHERE created_at < ?`, cutoff)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return res.RowsAffected()
|
||||||
|
}
|
||||||
@@ -26,7 +26,7 @@ type Conversation struct {
|
|||||||
// Message 消息
|
// Message 消息
|
||||||
type Message struct {
|
type Message struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
ConversationID string `json:"conversationId"`
|
ConversationID string `json:"conversationId"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Content string `json:"content"`
|
Content string `json:"content"`
|
||||||
ReasoningContent string `json:"reasoningContent,omitempty"`
|
ReasoningContent string `json:"reasoningContent,omitempty"`
|
||||||
@@ -37,12 +37,12 @@ type Message struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateConversation 创建新对话
|
// CreateConversation 创建新对话
|
||||||
func (db *DB) CreateConversation(title string) (*Conversation, error) {
|
func (db *DB) CreateConversation(title string, meta ConversationCreateMeta) (*Conversation, error) {
|
||||||
return db.CreateConversationWithWebshell("", title)
|
return db.CreateConversationWithWebshell("", title, meta)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话)
|
// CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话)
|
||||||
func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string) (*Conversation, error) {
|
func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string, meta ConversationCreateMeta) (*Conversation, error) {
|
||||||
id := uuid.New().String()
|
id := uuid.New().String()
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
@@ -62,12 +62,17 @@ func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string)
|
|||||||
return nil, fmt.Errorf("创建对话失败: %w", err)
|
return nil, fmt.Errorf("创建对话失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Conversation{
|
conv := &Conversation{
|
||||||
ID: id,
|
ID: id,
|
||||||
Title: title,
|
Title: title,
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
UpdatedAt: now,
|
UpdatedAt: now,
|
||||||
}, nil
|
}
|
||||||
|
if webshellConnectionID != "" {
|
||||||
|
meta.WebShellConnectionID = webshellConnectionID
|
||||||
|
}
|
||||||
|
notifyConversationCreated(conv, meta)
|
||||||
|
return conv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化)
|
// GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化)
|
||||||
@@ -117,6 +122,7 @@ func (db *DB) GetConversationByWebshellConnectionID(connectionID string) (*Conve
|
|||||||
}
|
}
|
||||||
for i := range conv.Messages {
|
for i := range conv.Messages {
|
||||||
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
||||||
|
details = DedupeConsecutiveProcessDetails(details)
|
||||||
detailsJSON := make([]map[string]interface{}, len(details))
|
detailsJSON := make([]map[string]interface{}, len(details))
|
||||||
for j, detail := range details {
|
for j, detail := range details {
|
||||||
var data interface{}
|
var data interface{}
|
||||||
@@ -181,6 +187,23 @@ func (db *DB) ListConversationsByWebshellConnectionID(connectionID string) ([]We
|
|||||||
return list, rows.Err()
|
return list, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ConversationExists reports whether a conversation row exists (lightweight check for audit links).
|
||||||
|
func (db *DB) ConversationExists(id string) (bool, error) {
|
||||||
|
id = strings.TrimSpace(id)
|
||||||
|
if id == "" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
var one int
|
||||||
|
err := db.QueryRow("SELECT 1 FROM conversations WHERE id = ? LIMIT 1", id).Scan(&one)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetConversation 获取对话
|
// GetConversation 获取对话
|
||||||
func (db *DB) GetConversation(id string) (*Conversation, error) {
|
func (db *DB) GetConversation(id string) (*Conversation, error) {
|
||||||
var conv Conversation
|
var conv Conversation
|
||||||
@@ -235,6 +258,7 @@ func (db *DB) GetConversation(id string) (*Conversation, error) {
|
|||||||
// 将过程详情附加到对应的消息上
|
// 将过程详情附加到对应的消息上
|
||||||
for i := range conv.Messages {
|
for i := range conv.Messages {
|
||||||
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
|
||||||
|
details = DedupeConsecutiveProcessDetails(details)
|
||||||
// 将ProcessDetail转换为JSON格式,以便前端使用
|
// 将ProcessDetail转换为JSON格式,以便前端使用
|
||||||
detailsJSON := make([]map[string]interface{}, len(details))
|
detailsJSON := make([]map[string]interface{}, len(details))
|
||||||
for j, detail := range details {
|
for j, detail := range details {
|
||||||
|
|||||||
@@ -0,0 +1,29 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
// ConversationCreateMeta describes how a conversation was created (for audit hooks).
|
||||||
|
type ConversationCreateMeta struct {
|
||||||
|
Source string
|
||||||
|
WebShellConnectionID string
|
||||||
|
ClientIP string
|
||||||
|
SessionHint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConversationCreateHook is invoked after a conversation row is inserted.
|
||||||
|
type ConversationCreateHook func(conv *Conversation, meta ConversationCreateMeta)
|
||||||
|
|
||||||
|
var conversationCreateHook ConversationCreateHook
|
||||||
|
|
||||||
|
// SetConversationCreateHook registers a global hook (e.g. platform audit).
|
||||||
|
func SetConversationCreateHook(h ConversationCreateHook) {
|
||||||
|
conversationCreateHook = h
|
||||||
|
}
|
||||||
|
|
||||||
|
func notifyConversationCreated(conv *Conversation, meta ConversationCreateMeta) {
|
||||||
|
if conversationCreateHook == nil || conv == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if meta.Source == "" {
|
||||||
|
meta.Source = "unknown"
|
||||||
|
}
|
||||||
|
conversationCreateHook(conv, meta)
|
||||||
|
}
|
||||||
@@ -387,6 +387,24 @@ func (db *DB) initTables() error {
|
|||||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||||
);`
|
);`
|
||||||
|
|
||||||
|
createAuditLogsTable := `
|
||||||
|
CREATE TABLE IF NOT EXISTS audit_logs (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
created_at DATETIME NOT NULL,
|
||||||
|
level TEXT NOT NULL DEFAULT 'info',
|
||||||
|
category TEXT NOT NULL,
|
||||||
|
action TEXT NOT NULL,
|
||||||
|
result TEXT NOT NULL,
|
||||||
|
actor TEXT NOT NULL DEFAULT 'admin',
|
||||||
|
session_hint TEXT,
|
||||||
|
client_ip TEXT,
|
||||||
|
user_agent TEXT,
|
||||||
|
resource_type TEXT,
|
||||||
|
resource_id TEXT,
|
||||||
|
message TEXT NOT NULL,
|
||||||
|
detail_json TEXT
|
||||||
|
);`
|
||||||
|
|
||||||
createC2ProfilesTable := `
|
createC2ProfilesTable := `
|
||||||
CREATE TABLE IF NOT EXISTS c2_profiles (
|
CREATE TABLE IF NOT EXISTS c2_profiles (
|
||||||
id TEXT PRIMARY KEY,
|
id TEXT PRIMARY KEY,
|
||||||
@@ -445,6 +463,10 @@ func (db *DB) initTables() error {
|
|||||||
CREATE INDEX IF NOT EXISTS idx_c2_events_created_at ON c2_events(created_at);
|
CREATE INDEX IF NOT EXISTS idx_c2_events_created_at ON c2_events(created_at);
|
||||||
CREATE INDEX IF NOT EXISTS idx_c2_events_category ON c2_events(category);
|
CREATE INDEX IF NOT EXISTS idx_c2_events_category ON c2_events(category);
|
||||||
CREATE INDEX IF NOT EXISTS idx_c2_events_session ON c2_events(session_id);
|
CREATE INDEX IF NOT EXISTS idx_c2_events_session ON c2_events(session_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_audit_logs_created_at ON audit_logs(created_at);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_audit_logs_category ON audit_logs(category);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_audit_logs_action ON audit_logs(action);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_audit_logs_result ON audit_logs(result);
|
||||||
`
|
`
|
||||||
|
|
||||||
if _, err := db.Exec(createConversationsTable); err != nil {
|
if _, err := db.Exec(createConversationsTable); err != nil {
|
||||||
@@ -514,6 +536,10 @@ func (db *DB) initTables() error {
|
|||||||
return fmt.Errorf("创建webshell_connection_states表失败: %w", err)
|
return fmt.Errorf("创建webshell_connection_states表失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec(createAuditLogsTable); err != nil {
|
||||||
|
return fmt.Errorf("创建audit_logs表失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
for tableName, ddl := range map[string]string{
|
for tableName, ddl := range map[string]string{
|
||||||
"c2_listeners": createC2ListenersTable,
|
"c2_listeners": createC2ListenersTable,
|
||||||
"c2_sessions": createC2SessionsTable,
|
"c2_sessions": createC2SessionsTable,
|
||||||
|
|||||||
@@ -0,0 +1,28 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DedupeConsecutiveProcessDetails 去掉相邻且语义相同的过程详情(使用 DB 中 data 列原始 JSON 作指纹,避免 map 序列化键序不稳定)。
|
||||||
|
func DedupeConsecutiveProcessDetails(rows []ProcessDetail) []ProcessDetail {
|
||||||
|
if len(rows) < 2 {
|
||||||
|
return rows
|
||||||
|
}
|
||||||
|
out := make([]ProcessDetail, 0, len(rows))
|
||||||
|
var lastKey string
|
||||||
|
for _, d := range rows {
|
||||||
|
key := processDetailRowKey(d)
|
||||||
|
if len(out) > 0 && key != "" && key == lastKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, d)
|
||||||
|
lastKey = key
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func processDetailRowKey(d ProcessDetail) string {
|
||||||
|
return fmt.Sprintf("%s\x00%s\x00%s", d.EventType, strings.TrimSpace(d.Message), d.Data)
|
||||||
|
}
|
||||||
@@ -3,12 +3,84 @@ package database
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// VulnerabilityListFilter 列表/统计/导出共用的筛选条件
|
||||||
|
type VulnerabilityListFilter struct {
|
||||||
|
ID string
|
||||||
|
Search string // 关键词模糊匹配(标题、描述、类型、目标等)
|
||||||
|
ConversationID string
|
||||||
|
Severity string
|
||||||
|
Status string
|
||||||
|
TaskID string
|
||||||
|
ConversationTag string
|
||||||
|
TaskTag string
|
||||||
|
}
|
||||||
|
|
||||||
|
func escapeVulnerabilityLikePattern(s string) string {
|
||||||
|
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||||
|
s = strings.ReplaceAll(s, `%`, `\%`)
|
||||||
|
s = strings.ReplaceAll(s, `_`, `\_`)
|
||||||
|
return "%" + s + "%"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f VulnerabilityListFilter) appendWhere(query string, args []interface{}) (string, []interface{}) {
|
||||||
|
if f.ID != "" {
|
||||||
|
query += " AND id = ?"
|
||||||
|
args = append(args, f.ID)
|
||||||
|
}
|
||||||
|
if f.ConversationID != "" {
|
||||||
|
query += " AND conversation_id = ?"
|
||||||
|
args = append(args, f.ConversationID)
|
||||||
|
}
|
||||||
|
if f.TaskID != "" {
|
||||||
|
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
||||||
|
args = append(args, f.TaskID, f.TaskID)
|
||||||
|
}
|
||||||
|
if f.ConversationTag != "" {
|
||||||
|
query += " AND conversation_tag = ?"
|
||||||
|
args = append(args, f.ConversationTag)
|
||||||
|
}
|
||||||
|
if f.TaskTag != "" {
|
||||||
|
query += " AND task_tag = ?"
|
||||||
|
args = append(args, f.TaskTag)
|
||||||
|
}
|
||||||
|
if f.Severity != "" {
|
||||||
|
query += " AND severity = ?"
|
||||||
|
args = append(args, f.Severity)
|
||||||
|
}
|
||||||
|
if f.Status != "" {
|
||||||
|
query += " AND status = ?"
|
||||||
|
args = append(args, f.Status)
|
||||||
|
}
|
||||||
|
search := strings.TrimSpace(f.Search)
|
||||||
|
if search != "" {
|
||||||
|
pattern := escapeVulnerabilityLikePattern(search)
|
||||||
|
query += ` AND (
|
||||||
|
LOWER(id) LIKE LOWER(?) OR
|
||||||
|
LOWER(title) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(description, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(vulnerability_type, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(target, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(proof, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(impact, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(recommendation, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(conversation_id, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(conversation_tag, '')) LIKE LOWER(?) OR
|
||||||
|
LOWER(COALESCE(task_tag, '')) LIKE LOWER(?)
|
||||||
|
)`
|
||||||
|
for i := 0; i < 11; i++ {
|
||||||
|
args = append(args, pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return query, args
|
||||||
|
}
|
||||||
|
|
||||||
// Vulnerability 漏洞
|
// Vulnerability 漏洞
|
||||||
type Vulnerability struct {
|
type Vulnerability struct {
|
||||||
ID string `json:"id"`
|
ID string `json:"id"`
|
||||||
@@ -97,7 +169,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListVulnerabilities 列出漏洞
|
// ListVulnerabilities 列出漏洞
|
||||||
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) {
|
func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag,
|
SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag,
|
||||||
vulnerability_type, target, proof, impact, recommendation,
|
vulnerability_type, target, proof, impact, recommendation,
|
||||||
@@ -108,35 +180,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
|
|||||||
WHERE 1=1
|
WHERE 1=1
|
||||||
`
|
`
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
|
query, args = filter.appendWhere(query, args)
|
||||||
if id != "" {
|
|
||||||
query += " AND id = ?"
|
|
||||||
args = append(args, id)
|
|
||||||
}
|
|
||||||
if conversationID != "" {
|
|
||||||
query += " AND conversation_id = ?"
|
|
||||||
args = append(args, conversationID)
|
|
||||||
}
|
|
||||||
if taskID != "" {
|
|
||||||
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
|
||||||
args = append(args, taskID, taskID)
|
|
||||||
}
|
|
||||||
if conversationTag != "" {
|
|
||||||
query += " AND conversation_tag = ?"
|
|
||||||
args = append(args, conversationTag)
|
|
||||||
}
|
|
||||||
if taskTag != "" {
|
|
||||||
query += " AND task_tag = ?"
|
|
||||||
args = append(args, taskTag)
|
|
||||||
}
|
|
||||||
if severity != "" {
|
|
||||||
query += " AND severity = ?"
|
|
||||||
args = append(args, severity)
|
|
||||||
}
|
|
||||||
if status != "" {
|
|
||||||
query += " AND status = ?"
|
|
||||||
args = append(args, status)
|
|
||||||
}
|
|
||||||
|
|
||||||
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||||
args = append(args, limit, offset)
|
args = append(args, limit, offset)
|
||||||
@@ -168,38 +212,10 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
|
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
|
||||||
func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) {
|
func (db *DB) CountVulnerabilities(filter VulnerabilityListFilter) (int, error) {
|
||||||
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
|
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
|
query, args = filter.appendWhere(query, args)
|
||||||
if id != "" {
|
|
||||||
query += " AND id = ?"
|
|
||||||
args = append(args, id)
|
|
||||||
}
|
|
||||||
if conversationID != "" {
|
|
||||||
query += " AND conversation_id = ?"
|
|
||||||
args = append(args, conversationID)
|
|
||||||
}
|
|
||||||
if taskID != "" {
|
|
||||||
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
|
||||||
args = append(args, taskID, taskID)
|
|
||||||
}
|
|
||||||
if conversationTag != "" {
|
|
||||||
query += " AND conversation_tag = ?"
|
|
||||||
args = append(args, conversationTag)
|
|
||||||
}
|
|
||||||
if taskTag != "" {
|
|
||||||
query += " AND task_tag = ?"
|
|
||||||
args = append(args, taskTag)
|
|
||||||
}
|
|
||||||
if severity != "" {
|
|
||||||
query += " AND severity = ?"
|
|
||||||
args = append(args, severity)
|
|
||||||
}
|
|
||||||
if status != "" {
|
|
||||||
query += " AND status = ?"
|
|
||||||
args = append(args, status)
|
|
||||||
}
|
|
||||||
|
|
||||||
var count int
|
var count int
|
||||||
err := db.QueryRow(query, args...).Scan(&count)
|
err := db.QueryRow(query, args...).Scan(&count)
|
||||||
@@ -245,19 +261,12 @@ func (db *DB) DeleteVulnerability(id string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致)
|
// GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致)
|
||||||
func (db *DB) GetVulnerabilityStats(conversationID, taskID string) (map[string]interface{}, error) {
|
func (db *DB) GetVulnerabilityStats(filter VulnerabilityListFilter) (map[string]interface{}, error) {
|
||||||
stats := make(map[string]interface{})
|
stats := make(map[string]interface{})
|
||||||
|
|
||||||
where := "WHERE 1=1"
|
where := "WHERE 1=1"
|
||||||
args := []interface{}{}
|
args := []interface{}{}
|
||||||
if conversationID != "" {
|
where, args = filter.appendWhere(where, args)
|
||||||
where += " AND conversation_id = ?"
|
|
||||||
args = append(args, conversationID)
|
|
||||||
}
|
|
||||||
if taskID != "" {
|
|
||||||
where += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
|
|
||||||
args = append(args, taskID, taskID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 总漏洞数
|
// 总漏洞数
|
||||||
var totalCount int
|
var totalCount int
|
||||||
|
|||||||
+191
-49
@@ -17,12 +17,14 @@ import (
|
|||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/reasoning"
|
"cyberstrike-ai/internal/reasoning"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/mcp/builtin"
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
"cyberstrike-ai/internal/multiagent"
|
"cyberstrike-ai/internal/multiagent"
|
||||||
|
"cyberstrike-ai/internal/openai"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/robfig/cron/v3"
|
"github.com/robfig/cron/v3"
|
||||||
@@ -130,6 +132,12 @@ type AgentHandler struct {
|
|||||||
batchRunning map[string]struct{}
|
batchRunning map[string]struct{}
|
||||||
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
|
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
|
||||||
hitlWhitelistSaver HitlToolWhitelistSaver
|
hitlWhitelistSaver HitlToolWhitelistSaver
|
||||||
|
audit *audit.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAudit wires platform audit logging.
|
||||||
|
func (h *AgentHandler) SetAudit(s *audit.Service) {
|
||||||
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
||||||
@@ -206,7 +214,7 @@ type ChatAttachment struct {
|
|||||||
type ChatReasoningRequest struct {
|
type ChatReasoningRequest struct {
|
||||||
// Mode: default(跟随系统)| off | on | auto
|
// Mode: default(跟随系统)| off | on | auto
|
||||||
Mode string `json:"mode,omitempty"`
|
Mode string `json:"mode,omitempty"`
|
||||||
// Effort: low | medium | high | max;空表示不指定(由系统默认与各 profile 决定)。
|
// Effort: low | medium | high | max | xhigh(原样下发;不同网关最高档命名不同)。空表示不指定。
|
||||||
Effort string `json:"effort,omitempty"`
|
Effort string `json:"effort,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -552,7 +560,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
|||||||
conversationID := req.ConversationID
|
conversationID := req.ConversationID
|
||||||
if conversationID == "" {
|
if conversationID == "" {
|
||||||
title := safeTruncateString(req.Message, 50)
|
title := safeTruncateString(req.Message, 50)
|
||||||
conv, err := h.db.CreateConversation(title)
|
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "agent_loop"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("创建对话失败", zap.Error(err))
|
h.logger.Error("创建对话失败", zap.Error(err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -716,11 +724,43 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) finalizeRobotAgentError(ctx context.Context, assistantMessageID, conversationID string, resultMA *multiagent.RunResult, errMA error) (string, string, error) {
|
||||||
|
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||||
|
}
|
||||||
|
errMsg := "执行失败: " + errMA.Error()
|
||||||
|
if assistantMessageID != "" {
|
||||||
|
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||||
|
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||||
|
}
|
||||||
|
return "", conversationID, errMA
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) finalizeRobotAgentSuccess(assistantMessageID, conversationID string, resultMA *multiagent.RunResult) (string, string, error) {
|
||||||
|
if assistantMessageID != "" {
|
||||||
|
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resultMA.Response, resultMA.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(resultMA.LastAgentTraceInput)); errU != nil {
|
||||||
|
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, err := h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil {
|
||||||
|
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
|
||||||
|
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
|
||||||
|
}
|
||||||
|
return resultMA.Response, conversationID, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ProcessMessageForRobot 供机器人(企业微信/钉钉/飞书)调用:与 /api/agent-loop/stream 相同执行路径(含 progressCallback、过程详情),仅不发送 SSE,最后返回完整回复
|
// ProcessMessageForRobot 供机器人(企业微信/钉钉/飞书)调用:与 /api/agent-loop/stream 相同执行路径(含 progressCallback、过程详情),仅不发送 SSE,最后返回完整回复
|
||||||
func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationID, message, role string) (response string, convID string, err error) {
|
func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, conversationID, message, role string) (response string, convID string, err error) {
|
||||||
if conversationID == "" {
|
if conversationID == "" {
|
||||||
title := safeTruncateString(message, 50)
|
title := safeTruncateString(message, 50)
|
||||||
conv, createErr := h.db.CreateConversation(title)
|
src := "robot"
|
||||||
|
if strings.TrimSpace(platform) != "" {
|
||||||
|
src = "robot:" + strings.TrimSpace(platform)
|
||||||
|
}
|
||||||
|
conv, createErr := h.db.CreateConversation(title, audit.ConversationCreateMeta(src))
|
||||||
if createErr != nil {
|
if createErr != nil {
|
||||||
return "", "", fmt.Errorf("创建对话失败: %w", createErr)
|
return "", "", fmt.Errorf("创建对话失败: %w", createErr)
|
||||||
}
|
}
|
||||||
@@ -768,53 +808,92 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
|||||||
if assistantMsg != nil {
|
if assistantMsg != nil {
|
||||||
assistantMessageID = assistantMsg.ID
|
assistantMessageID = assistantMsg.ID
|
||||||
}
|
}
|
||||||
progressCallback := h.createProgressCallback(ctx, nil, conversationID, assistantMessageID, nil)
|
|
||||||
|
|
||||||
useRobotMulti := h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.RobotUseMultiAgent
|
// 注册运行中任务并向 taskEventBus 镜像进度事件,供 Web 端 task-events 补流(与 agent-loop/stream 一致)。
|
||||||
if useRobotMulti {
|
taskCtx, cancelWithCause := context.WithCancelCause(ctx)
|
||||||
resultMA, errMA := multiagent.RunDeepAgent(
|
defer cancelWithCause(nil)
|
||||||
ctx,
|
taskStatus := "completed"
|
||||||
h.config,
|
defer func() {
|
||||||
&h.config.MultiAgent,
|
h.tasks.FinishTask(conversationID, taskStatus)
|
||||||
h.agent,
|
}()
|
||||||
h.logger,
|
if _, err := h.tasks.StartTask(conversationID, message, cancelWithCause); err != nil {
|
||||||
conversationID,
|
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||||
finalMessage,
|
return "", conversationID, fmt.Errorf("当前会话已有任务正在执行中,请稍后再试")
|
||||||
agentHistoryMessages,
|
|
||||||
roleTools,
|
|
||||||
progressCallback,
|
|
||||||
h.agentsMarkdownDir,
|
|
||||||
"deep",
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
if errMA != nil {
|
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
|
|
||||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
|
||||||
}
|
|
||||||
errMsg := "执行失败: " + errMA.Error()
|
|
||||||
if assistantMessageID != "" {
|
|
||||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
|
||||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
|
||||||
}
|
|
||||||
return "", conversationID, errMA
|
|
||||||
}
|
}
|
||||||
if assistantMessageID != "" {
|
return "", conversationID, fmt.Errorf("无法启动任务: %w", err)
|
||||||
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resultMA.Response, resultMA.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(resultMA.LastAgentTraceInput)); errU != nil {
|
}
|
||||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
|
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, nil)
|
||||||
|
|
||||||
|
robotMode := "react"
|
||||||
|
if h.config != nil {
|
||||||
|
robotMode = config.NormalizeRobotAgentMode(h.config.MultiAgent)
|
||||||
|
}
|
||||||
|
switch robotMode {
|
||||||
|
case "eino_single":
|
||||||
|
curHist := agentHistoryMessages
|
||||||
|
curMsg := finalMessage
|
||||||
|
segmentUserMessage := finalMessage
|
||||||
|
var resultMA *multiagent.RunResult
|
||||||
|
var errMA error
|
||||||
|
var transientRunAttempts int
|
||||||
|
for {
|
||||||
|
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
|
||||||
|
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
|
||||||
|
conversationID, curMsg, curHist, roleTools, progressCallback, nil,
|
||||||
|
)
|
||||||
|
if errMA == nil {
|
||||||
|
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
|
||||||
|
transientRunAttempts = 0
|
||||||
|
break
|
||||||
}
|
}
|
||||||
} else {
|
if handled, _ := h.handleEinoTransientRetryContinue(
|
||||||
if _, err = h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil {
|
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
||||||
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
|
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||||
|
); handled {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
taskStatus = "failed"
|
||||||
|
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||||
}
|
}
|
||||||
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
|
return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA)
|
||||||
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
|
case "deep", "plan_execute", "supervisor":
|
||||||
|
if h.config == nil || !h.config.MultiAgent.Enabled {
|
||||||
|
h.logger.Warn("机器人配置为多代理模式但未启用 multi_agent,回退原生 ReAct",
|
||||||
|
zap.String("robot_mode", robotMode))
|
||||||
|
break
|
||||||
}
|
}
|
||||||
return resultMA.Response, conversationID, nil
|
curHist := agentHistoryMessages
|
||||||
|
curMsg := finalMessage
|
||||||
|
segmentUserMessage := finalMessage
|
||||||
|
var resultMA *multiagent.RunResult
|
||||||
|
var errMA error
|
||||||
|
var transientRunAttempts int
|
||||||
|
for {
|
||||||
|
resultMA, errMA = multiagent.RunDeepAgent(
|
||||||
|
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
|
||||||
|
conversationID, curMsg, curHist, roleTools, progressCallback,
|
||||||
|
h.agentsMarkdownDir, robotMode, nil,
|
||||||
|
)
|
||||||
|
if errMA == nil {
|
||||||
|
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
|
||||||
|
transientRunAttempts = 0
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if handled, _ := h.handleEinoTransientRetryContinue(
|
||||||
|
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
||||||
|
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||||
|
); handled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
taskStatus = "failed"
|
||||||
|
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||||
|
}
|
||||||
|
return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA)
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := h.agent.AgentLoopWithProgress(ctx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
taskStatus = "failed"
|
||||||
errMsg := "执行失败: " + err.Error()
|
errMsg := "执行失败: " + err.Error()
|
||||||
if assistantMessageID != "" {
|
if assistantMessageID != "" {
|
||||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||||
@@ -846,6 +925,23 @@ type StreamEvent struct {
|
|||||||
Data interface{} `json:"data,omitempty"`
|
Data interface{} `json:"data,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// publishProgressToTaskEventBus 将进度事件镜像到 taskEventBus(机器人/无 HTTP SSE 客户端时供 Web task-events 订阅)。
|
||||||
|
func (h *AgentHandler) publishProgressToTaskEventBus(conversationID, eventType, message string, data interface{}) {
|
||||||
|
if h == nil || h.taskEventBus == nil || strings.TrimSpace(conversationID) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
event := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||||
|
eventJSON, err := json.Marshal(event)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sseLine := make([]byte, 0, len(eventJSON)+8)
|
||||||
|
sseLine = append(sseLine, []byte("data: ")...)
|
||||||
|
sseLine = append(sseLine, eventJSON...)
|
||||||
|
sseLine = append(sseLine, '\n', '\n')
|
||||||
|
h.taskEventBus.Publish(conversationID, sseLine)
|
||||||
|
}
|
||||||
|
|
||||||
// createProgressCallback 创建进度回调函数,用于保存processDetails
|
// createProgressCallback 创建进度回调函数,用于保存processDetails
|
||||||
// sendEventFunc: 可选的流式事件发送函数,如果为nil则不发送流式事件
|
// sendEventFunc: 可选的流式事件发送函数,如果为nil则不发送流式事件
|
||||||
func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback {
|
func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback {
|
||||||
@@ -955,9 +1051,11 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
}
|
}
|
||||||
|
|
||||||
return func(eventType, message string, data interface{}) {
|
return func(eventType, message string, data interface{}) {
|
||||||
// 如果提供了sendEventFunc,发送流式事件
|
// 流式:写 HTTP SSE;非流式(机器人等):镜像到 taskEventBus 供 Web 订阅
|
||||||
if sendEventFunc != nil {
|
if sendEventFunc != nil {
|
||||||
sendEventFunc(eventType, message, data)
|
sendEventFunc(eventType, message, data)
|
||||||
|
} else {
|
||||||
|
h.publishProgressToTaskEventBus(conversationID, eventType, message, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 保存tool_call事件中的参数
|
// 保存tool_call事件中的参数
|
||||||
@@ -1158,7 +1256,16 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if eventType == "response_delta" {
|
if eventType == "response_delta" {
|
||||||
respPlan.b.WriteString(message)
|
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||||
|
if acc, okAcc := dataMap[openai.SSEAccumulatedKey].(string); okAcc {
|
||||||
|
respPlan.b.Reset()
|
||||||
|
respPlan.b.WriteString(acc)
|
||||||
|
} else {
|
||||||
|
respPlan.b.WriteString(message)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
respPlan.b.WriteString(message)
|
||||||
|
}
|
||||||
if dataMap, ok := data.(map[string]interface{}); ok && respPlan.meta == nil {
|
if dataMap, ok := data.(map[string]interface{}); ok && respPlan.meta == nil {
|
||||||
respPlan.meta = make(map[string]interface{}, len(dataMap))
|
respPlan.meta = make(map[string]interface{}, len(dataMap))
|
||||||
for k, v := range dataMap {
|
for k, v := range dataMap {
|
||||||
@@ -1213,8 +1320,12 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
} else if tb.persistAs == "" {
|
} else if tb.persistAs == "" {
|
||||||
tb.persistAs = persistAs
|
tb.persistAs = persistAs
|
||||||
}
|
}
|
||||||
// delta 片段直接拼接
|
if acc, okAcc := dataMap[openai.SSEAccumulatedKey].(string); okAcc {
|
||||||
tb.b.WriteString(message)
|
tb.b.Reset()
|
||||||
|
tb.b.WriteString(acc)
|
||||||
|
} else {
|
||||||
|
tb.b.WriteString(message)
|
||||||
|
}
|
||||||
// 有时 delta 先到 start 未到,补充元信息
|
// 有时 delta 先到 start 未到,补充元信息
|
||||||
for k, v := range dataMap {
|
for k, v := range dataMap {
|
||||||
tb.meta[k] = v
|
tb.meta[k] = v
|
||||||
@@ -1406,10 +1517,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
|||||||
title := safeTruncateString(req.Message, 50)
|
title := safeTruncateString(req.Message, 50)
|
||||||
var conv *database.Conversation
|
var conv *database.Conversation
|
||||||
var err error
|
var err error
|
||||||
|
meta := audit.ConversationCreateMetaFromGin(c, "agent_loop_stream")
|
||||||
if req.WebShellConnectionID != "" {
|
if req.WebShellConnectionID != "" {
|
||||||
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title)
|
meta.Source = "webshell_chat"
|
||||||
|
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title, meta)
|
||||||
} else {
|
} else {
|
||||||
conv, err = h.db.CreateConversation(title)
|
conv, err = h.db.CreateConversation(title, meta)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("创建对话失败", zap.Error(err))
|
h.logger.Error("创建对话失败", zap.Error(err))
|
||||||
@@ -2025,6 +2138,11 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
|
|||||||
queue = refreshed
|
queue = refreshed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "task", "create_queue", "创建批量任务队列", "batch_queue", queue.ID, map[string]interface{}{
|
||||||
|
"task_count": len(validTasks), "started": started,
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"queueId": queue.ID,
|
"queueId": queue.ID,
|
||||||
"queue": queue,
|
"queue": queue,
|
||||||
@@ -2132,6 +2250,9 @@ func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
|
|||||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "task", "start_queue", "启动批量任务队列", "batch_queue", queueID, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID})
|
c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2160,6 +2281,9 @@ func (h *AgentHandler) RerunBatchQueue(c *gin.Context) {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "启动失败"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "启动失败"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "task", "rerun_queue", "重跑批量任务队列", "batch_queue", queueID, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已重新开始执行", "queueId": queueID})
|
c.JSON(http.StatusOK, gin.H{"message": "批量任务已重新开始执行", "queueId": queueID})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2171,6 +2295,9 @@ func (h *AgentHandler) PauseBatchQueue(c *gin.Context) {
|
|||||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法暂停"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法暂停"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "task", "pause_queue", "暂停批量任务队列", "batch_queue", queueID, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务已暂停"})
|
c.JSON(http.StatusOK, gin.H{"message": "批量任务已暂停"})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2266,6 +2393,16 @@ func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
|
|||||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.Record(c, audit.Entry{
|
||||||
|
Category: "task",
|
||||||
|
Action: "delete_queue",
|
||||||
|
Result: "success",
|
||||||
|
ResourceType: "batch_queue",
|
||||||
|
ResourceID: queueID,
|
||||||
|
Message: "删除批量任务队列",
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "批量任务队列已删除"})
|
c.JSON(http.StatusOK, gin.H{"message": "批量任务队列已删除"})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2351,6 +2488,11 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
|
|||||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "task", "delete_batch_task", "删除批量子任务", "batch_task", taskID, map[string]interface{}{
|
||||||
|
"batch_queue_id": queueID,
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
|
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2509,7 +2651,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
|
|
||||||
// 创建新对话
|
// 创建新对话
|
||||||
title := safeTruncateString(task.Message, 50)
|
title := safeTruncateString(task.Message, 50)
|
||||||
conv, err := h.db.CreateConversation(title)
|
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMeta("batch_task"))
|
||||||
var conversationID string
|
var conversationID string
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||||
|
|||||||
@@ -0,0 +1,147 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuditHandler serves platform audit log APIs.
|
||||||
|
type AuditHandler struct {
|
||||||
|
db *database.DB
|
||||||
|
audit *audit.Service
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAuditHandler creates an audit log handler.
|
||||||
|
func NewAuditHandler(db *database.DB, auditSvc *audit.Service, logger *zap.Logger) *AuditHandler {
|
||||||
|
return &AuditHandler{db: db, audit: auditSvc, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Meta GET /api/audit/meta
|
||||||
|
func (h *AuditHandler) Meta(c *gin.Context) {
|
||||||
|
enabled := false
|
||||||
|
retentionDays := 0
|
||||||
|
if h.audit != nil {
|
||||||
|
enabled = h.audit.Enabled()
|
||||||
|
retentionDays = h.audit.RetentionDays()
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"enabled": enabled,
|
||||||
|
"retention_days": retentionDays,
|
||||||
|
"default_page_size": 20,
|
||||||
|
"max_page_size": 100,
|
||||||
|
"max_export": 5000,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Summary GET /api/audit/summary
|
||||||
|
func (h *AuditHandler) Summary(c *gin.Context) {
|
||||||
|
if h.db == nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
base := auditFilterFromQuery(c)
|
||||||
|
total, err := h.db.CountAuditLogs(base)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
failFilter := base
|
||||||
|
failFilter.Result = "failure"
|
||||||
|
failures, err := h.db.CountAuditLogs(failFilter)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
since := time.Now().AddDate(0, 0, -7)
|
||||||
|
recentFilter := base
|
||||||
|
recentFilter.Since = &since
|
||||||
|
recent7d, err := h.db.CountAuditLogs(recentFilter)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"total": total,
|
||||||
|
"failures": failures,
|
||||||
|
"recent_7d": recent7d,
|
||||||
|
"has_filters": c.Query("category") != "" || c.Query("action") != "" || c.Query("result") != "" ||
|
||||||
|
c.Query("q") != "" || c.Query("since") != "" || c.Query("until") != "",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListLogs GET /api/audit/logs
|
||||||
|
func (h *AuditHandler) ListLogs(c *gin.Context) {
|
||||||
|
if h.db == nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filter := auditFilterFromQuery(c)
|
||||||
|
page, pageSize := auditPaginationFromQuery(c)
|
||||||
|
filter.Limit = pageSize
|
||||||
|
filter.Offset = (page - 1) * pageSize
|
||||||
|
|
||||||
|
logs, err := h.db.ListAuditLogs(filter)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
total, err := h.db.CountAuditLogs(filter)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"logs": logs,
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"page_size": pageSize,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLog GET /api/audit/logs/:id
|
||||||
|
func (h *AuditHandler) GetLog(c *gin.Context) {
|
||||||
|
if h.db == nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
row, err := h.db.GetAuditLogByID(c.Param("id"))
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "审计记录不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
audit.ApplyResourceAvailability(h.db, row)
|
||||||
|
c.JSON(http.StatusOK, gin.H{"log": row})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExportLogs GET /api/audit/logs/export — JSON or CSV (?format=csv), max 5000 rows.
|
||||||
|
func (h *AuditHandler) ExportLogs(c *gin.Context) {
|
||||||
|
if h.db == nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filter := auditFilterFromQuery(c)
|
||||||
|
filter.Limit = 5000
|
||||||
|
filter.Offset = 0
|
||||||
|
|
||||||
|
logs, err := h.db.ListAuditLogs(filter)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.Query("format") == "csv" {
|
||||||
|
writeAuditLogsCSV(c, logs)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Header("Content-Disposition", `attachment; filename="audit-logs.json"`)
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"exported_at": time.Now().UTC().Format(time.RFC3339),
|
||||||
|
"logs": logs,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/csv"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func writeAuditLogsCSV(c *gin.Context, logs []*database.AuditLog) {
|
||||||
|
c.Header("Content-Type", "text/csv; charset=utf-8")
|
||||||
|
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="audit-logs-%s.csv"`, time.Now().Format("20060102")))
|
||||||
|
|
||||||
|
w := csv.NewWriter(c.Writer)
|
||||||
|
_ = w.Write([]string{
|
||||||
|
"id", "created_at", "level", "category", "action", "result", "actor",
|
||||||
|
"session_hint", "client_ip", "resource_type", "resource_id", "message",
|
||||||
|
})
|
||||||
|
for _, row := range logs {
|
||||||
|
if row == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_ = w.Write([]string{
|
||||||
|
row.ID,
|
||||||
|
row.CreatedAt.UTC().Format(time.RFC3339),
|
||||||
|
row.Level,
|
||||||
|
row.Category,
|
||||||
|
row.Action,
|
||||||
|
row.Result,
|
||||||
|
row.Actor,
|
||||||
|
row.SessionHint,
|
||||||
|
row.ClientIP,
|
||||||
|
row.ResourceType,
|
||||||
|
row.ResourceID,
|
||||||
|
row.Message,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
w.Flush()
|
||||||
|
}
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func auditFilterFromQuery(c *gin.Context) database.ListAuditLogsFilter {
|
||||||
|
filter := database.ListAuditLogsFilter{
|
||||||
|
Level: c.Query("level"),
|
||||||
|
Category: c.Query("category"),
|
||||||
|
Action: c.Query("action"),
|
||||||
|
Result: c.Query("result"),
|
||||||
|
Query: c.Query("q"),
|
||||||
|
ResourceType: c.Query("resource_type"),
|
||||||
|
ResourceID: c.Query("resource_id"),
|
||||||
|
}
|
||||||
|
if since := c.Query("since"); since != "" {
|
||||||
|
if t, err := time.Parse(time.RFC3339, since); err == nil {
|
||||||
|
filter.Since = &t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if until := c.Query("until"); until != "" {
|
||||||
|
if t, err := time.Parse(time.RFC3339, until); err == nil {
|
||||||
|
filter.Until = &t
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
|
||||||
|
func auditPaginationFromQuery(c *gin.Context) (page, pageSize int) {
|
||||||
|
page = 1
|
||||||
|
pageSize = 20
|
||||||
|
if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 {
|
||||||
|
page = p
|
||||||
|
}
|
||||||
|
if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "20")); err == nil && ps > 0 {
|
||||||
|
pageSize = ps
|
||||||
|
if pageSize > 100 {
|
||||||
|
pageSize = 100
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return page, pageSize
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/security"
|
"cyberstrike-ai/internal/security"
|
||||||
|
|
||||||
@@ -18,6 +19,12 @@ type AuthHandler struct {
|
|||||||
config *config.Config
|
config *config.Config
|
||||||
configPath string
|
configPath string
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
|
audit *audit.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAudit wires platform audit logging.
|
||||||
|
func (h *AuthHandler) SetAudit(s *audit.Service) {
|
||||||
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthHandler creates a new AuthHandler.
|
// NewAuthHandler creates a new AuthHandler.
|
||||||
@@ -49,10 +56,32 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
|||||||
|
|
||||||
token, expiresAt, err := h.manager.Authenticate(req.Password)
|
token, expiresAt, err := h.manager.Authenticate(req.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.Record(c, audit.Entry{
|
||||||
|
Level: "warn",
|
||||||
|
Category: "auth",
|
||||||
|
Action: "login",
|
||||||
|
Result: "failure",
|
||||||
|
Message: "登录失败:密码错误",
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"})
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.Record(c, audit.Entry{
|
||||||
|
Category: "auth",
|
||||||
|
Action: "login",
|
||||||
|
Result: "success",
|
||||||
|
SessionHint: audit.HintFromToken(token),
|
||||||
|
Message: "登录成功",
|
||||||
|
Detail: map[string]interface{}{
|
||||||
|
"expires_at": expiresAt.UTC().Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"token": token,
|
"token": token,
|
||||||
"expires_at": expiresAt.UTC().Format(time.RFC3339),
|
"expires_at": expiresAt.UTC().Format(time.RFC3339),
|
||||||
@@ -73,6 +102,14 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.manager.RevokeToken(token)
|
h.manager.RevokeToken(token)
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.Record(c, audit.Entry{
|
||||||
|
Category: "auth",
|
||||||
|
Action: "logout",
|
||||||
|
Result: "success",
|
||||||
|
Message: "退出登录",
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "已退出登录"})
|
c.JSON(http.StatusOK, gin.H{"message": "已退出登录"})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,6 +140,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !h.manager.CheckPassword(oldPassword) {
|
if !h.manager.CheckPassword(oldPassword) {
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.Record(c, audit.Entry{
|
||||||
|
Level: "warn",
|
||||||
|
Category: "auth",
|
||||||
|
Action: "change_password",
|
||||||
|
Result: "failure",
|
||||||
|
Message: "修改密码失败:当前密码不正确",
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -132,6 +178,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
|
|||||||
h.logger.Info("登录密码已更新,所有会话已失效")
|
h.logger.Info("登录密码已更新,所有会话已失效")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.Record(c, audit.Entry{
|
||||||
|
Category: "auth",
|
||||||
|
Action: "change_password",
|
||||||
|
Result: "success",
|
||||||
|
Message: "登录密码已修改",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"})
|
c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/c2"
|
"cyberstrike-ai/internal/c2"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
@@ -25,6 +26,12 @@ import (
|
|||||||
type C2Handler struct {
|
type C2Handler struct {
|
||||||
mgrPtr atomic.Pointer[c2.Manager]
|
mgrPtr atomic.Pointer[c2.Manager]
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
|
audit *audit.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAudit wires platform audit logging.
|
||||||
|
func (h *C2Handler) SetAudit(s *audit.Service) {
|
||||||
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewC2Handler 创建 C2 处理器;manager 可为 nil(功能关闭时)
|
// NewC2Handler 创建 C2 处理器;manager 可为 nil(功能关闭时)
|
||||||
@@ -104,6 +111,11 @@ func (h *C2Handler) CreateListener(c *gin.Context) {
|
|||||||
implantToken := listener.ImplantToken
|
implantToken := listener.ImplantToken
|
||||||
listener.EncryptionKey = ""
|
listener.EncryptionKey = ""
|
||||||
listener.ImplantToken = ""
|
listener.ImplantToken = ""
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "c2", "listener_create", "创建 C2 监听器", "c2_listener", listener.ID, map[string]interface{}{
|
||||||
|
"name": listener.Name, "bind": listener.BindHost, "port": listener.BindPort,
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"listener": listener, "implant_token": implantToken})
|
c.JSON(http.StatusOK, gin.H{"listener": listener, "implant_token": implantToken})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,6 +217,9 @@ func (h *C2Handler) DeleteListener(c *gin.Context) {
|
|||||||
c.JSON(code, gin.H{"error": err.Error()})
|
c.JSON(code, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "c2", "listener_delete", "删除 C2 监听器", "c2_listener", id, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,6 +237,9 @@ func (h *C2Handler) StartListener(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
listener.EncryptionKey = ""
|
listener.EncryptionKey = ""
|
||||||
listener.ImplantToken = ""
|
listener.ImplantToken = ""
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "c2", "listener_start", "启动 C2 监听器", "c2_listener", id, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"listener": listener})
|
c.JSON(http.StatusOK, gin.H{"listener": listener})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,6 +254,9 @@ func (h *C2Handler) StopListener(c *gin.Context) {
|
|||||||
c.JSON(code, gin.H{"error": err.Error()})
|
c.JSON(code, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "c2", "listener_stop", "停止 C2 监听器", "c2_listener", id, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"stopped": true})
|
c.JSON(http.StatusOK, gin.H{"stopped": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -297,6 +318,9 @@ func (h *C2Handler) DeleteSession(c *gin.Context) {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "c2", "session_delete", "删除 C2 会话", "c2_session", id, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -407,6 +431,11 @@ func (h *C2Handler) DeleteTasks(c *gin.Context) {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "c2", "task_delete", "批量删除 C2 任务", "c2_task", "", map[string]interface{}{
|
||||||
|
"count": n, "ids": req.IDs,
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"deleted": n})
|
c.JSON(http.StatusOK, gin.H{"deleted": n})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -457,6 +486,11 @@ func (h *C2Handler) CreateTask(c *gin.Context) {
|
|||||||
c.JSON(code, gin.H{"error": err.Error()})
|
c.JSON(code, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "c2", "task_create", "创建 C2 任务", "c2_task", task.ID, map[string]interface{}{
|
||||||
|
"session_id": req.SessionID, "task_type": req.TaskType,
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"task": task})
|
c.JSON(http.StatusOK, gin.H{"task": task})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -471,6 +505,9 @@ func (h *C2Handler) CancelTask(c *gin.Context) {
|
|||||||
c.JSON(code, gin.H{"error": err.Error()})
|
c.JSON(code, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "c2", "task_cancel", "取消 C2 任务", "c2_task", id, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"cancelled": true})
|
c.JSON(http.StatusOK, gin.H{"cancelled": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
@@ -24,6 +26,12 @@ const (
|
|||||||
// ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API
|
// ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API
|
||||||
type ChatUploadsHandler struct {
|
type ChatUploadsHandler struct {
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
|
audit *audit.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAudit wires platform audit logging.
|
||||||
|
func (h *ChatUploadsHandler) SetAudit(s *audit.Service) {
|
||||||
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewChatUploadsHandler 创建处理器
|
// NewChatUploadsHandler 创建处理器
|
||||||
@@ -230,6 +238,9 @@ func (h *ChatUploadsHandler) Delete(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "file", "delete", "删除对话附件", "chat_upload", body.Path, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -503,6 +514,11 @@ func (h *ChatUploadsHandler) Upload(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
rel, _ := filepath.Rel(root, fullPath)
|
rel, _ := filepath.Rel(root, fullPath)
|
||||||
absSaved, _ := filepath.Abs(fullPath)
|
absSaved, _ := filepath.Abs(fullPath)
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "file", "upload", "上传对话附件", "chat_upload", filepath.ToSlash(rel), map[string]interface{}{
|
||||||
|
"name": unique,
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"ok": true,
|
"ok": true,
|
||||||
"relativePath": filepath.ToSlash(rel),
|
"relativePath": filepath.ToSlash(rel),
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/agents"
|
"cyberstrike-ai/internal/agents"
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/knowledge"
|
"cyberstrike-ai/internal/knowledge"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
@@ -87,6 +88,7 @@ type ConfigHandler struct {
|
|||||||
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
|
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
|
||||||
appUpdater AppUpdater // App更新器(可选)
|
appUpdater AppUpdater // App更新器(可选)
|
||||||
robotRestarter RobotRestarter // 机器人连接重启器(可选),ApplyConfig 时重启钉钉/飞书
|
robotRestarter RobotRestarter // 机器人连接重启器(可选),ApplyConfig 时重启钉钉/飞书
|
||||||
|
audit *audit.Service
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
|
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
|
||||||
@@ -206,6 +208,32 @@ func (h *ConfigHandler) SetRobotRestarter(restarter RobotRestarter) {
|
|||||||
h.robotRestarter = restarter
|
h.robotRestarter = restarter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAudit wires platform audit logging.
|
||||||
|
func (h *ConfigHandler) SetAudit(s *audit.Service) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.audit = s
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyWechatRobotBinding 微信 iLink 扫码绑定成功后写入配置并重启机器人连接
|
||||||
|
func (h *ConfigHandler) ApplyWechatRobotBinding(wc config.RobotWechatConfig) error {
|
||||||
|
h.mu.Lock()
|
||||||
|
wc.Enabled = true
|
||||||
|
h.config.Robots.Wechat = wc
|
||||||
|
h.mu.Unlock()
|
||||||
|
if err := h.saveConfig(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if h.robotRestarter != nil {
|
||||||
|
h.robotRestarter.RestartRobotConnections()
|
||||||
|
}
|
||||||
|
h.logger.Info("微信机器人绑定已保存",
|
||||||
|
zap.String("ilink_bot_id", wc.ILinkBotID),
|
||||||
|
zap.Bool("enabled", wc.Enabled),
|
||||||
|
)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetConfigResponse 获取配置响应
|
// GetConfigResponse 获取配置响应
|
||||||
type GetConfigResponse struct {
|
type GetConfigResponse struct {
|
||||||
OpenAI config.OpenAIConfig `json:"openai"`
|
OpenAI config.OpenAIConfig `json:"openai"`
|
||||||
@@ -291,7 +319,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
multiPub := config.MultiAgentPublic{
|
multiPub := config.MultiAgentPublic{
|
||||||
Enabled: h.config.MultiAgent.Enabled,
|
Enabled: h.config.MultiAgent.Enabled,
|
||||||
RobotUseMultiAgent: h.config.MultiAgent.RobotUseMultiAgent,
|
RobotDefaultAgentMode: config.NormalizeRobotAgentMode(h.config.MultiAgent),
|
||||||
BatchUseMultiAgent: h.config.MultiAgent.BatchUseMultiAgent,
|
BatchUseMultiAgent: h.config.MultiAgent.BatchUseMultiAgent,
|
||||||
SubAgentCount: subAgentCount,
|
SubAgentCount: subAgentCount,
|
||||||
Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration),
|
Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration),
|
||||||
@@ -735,6 +763,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
|||||||
if req.Robots != nil {
|
if req.Robots != nil {
|
||||||
h.config.Robots = *req.Robots
|
h.config.Robots = *req.Robots
|
||||||
h.logger.Info("更新机器人配置",
|
h.logger.Info("更新机器人配置",
|
||||||
|
zap.Bool("wechat_enabled", h.config.Robots.Wechat.Enabled),
|
||||||
zap.Bool("wecom_enabled", h.config.Robots.Wecom.Enabled),
|
zap.Bool("wecom_enabled", h.config.Robots.Wecom.Enabled),
|
||||||
zap.Bool("dingtalk_enabled", h.config.Robots.Dingtalk.Enabled),
|
zap.Bool("dingtalk_enabled", h.config.Robots.Dingtalk.Enabled),
|
||||||
zap.Bool("lark_enabled", h.config.Robots.Lark.Enabled),
|
zap.Bool("lark_enabled", h.config.Robots.Lark.Enabled),
|
||||||
@@ -750,15 +779,21 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
|||||||
// 多代理标量(sub_agents 等仍由 config.yaml 维护)
|
// 多代理标量(sub_agents 等仍由 config.yaml 维护)
|
||||||
if req.MultiAgent != nil {
|
if req.MultiAgent != nil {
|
||||||
h.config.MultiAgent.Enabled = req.MultiAgent.Enabled
|
h.config.MultiAgent.Enabled = req.MultiAgent.Enabled
|
||||||
h.config.MultiAgent.RobotUseMultiAgent = req.MultiAgent.RobotUseMultiAgent
|
|
||||||
h.config.MultiAgent.BatchUseMultiAgent = req.MultiAgent.BatchUseMultiAgent
|
h.config.MultiAgent.BatchUseMultiAgent = req.MultiAgent.BatchUseMultiAgent
|
||||||
|
if mode := strings.TrimSpace(req.MultiAgent.RobotDefaultAgentMode); mode != "" {
|
||||||
|
h.config.MultiAgent.RobotDefaultAgentMode = mode
|
||||||
|
} else {
|
||||||
|
h.config.MultiAgent.RobotDefaultAgentMode = "react"
|
||||||
|
}
|
||||||
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
|
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
|
||||||
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
|
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
|
||||||
}
|
}
|
||||||
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(req.MultiAgent.ToolSearchAlwaysVisibleTools)
|
if req.MultiAgent.ToolSearchAlwaysVisibleTools != nil {
|
||||||
|
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(*req.MultiAgent.ToolSearchAlwaysVisibleTools)
|
||||||
|
}
|
||||||
h.logger.Info("更新多代理配置",
|
h.logger.Info("更新多代理配置",
|
||||||
zap.Bool("enabled", h.config.MultiAgent.Enabled),
|
zap.Bool("enabled", h.config.MultiAgent.Enabled),
|
||||||
zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent),
|
zap.String("robot_default_agent_mode", config.NormalizeRobotAgentMode(h.config.MultiAgent)),
|
||||||
zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent),
|
zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent),
|
||||||
zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations),
|
zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations),
|
||||||
zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)),
|
zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)),
|
||||||
@@ -881,6 +916,9 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "config", "update", "更新内存配置", "config", "", nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1011,6 +1049,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
|||||||
h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件")
|
h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件")
|
||||||
if _, err := knowledgeInitializer(); err != nil {
|
if _, err := knowledgeInitializer(); err != nil {
|
||||||
h.logger.Error("动态初始化知识库失败", zap.Error(err))
|
h.logger.Error("动态初始化知识库失败", zap.Error(err))
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordFail(c, "config", "apply", "应用配置失败:初始化知识库", map[string]interface{}{"error": err.Error()})
|
||||||
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1045,6 +1086,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
|||||||
h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)")
|
h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)")
|
||||||
if _, err := reinitKnowledgeInitializer(); err != nil {
|
if _, err := reinitKnowledgeInitializer(); err != nil {
|
||||||
h.logger.Error("重新初始化知识库失败", zap.Error(err))
|
h.logger.Error("重新初始化知识库失败", zap.Error(err))
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordFail(c, "config", "apply", "应用配置失败:重新初始化知识库", map[string]interface{}{"error": err.Error()})
|
||||||
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1058,6 +1102,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
|||||||
if c2Rt != nil {
|
if c2Rt != nil {
|
||||||
if err := c2Rt.ReconcileC2AfterConfigApply(); err != nil {
|
if err := c2Rt.ReconcileC2AfterConfigApply(); err != nil {
|
||||||
h.logger.Error("C2 配置应用失败", zap.Error(err))
|
h.logger.Error("C2 配置应用失败", zap.Error(err))
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordFail(c, "config", "apply", "应用配置失败:C2", map[string]interface{}{"error": err.Error()})
|
||||||
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "C2 启动失败: " + err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "C2 启动失败: " + err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1199,6 +1246,20 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
|||||||
zap.Int("tools_count", len(h.config.Security.Tools)),
|
zap.Int("tools_count", len(h.config.Security.Tools)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.Record(c, audit.Entry{
|
||||||
|
Category: "config",
|
||||||
|
Action: "apply",
|
||||||
|
Result: "success",
|
||||||
|
Message: "配置已应用",
|
||||||
|
Detail: map[string]interface{}{
|
||||||
|
"tools_count": len(h.config.Security.Tools),
|
||||||
|
"knowledge_enabled": h.config.Knowledge.Enabled,
|
||||||
|
"c2_enabled": h.config.C2.EnabledEffective(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "配置已应用",
|
"message": "配置已应用",
|
||||||
"tools_count": len(h.config.Security.Tools),
|
"tools_count": len(h.config.Security.Tools),
|
||||||
@@ -1474,6 +1535,20 @@ func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
|||||||
root := doc.Content[0]
|
root := doc.Content[0]
|
||||||
robotsNode := ensureMap(root, "robots")
|
robotsNode := ensureMap(root, "robots")
|
||||||
|
|
||||||
|
if cfg.Session.StrictUserIdentity != nil {
|
||||||
|
sessionNode := ensureMap(robotsNode, "session")
|
||||||
|
setBoolInMap(sessionNode, "strict_user_identity", *cfg.Session.StrictUserIdentity)
|
||||||
|
}
|
||||||
|
|
||||||
|
wechatNode := ensureMap(robotsNode, "wechat")
|
||||||
|
setBoolInMap(wechatNode, "enabled", cfg.Wechat.Enabled)
|
||||||
|
setStringInMap(wechatNode, "bot_token", cfg.Wechat.BotToken)
|
||||||
|
setStringInMap(wechatNode, "ilink_bot_id", cfg.Wechat.ILinkBotID)
|
||||||
|
setStringInMap(wechatNode, "ilink_user_id", cfg.Wechat.ILinkUserID)
|
||||||
|
setStringInMap(wechatNode, "base_url", cfg.Wechat.BaseURL)
|
||||||
|
setStringInMap(wechatNode, "bot_type", cfg.Wechat.BotType)
|
||||||
|
setStringInMap(wechatNode, "bot_agent", cfg.Wechat.BotAgent)
|
||||||
|
|
||||||
wecomNode := ensureMap(robotsNode, "wecom")
|
wecomNode := ensureMap(robotsNode, "wecom")
|
||||||
setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled)
|
setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled)
|
||||||
setStringInMap(wecomNode, "token", cfg.Wecom.Token)
|
setStringInMap(wecomNode, "token", cfg.Wecom.Token)
|
||||||
@@ -1486,19 +1561,21 @@ func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
|||||||
setBoolInMap(dingtalkNode, "enabled", cfg.Dingtalk.Enabled)
|
setBoolInMap(dingtalkNode, "enabled", cfg.Dingtalk.Enabled)
|
||||||
setStringInMap(dingtalkNode, "client_id", cfg.Dingtalk.ClientID)
|
setStringInMap(dingtalkNode, "client_id", cfg.Dingtalk.ClientID)
|
||||||
setStringInMap(dingtalkNode, "client_secret", cfg.Dingtalk.ClientSecret)
|
setStringInMap(dingtalkNode, "client_secret", cfg.Dingtalk.ClientSecret)
|
||||||
|
setBoolInMap(dingtalkNode, "allow_conversation_id_fallback", cfg.Dingtalk.AllowConversationIDFallback)
|
||||||
|
|
||||||
larkNode := ensureMap(robotsNode, "lark")
|
larkNode := ensureMap(robotsNode, "lark")
|
||||||
setBoolInMap(larkNode, "enabled", cfg.Lark.Enabled)
|
setBoolInMap(larkNode, "enabled", cfg.Lark.Enabled)
|
||||||
setStringInMap(larkNode, "app_id", cfg.Lark.AppID)
|
setStringInMap(larkNode, "app_id", cfg.Lark.AppID)
|
||||||
setStringInMap(larkNode, "app_secret", cfg.Lark.AppSecret)
|
setStringInMap(larkNode, "app_secret", cfg.Lark.AppSecret)
|
||||||
setStringInMap(larkNode, "verify_token", cfg.Lark.VerifyToken)
|
setStringInMap(larkNode, "verify_token", cfg.Lark.VerifyToken)
|
||||||
|
setBoolInMap(larkNode, "allow_chat_id_fallback", cfg.Lark.AllowChatIDFallback)
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
|
func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
|
||||||
root := doc.Content[0]
|
root := doc.Content[0]
|
||||||
maNode := ensureMap(root, "multi_agent")
|
maNode := ensureMap(root, "multi_agent")
|
||||||
setBoolInMap(maNode, "enabled", cfg.Enabled)
|
setBoolInMap(maNode, "enabled", cfg.Enabled)
|
||||||
setBoolInMap(maNode, "robot_use_multi_agent", cfg.RobotUseMultiAgent)
|
setStringInMap(maNode, "robot_default_agent_mode", config.NormalizeRobotAgentMode(cfg))
|
||||||
setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent)
|
setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent)
|
||||||
setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations)
|
setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations)
|
||||||
mwNode := ensureMap(maNode, "eino_middleware")
|
mwNode := ensureMap(maNode, "eino_middleware")
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
@@ -14,6 +15,12 @@ import (
|
|||||||
type ConversationHandler struct {
|
type ConversationHandler struct {
|
||||||
db *database.DB
|
db *database.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
|
audit *audit.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAudit wires platform audit logging.
|
||||||
|
func (h *ConversationHandler) SetAudit(s *audit.Service) {
|
||||||
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConversationHandler 创建新的对话处理器
|
// NewConversationHandler 创建新的对话处理器
|
||||||
@@ -42,7 +49,7 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
|
|||||||
title = "新对话"
|
title = "新对话"
|
||||||
}
|
}
|
||||||
|
|
||||||
conv, err := h.db.CreateConversation(title)
|
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "api"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("创建对话失败", zap.Error(err))
|
h.logger.Error("创建对话失败", zap.Error(err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -117,6 +124,8 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
details = database.DedupeConsecutiveProcessDetails(details)
|
||||||
|
|
||||||
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致)
|
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致)
|
||||||
out := make([]map[string]interface{}, 0, len(details))
|
out := make([]map[string]interface{}, 0, len(details))
|
||||||
for _, d := range details {
|
for _, d := range details {
|
||||||
@@ -187,6 +196,17 @@ func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.Record(c, audit.Entry{
|
||||||
|
Category: "conversation",
|
||||||
|
Action: "delete",
|
||||||
|
Result: "success",
|
||||||
|
ResourceType: "conversation",
|
||||||
|
ResourceID: id,
|
||||||
|
Message: "删除对话",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,6 +245,12 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "conversation", "delete_turn", "删除对话轮次", "conversation", conversationID, map[string]interface{}{
|
||||||
|
"message_id": req.MessageID,
|
||||||
|
"deleted": len(deletedIDs),
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"deletedMessageIds": deletedIDs,
|
"deletedMessageIds": deletedIDs,
|
||||||
"message": "ok",
|
"message": "ok",
|
||||||
|
|||||||
@@ -0,0 +1,122 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/multiagent"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
|
||||||
|
if h.config != nil {
|
||||||
|
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
|
||||||
|
}
|
||||||
|
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
|
||||||
|
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
|
||||||
|
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
|
||||||
|
func (h *AgentHandler) applyEinoTraceResumeSegment(
|
||||||
|
conversationID string,
|
||||||
|
result *multiagent.RunResult,
|
||||||
|
curHistory *[]agent.ChatMessage,
|
||||||
|
curFinalMessage *string,
|
||||||
|
segmentUserMessage string,
|
||||||
|
) {
|
||||||
|
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||||
|
}
|
||||||
|
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||||
|
*curHistory = hist
|
||||||
|
}
|
||||||
|
if segmentUserMessage != "" {
|
||||||
|
*curFinalMessage = segmentUserMessage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
|
||||||
|
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
|
||||||
|
func (h *AgentHandler) applyEinoTransientRetrySegment(
|
||||||
|
conversationID string,
|
||||||
|
result *multiagent.RunResult,
|
||||||
|
curHistory *[]agent.ChatMessage,
|
||||||
|
curFinalMessage *string,
|
||||||
|
segmentUserMessage string,
|
||||||
|
) {
|
||||||
|
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||||
|
}
|
||||||
|
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||||
|
*curHistory = hist
|
||||||
|
}
|
||||||
|
if s := strings.TrimSpace(segmentUserMessage); s != "" {
|
||||||
|
*curFinalMessage = segmentUserMessage
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
|
||||||
|
func (h *AgentHandler) handleEinoTransientRetryContinue(
|
||||||
|
baseCtx context.Context,
|
||||||
|
conversationID string,
|
||||||
|
result *multiagent.RunResult,
|
||||||
|
runErr error,
|
||||||
|
transientAttempts *int,
|
||||||
|
curHistory *[]agent.ChatMessage,
|
||||||
|
curFinalMessage *string,
|
||||||
|
segmentUserMessage string,
|
||||||
|
progressCallback func(eventType, message string, data interface{}),
|
||||||
|
sendProgress func(msg string, extra map[string]interface{}),
|
||||||
|
) (handled bool, fatal error) {
|
||||||
|
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
maxAttempts := h.einoRunRetryMaxAttempts()
|
||||||
|
*transientAttempts++
|
||||||
|
if *transientAttempts > maxAttempts {
|
||||||
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
|
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||||
|
}
|
||||||
|
return false, errors.New("transient retry exhausted: " + runErr.Error())
|
||||||
|
}
|
||||||
|
attemptNo := *transientAttempts
|
||||||
|
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
|
||||||
|
if progressCallback != nil {
|
||||||
|
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
"attempt": attemptNo,
|
||||||
|
"maxAttempts": maxAttempts,
|
||||||
|
"backoffSec": int(backoff.Seconds()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-baseCtx.Done():
|
||||||
|
return false, context.Cause(baseCtx)
|
||||||
|
case <-time.After(backoff):
|
||||||
|
}
|
||||||
|
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
||||||
|
if progressCallback != nil {
|
||||||
|
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
"attempt": attemptNo,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if sendProgress != nil {
|
||||||
|
sendProgress("正在重试…", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "transient_retry",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
@@ -90,7 +90,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
zap.String("conversationId", req.ConversationID),
|
zap.String("conversationId", req.ConversationID),
|
||||||
)
|
)
|
||||||
|
|
||||||
prep, err := h.prepareMultiAgentSession(&req)
|
prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent_stream")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendEvent("error", err.Error(), nil)
|
sendEvent("error", err.Error(), nil)
|
||||||
sendEvent("done", "", nil)
|
sendEvent("done", "", nil)
|
||||||
@@ -119,6 +119,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
|
|
||||||
var cancelWithCause context.CancelCauseFunc
|
var cancelWithCause context.CancelCauseFunc
|
||||||
curFinalMessage := prep.FinalMessage
|
curFinalMessage := prep.FinalMessage
|
||||||
|
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||||
curHistory := prep.History
|
curHistory := prep.History
|
||||||
roleTools := prep.RoleTools
|
roleTools := prep.RoleTools
|
||||||
|
|
||||||
@@ -176,9 +177,41 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
taskOwned = true
|
taskOwned = true
|
||||||
|
|
||||||
var cumulativeMCPExecutionIDs []string
|
var cumulativeMCPExecutionIDs []string
|
||||||
|
var transientRunAttempts int
|
||||||
|
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||||
|
var mainIterationOffset int
|
||||||
|
|
||||||
for {
|
for {
|
||||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
segmentMainIterationMax := 0
|
||||||
|
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||||
|
progressCallback := func(eventType, message string, data interface{}) {
|
||||||
|
if eventType == "iteration" {
|
||||||
|
if m, ok := data.(map[string]interface{}); ok {
|
||||||
|
if scope, _ := m["einoScope"].(string); scope == "main" {
|
||||||
|
raw := 0
|
||||||
|
switch v := m["iteration"].(type) {
|
||||||
|
case int:
|
||||||
|
raw = v
|
||||||
|
case int32:
|
||||||
|
raw = int(v)
|
||||||
|
case int64:
|
||||||
|
raw = int(v)
|
||||||
|
case float64:
|
||||||
|
raw = int(v)
|
||||||
|
case float32:
|
||||||
|
raw = int(v)
|
||||||
|
}
|
||||||
|
if raw > 0 {
|
||||||
|
if raw > segmentMainIterationMax {
|
||||||
|
segmentMainIterationMax = raw
|
||||||
|
}
|
||||||
|
m["iteration"] = raw + mainIterationOffset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rawProgressCallback(eventType, message, data)
|
||||||
|
}
|
||||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||||
@@ -198,16 +231,36 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
progressCallback,
|
progressCallback,
|
||||||
chatReasoningToClientIntent(req.Reasoning),
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
)
|
)
|
||||||
timeoutCancel()
|
|
||||||
|
|
||||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
|
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||||
|
transientRunAttempts = 0
|
||||||
|
timeoutCancel()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||||
|
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||||
|
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||||
|
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||||
|
)
|
||||||
|
if handled {
|
||||||
|
mainIterationOffset += segmentMainIterationMax
|
||||||
|
timeoutCancel()
|
||||||
|
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||||
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
|
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||||
|
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if fatalErr != nil {
|
||||||
|
runErr = fatalErr
|
||||||
|
}
|
||||||
|
|
||||||
cause := context.Cause(baseCtx)
|
cause := context.Cause(baseCtx)
|
||||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
@@ -231,10 +284,14 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"source": "interrupt_continue",
|
"source": "interrupt_continue",
|
||||||
})
|
})
|
||||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
mainIterationOffset += segmentMainIterationMax
|
||||||
|
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||||
|
transientRunAttempts = 0
|
||||||
|
timeoutCancel()
|
||||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||||
|
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -261,6 +318,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
"messageId": assistantMessageID,
|
"messageId": assistantMessageID,
|
||||||
})
|
})
|
||||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||||
|
timeoutCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,6 +336,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
"errorType": "timeout",
|
"errorType": "timeout",
|
||||||
})
|
})
|
||||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||||
|
timeoutCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -294,9 +353,12 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
"messageId": assistantMessageID,
|
"messageId": assistantMessageID,
|
||||||
})
|
})
|
||||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||||
|
timeoutCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
timeoutCancel()
|
||||||
|
|
||||||
if assistantMessageID != "" {
|
if assistantMessageID != "" {
|
||||||
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||||
}
|
}
|
||||||
@@ -326,7 +388,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
|||||||
|
|
||||||
h.logger.Info("收到 Eino ADK 单代理非流式请求", zap.String("conversationId", req.ConversationID))
|
h.logger.Info("收到 Eino ADK 单代理非流式请求", zap.String("conversationId", req.ConversationID))
|
||||||
|
|
||||||
prep, err := h.prepareMultiAgentSession(&req)
|
prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
|
|
||||||
@@ -20,9 +21,15 @@ type ExternalMCPHandler struct {
|
|||||||
config *config.Config
|
config *config.Config
|
||||||
configPath string
|
configPath string
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
|
audit *audit.Service
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAudit wires platform audit logging.
|
||||||
|
func (h *ExternalMCPHandler) SetAudit(s *audit.Service) {
|
||||||
|
h.audit = s
|
||||||
|
}
|
||||||
|
|
||||||
// NewExternalMCPHandler 创建外部MCP处理器
|
// NewExternalMCPHandler 创建外部MCP处理器
|
||||||
func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler {
|
func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler {
|
||||||
return &ExternalMCPHandler{
|
return &ExternalMCPHandler{
|
||||||
@@ -180,6 +187,16 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
|
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.Record(c, audit.Entry{
|
||||||
|
Category: "external_mcp",
|
||||||
|
Action: "upsert",
|
||||||
|
Result: "success",
|
||||||
|
ResourceType: "external_mcp",
|
||||||
|
ResourceID: name,
|
||||||
|
Message: "更新外部 MCP 配置",
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -209,6 +226,16 @@ func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
|
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.Record(c, audit.Entry{
|
||||||
|
Category: "external_mcp",
|
||||||
|
Action: "delete",
|
||||||
|
Result: "success",
|
||||||
|
ResourceType: "external_mcp",
|
||||||
|
ResourceID: name,
|
||||||
|
Message: "删除外部 MCP 配置",
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
|
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -616,6 +616,11 @@ func (h *AgentHandler) DecideHITLInterrupt(c *gin.Context) {
|
|||||||
c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
|
c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "hitl", "decision", "HITL 审批决策", "hitl_interrupt", req.InterruptID, map[string]interface{}{
|
||||||
|
"decision": req.Decision,
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/knowledge"
|
"cyberstrike-ai/internal/knowledge"
|
||||||
|
|
||||||
@@ -20,6 +21,12 @@ type KnowledgeHandler struct {
|
|||||||
indexer *knowledge.Indexer
|
indexer *knowledge.Indexer
|
||||||
db *database.DB
|
db *database.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
|
audit *audit.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAudit wires platform audit logging.
|
||||||
|
func (h *KnowledgeHandler) SetAudit(s *audit.Service) {
|
||||||
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewKnowledgeHandler 创建新的知识库处理器
|
// NewKnowledgeHandler 创建新的知识库处理器
|
||||||
@@ -303,6 +310,9 @@ func (h *KnowledgeHandler) DeleteItem(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "knowledge", "item_delete", "删除知识项", "knowledge_item", id, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -316,6 +326,9 @@ func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "knowledge", "index_rebuild", "重建知识库索引", "knowledge", "", nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"})
|
c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/agents"
|
"cyberstrike-ai/internal/agents"
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -18,7 +19,8 @@ var markdownAgentFilenameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]*\.m
|
|||||||
|
|
||||||
// MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。
|
// MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。
|
||||||
type MarkdownAgentsHandler struct {
|
type MarkdownAgentsHandler struct {
|
||||||
dir string
|
dir string
|
||||||
|
audit *audit.Service
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMarkdownAgentsHandler dir 须为已解析的绝对路径。
|
// NewMarkdownAgentsHandler dir 须为已解析的绝对路径。
|
||||||
@@ -26,6 +28,11 @@ func NewMarkdownAgentsHandler(dir string) *MarkdownAgentsHandler {
|
|||||||
return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)}
|
return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetAudit wires platform audit logging.
|
||||||
|
func (h *MarkdownAgentsHandler) SetAudit(s *audit.Service) {
|
||||||
|
h.audit = s
|
||||||
|
}
|
||||||
|
|
||||||
func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) {
|
func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) {
|
||||||
filename = strings.TrimSpace(filename)
|
filename = strings.TrimSpace(filename)
|
||||||
if filename == "" || !markdownAgentFilenameRe.MatchString(filename) {
|
if filename == "" || !markdownAgentFilenameRe.MatchString(filename) {
|
||||||
@@ -227,6 +234,9 @@ func (h *MarkdownAgentsHandler) CreateMarkdownAgent(c *gin.Context) {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "agent", "markdown_create", "创建 Markdown 子代理", "markdown_agent", filepath.Base(path), nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"})
|
c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -294,6 +304,9 @@ func (h *MarkdownAgentsHandler) UpdateMarkdownAgent(c *gin.Context) {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "agent", "markdown_update", "更新 Markdown 子代理", "markdown_agent", filename, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "已保存"})
|
c.JSON(http.StatusOK, gin.H{"message": "已保存"})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -313,5 +326,8 @@ func (h *MarkdownAgentsHandler) DeleteMarkdownAgent(c *gin.Context) {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "agent", "markdown_delete", "删除 Markdown 子代理", "markdown_agent", filename, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "已删除"})
|
c.JSON(http.StatusOK, gin.H{"message": "已删除"})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/security"
|
"cyberstrike-ai/internal/security"
|
||||||
@@ -23,6 +24,12 @@ type MonitorHandler struct {
|
|||||||
executor *security.Executor
|
executor *security.Executor
|
||||||
db *database.DB
|
db *database.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
|
audit *audit.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAudit wires platform audit logging.
|
||||||
|
func (h *MonitorHandler) SetAudit(s *audit.Service) {
|
||||||
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMonitorHandler 创建新的监控处理器
|
// NewMonitorHandler 创建新的监控处理器
|
||||||
@@ -365,6 +372,11 @@ func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName))
|
h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName))
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "tool", "execution_delete", "删除工具执行记录", "tool_execution", id, map[string]interface{}{
|
||||||
|
"tool_name": exec.ToolName,
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"})
|
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -440,6 +452,11 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs)))
|
h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs)))
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "tool", "execution_delete_batch", "批量删除工具执行记录", "tool_execution", "", map[string]interface{}{
|
||||||
|
"count": len(request.IDs),
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)})
|
c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
zap.String("conversationId", req.ConversationID),
|
zap.String("conversationId", req.ConversationID),
|
||||||
)
|
)
|
||||||
|
|
||||||
prep, err := h.prepareMultiAgentSession(&req)
|
prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent_stream")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendEvent("error", err.Error(), nil)
|
sendEvent("error", err.Error(), nil)
|
||||||
sendEvent("done", "", nil)
|
sendEvent("done", "", nil)
|
||||||
@@ -136,6 +136,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
|
|
||||||
var cancelWithCause context.CancelCauseFunc
|
var cancelWithCause context.CancelCauseFunc
|
||||||
curFinalMessage := prep.FinalMessage
|
curFinalMessage := prep.FinalMessage
|
||||||
|
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||||
curHistory := prep.History
|
curHistory := prep.History
|
||||||
roleTools := prep.RoleTools
|
roleTools := prep.RoleTools
|
||||||
orch := strings.TrimSpace(req.Orchestration)
|
orch := strings.TrimSpace(req.Orchestration)
|
||||||
@@ -186,9 +187,41 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
|
|
||||||
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
||||||
var cumulativeMCPExecutionIDs []string
|
var cumulativeMCPExecutionIDs []string
|
||||||
|
var transientRunAttempts int
|
||||||
|
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||||
|
var mainIterationOffset int
|
||||||
|
|
||||||
for {
|
for {
|
||||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
segmentMainIterationMax := 0
|
||||||
|
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||||
|
progressCallback := func(eventType, message string, data interface{}) {
|
||||||
|
if eventType == "iteration" {
|
||||||
|
if m, ok := data.(map[string]interface{}); ok {
|
||||||
|
if scope, _ := m["einoScope"].(string); scope == "main" {
|
||||||
|
raw := 0
|
||||||
|
switch v := m["iteration"].(type) {
|
||||||
|
case int:
|
||||||
|
raw = v
|
||||||
|
case int32:
|
||||||
|
raw = int(v)
|
||||||
|
case int64:
|
||||||
|
raw = int(v)
|
||||||
|
case float64:
|
||||||
|
raw = int(v)
|
||||||
|
case float32:
|
||||||
|
raw = int(v)
|
||||||
|
}
|
||||||
|
if raw > 0 {
|
||||||
|
if raw > segmentMainIterationMax {
|
||||||
|
segmentMainIterationMax = raw
|
||||||
|
}
|
||||||
|
m["iteration"] = raw + mainIterationOffset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rawProgressCallback(eventType, message, data)
|
||||||
|
}
|
||||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||||
@@ -210,16 +243,36 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
orch,
|
orch,
|
||||||
chatReasoningToClientIntent(req.Reasoning),
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
)
|
)
|
||||||
timeoutCancel()
|
|
||||||
|
|
||||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
|
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||||
|
transientRunAttempts = 0
|
||||||
|
timeoutCancel()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||||
|
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||||
|
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||||
|
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||||
|
)
|
||||||
|
if handled {
|
||||||
|
mainIterationOffset += segmentMainIterationMax
|
||||||
|
timeoutCancel()
|
||||||
|
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||||
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
|
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||||
|
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if fatalErr != nil {
|
||||||
|
runErr = fatalErr
|
||||||
|
}
|
||||||
|
|
||||||
cause := context.Cause(baseCtx)
|
cause := context.Cause(baseCtx)
|
||||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
@@ -243,10 +296,14 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"source": "interrupt_continue",
|
"source": "interrupt_continue",
|
||||||
})
|
})
|
||||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
mainIterationOffset += segmentMainIterationMax
|
||||||
|
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||||
|
transientRunAttempts = 0
|
||||||
|
timeoutCancel()
|
||||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||||
|
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -273,6 +330,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
"messageId": assistantMessageID,
|
"messageId": assistantMessageID,
|
||||||
})
|
})
|
||||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||||
|
timeoutCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -290,6 +348,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
"errorType": "timeout",
|
"errorType": "timeout",
|
||||||
})
|
})
|
||||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||||
|
timeoutCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -306,9 +365,12 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
"messageId": assistantMessageID,
|
"messageId": assistantMessageID,
|
||||||
})
|
})
|
||||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||||
|
timeoutCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
timeoutCancel()
|
||||||
|
|
||||||
if assistantMessageID != "" {
|
if assistantMessageID != "" {
|
||||||
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||||
}
|
}
|
||||||
@@ -347,7 +409,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
|||||||
|
|
||||||
h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID))
|
h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID))
|
||||||
|
|
||||||
prep, err := h.prepareMultiAgentSession(&req)
|
prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
status, msg := multiAgentHTTPErrorStatus(err)
|
status, msg := multiAgentHTTPErrorStatus(err)
|
||||||
c.JSON(status, gin.H{"error": msg})
|
c.JSON(status, gin.H{"error": msg})
|
||||||
|
|||||||
@@ -5,9 +5,11 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/mcp/builtin"
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,7 +24,7 @@ type multiAgentPrepared struct {
|
|||||||
UserMessageID string
|
UserMessageID string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPrepared, error) {
|
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context, source string) (*multiAgentPrepared, error) {
|
||||||
if len(req.Attachments) > maxAttachments {
|
if len(req.Attachments) > maxAttachments {
|
||||||
return nil, fmt.Errorf("附件最多 %d 个", maxAttachments)
|
return nil, fmt.Errorf("附件最多 %d 个", maxAttachments)
|
||||||
}
|
}
|
||||||
@@ -33,10 +35,13 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
|
|||||||
title := safeTruncateString(req.Message, 50)
|
title := safeTruncateString(req.Message, 50)
|
||||||
var conv *database.Conversation
|
var conv *database.Conversation
|
||||||
var err error
|
var err error
|
||||||
|
meta := audit.ConversationCreateMetaFromGin(c, source)
|
||||||
if strings.TrimSpace(req.WebShellConnectionID) != "" {
|
if strings.TrimSpace(req.WebShellConnectionID) != "" {
|
||||||
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title)
|
meta.Source = source + "_webshell"
|
||||||
|
meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID)
|
||||||
|
conv, err = h.db.CreateConversationWithWebshell(meta.WebShellConnectionID, title, meta)
|
||||||
} else {
|
} else {
|
||||||
conv, err = h.db.CreateConversation(title)
|
conv, err = h.db.CreateConversation(title, meta)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("创建对话失败: %w", err)
|
return nil, fmt.Errorf("创建对话失败: %w", err)
|
||||||
|
|||||||
@@ -6254,7 +6254,7 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取漏洞列表
|
// 获取漏洞列表
|
||||||
vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "", "", "", "")
|
vulnList, err := h.db.ListVulnerabilities(1000, 0, database.VulnerabilityListFilter{ConversationID: conversationID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Warn("获取漏洞列表失败", zap.Error(err))
|
h.logger.Warn("获取漏洞列表失败", zap.Error(err))
|
||||||
vulnList = []*database.Vulnerability{}
|
vulnList = []*database.Vulnerability{}
|
||||||
|
|||||||
@@ -133,7 +133,7 @@ func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (
|
|||||||
} else {
|
} else {
|
||||||
t = safeTruncateString(t, 50)
|
t = safeTruncateString(t, 50)
|
||||||
}
|
}
|
||||||
conv, err := h.db.CreateConversation(t)
|
conv, err := h.db.CreateConversation(t, database.ConversationCreateMeta{Source: "robot:" + platform})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Warn("创建机器人会话失败", zap.Error(err))
|
h.logger.Warn("创建机器人会话失败", zap.Error(err))
|
||||||
return "", false
|
return "", false
|
||||||
@@ -188,7 +188,7 @@ func (h *RobotHandler) setRole(platform, userID, roleName string) {
|
|||||||
// clearConversation 清空当前会话(切换到新对话)
|
// clearConversation 清空当前会话(切换到新对话)
|
||||||
func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) {
|
func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) {
|
||||||
title := "新对话 " + time.Now().Format("01-02 15:04")
|
title := "新对话 " + time.Now().Format("01-02 15:04")
|
||||||
conv, err := h.db.CreateConversation(title)
|
conv, err := h.db.CreateConversation(title, database.ConversationCreateMeta{Source: "robot:" + platform + ":new"})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Warn("创建新对话失败", zap.Error(err))
|
h.logger.Warn("创建新对话失败", zap.Error(err))
|
||||||
return ""
|
return ""
|
||||||
@@ -242,7 +242,7 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
|
|||||||
h.cancelMu.Unlock()
|
h.cancelMu.Unlock()
|
||||||
}()
|
}()
|
||||||
role := h.getRole(platform, userID)
|
role := h.getRole(platform, userID)
|
||||||
resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, convID, text, role)
|
resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, platform, convID, text, role)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err))
|
h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err))
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
@@ -21,6 +22,12 @@ type RoleHandler struct {
|
|||||||
config *config.Config
|
config *config.Config
|
||||||
configPath string
|
configPath string
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
|
audit *audit.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAudit wires platform audit logging.
|
||||||
|
func (h *RoleHandler) SetAudit(s *audit.Service) {
|
||||||
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRoleHandler 创建新的角色处理器
|
// NewRoleHandler 创建新的角色处理器
|
||||||
@@ -174,6 +181,9 @@ func (h *RoleHandler) UpdateRole(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
|
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "role", "update", "更新角色", "role", finalKey, map[string]interface{}{"name": req.Name})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "角色已更新",
|
"message": "角色已更新",
|
||||||
"role": req,
|
"role": req,
|
||||||
@@ -219,6 +229,9 @@ func (h *RoleHandler) CreateRole(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.logger.Info("创建角色", zap.String("roleName", req.Name))
|
h.logger.Info("创建角色", zap.String("roleName", req.Name))
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "role", "create", "创建角色", "role", req.Name, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "角色已创建",
|
"message": "角色已创建",
|
||||||
"role": req,
|
"role": req,
|
||||||
@@ -287,6 +300,9 @@ func (h *RoleHandler) DeleteRole(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.logger.Info("删除角色", zap.String("roleName", roleName))
|
h.logger.Info("删除角色", zap.String("roleName", roleName))
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "role", "delete", "删除角色", "role", roleName, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "角色已删除",
|
"message": "角色已删除",
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/skillpackage"
|
"cyberstrike-ai/internal/skillpackage"
|
||||||
@@ -23,6 +24,12 @@ type SkillsHandler struct {
|
|||||||
configPath string
|
configPath string
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
db *database.DB // 数据库连接(遗留统计;MCP list/read 已移除)
|
db *database.DB // 数据库连接(遗留统计;MCP list/read 已移除)
|
||||||
|
audit *audit.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAudit wires platform audit logging.
|
||||||
|
func (h *SkillsHandler) SetAudit(s *audit.Service) {
|
||||||
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSkillsHandler 创建新的Skills处理器
|
// NewSkillsHandler 创建新的Skills处理器
|
||||||
@@ -365,6 +372,9 @@ func (h *SkillsHandler) CreateSkill(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.logger.Info("创建skill成功", zap.String("skill", req.Name))
|
h.logger.Info("创建skill成功", zap.String("skill", req.Name))
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "skill", "create", "创建 Skill", "skill", req.Name, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "skill已创建",
|
"message": "skill已创建",
|
||||||
"skill": map[string]interface{}{
|
"skill": map[string]interface{}{
|
||||||
@@ -425,6 +435,9 @@ func (h *SkillsHandler) UpdateSkill(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.logger.Info("更新skill成功", zap.String("skill", skillName))
|
h.logger.Info("更新skill成功", zap.String("skill", skillName))
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "skill", "update", "更新 Skill", "skill", skillName, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": "skill已更新",
|
"message": "skill已更新",
|
||||||
})
|
})
|
||||||
@@ -459,6 +472,11 @@ func (h *SkillsHandler) DeleteSkill(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.logger.Info("删除skill成功", zap.String("skill", skillName))
|
h.logger.Info("删除skill成功", zap.String("skill", skillName))
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "skill", "delete", "删除 Skill", "skill", skillName, map[string]interface{}{
|
||||||
|
"affected_roles": affectedRoles,
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"message": responseMsg,
|
"message": responseMsg,
|
||||||
"affected_roles": affectedRoles,
|
"affected_roles": affectedRoles,
|
||||||
|
|||||||
@@ -253,5 +253,5 @@ func (h *TerminalHandler) RunCommandStream(c *gin.Context) {
|
|||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
runCommandStreamImpl(cmd, sendEvent, ctx)
|
_ = runCommandStreamImpl(cmd, sendEvent, ctx)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,11 +15,11 @@ const ptyCols = 256
|
|||||||
const ptyRows = 40
|
const ptyRows = 40
|
||||||
|
|
||||||
// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真)
|
// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真)
|
||||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) {
|
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
|
||||||
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows})
|
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendEvent(streamEvent{T: "exit", C: -1})
|
sendEvent(streamEvent{T: "exit", C: -1})
|
||||||
return
|
return -1
|
||||||
}
|
}
|
||||||
defer ptmx.Close()
|
defer ptmx.Close()
|
||||||
|
|
||||||
@@ -43,4 +43,5 @@ func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx contex
|
|||||||
exitCode = -1
|
exitCode = -1
|
||||||
}
|
}
|
||||||
sendEvent(streamEvent{T: "exit", C: exitCode})
|
sendEvent(streamEvent{T: "exit", C: exitCode})
|
||||||
|
return exitCode
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,20 +11,20 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行
|
// runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行
|
||||||
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) {
|
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
|
||||||
stdoutPipe, err := cmd.StdoutPipe()
|
stdoutPipe, err := cmd.StdoutPipe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendEvent(streamEvent{T: "exit", C: -1})
|
sendEvent(streamEvent{T: "exit", C: -1})
|
||||||
return
|
return -1
|
||||||
}
|
}
|
||||||
stderrPipe, err := cmd.StderrPipe()
|
stderrPipe, err := cmd.StderrPipe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sendEvent(streamEvent{T: "exit", C: -1})
|
sendEvent(streamEvent{T: "exit", C: -1})
|
||||||
return
|
return -1
|
||||||
}
|
}
|
||||||
if err := cmd.Start(); err != nil {
|
if err := cmd.Start(); err != nil {
|
||||||
sendEvent(streamEvent{T: "exit", C: -1})
|
sendEvent(streamEvent{T: "exit", C: -1})
|
||||||
return
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
normalize := func(s string) string {
|
normalize := func(s string) string {
|
||||||
@@ -62,4 +62,5 @@ func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx contex
|
|||||||
exitCode = -1
|
exitCode = -1
|
||||||
}
|
}
|
||||||
sendEvent(streamEvent{T: "exit", C: exitCode})
|
sendEvent(streamEvent{T: "exit", C: exitCode})
|
||||||
|
return exitCode
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
@@ -16,6 +17,12 @@ import (
|
|||||||
type VulnerabilityHandler struct {
|
type VulnerabilityHandler struct {
|
||||||
db *database.DB
|
db *database.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
|
audit *audit.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAudit wires platform audit logging.
|
||||||
|
func (h *VulnerabilityHandler) SetAudit(s *audit.Service) {
|
||||||
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewVulnerabilityHandler 创建新的漏洞处理器
|
// NewVulnerabilityHandler 创建新的漏洞处理器
|
||||||
@@ -72,6 +79,11 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "vulnerability", "create", "创建漏洞记录", "vulnerability", created.ID, map[string]interface{}{
|
||||||
|
"severity": created.Severity, "title": created.Title,
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, created)
|
c.JSON(http.StatusOK, created)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,18 +110,29 @@ type ListVulnerabilitiesResponse struct {
|
|||||||
TotalPages int `json:"total_pages"`
|
TotalPages int `json:"total_pages"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilter {
|
||||||
|
q := strings.TrimSpace(c.Query("q"))
|
||||||
|
if q == "" {
|
||||||
|
q = strings.TrimSpace(c.Query("search"))
|
||||||
|
}
|
||||||
|
return database.VulnerabilityListFilter{
|
||||||
|
ID: c.Query("id"),
|
||||||
|
Search: q,
|
||||||
|
ConversationID: c.Query("conversation_id"),
|
||||||
|
Severity: c.Query("severity"),
|
||||||
|
Status: c.Query("status"),
|
||||||
|
TaskID: c.Query("task_id"),
|
||||||
|
ConversationTag: c.Query("conversation_tag"),
|
||||||
|
TaskTag: c.Query("task_tag"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ListVulnerabilities 列出漏洞
|
// ListVulnerabilities 列出漏洞
|
||||||
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||||
limitStr := c.DefaultQuery("limit", "20")
|
limitStr := c.DefaultQuery("limit", "20")
|
||||||
offsetStr := c.DefaultQuery("offset", "0")
|
offsetStr := c.DefaultQuery("offset", "0")
|
||||||
pageStr := c.Query("page")
|
pageStr := c.Query("page")
|
||||||
id := c.Query("id")
|
filter := parseVulnerabilityListFilter(c)
|
||||||
conversationID := c.Query("conversation_id")
|
|
||||||
severity := c.Query("severity")
|
|
||||||
status := c.Query("status")
|
|
||||||
taskID := c.Query("task_id")
|
|
||||||
conversationTag := c.Query("conversation_tag")
|
|
||||||
taskTag := c.Query("task_tag")
|
|
||||||
|
|
||||||
limit, _ := strconv.Atoi(limitStr)
|
limit, _ := strconv.Atoi(limitStr)
|
||||||
offset, _ := strconv.Atoi(offsetStr)
|
offset, _ := strconv.Atoi(offsetStr)
|
||||||
@@ -131,7 +154,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取总数
|
// 获取总数
|
||||||
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
total, err := h.db.CountVulnerabilities(filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("获取漏洞总数失败", zap.Error(err))
|
h.logger.Error("获取漏洞总数失败", zap.Error(err))
|
||||||
// 继续执行,使用0作为总数
|
// 继续执行,使用0作为总数
|
||||||
@@ -139,7 +162,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取漏洞列表
|
// 获取漏洞列表
|
||||||
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("获取漏洞列表失败", zap.Error(err))
|
h.logger.Error("获取漏洞列表失败", zap.Error(err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -249,6 +272,11 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, map[string]interface{}{
|
||||||
|
"severity": updated.Severity, "status": updated.Status,
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, updated)
|
c.JSON(http.StatusOK, updated)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -262,15 +290,25 @@ func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.Record(c, audit.Entry{
|
||||||
|
Category: "vulnerability",
|
||||||
|
Action: "delete",
|
||||||
|
Result: "success",
|
||||||
|
ResourceType: "vulnerability",
|
||||||
|
ResourceID: id,
|
||||||
|
Message: "删除漏洞记录",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetVulnerabilityStats 获取漏洞统计
|
// GetVulnerabilityStats 获取漏洞统计
|
||||||
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
|
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
|
||||||
conversationID := c.Query("conversation_id")
|
filter := parseVulnerabilityListFilter(c)
|
||||||
taskID := c.Query("task_id")
|
|
||||||
|
|
||||||
stats, err := h.db.GetVulnerabilityStats(conversationID, taskID)
|
stats, err := h.db.GetVulnerabilityStats(filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.logger.Error("获取漏洞统计失败", zap.Error(err))
|
h.logger.Error("获取漏洞统计失败", zap.Error(err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -304,15 +342,9 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
id := c.Query("id")
|
filter := parseVulnerabilityListFilter(c)
|
||||||
conversationID := c.Query("conversation_id")
|
|
||||||
severity := c.Query("severity")
|
|
||||||
status := c.Query("status")
|
|
||||||
taskID := c.Query("task_id")
|
|
||||||
conversationTag := c.Query("conversation_tag")
|
|
||||||
taskTag := c.Query("task_tag")
|
|
||||||
|
|
||||||
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
total, err := h.db.CountVulnerabilities(filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
@@ -322,7 +354,7 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
items, err := h.db.ListVulnerabilities(total, 0, id, conversationID, severity, status, taskID, conversationTag, taskTag)
|
items, err := h.db.ListVulnerabilities(total, 0, filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -12,6 +13,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/audit"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
@@ -304,6 +306,12 @@ type WebShellHandler struct {
|
|||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
client *http.Client
|
client *http.Client
|
||||||
db *database.DB
|
db *database.DB
|
||||||
|
audit *audit.Service
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAudit wires platform audit logging.
|
||||||
|
func (h *WebShellHandler) SetAudit(s *audit.Service) {
|
||||||
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWebShellHandler 创建 WebShell 处理器,db 可为 nil(连接配置接口将不可用)
|
// NewWebShellHandler 创建 WebShell 处理器,db 可为 nil(连接配置接口将不可用)
|
||||||
@@ -311,8 +319,12 @@ func NewWebShellHandler(logger *zap.Logger, db *database.DB) *WebShellHandler {
|
|||||||
return &WebShellHandler{
|
return &WebShellHandler{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
client: &http.Client{
|
client: &http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
Transport: &http.Transport{DisableKeepAlives: false},
|
Transport: &http.Transport{
|
||||||
|
DisableKeepAlives: false,
|
||||||
|
// WebShell 场景常见自签证书或 IP 访问(证书无 IP SAN);默认跳过校验,与蚁剑等客户端一致。
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // intentional for webshell proxy
|
||||||
|
},
|
||||||
},
|
},
|
||||||
db: db,
|
db: db,
|
||||||
}
|
}
|
||||||
@@ -403,6 +415,15 @@ func (h *WebShellHandler) CreateConnection(c *gin.Context) {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
host := req.URL
|
||||||
|
if u, err := url.Parse(req.URL); err == nil {
|
||||||
|
host = u.Host
|
||||||
|
}
|
||||||
|
h.audit.RecordOK(c, "webshell", "connection_create", "创建 WebShell 连接", "webshell_connection", conn.ID, map[string]interface{}{
|
||||||
|
"host": host, "type": shellType,
|
||||||
|
})
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, conn)
|
c.JSON(http.StatusOK, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -485,6 +506,9 @@ func (h *WebShellHandler) DeleteConnection(c *gin.Context) {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "webshell", "connection_delete", "删除 WebShell 连接", "webshell_connection", id, nil)
|
||||||
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -714,8 +738,9 @@ func (h *WebShellHandler) Exec(c *gin.Context) {
|
|||||||
output := decodeWebshellOutput(out, req.Encoding)
|
output := decodeWebshellOutput(out, req.Encoding)
|
||||||
httpCode := resp.StatusCode
|
httpCode := resp.StatusCode
|
||||||
|
|
||||||
|
ok := resp.StatusCode == http.StatusOK
|
||||||
c.JSON(http.StatusOK, ExecResponse{
|
c.JSON(http.StatusOK, ExecResponse{
|
||||||
OK: resp.StatusCode == http.StatusOK,
|
OK: ok,
|
||||||
Output: output,
|
Output: output,
|
||||||
HTTPCode: httpCode,
|
HTTPCode: httpCode,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -0,0 +1,293 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/robot/ilink"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const wechatLoginTTL = 5 * time.Minute
|
||||||
|
|
||||||
|
// WechatConfigSaver 绑定成功后写入配置并重启机器人连接
|
||||||
|
type WechatConfigSaver interface {
|
||||||
|
ApplyWechatRobotBinding(cfg config.RobotWechatConfig) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type wechatLoginSession struct {
|
||||||
|
QRCode string
|
||||||
|
QRCodeImgURL string
|
||||||
|
PendingVerify string
|
||||||
|
CurrentBaseURL string
|
||||||
|
StartedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// WechatRobotHandler 微信 iLink 机器人(扫码绑定 + 配置)
|
||||||
|
type WechatRobotHandler struct {
|
||||||
|
config *config.Config
|
||||||
|
configSaver WechatConfigSaver
|
||||||
|
logger *zap.Logger
|
||||||
|
mu sync.Mutex
|
||||||
|
logins map[string]*wechatLoginSession
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWechatRobotHandler 创建微信机器人处理器
|
||||||
|
func NewWechatRobotHandler(cfg *config.Config, saver WechatConfigSaver, logger *zap.Logger) *WechatRobotHandler {
|
||||||
|
return &WechatRobotHandler{
|
||||||
|
config: cfg,
|
||||||
|
configSaver: saver,
|
||||||
|
logger: logger,
|
||||||
|
logins: make(map[string]*wechatLoginSession),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *WechatRobotHandler) purgeExpiredLogins() {
|
||||||
|
now := time.Now()
|
||||||
|
for k, v := range h.logins {
|
||||||
|
if now.Sub(v.StartedAt) > wechatLoginTTL {
|
||||||
|
delete(h.logins, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *WechatRobotHandler) ilinkClient(baseURL string) *ilink.Client {
|
||||||
|
ver := h.config.Version
|
||||||
|
if ver == "" {
|
||||||
|
ver = "1.0.0"
|
||||||
|
}
|
||||||
|
ver = strings.TrimPrefix(strings.TrimSpace(ver), "v")
|
||||||
|
ver = strings.TrimPrefix(ver, "V")
|
||||||
|
wc := h.config.Robots.Wechat
|
||||||
|
return ilink.NewClient(baseURL, wc.BotToken, wc.BotAgent, ilink.BuildClientVersion(ver))
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleWechatQRCode POST /api/robot/wechat/qrcode — 生成绑定二维码
|
||||||
|
func (h *WechatRobotHandler) HandleWechatQRCode(c *gin.Context) {
|
||||||
|
h.mu.Lock()
|
||||||
|
h.purgeExpiredLogins()
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
var req struct {
|
||||||
|
BotType string `json:"bot_type"`
|
||||||
|
}
|
||||||
|
_ = c.ShouldBindJSON(&req)
|
||||||
|
|
||||||
|
botType := req.BotType
|
||||||
|
if botType == "" {
|
||||||
|
botType = h.config.Robots.Wechat.BotType
|
||||||
|
}
|
||||||
|
if botType == "" {
|
||||||
|
botType = ilink.DefaultBotType
|
||||||
|
}
|
||||||
|
baseURL := h.config.Robots.Wechat.BaseURL
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = ilink.DefaultBaseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
var localTokens []string
|
||||||
|
if t := h.config.Robots.Wechat.BotToken; t != "" {
|
||||||
|
localTokens = []string{t}
|
||||||
|
}
|
||||||
|
|
||||||
|
client := h.ilinkClient(baseURL)
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
qr, err := client.GetBotQRCode(ctx, botType, localTokens)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("获取微信二维码失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": "获取二维码失败: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if qr.QRCode == "" || qr.QRCodeImgContent == "" {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": "微信服务器未返回有效二维码"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionKey := uuid.New().String()
|
||||||
|
h.mu.Lock()
|
||||||
|
h.logins[sessionKey] = &wechatLoginSession{
|
||||||
|
QRCode: qr.QRCode,
|
||||||
|
QRCodeImgURL: qr.QRCodeImgContent,
|
||||||
|
CurrentBaseURL: baseURL,
|
||||||
|
StartedAt: time.Now(),
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
|
||||||
|
resp := gin.H{
|
||||||
|
"session_key": sessionKey,
|
||||||
|
"qrcode": qr.QRCode,
|
||||||
|
"qrcode_open_url": qr.QRCodeImgContent,
|
||||||
|
"message": "请使用微信扫描二维码并确认绑定",
|
||||||
|
}
|
||||||
|
if dataURL, err := ilink.QRCodeDataURL(qr.QRCodeImgContent, 256); err != nil {
|
||||||
|
h.logger.Warn("生成二维码图片失败", zap.Error(err))
|
||||||
|
} else {
|
||||||
|
resp["qrcode_image_data_url"] = dataURL
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleWechatQRCodeStatus GET /api/robot/wechat/qrcode/status — 轮询扫码状态
|
||||||
|
func (h *WechatRobotHandler) HandleWechatQRCodeStatus(c *gin.Context) {
|
||||||
|
sessionKey := c.Query("session_key")
|
||||||
|
verifyCode := c.Query("verify_code")
|
||||||
|
if sessionKey == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 session_key"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.mu.Lock()
|
||||||
|
sess, ok := h.logins[sessionKey]
|
||||||
|
h.mu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期,请重新生成二维码"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if time.Since(sess.StartedAt) > wechatLoginTTL {
|
||||||
|
h.mu.Lock()
|
||||||
|
delete(h.logins, sessionKey)
|
||||||
|
h.mu.Unlock()
|
||||||
|
c.JSON(http.StatusGone, gin.H{"error": "二维码已过期,请重新生成"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := sess.CurrentBaseURL
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = ilink.DefaultBaseURL
|
||||||
|
}
|
||||||
|
vc := verifyCode
|
||||||
|
if vc == "" {
|
||||||
|
vc = sess.PendingVerify
|
||||||
|
}
|
||||||
|
|
||||||
|
client := h.ilinkClient(baseURL)
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 40*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
st, err := client.GetQRCodeStatus(ctx, sess.QRCode, vc)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Warn("轮询微信二维码状态失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch st.Status {
|
||||||
|
case "wait", "scaned":
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": st.Status})
|
||||||
|
return
|
||||||
|
case "need_verifycode":
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": st.Status,
|
||||||
|
"message": "请在手机微信查看配对数字,并在下方输入",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
case "scaned_but_redirect":
|
||||||
|
if st.RedirectHost != "" {
|
||||||
|
h.mu.Lock()
|
||||||
|
if s, ok := h.logins[sessionKey]; ok {
|
||||||
|
s.CurrentBaseURL = "https://" + st.RedirectHost
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": st.Status})
|
||||||
|
return
|
||||||
|
case "binded_redirect":
|
||||||
|
h.mu.Lock()
|
||||||
|
delete(h.logins, sessionKey)
|
||||||
|
h.mu.Unlock()
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": st.Status,
|
||||||
|
"already_connected": true,
|
||||||
|
"message": "该微信已绑定过,无需重复绑定",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
case "confirmed":
|
||||||
|
if st.BotToken == "" || st.ILinkBotID == "" {
|
||||||
|
c.JSON(http.StatusBadGateway, gin.H{"error": "绑定确认成功但缺少 bot_token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
saveBase := st.BaseURL
|
||||||
|
if saveBase == "" {
|
||||||
|
saveBase = baseURL
|
||||||
|
}
|
||||||
|
wc := h.config.Robots.Wechat
|
||||||
|
wc.Enabled = true
|
||||||
|
wc.BotToken = st.BotToken
|
||||||
|
wc.ILinkBotID = st.ILinkBotID
|
||||||
|
wc.ILinkUserID = st.ILinkUserID
|
||||||
|
wc.BaseURL = saveBase
|
||||||
|
if wc.BotType == "" {
|
||||||
|
wc.BotType = ilink.DefaultBotType
|
||||||
|
}
|
||||||
|
if wc.BotAgent == "" {
|
||||||
|
wc.BotAgent = ilink.DefaultBotAgent
|
||||||
|
}
|
||||||
|
if h.configSaver != nil {
|
||||||
|
if err := h.configSaver.ApplyWechatRobotBinding(wc); err != nil {
|
||||||
|
h.logger.Warn("保存微信机器人配置失败", zap.Error(err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
h.config.Robots.Wechat = wc
|
||||||
|
}
|
||||||
|
h.mu.Lock()
|
||||||
|
delete(h.logins, sessionKey)
|
||||||
|
h.mu.Unlock()
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": "confirmed",
|
||||||
|
"message": "绑定成功,微信机器人已启用",
|
||||||
|
"ilink_bot_id": st.ILinkBotID,
|
||||||
|
"ilink_user_id": st.ILinkUserID,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": st.Status})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleWechatVerifyCode POST /api/robot/wechat/qrcode/verify — 提交手机配对数字
|
||||||
|
func (h *WechatRobotHandler) HandleWechatVerifyCode(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
SessionKey string `json:"session_key"`
|
||||||
|
VerifyCode string `json:"verify_code"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil || req.SessionKey == "" || req.VerifyCode == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "需要 session_key 与 verify_code"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.mu.Lock()
|
||||||
|
sess, ok := h.logins[req.SessionKey]
|
||||||
|
if ok {
|
||||||
|
sess.PendingVerify = req.VerifyCode
|
||||||
|
}
|
||||||
|
h.mu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"message": "已提交配对码,请继续等待绑定"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleWechatStatus GET /api/robot/wechat/status — 当前绑定状态(供前端展示)
|
||||||
|
func (h *WechatRobotHandler) HandleWechatStatus(c *gin.Context) {
|
||||||
|
wc := h.config.Robots.Wechat
|
||||||
|
bound := wc.BotToken != "" && wc.ILinkBotID != ""
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"enabled": wc.Enabled,
|
||||||
|
"bound": bound,
|
||||||
|
"ilink_bot_id": wc.ILinkBotID,
|
||||||
|
"ilink_user_id": wc.ILinkUserID,
|
||||||
|
"base_url": wc.BaseURL,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -15,8 +15,8 @@ import (
|
|||||||
|
|
||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/einoobserve"
|
|
||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
|
"cyberstrike-ai/internal/einoobserve"
|
||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
@@ -77,6 +77,9 @@ type einoADKRunLoopArgs struct {
|
|||||||
StreamsMainAssistant func(agent string) bool
|
StreamsMainAssistant func(agent string) bool
|
||||||
EinoRoleTag func(agent string) string
|
EinoRoleTag func(agent string) string
|
||||||
CheckpointDir string
|
CheckpointDir string
|
||||||
|
// RunRetryMaxAttempts / RunRetryMaxBackoffSec:429、5xx、网络抖动时的指数退避续跑(0=默认 10 次 / 30s 上限)。
|
||||||
|
RunRetryMaxAttempts int
|
||||||
|
RunRetryMaxBackoffSec int
|
||||||
|
|
||||||
McpIDsMu *sync.Mutex
|
McpIDsMu *sync.Mutex
|
||||||
McpIDs *[]string
|
McpIDs *[]string
|
||||||
@@ -177,6 +180,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
var einoMainRound int
|
var einoMainRound int
|
||||||
var einoLastAgent string
|
var einoLastAgent string
|
||||||
subAgentToolStep := make(map[string]int)
|
subAgentToolStep := make(map[string]int)
|
||||||
|
// mainAgentToolStep:主代理每次工具调用批次递增,供 UI 显示「第 N 轮」(单代理无子代理切换时原先会一直停在第 1 轮)。
|
||||||
|
mainAgentToolStep := make(map[string]int)
|
||||||
pendingByID := make(map[string]toolCallPendingInfo)
|
pendingByID := make(map[string]toolCallPendingInfo)
|
||||||
pendingQueueByAgent := make(map[string][]string)
|
pendingQueueByAgent := make(map[string][]string)
|
||||||
markPending := func(tc toolCallPendingInfo) {
|
markPending := func(tc toolCallPendingInfo) {
|
||||||
@@ -267,7 +272,16 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
isErr := !success || invokeErr != nil
|
isErr := !success || invokeErr != nil
|
||||||
body := content
|
body := content
|
||||||
if invokeErr != nil {
|
if invokeErr != nil {
|
||||||
body = invokeErr.Error()
|
// 保留已流式累计的 stdout(如 execute 超时前的一半输出),避免 tool_result 只剩错误串、模型与 UI 丢失上下文
|
||||||
|
tail := friendlyEinoExecuteInvokeTail(invokeErr)
|
||||||
|
// execute 流式包装可能已把超时句写入 content(供 ADK tool 与流式 delta);勿重复拼接
|
||||||
|
if tail != "" && strings.Contains(content, tail) {
|
||||||
|
body = content
|
||||||
|
} else if strings.TrimSpace(content) != "" {
|
||||||
|
body = strings.TrimRight(content, "\n") + "\n\n" + tail
|
||||||
|
} else {
|
||||||
|
body = tail
|
||||||
|
}
|
||||||
isErr = true
|
isErr = true
|
||||||
}
|
}
|
||||||
recordPendingExecuteStdoutDup(toolName, body, isErr)
|
recordPendingExecuteStdoutDup(toolName, body, isErr)
|
||||||
@@ -426,6 +440,28 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
return runErr
|
return runErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// maybeRetryTransientRun:不在此层 runner.Run/Resume;由 handler 落库 + loadHistoryFromAgentTrace 分段续跑(同中断并继续)。
|
||||||
|
maybeRetryTransientRun := func(runErr error) (retry bool, fatal error) {
|
||||||
|
if runErr == nil || !isEinoTransientRunError(runErr) {
|
||||||
|
return false, handleRunErr(runErr)
|
||||||
|
}
|
||||||
|
if logger != nil {
|
||||||
|
logger.Warn("eino transient error, ending run segment for handler resume",
|
||||||
|
zap.Error(runErr),
|
||||||
|
zap.String("orchestration", orchMode))
|
||||||
|
}
|
||||||
|
if progress != nil {
|
||||||
|
progress("eino_run_retry", "遇到临时错误(限流或网络波动),将保存上下文并重试…", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
"orchestration": orchMode,
|
||||||
|
"error": runErr.Error(),
|
||||||
|
"resumeKind": "trace_segment",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return false, ErrTransientRetryContinue
|
||||||
|
}
|
||||||
|
|
||||||
takePartial := func(runErr error) (*RunResult, error) {
|
takePartial := func(runErr error) (*RunResult, error) {
|
||||||
if len(runAccumulatedMsgs) <= baseAccumulatedCount {
|
if len(runAccumulatedMsgs) <= baseAccumulatedCount {
|
||||||
return nil, runErr
|
return nil, runErr
|
||||||
@@ -508,7 +544,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if ev.Err != nil {
|
if ev.Err != nil {
|
||||||
if retErr := handleRunErr(ev.Err); retErr != nil {
|
if _, retErr := maybeRetryTransientRun(ev.Err); retErr != nil {
|
||||||
return takePartial(retErr)
|
return takePartial(retErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -520,8 +556,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if streamsMainAssistant(ev.AgentName) {
|
if streamsMainAssistant(ev.AgentName) {
|
||||||
|
mainIterKey := einoMainIterationKey(iterEinoAgent, orchestratorName)
|
||||||
if einoMainRound == 0 {
|
if einoMainRound == 0 {
|
||||||
einoMainRound = 1
|
einoMainRound = 1
|
||||||
|
mainAgentToolStep[mainIterKey] = 1
|
||||||
progress("iteration", "", map[string]interface{}{
|
progress("iteration", "", map[string]interface{}{
|
||||||
"iteration": 1,
|
"iteration": 1,
|
||||||
"einoScope": "main",
|
"einoScope": "main",
|
||||||
@@ -531,17 +569,26 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"source": "eino",
|
"source": "eino",
|
||||||
})
|
})
|
||||||
} else if einoLastAgent != "" && !streamsMainAssistant(einoLastAgent) {
|
} else if einoLastAgent != "" {
|
||||||
einoMainRound++
|
needBump := false
|
||||||
progress("iteration", "", map[string]interface{}{
|
if !streamsMainAssistant(einoLastAgent) {
|
||||||
"iteration": einoMainRound,
|
needBump = true // 子代理 → 主代理
|
||||||
"einoScope": "main",
|
} else if einoLastAgent != ev.AgentName {
|
||||||
"einoRole": "orchestrator",
|
needBump = true // plan_execute:planner ↔ executor 等主代理切换
|
||||||
"einoAgent": iterEinoAgent,
|
}
|
||||||
"orchestration": orchMode,
|
if needBump {
|
||||||
"conversationId": conversationID,
|
einoMainRound++
|
||||||
"source": "eino",
|
mainAgentToolStep[mainIterKey] = einoMainRound
|
||||||
})
|
progress("iteration", "", map[string]interface{}{
|
||||||
|
"iteration": einoMainRound,
|
||||||
|
"einoScope": "main",
|
||||||
|
"einoRole": "orchestrator",
|
||||||
|
"einoAgent": iterEinoAgent,
|
||||||
|
"orchestration": orchMode,
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
einoLastAgent = ev.AgentName
|
einoLastAgent = ev.AgentName
|
||||||
@@ -564,6 +611,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
var subAssistantBuf string
|
var subAssistantBuf string
|
||||||
var subReplyStreamID string
|
var subReplyStreamID string
|
||||||
var mainAssistantBuf string
|
var mainAssistantBuf string
|
||||||
|
// 已通过 response_delta 推到前端的正文(与 monitor.js normalizeStreamingDeltaJs 累积一致)
|
||||||
|
var mainAssistWireAccum string
|
||||||
var mainAssistDupTarget string // 非空表示本段主助手流需缓冲至 EOF,与 execute 输出比对去重
|
var mainAssistDupTarget string // 非空表示本段主助手流需缓冲至 EOF,与 execute 输出比对去重
|
||||||
var reasoningBuf string
|
var reasoningBuf string
|
||||||
var prevReasoningDisplay string // UI 用:剥离 Claude 内部 signature 尾缀后的累计展示
|
var prevReasoningDisplay string // UI 用:剥离 Claude 内部 signature 尾缀后的累计展示
|
||||||
@@ -633,9 +682,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"orchestration": orchMode,
|
"orchestration": orchMode,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
progress("reasoning_chain_stream_delta", displayDelta, map[string]interface{}{
|
progress("reasoning_chain_stream_delta", displayDelta, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
"streamId": reasoningStreamID,
|
"streamId": reasoningStreamID,
|
||||||
})
|
}, fullDisplay))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -665,13 +714,14 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
})
|
})
|
||||||
streamHeaderSent = true
|
streamHeaderSent = true
|
||||||
}
|
}
|
||||||
progress("response_delta", contentDelta, map[string]interface{}{
|
progress("response_delta", contentDelta, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"mcpExecutionIds": snapshotMCPIDs(),
|
"mcpExecutionIds": snapshotMCPIDs(),
|
||||||
"einoRole": "orchestrator",
|
"einoRole": "orchestrator",
|
||||||
"einoAgent": ev.AgentName,
|
"einoAgent": ev.AgentName,
|
||||||
"orchestration": orchMode,
|
"orchestration": orchMode,
|
||||||
})
|
}, mainAssistantBuf))
|
||||||
|
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, contentDelta)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if !streamsMainAssistant(ev.AgentName) {
|
} else if !streamsMainAssistant(ev.AgentName) {
|
||||||
@@ -689,10 +739,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"source": "eino",
|
"source": "eino",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
progress("eino_agent_reply_stream_delta", subDelta, map[string]interface{}{
|
progress("eino_agent_reply_stream_delta", subDelta, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
"streamId": subReplyStreamID,
|
"streamId": subReplyStreamID,
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
})
|
}, subAssistantBuf))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -717,21 +767,29 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
} else if s != "" {
|
} else if s != "" {
|
||||||
if progress != nil {
|
if progress != nil {
|
||||||
progress("response_start", "", map[string]interface{}{
|
// 仅用 TrimSpace 与 execute 比对;推到 UI 的必须是 mainAssistantBuf,
|
||||||
"conversationId": conversationID,
|
// 否则尾部空白/换行与已流式前缀不一致时,前端 normalize 会走拼接路径造成叠字。
|
||||||
"mcpExecutionIds": snapshotMCPIDs(),
|
_, eofTail := normalizeStreamingDelta(mainAssistWireAccum, mainAssistantBuf)
|
||||||
"messageGeneratedBy": "eino:" + ev.AgentName,
|
if eofTail != "" {
|
||||||
"einoRole": "orchestrator",
|
if !streamHeaderSent {
|
||||||
"einoAgent": ev.AgentName,
|
progress("response_start", "", map[string]interface{}{
|
||||||
"orchestration": orchMode,
|
"conversationId": conversationID,
|
||||||
})
|
"mcpExecutionIds": snapshotMCPIDs(),
|
||||||
progress("response_delta", s, map[string]interface{}{
|
"messageGeneratedBy": "eino:" + ev.AgentName,
|
||||||
"conversationId": conversationID,
|
"einoRole": "orchestrator",
|
||||||
"mcpExecutionIds": snapshotMCPIDs(),
|
"einoAgent": ev.AgentName,
|
||||||
"einoRole": "orchestrator",
|
"orchestration": orchMode,
|
||||||
"einoAgent": ev.AgentName,
|
})
|
||||||
"orchestration": orchMode,
|
}
|
||||||
})
|
progress("response_delta", eofTail, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"mcpExecutionIds": snapshotMCPIDs(),
|
||||||
|
"einoRole": "orchestrator",
|
||||||
|
"einoAgent": ev.AgentName,
|
||||||
|
"orchestration": orchMode,
|
||||||
|
}, mainAssistantBuf))
|
||||||
|
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, eofTail)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
lastAssistant = s
|
lastAssistant = s
|
||||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
|
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
|
||||||
@@ -771,7 +829,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
|
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
|
||||||
lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged})
|
lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged})
|
||||||
}
|
}
|
||||||
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
|
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
|
||||||
// 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。
|
// 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。
|
||||||
if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 {
|
if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 {
|
||||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls))
|
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls))
|
||||||
@@ -788,7 +846,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"einoRole": einoRoleTag(ev.AgentName),
|
"einoRole": einoRoleTag(ev.AgentName),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if retErr := handleRunErr(streamRecvErr); retErr != nil {
|
if _, retErr := maybeRetryTransientRun(streamRecvErr); retErr != nil {
|
||||||
return takePartial(retErr)
|
return takePartial(retErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -800,7 +858,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
|
runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
|
||||||
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
|
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
|
||||||
|
|
||||||
if mv.Role == schema.Assistant {
|
if mv.Role == schema.Assistant {
|
||||||
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
||||||
@@ -839,13 +897,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"einoAgent": ev.AgentName,
|
"einoAgent": ev.AgentName,
|
||||||
"orchestration": orchMode,
|
"orchestration": orchMode,
|
||||||
})
|
})
|
||||||
progress("response_delta", body, map[string]interface{}{
|
progress("response_delta", body, openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"mcpExecutionIds": snapshotMCPIDs(),
|
"mcpExecutionIds": snapshotMCPIDs(),
|
||||||
"einoRole": "orchestrator",
|
"einoRole": "orchestrator",
|
||||||
"einoAgent": ev.AgentName,
|
"einoAgent": ev.AgentName,
|
||||||
"orchestration": orchMode,
|
"orchestration": orchMode,
|
||||||
})
|
}, body))
|
||||||
}
|
}
|
||||||
lastAssistant = body
|
lastAssistant = body
|
||||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||||
@@ -948,6 +1006,17 @@ func einoPartialRunLastOutputHint() string {
|
|||||||
"[Run ended abnormally; continue from the trace above without repeating completed steps.]"
|
"[Run ended abnormally; continue from the trace above without repeating completed steps.]"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// friendlyEinoExecuteInvokeTail 将 Eino execute 等非 MCP 路径的结尾错误转成简短提示;其它情况保留原 error 文本。
|
||||||
|
func friendlyEinoExecuteInvokeTail(invokeErr error) string {
|
||||||
|
if invokeErr == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||||
|
return einoExecuteTimeoutUserHint()
|
||||||
|
}
|
||||||
|
return "[执行未正常结束] " + invokeErr.Error()
|
||||||
|
}
|
||||||
|
|
||||||
func buildEinoRunResultFromAccumulated(
|
func buildEinoRunResultFromAccumulated(
|
||||||
orchMode string,
|
orchMode string,
|
||||||
runAccumulatedMsgs []adk.Message,
|
runAccumulatedMsgs []adk.Message,
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
"cyberstrike-ai/internal/security"
|
"cyberstrike-ai/internal/security"
|
||||||
@@ -15,6 +16,24 @@ import (
|
|||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// prependPythonUnbufferedEnv 为 /bin/sh -c 注入 PYTHONUNBUFFERED=1。
|
||||||
|
// eino-ext local 对流式 stdout 使用 bufio 按「行」推送;python3 写管道时默认块缓冲,print 长期留在用户态缓冲,
|
||||||
|
// 管道里收不到换行,表现为长时间无输出直至超时或退出。若命令里已出现 PYTHONUNBUFFERED 则不再覆盖。
|
||||||
|
func prependPythonUnbufferedEnv(shellCommand string) string {
|
||||||
|
if strings.TrimSpace(shellCommand) == "" {
|
||||||
|
return shellCommand
|
||||||
|
}
|
||||||
|
if strings.Contains(strings.ToUpper(shellCommand), "PYTHONUNBUFFERED") {
|
||||||
|
return shellCommand
|
||||||
|
}
|
||||||
|
return "export PYTHONUNBUFFERED=1\n" + shellCommand
|
||||||
|
}
|
||||||
|
|
||||||
|
// einoExecuteTimeoutUserHint 与写入 ADK 工具消息(模型可见)及 SSE tool_result 尾标一致。
|
||||||
|
func einoExecuteTimeoutUserHint() string {
|
||||||
|
return "已超时终止 · Timed out"
|
||||||
|
}
|
||||||
|
|
||||||
// einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShell(cloudwego eino-ext local.Local)。
|
// einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShell(cloudwego eino-ext local.Local)。
|
||||||
// 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连,
|
// 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连,
|
||||||
// streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。
|
// streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。
|
||||||
@@ -29,6 +48,10 @@ type einoStreamingShellWrap struct {
|
|||||||
inner filesystem.StreamingShell
|
inner filesystem.StreamingShell
|
||||||
invokeNotify *einomcp.ToolInvokeNotifyHolder
|
invokeNotify *einomcp.ToolInvokeNotifyHolder
|
||||||
einoAgentName string
|
einoAgentName string
|
||||||
|
// outputChunk 可选;非 nil 时在收到内层 ExecuteResponse 片段时推送,与 MCP 工具的 tool_result_delta 一致(需有效 toolCallId)。
|
||||||
|
outputChunk func(toolName, toolCallID, chunk string)
|
||||||
|
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
|
||||||
|
toolTimeoutMinutes int
|
||||||
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
|
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
|
||||||
recordMonitor func(command, stdout string, success bool, invokeErr error)
|
recordMonitor func(command, stdout string, success bool, invokeErr error)
|
||||||
}
|
}
|
||||||
@@ -41,17 +64,27 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
return w.inner.ExecuteStreaming(ctx, nil)
|
return w.inner.ExecuteStreaming(ctx, nil)
|
||||||
}
|
}
|
||||||
req := *input
|
req := *input
|
||||||
cmd := strings.TrimSpace(req.Command)
|
userCmd := strings.TrimSpace(req.Command)
|
||||||
if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround {
|
if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround {
|
||||||
req.RunInBackendGround = true
|
req.RunInBackendGround = true
|
||||||
}
|
}
|
||||||
|
req.Command = prependPythonUnbufferedEnv(req.Command)
|
||||||
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
||||||
agentTag := strings.TrimSpace(w.einoAgentName)
|
agentTag := strings.TrimSpace(w.einoAgentName)
|
||||||
|
|
||||||
sr, err := w.inner.ExecuteStreaming(ctx, &req)
|
execCtx := ctx
|
||||||
|
var execCancel context.CancelFunc
|
||||||
|
if w.toolTimeoutMinutes > 0 {
|
||||||
|
execCtx, execCancel = context.WithTimeout(ctx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
sr, err := w.inner.ExecuteStreaming(execCtx, &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if execCancel != nil {
|
||||||
|
execCancel()
|
||||||
|
}
|
||||||
if w.recordMonitor != nil {
|
if w.recordMonitor != nil {
|
||||||
w.recordMonitor(cmd, "", false, err)
|
w.recordMonitor(userCmd, "", false, err)
|
||||||
}
|
}
|
||||||
if w.invokeNotify != nil && tid != "" {
|
if w.invokeNotify != nil && tid != "" {
|
||||||
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
||||||
@@ -59,13 +92,19 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if sr == nil || w.invokeNotify == nil || tid == "" {
|
if sr == nil || w.invokeNotify == nil || tid == "" {
|
||||||
|
if execCancel != nil {
|
||||||
|
execCancel()
|
||||||
|
}
|
||||||
return sr, nil
|
return sr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
||||||
|
|
||||||
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string) {
|
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, tctx context.Context) {
|
||||||
defer inner.Close()
|
defer inner.Close()
|
||||||
|
if cancel != nil {
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
const maxCapture = 16 * 1024
|
const maxCapture = 16 * 1024
|
||||||
@@ -90,12 +129,18 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
hasExitCode = true
|
hasExitCode = true
|
||||||
exitCode = *resp.ExitCode
|
exitCode = *resp.ExitCode
|
||||||
}
|
}
|
||||||
|
var appended string
|
||||||
if remain := maxCapture - sb.Len(); remain > 0 {
|
if remain := maxCapture - sb.Len(); remain > 0 {
|
||||||
out := resp.Output
|
out := resp.Output
|
||||||
if len(out) > remain {
|
if len(out) > remain {
|
||||||
out = out[:remain]
|
out = out[:remain]
|
||||||
}
|
}
|
||||||
sb.WriteString(out)
|
sb.WriteString(out)
|
||||||
|
appended = out
|
||||||
|
}
|
||||||
|
// 仅推送写入 sb 的片段,与末尾 Fire/recordMonitor 的截断累计一致,避免最终 tool_result 短于已展示增量。
|
||||||
|
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
||||||
|
w.outputChunk("execute", tid, appended)
|
||||||
}
|
}
|
||||||
if outW.Send(resp, nil) {
|
if outW.Send(resp, nil) {
|
||||||
success = false
|
success = false
|
||||||
@@ -109,12 +154,33 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
success = false
|
success = false
|
||||||
invokeErr = fmt.Errorf("execute exited with code %d", exitCode)
|
invokeErr = fmt.Errorf("execute exited with code %d", exitCode)
|
||||||
}
|
}
|
||||||
|
// WithTimeout 触发后,子进程常被信号结束,local 侧多报 exit -1 / canceled,错误链里不一定带 DeadlineExceeded。
|
||||||
|
// 用执行所用 ctx 归一化,便于 UI 展示「超时」而非含糊的 -1。
|
||||||
|
if tctx != nil && errors.Is(tctx.Err(), context.DeadlineExceeded) {
|
||||||
|
success = false
|
||||||
|
invokeErr = context.DeadlineExceeded
|
||||||
|
}
|
||||||
|
// ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。
|
||||||
|
if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||||
|
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: hint}, nil)
|
||||||
|
if w.outputChunk != nil && tid != "" {
|
||||||
|
w.outputChunk("execute", tid, hint)
|
||||||
|
}
|
||||||
|
if remain := maxCapture - sb.Len(); remain > 0 {
|
||||||
|
h := hint
|
||||||
|
if len(h) > remain {
|
||||||
|
h = h[:remain]
|
||||||
|
}
|
||||||
|
sb.WriteString(h)
|
||||||
|
}
|
||||||
|
}
|
||||||
if w.recordMonitor != nil {
|
if w.recordMonitor != nil {
|
||||||
w.recordMonitor(command, sb.String(), success, invokeErr)
|
w.recordMonitor(command, sb.String(), success, invokeErr)
|
||||||
}
|
}
|
||||||
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
||||||
outW.Close()
|
outW.Close()
|
||||||
}(sr, cmd)
|
}(sr, userCmd, execCancel, execCtx)
|
||||||
|
|
||||||
return outR, nil
|
return outR, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -161,6 +161,8 @@ func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddl
|
|||||||
}
|
}
|
||||||
|
|
||||||
// prependEinoMiddlewares returns handlers to prepend (outermost first) and optionally replaces tools when tool_search is used.
|
// prependEinoMiddlewares returns handlers to prepend (outermost first) and optionally replaces tools when tool_search is used.
|
||||||
|
// toolSearchActive is true when the toolsearch middleware was mounted (dynamic tools split off); callers should pass this to
|
||||||
|
// injectToolNamesOnlyInstruction — tool_search is not part of the pre-middleware tools list, so name-scanning alone cannot detect it.
|
||||||
func prependEinoMiddlewares(
|
func prependEinoMiddlewares(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
mw *config.MultiAgentEinoMiddlewareConfig,
|
mw *config.MultiAgentEinoMiddlewareConfig,
|
||||||
@@ -170,16 +172,16 @@ func prependEinoMiddlewares(
|
|||||||
skillsRoot string,
|
skillsRoot string,
|
||||||
conversationID string,
|
conversationID string,
|
||||||
logger *zap.Logger,
|
logger *zap.Logger,
|
||||||
) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, err error) {
|
) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, toolSearchActive bool, err error) {
|
||||||
if mw == nil {
|
if mw == nil {
|
||||||
return tools, nil, nil
|
return tools, nil, false, nil
|
||||||
}
|
}
|
||||||
outTools = tools
|
outTools = tools
|
||||||
|
|
||||||
if mw.PatchToolCallsEffective() {
|
if mw.PatchToolCallsEffective() {
|
||||||
patchMW, perr := patchtoolcalls.New(ctx, &patchtoolcalls.Config{})
|
patchMW, perr := patchtoolcalls.New(ctx, &patchtoolcalls.Config{})
|
||||||
if perr != nil {
|
if perr != nil {
|
||||||
return nil, nil, fmt.Errorf("patchtoolcalls: %w", perr)
|
return nil, nil, false, fmt.Errorf("patchtoolcalls: %w", perr)
|
||||||
}
|
}
|
||||||
extraHandlers = append(extraHandlers, patchMW)
|
extraHandlers = append(extraHandlers, patchMW)
|
||||||
}
|
}
|
||||||
@@ -190,7 +192,7 @@ func prependEinoMiddlewares(
|
|||||||
} else {
|
} else {
|
||||||
redMW, rerr := buildReductionMiddleware(ctx, *mw, conversationID, einoLoc, logger)
|
redMW, rerr := buildReductionMiddleware(ctx, *mw, conversationID, einoLoc, logger)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
return nil, nil, rerr
|
return nil, nil, false, rerr
|
||||||
}
|
}
|
||||||
extraHandlers = append(extraHandlers, redMW)
|
extraHandlers = append(extraHandlers, redMW)
|
||||||
}
|
}
|
||||||
@@ -209,10 +211,11 @@ func prependEinoMiddlewares(
|
|||||||
if split && len(dynamic) > 0 {
|
if split && len(dynamic) > 0 {
|
||||||
ts, terr := toolsearch.New(ctx, &toolsearch.Config{DynamicTools: dynamic})
|
ts, terr := toolsearch.New(ctx, &toolsearch.Config{DynamicTools: dynamic})
|
||||||
if terr != nil {
|
if terr != nil {
|
||||||
return nil, nil, fmt.Errorf("toolsearch: %w", terr)
|
return nil, nil, false, fmt.Errorf("toolsearch: %w", terr)
|
||||||
}
|
}
|
||||||
extraHandlers = append(extraHandlers, ts)
|
extraHandlers = append(extraHandlers, ts)
|
||||||
outTools = static
|
outTools = static
|
||||||
|
toolSearchActive = true
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
logger.Info("eino middleware: tool_search enabled",
|
logger.Info("eino middleware: tool_search enabled",
|
||||||
zap.Int("static_tools", len(static)),
|
zap.Int("static_tools", len(static)),
|
||||||
@@ -233,12 +236,12 @@ func prependEinoMiddlewares(
|
|||||||
}
|
}
|
||||||
baseDir := filepath.Join(skillsRoot, rel, sanitizeEinoPathSegment(conversationID))
|
baseDir := filepath.Join(skillsRoot, rel, sanitizeEinoPathSegment(conversationID))
|
||||||
if mk := os.MkdirAll(baseDir, 0o755); mk != nil {
|
if mk := os.MkdirAll(baseDir, 0o755); mk != nil {
|
||||||
return nil, nil, fmt.Errorf("plantask mkdir: %w", mk)
|
return nil, nil, toolSearchActive, fmt.Errorf("plantask mkdir: %w", mk)
|
||||||
}
|
}
|
||||||
ptBE := &localPlantaskBackend{Local: einoLoc}
|
ptBE := &localPlantaskBackend{Local: einoLoc}
|
||||||
pt, perr := plantask.New(ctx, &plantask.Config{Backend: ptBE, BaseDir: baseDir})
|
pt, perr := plantask.New(ctx, &plantask.Config{Backend: ptBE, BaseDir: baseDir})
|
||||||
if perr != nil {
|
if perr != nil {
|
||||||
return nil, nil, fmt.Errorf("plantask: %w", perr)
|
return nil, nil, toolSearchActive, fmt.Errorf("plantask: %w", perr)
|
||||||
}
|
}
|
||||||
extraHandlers = append(extraHandlers, pt)
|
extraHandlers = append(extraHandlers, pt)
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
@@ -247,7 +250,7 @@ func prependEinoMiddlewares(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return outTools, extraHandlers, nil
|
return outTools, extraHandlers, toolSearchActive, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry *adk.ModelRetryConfig, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry *adk.ModelRetryConfig, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import (
|
|||||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
"github.com/cloudwego/eino/compose"
|
"github.com/cloudwego/eino/compose"
|
||||||
"github.com/cloudwego/eino/schema"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -96,7 +95,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
mainToolsForCfg, mainOrchestratorPre, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
mainToolsForCfg, mainOrchestratorPre, singleToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("eino single eino 中间件: %w", err)
|
return nil, fmt.Errorf("eino single eino 中间件: %w", err)
|
||||||
}
|
}
|
||||||
@@ -143,7 +142,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
}
|
}
|
||||||
if einoSkillMW != nil {
|
if einoSkillMW != nil {
|
||||||
if einoFSTools && einoLoc != nil {
|
if einoFSTools && einoLoc != nil {
|
||||||
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor)
|
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
||||||
if fsErr != nil {
|
if fsErr != nil {
|
||||||
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
||||||
}
|
}
|
||||||
@@ -178,22 +177,15 @@ func RunEinoSingleChatModelAgent(
|
|||||||
},
|
},
|
||||||
EmitInternalEvents: true,
|
EmitInternalEvents: true,
|
||||||
}
|
}
|
||||||
ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools)
|
ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools, singleToolSearchActive)
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
names := collectToolNames(ctx, mainTools)
|
names := collectToolNames(ctx, mainTools)
|
||||||
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
||||||
hasToolSearch := false
|
|
||||||
for _, n := range names {
|
|
||||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
|
||||||
hasToolSearch = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.Info("eino tool-name injection",
|
logger.Info("eino tool-name injection",
|
||||||
zap.String("scope", "eino_single"),
|
zap.String("scope", "eino_single"),
|
||||||
zap.Int("tool_names", len(names)),
|
zap.Int("tool_names", len(names)),
|
||||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||||
zap.Bool("has_tool_search", hasToolSearch),
|
zap.Bool("tool_search_middleware", singleToolSearchActive),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -220,7 +212,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
|
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
|
||||||
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
|
baseMsgs = appendUserMessageIfNeeded(baseMsgs, userMessage)
|
||||||
|
|
||||||
streamsMainAssistant := func(agent string) bool {
|
streamsMainAssistant := func(agent string) bool {
|
||||||
return agent == "" || agent == einoSingleAgentName
|
return agent == "" || agent == einoSingleAgentName
|
||||||
@@ -240,6 +232,8 @@ func RunEinoSingleChatModelAgent(
|
|||||||
StreamsMainAssistant: streamsMainAssistant,
|
StreamsMainAssistant: streamsMainAssistant,
|
||||||
EinoRoleTag: einoRoleTag,
|
EinoRoleTag: einoRoleTag,
|
||||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||||
|
RunRetryMaxAttempts: ma.EinoMiddleware.RunRetryMaxAttempts,
|
||||||
|
RunRetryMaxBackoffSec: ma.EinoMiddleware.RunRetryMaxBackoffSec,
|
||||||
McpIDsMu: &mcpIDsMu,
|
McpIDsMu: &mcpIDsMu,
|
||||||
McpIDs: &mcpIDs,
|
McpIDs: &mcpIDs,
|
||||||
FilesystemMonitorAgent: ag,
|
FilesystemMonitorAgent: ag,
|
||||||
|
|||||||
@@ -82,6 +82,8 @@ func subAgentFilesystemMiddleware(
|
|||||||
invokeNotify *einomcp.ToolInvokeNotifyHolder,
|
invokeNotify *einomcp.ToolInvokeNotifyHolder,
|
||||||
einoAgentName string,
|
einoAgentName string,
|
||||||
recordMonitor func(command, stdout string, success bool, invokeErr error),
|
recordMonitor func(command, stdout string, success bool, invokeErr error),
|
||||||
|
toolTimeoutMinutes int,
|
||||||
|
outputChunk func(toolName, toolCallID, chunk string),
|
||||||
) (adk.ChatModelAgentMiddleware, error) {
|
) (adk.ChatModelAgentMiddleware, error) {
|
||||||
if loc == nil {
|
if loc == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@@ -89,10 +91,20 @@ func subAgentFilesystemMiddleware(
|
|||||||
return filesystem.New(ctx, &filesystem.MiddlewareConfig{
|
return filesystem.New(ctx, &filesystem.MiddlewareConfig{
|
||||||
Backend: loc,
|
Backend: loc,
|
||||||
StreamingShell: &einoStreamingShellWrap{
|
StreamingShell: &einoStreamingShellWrap{
|
||||||
inner: loc,
|
inner: loc,
|
||||||
invokeNotify: invokeNotify,
|
invokeNotify: invokeNotify,
|
||||||
einoAgentName: strings.TrimSpace(einoAgentName),
|
einoAgentName: strings.TrimSpace(einoAgentName),
|
||||||
recordMonitor: recordMonitor,
|
outputChunk: outputChunk,
|
||||||
|
recordMonitor: recordMonitor,
|
||||||
|
toolTimeoutMinutes: toolTimeoutMinutes,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// agentToolTimeoutMinutes 返回 agent.tool_timeout_minutes(与 executeToolViaMCP 一致);cfg 为 nil 时 0。
|
||||||
|
func agentToolTimeoutMinutes(cfg *config.Config) int {
|
||||||
|
if cfg == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return cfg.Agent.ToolTimeoutMinutes
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,34 +9,43 @@ import (
|
|||||||
|
|
||||||
// injectToolNamesOnlyInstruction prepends a compact tool-name-only section into
|
// injectToolNamesOnlyInstruction prepends a compact tool-name-only section into
|
||||||
// the system instruction so the model can reference current callable names.
|
// the system instruction so the model can reference current callable names.
|
||||||
func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, tools []tool.BaseTool) string {
|
// toolSearchMiddlewareActive must be true when prependEinoMiddlewares mounted toolsearch (dynamic tools); do not infer this
|
||||||
|
// by scanning tool names — tool_search is injected by middleware and is usually absent from the pre-split tools list.
|
||||||
|
func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, tools []tool.BaseTool, toolSearchMiddlewareActive bool) string {
|
||||||
names := collectToolNames(ctx, tools)
|
names := collectToolNames(ctx, tools)
|
||||||
if len(names) == 0 {
|
if len(names) == 0 {
|
||||||
return strings.TrimSpace(instruction)
|
return strings.TrimSpace(instruction)
|
||||||
}
|
}
|
||||||
hasToolSearch := false
|
hasToolSearch := toolSearchMiddlewareActive
|
||||||
for _, n := range names {
|
if !hasToolSearch {
|
||||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
for _, n := range names {
|
||||||
hasToolSearch = true
|
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||||
break
|
hasToolSearch = true
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
sb.WriteString("以下是当前会话中可调用的工具名称列表(仅名称,无参数定义):\n")
|
sb.WriteString("以下是当前会话绑定的工具名称索引(仅名称,无参数 JSON Schema)。\n")
|
||||||
|
sb.WriteString("说明:若启用了 tool_search,则列表里可能含「非常驻」工具——它们不一定出现在当前轮次下发给模型的工具定义中;在未看到该工具的完整 schema 前,禁止凭名称臆测参数。\n")
|
||||||
for _, name := range names {
|
for _, name := range names {
|
||||||
sb.WriteString("- ")
|
sb.WriteString("- ")
|
||||||
sb.WriteString(name)
|
sb.WriteString(name)
|
||||||
sb.WriteByte('\n')
|
sb.WriteByte('\n')
|
||||||
}
|
}
|
||||||
sb.WriteString("\n使用规则:\n")
|
sb.WriteString("\n使用规则:\n")
|
||||||
sb.WriteString("1) 上述仅为名称列表,不包含参数定义。\n")
|
sb.WriteString("1) 上表仅为名称索引,不含参数定义。禁止猜测参数名、类型、枚举取值或是否必填。\n")
|
||||||
if hasToolSearch {
|
if hasToolSearch {
|
||||||
sb.WriteString("2) 在调用具体工具前,应先使用 tool_search 查看工具详情与参数要求,再发起调用。\n")
|
sb.WriteString("【强制 / 最高优先级】本会话已启用 tool_search(动态工具池)。凡名称索引里出现、但你在「当前请求所附 tools 定义」中看不到其完整参数 schema 的工具,一律必须先调用 tool_search;为省 token 或赶进度而跳过 tool_search、直接调用业务工具,属于明确禁止的错误流程。\n")
|
||||||
|
sb.WriteString("2) 默认策略:只要对目标工具的参数定义有任何不确定,就先 tool_search;宁可多一次 tool_search,也不要在未见 schema 时盲调业务工具。\n")
|
||||||
|
sb.WriteString("3) 调用顺序:先 tool_search(唯一必填参数 regex_pattern:按工具名匹配的正则,如子串 nuclei 或 ^exact_tool_name$)→ 在后续轮次确认目标工具已出现在 tools 列表且已阅读其 schema → 再发起对该工具的真实调用。\n")
|
||||||
|
sb.WriteString("4) tool_search 的返回仅为匹配到的工具名列表;schema 在解锁后的下一轮才会下发。禁止在 schema 未出现时编造 JSON 参数。\n")
|
||||||
|
sb.WriteString("5) 不要臆造不存在的工具名。\n\n")
|
||||||
} else {
|
} else {
|
||||||
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求;不确定时先澄清再调用。\n")
|
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求(以当前请求中的工具定义为准);不确定时先澄清再调用。\n")
|
||||||
|
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
|
||||||
}
|
}
|
||||||
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
|
|
||||||
if s := strings.TrimSpace(instruction); s != "" {
|
if s := strings.TrimSpace(instruction); s != "" {
|
||||||
sb.WriteString(s)
|
sb.WriteString(s)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,173 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultEinoRunRetryMaxAttempts = 10
|
||||||
|
defaultEinoRunRetryMaxBackoff = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// isEinoTransientRunError 判断 ADK 运行期错误是否适合指数退避续跑(429、5xx、网络抖动等)。
|
||||||
|
// 用户取消、超时、迭代上限等由 run loop 单独处理,不在此列。
|
||||||
|
func isEinoTransientRunError(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if isEinoIterationLimitError(err) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
msg := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||||
|
if msg == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
transientMarkers := []string{
|
||||||
|
"406",
|
||||||
|
"429",
|
||||||
|
"too many requests",
|
||||||
|
"rate limit",
|
||||||
|
"rate_limit",
|
||||||
|
"ratelimit",
|
||||||
|
"quota exceeded",
|
||||||
|
"overloaded",
|
||||||
|
"capacity",
|
||||||
|
"temporarily unavailable",
|
||||||
|
"service unavailable",
|
||||||
|
"bad gateway",
|
||||||
|
"gateway timeout",
|
||||||
|
"internal server error",
|
||||||
|
"connection reset",
|
||||||
|
"connection refused",
|
||||||
|
"connection closed",
|
||||||
|
"i/o timeout",
|
||||||
|
"no such host",
|
||||||
|
"network is unreachable",
|
||||||
|
"broken pipe",
|
||||||
|
"eof",
|
||||||
|
"read tcp",
|
||||||
|
"write tcp",
|
||||||
|
"dial tcp",
|
||||||
|
"tls handshake timeout",
|
||||||
|
"stream error",
|
||||||
|
"unexpected eof",
|
||||||
|
"unexpected end of json",
|
||||||
|
"status code: 406",
|
||||||
|
"status code: 502",
|
||||||
|
"502",
|
||||||
|
"503",
|
||||||
|
"504",
|
||||||
|
"500",
|
||||||
|
}
|
||||||
|
for _, m := range transientMarkers {
|
||||||
|
if strings.Contains(msg, m) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
||||||
|
if args != nil && args.RunRetryMaxAttempts > 0 {
|
||||||
|
return args.RunRetryMaxAttempts
|
||||||
|
}
|
||||||
|
return defaultEinoRunRetryMaxAttempts
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunRetryMaxAttemptsFromConfig 供 handler 分段续跑计数(与 eino_middleware.run_retry_max_attempts 一致)。
|
||||||
|
func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
|
||||||
|
if mw != nil && mw.RunRetryMaxAttempts > 0 {
|
||||||
|
return mw.RunRetryMaxAttempts
|
||||||
|
}
|
||||||
|
return defaultEinoRunRetryMaxAttempts
|
||||||
|
}
|
||||||
|
|
||||||
|
// TransientRetryBackoff 供 handler 在分段续跑前退避。
|
||||||
|
func TransientRetryBackoff(attempt int, maxBackoffSec int) time.Duration {
|
||||||
|
max := defaultEinoRunRetryMaxBackoff
|
||||||
|
if maxBackoffSec > 0 {
|
||||||
|
max = time.Duration(maxBackoffSec) * time.Second
|
||||||
|
}
|
||||||
|
return einoTransientRetryBackoff(attempt, max)
|
||||||
|
}
|
||||||
|
|
||||||
|
func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration {
|
||||||
|
if args != nil && args.RunRetryMaxBackoffSec > 0 {
|
||||||
|
return time.Duration(args.RunRetryMaxBackoffSec) * time.Second
|
||||||
|
}
|
||||||
|
return defaultEinoRunRetryMaxBackoff
|
||||||
|
}
|
||||||
|
|
||||||
|
// einoRunRestartContextSource 描述无 checkpoint Resume 时 Run 使用的消息来源(日志/SSE)。
|
||||||
|
type einoRunRestartContextSource string
|
||||||
|
|
||||||
|
const (
|
||||||
|
einoRestartContextInitial einoRunRestartContextSource = "initial"
|
||||||
|
einoRestartContextAccumulated einoRunRestartContextSource = "accumulated"
|
||||||
|
einoRestartContextModelTrace einoRunRestartContextSource = "model_trace"
|
||||||
|
)
|
||||||
|
|
||||||
|
// einoMessagesForRunRestart 在退避后重新 Run 时选用最完整的上下文:
|
||||||
|
// 1) ModelFacingTrace(与模型实际入参一致) 2) 事件流累积的 runAccumulatedMsgs 3) 初始 msgs。
|
||||||
|
func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated []adk.Message, baseCount int) ([]adk.Message, einoRunRestartContextSource) {
|
||||||
|
if trace := persistTraceSource(args, nil); len(trace) > 0 {
|
||||||
|
return append([]adk.Message(nil), trace...), einoRestartContextModelTrace
|
||||||
|
}
|
||||||
|
if len(accumulated) > baseCount {
|
||||||
|
return append([]adk.Message(nil), accumulated...), einoRestartContextAccumulated
|
||||||
|
}
|
||||||
|
return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial
|
||||||
|
}
|
||||||
|
|
||||||
|
// adkMessagesHasUserContent 从尾部向前查找,是否已有与 want 相同的 user 消息(避免重复 append)。
|
||||||
|
func adkMessagesHasUserContent(msgs []adk.Message, want string) bool {
|
||||||
|
want = strings.TrimSpace(want)
|
||||||
|
if want == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for i := len(msgs) - 1; i >= 0; i-- {
|
||||||
|
m := msgs[i]
|
||||||
|
if m == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if m.Role == schema.User {
|
||||||
|
return strings.TrimSpace(m.Content) == want
|
||||||
|
}
|
||||||
|
if m.Role == schema.Assistant || m.Role == schema.Tool {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendUserMessageIfNeeded 在 history 轨迹之后追加本轮 user 消息(仅当轨迹中尚未包含该句)。
|
||||||
|
func appendUserMessageIfNeeded(msgs []adk.Message, userMessage string) []adk.Message {
|
||||||
|
if strings.TrimSpace(userMessage) == "" || adkMessagesHasUserContent(msgs, userMessage) {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
return append(msgs, schema.UserMessage(userMessage))
|
||||||
|
}
|
||||||
|
|
||||||
|
// einoTransientRetryBackoff 指数退避:2s, 4s, 8s… capped by maxBackoff。
|
||||||
|
func einoTransientRetryBackoff(attempt int, maxBackoff time.Duration) time.Duration {
|
||||||
|
if attempt < 0 {
|
||||||
|
attempt = 0
|
||||||
|
}
|
||||||
|
backoff := time.Duration(1<<uint(attempt+1)) * time.Second
|
||||||
|
if maxBackoff > 0 && backoff > maxBackoff {
|
||||||
|
backoff = maxBackoff
|
||||||
|
}
|
||||||
|
return backoff
|
||||||
|
}
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIsEinoTransientRunError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
err error
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"nil", nil, false},
|
||||||
|
{"429", errors.New("HTTP 429 Too Many Requests"), true},
|
||||||
|
{"rate limit", errors.New(`{"error":"rate limit exceeded"}`), true},
|
||||||
|
{"connection reset", errors.New("read tcp: connection reset by peer"), true},
|
||||||
|
{"503", errors.New("upstream returned 503"), true},
|
||||||
|
{"iteration limit", errors.New("max iteration reached"), false},
|
||||||
|
{"canceled", context.Canceled, false},
|
||||||
|
{"deadline", context.DeadlineExceeded, false},
|
||||||
|
{"auth", errors.New("invalid api key"), false},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if got := isEinoTransientRunError(tc.err); got != tc.want {
|
||||||
|
t.Fatalf("isEinoTransientRunError(%v) = %v, want %v", tc.err, got, tc.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoTransientRetryBackoff(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
max := 30 * time.Second
|
||||||
|
if got := einoTransientRetryBackoff(0, max); got != 2*time.Second {
|
||||||
|
t.Fatalf("attempt 0: got %v", got)
|
||||||
|
}
|
||||||
|
if got := einoTransientRetryBackoff(4, max); got != 30*time.Second {
|
||||||
|
t.Fatalf("attempt 4 capped: got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoMessagesForRunRestart(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
base := []adk.Message{schema.UserMessage("hi")}
|
||||||
|
acc := append([]adk.Message(nil), base...)
|
||||||
|
acc = append(acc, schema.AssistantMessage("step1", nil))
|
||||||
|
|
||||||
|
got, src := einoMessagesForRunRestart(nil, base, acc, len(base))
|
||||||
|
if src != einoRestartContextAccumulated || len(got) != 2 {
|
||||||
|
t.Fatalf("accumulated: src=%s len=%d", src, len(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
holder := newModelFacingTraceHolder()
|
||||||
|
holder.storeFromState(&adk.ChatModelAgentState{
|
||||||
|
Messages: []adk.Message{schema.UserMessage("u"), schema.AssistantMessage("model-view", nil)},
|
||||||
|
})
|
||||||
|
got2, src2 := einoMessagesForRunRestart(&einoADKRunLoopArgs{ModelFacingTrace: holder}, base, acc, len(base))
|
||||||
|
if src2 != einoRestartContextModelTrace || len(got2) != 2 {
|
||||||
|
t.Fatalf("model trace: src=%s len=%d", src2, len(got2))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoRunRetryMaxAttemptsFromArgs(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if einoRunRetryMaxAttempts(nil) != defaultEinoRunRetryMaxAttempts {
|
||||||
|
t.Fatal("nil args should use default")
|
||||||
|
}
|
||||||
|
if einoRunRetryMaxAttempts(&einoADKRunLoopArgs{RunRetryMaxAttempts: 3}) != 3 {
|
||||||
|
t.Fatal("custom max attempts")
|
||||||
|
}
|
||||||
|
if RunRetryMaxAttemptsFromConfig(nil) != defaultEinoRunRetryMaxAttempts {
|
||||||
|
t.Fatal("config nil should use default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendUserMessageIfNeeded(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
msgs := []adk.Message{schema.UserMessage("old task")}
|
||||||
|
out := appendUserMessageIfNeeded(msgs, "你好,你是谁")
|
||||||
|
if len(out) != 2 || out[1].Content != "你好,你是谁" {
|
||||||
|
t.Fatalf("should append user: len=%d", len(out))
|
||||||
|
}
|
||||||
|
dup := appendUserMessageIfNeeded(out, "你好,你是谁")
|
||||||
|
if len(dup) != 2 {
|
||||||
|
t.Fatalf("should not duplicate user message: len=%d", len(dup))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrTransientRetryContinue(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if !errors.Is(ErrTransientRetryContinue, ErrTransientRetryContinue) {
|
||||||
|
t.Fatal("sentinel should match")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,3 +5,7 @@ import "errors"
|
|||||||
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
|
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
|
||||||
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
|
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
|
||||||
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
|
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
|
||||||
|
|
||||||
|
// ErrTransientRetryContinue 表示 Run 因 429/网络等临时错误结束,应由 handler 落库轨迹后
|
||||||
|
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue 同级的「分段续跑」语义)。
|
||||||
|
var ErrTransientRetryContinue = errors.New("agent transient: retry after persisting trace")
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Eino execute 去重分支 EOF flush 须以 mainAssistantBuf 为基准计算 tail,
|
||||||
|
// 若误用 TrimSpace(mainAssistantBuf),会与已推前缀在空白处失配,normalize 走拼接路径叠字。
|
||||||
|
func TestNormalizeStreamingDelta_eofTailUsesRawBufNotTrim(t *testing.T) {
|
||||||
|
wireAccum := "phrase "
|
||||||
|
rawFull := "phrase \n"
|
||||||
|
_, tail := normalizeStreamingDelta(wireAccum, rawFull)
|
||||||
|
if want := "\n"; tail != want {
|
||||||
|
t.Fatalf("tail=%q want %q", tail, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
nextWrong, badTail := normalizeStreamingDelta(wireAccum, strings.TrimSpace(rawFull))
|
||||||
|
if badTail != "phrase" || nextWrong != "phrase phrase" {
|
||||||
|
t.Fatalf("trimmed full vs wire prefix mismatch should concat-append; got next=%q badTail=%q", nextWrong, badTail)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -223,7 +223,7 @@ func RunDeepAgent(
|
|||||||
return nil, fmt.Errorf("子代理 %q 工具: %w", id, err)
|
return nil, fmt.Errorf("子代理 %q 工具: %w", id, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
subToolsForCfg, subPre, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWSub, subTools, einoLoc, skillsRoot, conversationID, logger)
|
subToolsForCfg, subPre, subToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWSub, subTools, einoLoc, skillsRoot, conversationID, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("子代理 %q eino 中间件: %w", id, err)
|
return nil, fmt.Errorf("子代理 %q eino 中间件: %w", id, err)
|
||||||
}
|
}
|
||||||
@@ -244,7 +244,7 @@ func RunDeepAgent(
|
|||||||
}
|
}
|
||||||
if einoSkillMW != nil {
|
if einoSkillMW != nil {
|
||||||
if einoFSTools && einoLoc != nil {
|
if einoFSTools && einoLoc != nil {
|
||||||
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor)
|
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
||||||
if fsErr != nil {
|
if fsErr != nil {
|
||||||
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
|
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
|
||||||
}
|
}
|
||||||
@@ -260,23 +260,16 @@ func RunDeepAgent(
|
|||||||
subHandlers = append(subHandlers, teleMw)
|
subHandlers = append(subHandlers, teleMw)
|
||||||
}
|
}
|
||||||
|
|
||||||
subInstrFinal := injectToolNamesOnlyInstruction(ctx, instr, subTools)
|
subInstrFinal := injectToolNamesOnlyInstruction(ctx, instr, subTools, subToolSearchActive)
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
subNames := collectToolNames(ctx, subTools)
|
subNames := collectToolNames(ctx, subTools)
|
||||||
mountedNames := collectToolNames(ctx, subToolsForCfg)
|
mountedNames := collectToolNames(ctx, subToolsForCfg)
|
||||||
hasToolSearch := false
|
|
||||||
for _, n := range subNames {
|
|
||||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
|
||||||
hasToolSearch = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.Info("eino tool-name injection",
|
logger.Info("eino tool-name injection",
|
||||||
zap.String("scope", "sub_agent"),
|
zap.String("scope", "sub_agent"),
|
||||||
zap.String("agent", id),
|
zap.String("agent", id),
|
||||||
zap.Int("tool_names", len(subNames)),
|
zap.Int("tool_names", len(subNames)),
|
||||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||||
zap.Bool("has_tool_search", hasToolSearch),
|
zap.Bool("tool_search_middleware", subToolSearchActive),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
|
sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
|
||||||
@@ -341,28 +334,21 @@ func RunDeepAgent(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mainToolsForCfg, mainOrchestratorPre, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
mainToolsForCfg, mainOrchestratorPre, mainToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools)
|
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools, mainToolSearchActive)
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
mainNames := collectToolNames(ctx, mainTools)
|
mainNames := collectToolNames(ctx, mainTools)
|
||||||
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
||||||
hasToolSearch := false
|
|
||||||
for _, n := range mainNames {
|
|
||||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
|
||||||
hasToolSearch = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.Info("eino tool-name injection",
|
logger.Info("eino tool-name injection",
|
||||||
zap.String("scope", "orchestrator"),
|
zap.String("scope", "orchestrator"),
|
||||||
zap.String("orchestration", orchMode),
|
zap.String("orchestration", orchMode),
|
||||||
zap.Int("tool_names", len(mainNames)),
|
zap.Int("tool_names", len(mainNames)),
|
||||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||||
zap.Bool("has_tool_search", hasToolSearch),
|
zap.Bool("tool_search_middleware", mainToolSearchActive),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -390,10 +376,12 @@ func RunDeepAgent(
|
|||||||
if einoLoc != nil && einoFSTools {
|
if einoLoc != nil && einoFSTools {
|
||||||
deepBackend = einoLoc
|
deepBackend = einoLoc
|
||||||
deepShell = &einoStreamingShellWrap{
|
deepShell = &einoStreamingShellWrap{
|
||||||
inner: einoLoc,
|
inner: einoLoc,
|
||||||
invokeNotify: toolInvokeNotify,
|
invokeNotify: toolInvokeNotify,
|
||||||
einoAgentName: orchestratorName,
|
einoAgentName: orchestratorName,
|
||||||
recordMonitor: einoExecMonitor,
|
outputChunk: toolOutputChunk,
|
||||||
|
recordMonitor: einoExecMonitor,
|
||||||
|
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -457,7 +445,7 @@ func RunDeepAgent(
|
|||||||
// 构建 filesystem 中间件(与 Deep sub-agent 一致)
|
// 构建 filesystem 中间件(与 Deep sub-agent 一致)
|
||||||
var peFsMw adk.ChatModelAgentMiddleware
|
var peFsMw adk.ChatModelAgentMiddleware
|
||||||
if einoSkillMW != nil && einoFSTools && einoLoc != nil {
|
if einoSkillMW != nil && einoFSTools && einoLoc != nil {
|
||||||
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor)
|
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
|
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
|
||||||
}
|
}
|
||||||
@@ -550,7 +538,7 @@ func RunDeepAgent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
|
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
|
||||||
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
|
baseMsgs = appendUserMessageIfNeeded(baseMsgs, userMessage)
|
||||||
|
|
||||||
streamsMainAssistant := func(agent string) bool {
|
streamsMainAssistant := func(agent string) bool {
|
||||||
if orchMode == "plan_execute" {
|
if orchMode == "plan_execute" {
|
||||||
@@ -578,6 +566,8 @@ func RunDeepAgent(
|
|||||||
StreamsMainAssistant: streamsMainAssistant,
|
StreamsMainAssistant: streamsMainAssistant,
|
||||||
EinoRoleTag: einoRoleTag,
|
EinoRoleTag: einoRoleTag,
|
||||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||||
|
RunRetryMaxAttempts: ma.EinoMiddleware.RunRetryMaxAttempts,
|
||||||
|
RunRetryMaxBackoffSec: ma.EinoMiddleware.RunRetryMaxBackoffSec,
|
||||||
McpIDsMu: &mcpIDsMu,
|
McpIDsMu: &mcpIDsMu,
|
||||||
McpIDs: &mcpIDs,
|
McpIDs: &mcpIDs,
|
||||||
FilesystemMonitorAgent: ag,
|
FilesystemMonitorAgent: ag,
|
||||||
@@ -607,6 +597,13 @@ func chatToolCallsToSchema(tcs []agent.ToolCall) []schema.ToolCall {
|
|||||||
argsStr = string(b)
|
argsStr = string(b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Some OpenAI-compatible gateways require `function.arguments` to exist
|
||||||
|
// on every assistant tool_call message. When args are empty, omitempty may
|
||||||
|
// drop the field during serialization and cause "missing field arguments"
|
||||||
|
// on the next turn history replay.
|
||||||
|
if strings.TrimSpace(argsStr) == "" {
|
||||||
|
argsStr = "{}"
|
||||||
|
}
|
||||||
typ := tc.Type
|
typ := tc.Type
|
||||||
if typ == "" {
|
if typ == "" {
|
||||||
typ = "function"
|
typ = "function"
|
||||||
@@ -749,12 +746,23 @@ func toolCallsRichSignature(msg *schema.Message) string {
|
|||||||
return base + "|" + strings.Join(parts, ";")
|
return base + "|" + strings.Join(parts, ";")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func einoMainIterationKey(agentName, orchestratorName string) string {
|
||||||
|
key := strings.TrimSpace(agentName)
|
||||||
|
if key == "" {
|
||||||
|
key = strings.TrimSpace(orchestratorName)
|
||||||
|
}
|
||||||
|
if key == "" {
|
||||||
|
return "_main"
|
||||||
|
}
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
func tryEmitToolCallsOnce(
|
func tryEmitToolCallsOnce(
|
||||||
msg *schema.Message,
|
msg *schema.Message,
|
||||||
agentName, orchestratorName, conversationID string,
|
agentName, orchestratorName, conversationID, orchMode string,
|
||||||
progress func(string, string, interface{}),
|
progress func(string, string, interface{}),
|
||||||
seen map[string]struct{},
|
seen map[string]struct{},
|
||||||
subAgentToolStep map[string]int,
|
subAgentToolStep, mainAgentToolStep map[string]int,
|
||||||
markPending func(toolCallPendingInfo),
|
markPending func(toolCallPendingInfo),
|
||||||
) {
|
) {
|
||||||
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil || seen == nil {
|
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil || seen == nil {
|
||||||
@@ -768,14 +776,14 @@ func tryEmitToolCallsOnce(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
seen[sig] = struct{}{}
|
seen[sig] = struct{}{}
|
||||||
emitToolCallsFromMessage(msg, agentName, orchestratorName, conversationID, progress, subAgentToolStep, markPending)
|
emitToolCallsFromMessage(msg, agentName, orchestratorName, conversationID, orchMode, progress, subAgentToolStep, mainAgentToolStep, markPending)
|
||||||
}
|
}
|
||||||
|
|
||||||
func emitToolCallsFromMessage(
|
func emitToolCallsFromMessage(
|
||||||
msg *schema.Message,
|
msg *schema.Message,
|
||||||
agentName, orchestratorName, conversationID string,
|
agentName, orchestratorName, conversationID, orchMode string,
|
||||||
progress func(string, string, interface{}),
|
progress func(string, string, interface{}),
|
||||||
subAgentToolStep map[string]int,
|
subAgentToolStep, mainAgentToolStep map[string]int,
|
||||||
markPending func(toolCallPendingInfo),
|
markPending func(toolCallPendingInfo),
|
||||||
) {
|
) {
|
||||||
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil {
|
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil {
|
||||||
@@ -796,6 +804,22 @@ func emitToolCallsFromMessage(
|
|||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"source": "eino",
|
"source": "eino",
|
||||||
})
|
})
|
||||||
|
} else if mainAgentToolStep != nil {
|
||||||
|
key := einoMainIterationKey(agentName, orchestratorName)
|
||||||
|
mainAgentToolStep[key]++
|
||||||
|
n := mainAgentToolStep[key]
|
||||||
|
// 第 1 轮已在主代理进入时发出;此后每次工具批次对应新一轮 ReAct(与子代理按工具计步一致)。
|
||||||
|
if n > 1 {
|
||||||
|
progress("iteration", "", map[string]interface{}{
|
||||||
|
"iteration": n,
|
||||||
|
"einoScope": "main",
|
||||||
|
"einoRole": "orchestrator",
|
||||||
|
"einoAgent": agentName,
|
||||||
|
"orchestration": orchMode,
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
role := "orchestrator"
|
role := "orchestrator"
|
||||||
if isSubToolRound {
|
if isSubToolRound {
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
// SSEAccumulatedKey 为 SSE progress 事件 data 中的服务端权威流式全文快照字段。
|
||||||
|
// 前端应优先用该字段更新 buffer,避免对 delta 二次 normalize 导致叠字。
|
||||||
|
const SSEAccumulatedKey = "accumulated"
|
||||||
|
|
||||||
|
// WithSSEAccumulated 在 progress data 中附带当前流式累计全文(权威快照)。
|
||||||
|
func WithSSEAccumulated(data map[string]interface{}, accumulated string) map[string]interface{} {
|
||||||
|
if data == nil {
|
||||||
|
data = make(map[string]interface{}, 1)
|
||||||
|
}
|
||||||
|
data[SSEAccumulatedKey] = accumulated
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeStreamingDelta 将可能是“累计片段/重发片段”的内容归一化为“纯增量”。
|
||||||
|
// 与 unexported normalizeStreamingDelta 相同,供 agent / multiagent 等包在发 SSE 前累计正文。
|
||||||
|
func NormalizeStreamingDelta(current, incoming string) (next, delta string) {
|
||||||
|
return normalizeStreamingDelta(current, incoming)
|
||||||
|
}
|
||||||
@@ -149,13 +149,18 @@ func effectiveEffort(sr *config.OpenAIReasoningConfig, client *ClientIntent, all
|
|||||||
func normalizeEffort(s string) string {
|
func normalizeEffort(s string) string {
|
||||||
e := strings.ToLower(strings.TrimSpace(s))
|
e := strings.ToLower(strings.TrimSpace(s))
|
||||||
switch e {
|
switch e {
|
||||||
case "low", "medium", "high", "max":
|
case "low", "medium", "high", "max", "xhigh":
|
||||||
return e
|
return e
|
||||||
default:
|
default:
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// usesExtraFieldsReasoningEffort 为 Eino 无枚举的最高档 effort,经 ExtraFields 原样下发(max / xhigh 由网关自行识别,不做互转)。
|
||||||
|
func usesExtraFieldsReasoningEffort(e string) bool {
|
||||||
|
return e == "max" || e == "xhigh"
|
||||||
|
}
|
||||||
|
|
||||||
func resolveWireProfile(oa *config.OpenAIConfig, sr *config.OpenAIReasoningConfig) wireProfile {
|
func resolveWireProfile(oa *config.OpenAIConfig, sr *config.OpenAIReasoningConfig) wireProfile {
|
||||||
if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") {
|
if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") {
|
||||||
return wireClaude
|
return wireClaude
|
||||||
@@ -210,11 +215,11 @@ func applyOpenAICompat(cfg *einoopenai.ChatModelConfig, mode, effort string) {
|
|||||||
if e == "" {
|
if e == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if e == "max" {
|
if usesExtraFieldsReasoningEffort(e) {
|
||||||
if cfg.ExtraFields == nil {
|
if cfg.ExtraFields == nil {
|
||||||
cfg.ExtraFields = make(map[string]any)
|
cfg.ExtraFields = make(map[string]any)
|
||||||
}
|
}
|
||||||
cfg.ExtraFields["reasoning_effort"] = "max"
|
cfg.ExtraFields["reasoning_effort"] = effortStringForAPI(e)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
switch e {
|
switch e {
|
||||||
@@ -245,6 +250,6 @@ func applyOutputConfigEffort(cfg *einoopenai.ChatModelConfig, mode, effort strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
func effortStringForAPI(e string) string {
|
func effortStringForAPI(e string) string {
|
||||||
// Gateways expect lowercase strings; "max" kept as max.
|
// 原样透传:OpenAI 官方多为 xhigh,部分兼容网关为 max,由配置/对话 effort 选择。
|
||||||
return strings.ToLower(strings.TrimSpace(e))
|
return strings.ToLower(strings.TrimSpace(e))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,66 @@
|
|||||||
|
package reasoning
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
|
||||||
|
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEffortStringForAPI_passthrough(t *testing.T) {
|
||||||
|
cases := map[string]string{
|
||||||
|
"max": "max",
|
||||||
|
"xhigh": "xhigh",
|
||||||
|
"HIGH": "high",
|
||||||
|
"Medium": "medium",
|
||||||
|
}
|
||||||
|
for in, want := range cases {
|
||||||
|
if got := effortStringForAPI(in); got != want {
|
||||||
|
t.Fatalf("%q -> %q, want %q", in, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeEffort_maxAndXhigh(t *testing.T) {
|
||||||
|
if normalizeEffort("xhigh") != "xhigh" {
|
||||||
|
t.Fatal("xhigh not accepted")
|
||||||
|
}
|
||||||
|
if normalizeEffort("max") != "max" {
|
||||||
|
t.Fatal("max not accepted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyOpenAICompat_xhighExtraField(t *testing.T) {
|
||||||
|
cfg := &einoopenai.ChatModelConfig{}
|
||||||
|
oa := &config.OpenAIConfig{
|
||||||
|
Reasoning: config.OpenAIReasoningConfig{
|
||||||
|
Profile: "openai_compat",
|
||||||
|
Mode: "on",
|
||||||
|
Effort: "xhigh",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ApplyToEinoChatModelConfig(cfg, oa, nil)
|
||||||
|
if cfg.ExtraFields == nil {
|
||||||
|
t.Fatal("expected ExtraFields")
|
||||||
|
}
|
||||||
|
if got, _ := cfg.ExtraFields["reasoning_effort"].(string); got != "xhigh" {
|
||||||
|
t.Fatalf("reasoning_effort=%q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyOpenAICompat_maxPassthrough(t *testing.T) {
|
||||||
|
cfg := &einoopenai.ChatModelConfig{}
|
||||||
|
oa := &config.OpenAIConfig{
|
||||||
|
Reasoning: config.OpenAIReasoningConfig{
|
||||||
|
Profile: "openai_compat",
|
||||||
|
Mode: "on",
|
||||||
|
Effort: "max",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ApplyToEinoChatModelConfig(cfg, oa, nil)
|
||||||
|
got, _ := cfg.ExtraFields["reasoning_effort"].(string)
|
||||||
|
if got != "max" {
|
||||||
|
t.Fatalf("max effort wire=%q, want max", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,316 @@
|
|||||||
|
package ilink
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DefaultBaseURL = "https://ilinkai.weixin.qq.com"
|
||||||
|
DefaultBotType = "3"
|
||||||
|
DefaultBotAgent = "CyberStrikeAI/1.0"
|
||||||
|
ILinkAppID = "bot"
|
||||||
|
QRLongPollTimeout = 35 * time.Second
|
||||||
|
APIDefaultTimeout = 15 * time.Second
|
||||||
|
GetUpdatesTimeout = 35 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client 微信 iLink Bot HTTP 客户端(与 @tencent-weixin/openclaw-weixin 协议兼容)
|
||||||
|
type Client struct {
|
||||||
|
BaseURL string
|
||||||
|
BotToken string
|
||||||
|
BotAgent string
|
||||||
|
ClientVersion uint32
|
||||||
|
HTTP *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(baseURL, botToken, botAgent string, clientVersion uint32) *Client {
|
||||||
|
base := strings.TrimSpace(baseURL)
|
||||||
|
if base == "" {
|
||||||
|
base = DefaultBaseURL
|
||||||
|
}
|
||||||
|
agent := strings.TrimSpace(botAgent)
|
||||||
|
if agent == "" {
|
||||||
|
agent = DefaultBotAgent
|
||||||
|
}
|
||||||
|
return &Client{
|
||||||
|
BaseURL: strings.TrimRight(base, "/"),
|
||||||
|
BotToken: strings.TrimSpace(botToken),
|
||||||
|
BotAgent: sanitizeBotAgent(agent),
|
||||||
|
ClientVersion: clientVersion,
|
||||||
|
HTTP: &http.Client{Timeout: 0},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildClientVersion 将 semver 编码为 iLink-App-ClientVersion(0x00MMNNPP)
|
||||||
|
func BuildClientVersion(version string) uint32 {
|
||||||
|
parts := strings.Split(version, ".")
|
||||||
|
parse := func(i int) int {
|
||||||
|
if i >= len(parts) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
n, _ := strconv.Atoi(strings.TrimSpace(parts[i]))
|
||||||
|
if n < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
major := parse(0) & 0xff
|
||||||
|
minor := parse(1) & 0xff
|
||||||
|
patch := parse(2) & 0xff
|
||||||
|
return uint32((major << 16) | (minor << 8) | patch)
|
||||||
|
}
|
||||||
|
|
||||||
|
type baseInfo struct {
|
||||||
|
ChannelVersion string `json:"channel_version"`
|
||||||
|
BotAgent string `json:"bot_agent"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) buildBaseInfo() baseInfo {
|
||||||
|
return baseInfo{
|
||||||
|
ChannelVersion: "1.0.0",
|
||||||
|
BotAgent: c.BotAgent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomWechatUIN() string {
|
||||||
|
var b [4]byte
|
||||||
|
_, _ = rand.Read(b[:])
|
||||||
|
u := uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||||
|
return base64.StdEncoding.EncodeToString([]byte(strconv.FormatUint(uint64(u), 10)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) commonHeaders() http.Header {
|
||||||
|
h := http.Header{}
|
||||||
|
h.Set("iLink-App-Id", ILinkAppID)
|
||||||
|
h.Set("iLink-App-ClientVersion", strconv.FormatUint(uint64(c.ClientVersion), 10))
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) authHeaders() http.Header {
|
||||||
|
h := c.commonHeaders()
|
||||||
|
h.Set("Content-Type", "application/json")
|
||||||
|
h.Set("AuthorizationType", "ilink_bot_token")
|
||||||
|
h.Set("X-WECHAT-UIN", randomWechatUIN())
|
||||||
|
if c.BotToken != "" {
|
||||||
|
h.Set("Authorization", "Bearer "+c.BotToken)
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) endpointURL(path string) (string, error) {
|
||||||
|
u, err := url.Parse(c.BaseURL + "/")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
ref, err := url.Parse(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return u.ResolveReference(ref).String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) doRequest(ctx context.Context, method, path string, body []byte, headers http.Header, timeout time.Duration) ([]byte, error) {
|
||||||
|
reqURL, err := c.endpointURL(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var bodyReader io.Reader
|
||||||
|
if len(body) > 0 {
|
||||||
|
bodyReader = bytes.NewReader(body)
|
||||||
|
}
|
||||||
|
req, err := http.NewRequestWithContext(ctx, method, reqURL, bodyReader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for k, vs := range headers {
|
||||||
|
for _, v := range vs {
|
||||||
|
req.Header.Add(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
client := c.HTTP
|
||||||
|
if client == nil {
|
||||||
|
client = http.DefaultClient
|
||||||
|
}
|
||||||
|
if timeout > 0 {
|
||||||
|
ctx2, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
req = req.WithContext(ctx2)
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
raw, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
return nil, fmt.Errorf("ilink %s %s: %d %s", method, path, resp.StatusCode, string(raw))
|
||||||
|
}
|
||||||
|
return raw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QRCodeResponse 获取二维码响应
|
||||||
|
type QRCodeResponse struct {
|
||||||
|
QRCode string `json:"qrcode"`
|
||||||
|
QRCodeImgContent string `json:"qrcode_img_content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBotQRCode 获取绑定二维码
|
||||||
|
func (c *Client) GetBotQRCode(ctx context.Context, botType string, localTokenList []string) (*QRCodeResponse, error) {
|
||||||
|
if strings.TrimSpace(botType) == "" {
|
||||||
|
botType = DefaultBotType
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"local_token_list": localTokenList,
|
||||||
|
})
|
||||||
|
path := "ilink/bot/get_bot_qrcode?bot_type=" + url.QueryEscape(botType)
|
||||||
|
raw, err := c.doRequest(ctx, http.MethodPost, path, body, c.authHeaders(), APIDefaultTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var out QRCodeResponse
|
||||||
|
if err := json.Unmarshal(raw, &out); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QRStatusResponse 二维码状态轮询响应
|
||||||
|
type QRStatusResponse struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
BotToken string `json:"bot_token"`
|
||||||
|
ILinkBotID string `json:"ilink_bot_id"`
|
||||||
|
ILinkUserID string `json:"ilink_user_id"`
|
||||||
|
BaseURL string `json:"baseurl"`
|
||||||
|
RedirectHost string `json:"redirect_host"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetQRCodeStatus 长轮询二维码扫码状态
|
||||||
|
func (c *Client) GetQRCodeStatus(ctx context.Context, qrcode, verifyCode string) (*QRStatusResponse, error) {
|
||||||
|
path := "ilink/bot/get_qrcode_status?qrcode=" + url.QueryEscape(qrcode)
|
||||||
|
if verifyCode != "" {
|
||||||
|
path += "&verify_code=" + url.QueryEscape(verifyCode)
|
||||||
|
}
|
||||||
|
raw, err := c.doRequest(ctx, http.MethodGet, path, nil, c.commonHeaders(), QRLongPollTimeout)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return &QRStatusResponse{Status: "wait"}, nil
|
||||||
|
}
|
||||||
|
return &QRStatusResponse{Status: "wait"}, nil
|
||||||
|
}
|
||||||
|
var out QRStatusResponse
|
||||||
|
if err := json.Unmarshal(raw, &out); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MessageItem 消息内容项
|
||||||
|
type MessageItem struct {
|
||||||
|
Type int `json:"type"`
|
||||||
|
TextItem *struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"text_item,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WeixinMessage 入站消息
|
||||||
|
type WeixinMessage struct {
|
||||||
|
FromUserID string `json:"from_user_id"`
|
||||||
|
MessageType int `json:"message_type"`
|
||||||
|
MessageState int `json:"message_state"`
|
||||||
|
ItemList []MessageItem `json:"item_list"`
|
||||||
|
ContextToken string `json:"context_token"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUpdatesResponse 长轮询消息响应
|
||||||
|
type GetUpdatesResponse struct {
|
||||||
|
Ret int `json:"ret"`
|
||||||
|
ErrCode int `json:"errcode"`
|
||||||
|
ErrMsg string `json:"errmsg"`
|
||||||
|
Msgs []WeixinMessage `json:"msgs"`
|
||||||
|
GetUpdatesBuf string `json:"get_updates_buf"`
|
||||||
|
LongPollingTimeoutMs int `json:"longpolling_timeout_ms"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUpdates 长轮询获取新消息
|
||||||
|
func (c *Client) GetUpdates(ctx context.Context, getUpdatesBuf string) (*GetUpdatesResponse, error) {
|
||||||
|
body, _ := json.Marshal(map[string]interface{}{
|
||||||
|
"get_updates_buf": getUpdatesBuf,
|
||||||
|
"base_info": c.buildBaseInfo(),
|
||||||
|
})
|
||||||
|
raw, err := c.doRequest(ctx, http.MethodPost, "ilink/bot/getupdates", body, c.authHeaders(), GetUpdatesTimeout)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return &GetUpdatesResponse{Ret: 0, GetUpdatesBuf: getUpdatesBuf}, nil
|
||||||
|
}
|
||||||
|
return &GetUpdatesResponse{Ret: 0, GetUpdatesBuf: getUpdatesBuf}, nil
|
||||||
|
}
|
||||||
|
var out GetUpdatesResponse
|
||||||
|
if err := json.Unmarshal(raw, &out); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendTextMessage 发送文本回复
|
||||||
|
func (c *Client) SendTextMessage(ctx context.Context, toUserID, contextToken, text, clientID string) error {
|
||||||
|
if clientID == "" {
|
||||||
|
clientID = randomClientID()
|
||||||
|
}
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"msg": map[string]interface{}{
|
||||||
|
"to_user_id": toUserID,
|
||||||
|
"client_id": clientID,
|
||||||
|
"message_type": 2,
|
||||||
|
"message_state": 2,
|
||||||
|
"context_token": contextToken,
|
||||||
|
"item_list": []map[string]interface{}{
|
||||||
|
{"type": 1, "text_item": map[string]string{"text": text}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"base_info": c.buildBaseInfo(),
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(payload)
|
||||||
|
_, err := c.doRequest(ctx, http.MethodPost, "ilink/bot/sendmessage", body, c.authHeaders(), APIDefaultTimeout)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomClientID() string {
|
||||||
|
var b [8]byte
|
||||||
|
_, _ = rand.Read(b[:])
|
||||||
|
return fmt.Sprintf("%x", b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeBotAgent(raw string) string {
|
||||||
|
raw = strings.TrimSpace(raw)
|
||||||
|
if raw == "" {
|
||||||
|
return DefaultBotAgent
|
||||||
|
}
|
||||||
|
if len(raw) > 256 {
|
||||||
|
return raw[:256]
|
||||||
|
}
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExtractText 从消息中提取首条文本
|
||||||
|
func ExtractText(msg WeixinMessage) string {
|
||||||
|
for _, item := range msg.ItemList {
|
||||||
|
if item.Type == 1 && item.TextItem != nil {
|
||||||
|
return strings.TrimSpace(item.TextItem.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
@@ -0,0 +1,26 @@
|
|||||||
|
package ilink
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/skip2/go-qrcode"
|
||||||
|
)
|
||||||
|
|
||||||
|
// QRCodeDataURL 将扫码内容(一般为 liteapp 链接)编码为 PNG data URL,供 Web 端展示。
|
||||||
|
// qrcode_img_content 不是图片直链,不能用作 <img src>。
|
||||||
|
func QRCodeDataURL(content string, size int) (string, error) {
|
||||||
|
content = strings.TrimSpace(content)
|
||||||
|
if content == "" {
|
||||||
|
return "", fmt.Errorf("empty qr content")
|
||||||
|
}
|
||||||
|
if size <= 0 {
|
||||||
|
size = 256
|
||||||
|
}
|
||||||
|
png, err := qrcode.Encode(content, qrcode.Medium, size)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return "data:image/png;base64," + base64.StdEncoding.EncodeToString(png), nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
package robot
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/robot/ilink"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
wechatReconnectInitial = 5 * time.Second
|
||||||
|
wechatReconnectMax = 60 * time.Second
|
||||||
|
wechatPlatform = "wechat"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StartWechat 启动微信 iLink 长轮询(无需公网回调),收到消息后调用 handler 并回复。
|
||||||
|
func StartWechat(ctx context.Context, robotsCfg config.RobotsConfig, h MessageHandler, appVersion string, logger *zap.Logger) {
|
||||||
|
cfg := robotsCfg.Wechat
|
||||||
|
if !cfg.Enabled || cfg.BotToken == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go runWechatLoop(ctx, cfg, h, appVersion, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runWechatLoop(ctx context.Context, cfg config.RobotWechatConfig, h MessageHandler, appVersion string, logger *zap.Logger) {
|
||||||
|
backoff := wechatReconnectInitial
|
||||||
|
for {
|
||||||
|
err := runWechatPoll(ctx, cfg, h, appVersion, logger)
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
logger.Info("微信 iLink 长轮询已按配置关闭")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("微信 iLink 长轮询异常,将自动重连", zap.Error(err), zap.Duration("retry_after", backoff))
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(backoff):
|
||||||
|
if backoff < wechatReconnectMax {
|
||||||
|
backoff *= 2
|
||||||
|
if backoff > wechatReconnectMax {
|
||||||
|
backoff = wechatReconnectMax
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runWechatPoll(ctx context.Context, cfg config.RobotWechatConfig, h MessageHandler, appVersion string, logger *zap.Logger) error {
|
||||||
|
client := ilink.NewClient(cfg.BaseURL, cfg.BotToken, cfg.BotAgent, ilink.BuildClientVersion(appVersion))
|
||||||
|
buf := cfg.GetUpdatesBuf
|
||||||
|
logger.Info("微信 iLink 长轮询已启动", zap.String("ilink_bot_id", cfg.ILinkBotID))
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
resp, err := client.GetUpdates(ctx, buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if resp.ErrCode != 0 && resp.Ret != 0 {
|
||||||
|
logger.Warn("微信 getUpdates 返回错误", zap.Int("errcode", resp.ErrCode), zap.String("errmsg", resp.ErrMsg))
|
||||||
|
}
|
||||||
|
if resp.GetUpdatesBuf != "" {
|
||||||
|
buf = resp.GetUpdatesBuf
|
||||||
|
}
|
||||||
|
for _, msg := range resp.Msgs {
|
||||||
|
if msg.MessageType != 1 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
text := ilink.ExtractText(msg)
|
||||||
|
if text == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
userID := strings.TrimSpace(msg.FromUserID)
|
||||||
|
if userID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
logger.Info("微信收到消息", zap.String("from", userID), zap.String("content", text))
|
||||||
|
reply := h.HandleMessage(wechatPlatform, userID, text)
|
||||||
|
if strings.TrimSpace(reply) == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := client.SendTextMessage(ctx, userID, msg.ContextToken, reply, ""); err != nil {
|
||||||
|
logger.Warn("微信发送回复失败", zap.String("to", userID), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
### What it does
|
### What it does
|
||||||
|
|
||||||
- Configure **Host / Port / Password** and choose **Single-Agent** or **Multi-Agent**
|
- Configure **Host / Port / HTTPS / Password** and choose an agent mode
|
||||||
- Click **Validate** to login (`POST /api/auth/login`) and verify token (`GET /api/auth/validate`)
|
- Click **Validate** to login (`POST /api/auth/login`) and verify token (`GET /api/auth/validate`)
|
||||||
- Right-click any HTTP message in Burp and send it to CyberStrikeAI for **streaming web pentest**
|
- Right-click any HTTP message in Burp and send it to CyberStrikeAI for **streaming web pentest**
|
||||||
- Keep a **test history sidebar** (searchable) so you can revisit previous runs
|
- Keep a **test history sidebar** (searchable) so you can revisit previous runs
|
||||||
@@ -63,6 +63,7 @@ If you already have Gradle available, you can still use `build.gradle` to build.
|
|||||||
|
|
||||||
### Notes
|
### Notes
|
||||||
|
|
||||||
- This extension connects to your CyberStrikeAI server (default is `http://127.0.0.1:8080`).
|
- Default connection is `https://127.0.0.1:8080` (**HTTPS** checked). Self-signed / local certs are trusted automatically (no import).
|
||||||
|
- Uncheck **HTTPS** only if your server runs plain HTTP.
|
||||||
- It uses **Bearer Token** authentication obtained from the configured password.
|
- It uses **Bearer Token** authentication obtained from the configured password.
|
||||||
|
|
||||||
|
|||||||
@@ -81,7 +81,8 @@ cd plugins/burp-suite/cyberstrikeai-burp-extension
|
|||||||
2) 填写:
|
2) 填写:
|
||||||
- **Host**:例如 `127.0.0.1`
|
- **Host**:例如 `127.0.0.1`
|
||||||
- **Port**:例如 `8080`
|
- **Port**:例如 `8080`
|
||||||
- **Password**:你的 CyberStrikeAI 登录密码(对应服务端 `config.yaml` 的 `auth.password`)
|
- **HTTPS**:默认勾选(对接 `config.yaml` 中 `tls_enabled` / 自签证书);插件会自动信任本地自签证书,无需导入
|
||||||
|
- **Password**:你的 CyberStrikeAI 登录密码(对应服务端 `auth.password`)
|
||||||
- **Agent mode**:选择 `Single Agent` 或 `Multi Agent`
|
- **Agent mode**:选择 `Single Agent` 或 `Multi Agent`
|
||||||
3) 点击 **Validate**
|
3) 点击 **Validate**
|
||||||
- 成功:状态显示 `OK (token saved)`
|
- 成功:状态显示 `OK (token saved)`
|
||||||
@@ -94,8 +95,9 @@ cd plugins/burp-suite/cyberstrikeai-burp-extension
|
|||||||
|
|
||||||
- **Validate 失败 / 401**
|
- **Validate 失败 / 401**
|
||||||
- 确认密码是否正确(服务端 `auth.password`)
|
- 确认密码是否正确(服务端 `auth.password`)
|
||||||
- 确认 IP/端口是否能访问(例如浏览器能打开 `http://IP:PORT/`)
|
- 确认 IP/端口是否能访问(例如浏览器能打开 `https://IP:PORT/`)
|
||||||
- 若服务器启用了反向代理/HTTPS,需要把插件里 baseUrl 改成对应协议与端口(当前插件默认使用 `http://`)
|
- 服务端启用 TLS 时勾选 **HTTPS**(默认已勾选);自签证书无需手动导入
|
||||||
|
- 若仍为纯 HTTP 部署,取消勾选 **HTTPS**
|
||||||
|
|
||||||
- **选择 Multi Agent 后提示“多代理未启用”**
|
- **选择 Multi Agent 后提示“多代理未启用”**
|
||||||
- 服务端需要开启:`config.yaml` 中 `multi_agent.enabled: true`
|
- 服务端需要开启:`config.yaml` 中 `multi_agent.enabled: true`
|
||||||
|
|||||||
BIN
Binary file not shown.
+52
-11
@@ -73,15 +73,34 @@ public class BurpExtender implements IBurpExtender, IContextMenuFactory {
|
|||||||
public void onEvent(String type, String message, String rawJson) {
|
public void onEvent(String type, String message, String rawJson) {
|
||||||
if (type == null) type = "";
|
if (type == null) type = "";
|
||||||
switch (type) {
|
switch (type) {
|
||||||
|
case "response_start":
|
||||||
|
tab.appendProgressToRun(runId, "\n\n[主回复]\n");
|
||||||
|
break;
|
||||||
case "response_delta":
|
case "response_delta":
|
||||||
case "eino_agent_reply_stream_delta":
|
if (message != null && !message.isEmpty()) {
|
||||||
tab.appendFinalToRun(runId, message);
|
tab.appendFinalToRun(runId, message);
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case "response":
|
case "response":
|
||||||
tab.appendFinalToRun(runId, "\n\n--- Final Response ---\n");
|
|
||||||
tab.appendFinalToRun(runId, message);
|
tab.appendFinalToRun(runId, message);
|
||||||
tab.setFinalResponse(runId, message);
|
tab.setFinalResponse(runId, message);
|
||||||
break;
|
break;
|
||||||
|
case "eino_agent_reply_stream_start":
|
||||||
|
tab.appendProgressToRun(runId, "\n\n[子代理回复]\n");
|
||||||
|
break;
|
||||||
|
case "eino_agent_reply_stream_delta":
|
||||||
|
if (message != null && !message.isEmpty()) {
|
||||||
|
tab.appendProgressToRun(runId, message);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case "eino_agent_reply_stream_end":
|
||||||
|
tab.appendProgressToRun(runId, "\n");
|
||||||
|
break;
|
||||||
|
case "eino_agent_reply":
|
||||||
|
if (message != null && !message.isEmpty()) {
|
||||||
|
tab.appendProgressToRun(runId, "\n\n[子代理回复]\n" + message + "\n");
|
||||||
|
}
|
||||||
|
break;
|
||||||
case "progress":
|
case "progress":
|
||||||
tab.appendProgressToRun(runId, "\n[progress] " + message + "\n");
|
tab.appendProgressToRun(runId, "\n[progress] " + message + "\n");
|
||||||
tab.setRunStatus(runId, "running");
|
tab.setRunStatus(runId, "running");
|
||||||
@@ -94,21 +113,40 @@ public class BurpExtender implements IBurpExtender, IContextMenuFactory {
|
|||||||
tab.appendProgressToRun(runId, "\n[error] " + message + "\n");
|
tab.appendProgressToRun(runId, "\n[error] " + message + "\n");
|
||||||
tab.setRunStatus(runId, "error");
|
tab.setRunStatus(runId, "error");
|
||||||
break;
|
break;
|
||||||
|
case "reasoning_chain_stream_start":
|
||||||
|
tab.appendProgressToRun(runId, "\n\n[推理过程]\n");
|
||||||
|
break;
|
||||||
|
case "reasoning_chain_stream_delta":
|
||||||
|
if (message != null && !message.isEmpty()) {
|
||||||
|
tab.appendProgressToRun(runId, message);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case "reasoning_chain_stream_end":
|
||||||
|
tab.appendProgressToRun(runId, "\n");
|
||||||
|
break;
|
||||||
|
case "reasoning_chain":
|
||||||
|
if (message != null && !message.isEmpty()) {
|
||||||
|
String streamId = rawJson != null ? SimpleJson.extractStringField(rawJson, "streamId") : "";
|
||||||
|
if (streamId == null || streamId.isEmpty()) {
|
||||||
|
tab.appendProgressToRun(runId, "\n\n[推理过程]\n" + message + "\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
case "thinking_stream_start":
|
case "thinking_stream_start":
|
||||||
if (tab.isShowDebugEvents()) {
|
if (tab.isShowDebugEvents()) {
|
||||||
tab.resetThinkingStream(runId);
|
tab.resetThinkingStream(runId);
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case "thinking_stream_delta":
|
case "thinking_stream_delta":
|
||||||
|
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
|
||||||
|
tab.appendProgressToRun(runId, message);
|
||||||
|
}
|
||||||
|
break;
|
||||||
case "tool_call":
|
case "tool_call":
|
||||||
case "tool_result":
|
case "tool_result":
|
||||||
case "tool_result_delta":
|
case "tool_result_delta":
|
||||||
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
|
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
|
||||||
if ("thinking_stream_delta".equals(type)) {
|
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
|
||||||
tab.appendThinkingDelta(runId, message);
|
|
||||||
} else {
|
|
||||||
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case "conversation":
|
case "conversation":
|
||||||
@@ -125,7 +163,9 @@ public class BurpExtender implements IBurpExtender, IContextMenuFactory {
|
|||||||
case "done":
|
case "done":
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
|
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()
|
||||||
|
&& !type.endsWith("_stream_delta") && !type.endsWith("_stream_start")
|
||||||
|
&& !type.endsWith("_stream_end")) {
|
||||||
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
|
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
@@ -134,8 +174,9 @@ public class BurpExtender implements IBurpExtender, IContextMenuFactory {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onError(String message, Exception e) {
|
public void onError(String message, Exception e) {
|
||||||
tab.appendProgressToRun(runId, "\n[error] " + message + "\n");
|
boolean cancelled = message != null && message.toLowerCase().contains("cancel");
|
||||||
tab.setRunStatus(runId, "error");
|
tab.appendProgressToRun(runId, cancelled ? "\n[info] " + message + "\n" : "\n[error] " + message + "\n");
|
||||||
|
tab.setRunStatus(runId, cancelled ? "cancelled" : "error");
|
||||||
callbacks.printError("CyberStrikeAI stream error: " + message);
|
callbacks.printError("CyberStrikeAI stream error: " + message);
|
||||||
if (e != null) {
|
if (e != null) {
|
||||||
callbacks.printError(e.toString());
|
callbacks.printError(e.toString());
|
||||||
|
|||||||
+127
-11
@@ -2,17 +2,29 @@ package burp;
|
|||||||
|
|
||||||
import java.io.BufferedReader;
|
import java.io.BufferedReader;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.io.InterruptedIOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.io.InputStreamReader;
|
import java.io.InputStreamReader;
|
||||||
import java.io.OutputStream;
|
import java.io.OutputStream;
|
||||||
import java.net.HttpURLConnection;
|
import java.net.HttpURLConnection;
|
||||||
|
import java.net.SocketTimeoutException;
|
||||||
import java.net.URL;
|
import java.net.URL;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
|
||||||
final class CyberStrikeAIClient {
|
final class CyberStrikeAIClient {
|
||||||
|
|
||||||
|
private static final int AUTH_CONNECT_TIMEOUT_MS = 4_000;
|
||||||
|
private static final int AUTH_READ_TIMEOUT_MS = 5_000;
|
||||||
|
/** login + validate 整段上限,避免两次读超时叠加拖到半分钟 */
|
||||||
|
private static final int AUTH_OVERALL_TIMEOUT_MS = 10_000;
|
||||||
|
private static final int DEFAULT_READ_TIMEOUT_MS = 15_000;
|
||||||
|
|
||||||
|
private final AtomicReference<HttpURLConnection> activeConnection = new AtomicReference<>();
|
||||||
|
private final AtomicReference<Thread> activeThread = new AtomicReference<>();
|
||||||
|
|
||||||
static final class Config {
|
static final class Config {
|
||||||
final String baseUrl; // e.g. http://127.0.0.1:8080
|
final String baseUrl; // e.g. http://127.0.0.1:8080
|
||||||
final String password;
|
final String password;
|
||||||
@@ -49,15 +61,97 @@ final class CyberStrikeAIClient {
|
|||||||
void onDone();
|
void onDone();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
boolean hasActiveRequest() {
|
||||||
|
return activeConnection.get() != null;
|
||||||
|
}
|
||||||
|
|
||||||
|
void cancelActiveRequest() {
|
||||||
|
HttpURLConnection conn = activeConnection.getAndSet(null);
|
||||||
|
if (conn != null) {
|
||||||
|
try {
|
||||||
|
conn.disconnect();
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Thread t = activeThread.getAndSet(null);
|
||||||
|
if (t != null) {
|
||||||
|
t.interrupt();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
String loginAndValidate(Config cfg) throws IOException {
|
String loginAndValidate(Config cfg) throws IOException {
|
||||||
String token = login(cfg.baseUrl, cfg.password);
|
Thread worker = Thread.currentThread();
|
||||||
validate(cfg.baseUrl, token);
|
java.util.Timer deadline = new java.util.Timer("CyberStrikeAI-AuthDeadline", true);
|
||||||
return token;
|
deadline.schedule(new java.util.TimerTask() {
|
||||||
|
@Override
|
||||||
|
public void run() {
|
||||||
|
worker.interrupt();
|
||||||
|
cancelActiveRequest();
|
||||||
|
}
|
||||||
|
}, AUTH_OVERALL_TIMEOUT_MS);
|
||||||
|
try {
|
||||||
|
String token = login(cfg.baseUrl, cfg.password);
|
||||||
|
if (Thread.interrupted()) {
|
||||||
|
throw timeoutIOException();
|
||||||
|
}
|
||||||
|
validate(cfg.baseUrl, token);
|
||||||
|
if (Thread.interrupted()) {
|
||||||
|
throw timeoutIOException();
|
||||||
|
}
|
||||||
|
return token;
|
||||||
|
} catch (SocketTimeoutException e) {
|
||||||
|
throw timeoutIOException();
|
||||||
|
} finally {
|
||||||
|
deadline.cancel();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static IOException timeoutIOException() {
|
||||||
|
return new IOException("Connection timed out (~" + (AUTH_OVERALL_TIMEOUT_MS / 1000)
|
||||||
|
+ "s). Check host/port and HTTPS checkbox.");
|
||||||
|
}
|
||||||
|
|
||||||
|
private void trackConnection(HttpURLConnection conn) {
|
||||||
|
activeThread.set(Thread.currentThread());
|
||||||
|
activeConnection.set(conn);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void releaseConnection(HttpURLConnection conn) {
|
||||||
|
if (activeConnection.compareAndSet(conn, null)) {
|
||||||
|
activeThread.set(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static boolean isCancelled(Throwable e) {
|
||||||
|
if (e == null) {
|
||||||
|
return Thread.currentThread().isInterrupted();
|
||||||
|
}
|
||||||
|
if (Thread.currentThread().isInterrupted()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (e instanceof InterruptedIOException) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (e instanceof SocketTimeoutException) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
Throwable cause = e.getCause();
|
||||||
|
if (cause != null && cause != e) {
|
||||||
|
return isCancelled(cause);
|
||||||
|
}
|
||||||
|
String msg = e.getMessage();
|
||||||
|
return msg != null && (
|
||||||
|
msg.toLowerCase().contains("cancel")
|
||||||
|
|| msg.toLowerCase().contains("abort")
|
||||||
|
|| msg.toLowerCase().contains("closed")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String login(String baseUrl, String password) throws IOException {
|
private String login(String baseUrl, String password) throws IOException {
|
||||||
URL url = new URL(baseUrl + "/api/auth/login");
|
URL url = new URL(baseUrl + "/api/auth/login");
|
||||||
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
|
HttpURLConnection conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, AUTH_READ_TIMEOUT_MS);
|
||||||
|
trackConnection(conn);
|
||||||
|
try {
|
||||||
conn.setRequestMethod("POST");
|
conn.setRequestMethod("POST");
|
||||||
conn.setDoOutput(true);
|
conn.setDoOutput(true);
|
||||||
conn.setRequestProperty("Content-Type", "application/json");
|
conn.setRequestProperty("Content-Type", "application/json");
|
||||||
@@ -92,11 +186,16 @@ final class CyberStrikeAIClient {
|
|||||||
throw new IOException("Login response missing token. Check backend address and credentials.");
|
throw new IOException("Login response missing token. Check backend address and credentials.");
|
||||||
}
|
}
|
||||||
return token;
|
return token;
|
||||||
|
} finally {
|
||||||
|
releaseConnection(conn);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void validate(String baseUrl, String token) throws IOException {
|
private void validate(String baseUrl, String token) throws IOException {
|
||||||
URL url = new URL(baseUrl + "/api/auth/validate");
|
URL url = new URL(baseUrl + "/api/auth/validate");
|
||||||
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
|
HttpURLConnection conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, AUTH_READ_TIMEOUT_MS);
|
||||||
|
trackConnection(conn);
|
||||||
|
try {
|
||||||
conn.setRequestMethod("GET");
|
conn.setRequestMethod("GET");
|
||||||
conn.setRequestProperty("Authorization", "Bearer " + token);
|
conn.setRequestProperty("Authorization", "Bearer " + token);
|
||||||
int code = conn.getResponseCode();
|
int code = conn.getResponseCode();
|
||||||
@@ -104,6 +203,9 @@ final class CyberStrikeAIClient {
|
|||||||
if (code < 200 || code >= 300) {
|
if (code < 200 || code >= 300) {
|
||||||
throw new IOException("Validate failed (" + code + "): " + resp);
|
throw new IOException("Validate failed (" + code + "): " + resp);
|
||||||
}
|
}
|
||||||
|
} finally {
|
||||||
|
releaseConnection(conn);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void streamTest(Config cfg, String token, String message, StreamListener listener) {
|
void streamTest(Config cfg, String token, String message, StreamListener listener) {
|
||||||
@@ -117,11 +219,12 @@ final class CyberStrikeAIClient {
|
|||||||
payload.put("orchestration", cfg.agentMode.orchestration);
|
payload.put("orchestration", cfg.agentMode.orchestration);
|
||||||
}
|
}
|
||||||
|
|
||||||
new Thread(() -> {
|
Thread worker = new Thread(() -> {
|
||||||
HttpURLConnection conn = null;
|
HttpURLConnection conn = null;
|
||||||
try {
|
try {
|
||||||
URL url = new URL(urlStr);
|
URL url = new URL(urlStr);
|
||||||
conn = (HttpURLConnection) url.openConnection();
|
conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, 0);
|
||||||
|
trackConnection(conn);
|
||||||
conn.setRequestMethod("POST");
|
conn.setRequestMethod("POST");
|
||||||
conn.setDoOutput(true);
|
conn.setDoOutput(true);
|
||||||
conn.setRequestProperty("Content-Type", "application/json");
|
conn.setRequestProperty("Content-Type", "application/json");
|
||||||
@@ -142,6 +245,9 @@ final class CyberStrikeAIClient {
|
|||||||
try (BufferedReader br = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
|
try (BufferedReader br = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
|
||||||
String line;
|
String line;
|
||||||
while ((line = br.readLine()) != null) {
|
while ((line = br.readLine()) != null) {
|
||||||
|
if (Thread.currentThread().isInterrupted()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
// SSE format: "data: {json}"
|
// SSE format: "data: {json}"
|
||||||
if (line.startsWith("data:")) {
|
if (line.startsWith("data:")) {
|
||||||
String json = line.substring("data:".length()).trim();
|
String json = line.substring("data:".length()).trim();
|
||||||
@@ -156,15 +262,25 @@ final class CyberStrikeAIClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
listener.onDone();
|
if (Thread.currentThread().isInterrupted()) {
|
||||||
|
listener.onError("Cancelled.", null);
|
||||||
|
} else {
|
||||||
|
listener.onDone();
|
||||||
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
listener.onError(e.getMessage(), e);
|
if (isCancelled(e)) {
|
||||||
|
listener.onError("Cancelled.", e);
|
||||||
|
} else {
|
||||||
|
listener.onError(e.getMessage(), e);
|
||||||
|
}
|
||||||
} finally {
|
} finally {
|
||||||
if (conn != null) {
|
if (conn != null) {
|
||||||
|
releaseConnection(conn);
|
||||||
conn.disconnect();
|
conn.disconnect();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}, "CyberStrikeAI-Stream").start();
|
}, "CyberStrikeAI-Stream");
|
||||||
|
worker.start();
|
||||||
}
|
}
|
||||||
|
|
||||||
void cancelByConversationId(String baseUrl, String token, String conversationId) throws IOException {
|
void cancelByConversationId(String baseUrl, String token, String conversationId) throws IOException {
|
||||||
@@ -172,7 +288,7 @@ final class CyberStrikeAIClient {
|
|||||||
throw new IOException("Missing conversationId.");
|
throw new IOException("Missing conversationId.");
|
||||||
}
|
}
|
||||||
URL url = new URL(baseUrl + "/api/agent-loop/cancel");
|
URL url = new URL(baseUrl + "/api/agent-loop/cancel");
|
||||||
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
|
HttpURLConnection conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, AUTH_READ_TIMEOUT_MS);
|
||||||
conn.setRequestMethod("POST");
|
conn.setRequestMethod("POST");
|
||||||
conn.setDoOutput(true);
|
conn.setDoOutput(true);
|
||||||
conn.setRequestProperty("Content-Type", "application/json");
|
conn.setRequestProperty("Content-Type", "application/json");
|
||||||
|
|||||||
+130
-34
@@ -14,6 +14,7 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
|
|
||||||
private final JTextField hostField = new JTextField("127.0.0.1");
|
private final JTextField hostField = new JTextField("127.0.0.1");
|
||||||
private final JTextField portField = new JTextField("8080");
|
private final JTextField portField = new JTextField("8080");
|
||||||
|
private final JCheckBox useHttpsBox = new JCheckBox("HTTPS", true);
|
||||||
private final JPasswordField passwordField = new JPasswordField();
|
private final JPasswordField passwordField = new JPasswordField();
|
||||||
private final JComboBox<String> agentModeBox = new JComboBox<>(new String[]{
|
private final JComboBox<String> agentModeBox = new JComboBox<>(new String[]{
|
||||||
"Native ReAct", "Eino Single (ADK)", "Deep (DeepAgent)", "Plan-Execute", "Supervisor"
|
"Native ReAct", "Eino Single (ADK)", "Deep (DeepAgent)", "Plan-Execute", "Supervisor"
|
||||||
@@ -29,6 +30,10 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
|
|
||||||
private final JTextArea progressArea = new JTextArea();
|
private final JTextArea progressArea = new JTextArea();
|
||||||
private final JTextArea finalRawArea = new JTextArea(); // raw final stream / final response
|
private final JTextArea finalRawArea = new JTextArea(); // raw final stream / final response
|
||||||
|
private JScrollPane progressScrollPane;
|
||||||
|
private JScrollPane finalRawScrollPane;
|
||||||
|
/** 距底部在此像素内视为「跟随滚动」,否则用户上拉阅读时不抢滚动条 */
|
||||||
|
private static final int SCROLL_FOLLOW_THRESHOLD_PX = 48;
|
||||||
private final JEditorPane markdownPane = new JEditorPane("text/html", "");
|
private final JEditorPane markdownPane = new JEditorPane("text/html", "");
|
||||||
private final CardLayout outputCardsLayout = new CardLayout();
|
private final CardLayout outputCardsLayout = new CardLayout();
|
||||||
private final JPanel outputCards = new JPanel(outputCardsLayout);
|
private final JPanel outputCards = new JPanel(outputCardsLayout);
|
||||||
@@ -41,6 +46,7 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
|
|
||||||
private final CyberStrikeAIClient client = new CyberStrikeAIClient();
|
private final CyberStrikeAIClient client = new CyberStrikeAIClient();
|
||||||
private final AtomicReference<String> tokenRef = new AtomicReference<>("");
|
private final AtomicReference<String> tokenRef = new AtomicReference<>("");
|
||||||
|
private final AtomicReference<Thread> validateThreadRef = new AtomicReference<>();
|
||||||
|
|
||||||
private final DefaultListModel<TestRun> testListModel = new DefaultListModel<>();
|
private final DefaultListModel<TestRun> testListModel = new DefaultListModel<>();
|
||||||
private final JList<TestRun> testList = new JList<>(testListModel);
|
private final JList<TestRun> testList = new JList<>(testListModel);
|
||||||
@@ -107,6 +113,8 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
row1.add(hostField);
|
row1.add(hostField);
|
||||||
row1.add(new JLabel("Port"));
|
row1.add(new JLabel("Port"));
|
||||||
row1.add(portField);
|
row1.add(portField);
|
||||||
|
useHttpsBox.setToolTipText("Use https:// for CyberStrikeAI (self-signed certs are trusted automatically)");
|
||||||
|
row1.add(useHttpsBox);
|
||||||
row1.add(new JLabel("Password"));
|
row1.add(new JLabel("Password"));
|
||||||
row1.add(passwordField);
|
row1.add(passwordField);
|
||||||
row1.add(validateButton);
|
row1.add(validateButton);
|
||||||
@@ -186,15 +194,22 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
configureTextArea(requestArea, false);
|
configureTextArea(requestArea, false);
|
||||||
configureTextArea(responseArea, false);
|
configureTextArea(responseArea, false);
|
||||||
|
|
||||||
outputCards.add(new JScrollPane(finalRawArea), "raw");
|
finalRawScrollPane = new JScrollPane(finalRawArea);
|
||||||
|
finalRawScrollPane.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER);
|
||||||
|
finalRawScrollPane.getVerticalScrollBar().setUnitIncrement(16);
|
||||||
|
outputCards.add(finalRawScrollPane, "raw");
|
||||||
outputCards.add(new JScrollPane(markdownPane), "md");
|
outputCards.add(new JScrollPane(markdownPane), "md");
|
||||||
|
|
||||||
outputRoot.add(buildOutputHeader(), BorderLayout.NORTH);
|
outputRoot.add(buildOutputHeader(), BorderLayout.NORTH);
|
||||||
outputRoot.add(buildOutputBody(), BorderLayout.CENTER);
|
outputRoot.add(buildOutputBody(), BorderLayout.CENTER);
|
||||||
|
|
||||||
rightTabs.addTab("Output", outputRoot);
|
rightTabs.addTab("Output", outputRoot);
|
||||||
rightTabs.addTab("Request", new JScrollPane(requestArea));
|
JScrollPane requestScroll = new JScrollPane(requestArea);
|
||||||
rightTabs.addTab("Response", new JScrollPane(responseArea));
|
requestScroll.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER);
|
||||||
|
rightTabs.addTab("Request", requestScroll);
|
||||||
|
JScrollPane responseScroll = new JScrollPane(responseArea);
|
||||||
|
responseScroll.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER);
|
||||||
|
rightTabs.addTab("Response", responseScroll);
|
||||||
return rightTabs;
|
return rightTabs;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -210,12 +225,13 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private JComponent buildOutputBody() {
|
private JComponent buildOutputBody() {
|
||||||
JScrollPane progressScroll = new JScrollPane(progressArea);
|
progressScrollPane = new JScrollPane(progressArea);
|
||||||
progressScroll.setBorder(BorderFactory.createTitledBorder("Progress"));
|
progressScrollPane.setBorder(BorderFactory.createTitledBorder("Progress"));
|
||||||
progressScroll.getVerticalScrollBar().setUnitIncrement(16);
|
progressScrollPane.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER);
|
||||||
|
progressScrollPane.getVerticalScrollBar().setUnitIncrement(16);
|
||||||
|
|
||||||
JPanel empty = new JPanel();
|
JPanel empty = new JPanel();
|
||||||
progressContainer.add(progressScroll, "show");
|
progressContainer.add(progressScrollPane, "show");
|
||||||
progressContainer.add(empty, "hide");
|
progressContainer.add(empty, "hide");
|
||||||
((CardLayout) progressContainer.getLayout()).show(progressContainer, "show");
|
((CardLayout) progressContainer.getLayout()).show(progressContainer, "show");
|
||||||
|
|
||||||
@@ -259,10 +275,27 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
return split;
|
return split;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static boolean isScrollNearBottom(JScrollPane scrollPane) {
|
||||||
|
if (scrollPane == null) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
JScrollBar bar = scrollPane.getVerticalScrollBar();
|
||||||
|
int max = Math.max(0, bar.getMaximum() - bar.getVisibleAmount());
|
||||||
|
return bar.getValue() >= max - SCROLL_FOLLOW_THRESHOLD_PX;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void scrollPaneToBottom(JScrollPane scrollPane) {
|
||||||
|
if (scrollPane == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
JScrollBar bar = scrollPane.getVerticalScrollBar();
|
||||||
|
bar.setValue(bar.getMaximum());
|
||||||
|
}
|
||||||
|
|
||||||
private static void configureTextArea(JTextArea area, boolean monospaced) {
|
private static void configureTextArea(JTextArea area, boolean monospaced) {
|
||||||
area.setEditable(false);
|
area.setEditable(false);
|
||||||
area.setLineWrap(false);
|
area.setLineWrap(true);
|
||||||
area.setWrapStyleWord(false);
|
area.setWrapStyleWord(true);
|
||||||
if (monospaced) {
|
if (monospaced) {
|
||||||
area.setFont(new Font(Font.MONOSPACED, Font.PLAIN, 12));
|
area.setFont(new Font(Font.MONOSPACED, Font.PLAIN, 12));
|
||||||
} else {
|
} else {
|
||||||
@@ -381,24 +414,44 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
|
|
||||||
private void wireActions() {
|
private void wireActions() {
|
||||||
validateButton.addActionListener(e -> {
|
validateButton.addActionListener(e -> {
|
||||||
validateButton.setEnabled(false);
|
if ("Cancel".equals(validateButton.getText())) {
|
||||||
|
cancelValidateInProgress();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
validateButton.setText("Cancel");
|
||||||
|
validateButton.setEnabled(true);
|
||||||
|
stopButton.setEnabled(true);
|
||||||
statusLabel.setText("Validating...");
|
statusLabel.setText("Validating...");
|
||||||
log("Validating connection...");
|
log("Validating connection... (max ~10s; click Cancel or Stop to abort)");
|
||||||
new Thread(() -> {
|
Thread worker = new Thread(() -> {
|
||||||
try {
|
try {
|
||||||
CyberStrikeAIClient.Config cfg = currentConfig();
|
CyberStrikeAIClient.Config cfg = currentConfig();
|
||||||
String token = client.loginAndValidate(cfg);
|
String token = client.loginAndValidate(cfg);
|
||||||
|
if (Thread.currentThread().isInterrupted()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
tokenRef.set(token);
|
tokenRef.set(token);
|
||||||
SwingUtilities.invokeLater(() -> statusLabel.setText("OK (token saved)"));
|
SwingUtilities.invokeLater(() -> statusLabel.setText("OK (token saved)"));
|
||||||
log("Validation OK.");
|
log("Validation OK.");
|
||||||
} catch (Exception ex) {
|
} catch (Exception ex) {
|
||||||
tokenRef.set("");
|
tokenRef.set("");
|
||||||
SwingUtilities.invokeLater(() -> statusLabel.setText("Failed: " + ex.getMessage()));
|
if (Thread.currentThread().isInterrupted()) {
|
||||||
log("Validation failed: " + ex.getMessage());
|
SwingUtilities.invokeLater(() -> statusLabel.setText("Cancelled"));
|
||||||
|
log("Validation cancelled.");
|
||||||
|
} else {
|
||||||
|
SwingUtilities.invokeLater(() -> statusLabel.setText("Failed: " + ex.getMessage()));
|
||||||
|
log("Validation failed: " + ex.getMessage());
|
||||||
|
}
|
||||||
} finally {
|
} finally {
|
||||||
SwingUtilities.invokeLater(() -> validateButton.setEnabled(true));
|
validateThreadRef.set(null);
|
||||||
|
SwingUtilities.invokeLater(() -> {
|
||||||
|
validateButton.setText("Validate");
|
||||||
|
validateButton.setEnabled(true);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}, "CyberStrikeAI-Validate").start();
|
}, "CyberStrikeAI-Validate");
|
||||||
|
validateThreadRef.set(worker);
|
||||||
|
worker.start();
|
||||||
});
|
});
|
||||||
|
|
||||||
clearButton.addActionListener(e -> {
|
clearButton.addActionListener(e -> {
|
||||||
@@ -435,10 +488,23 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
});
|
});
|
||||||
|
|
||||||
stopButton.addActionListener(e -> {
|
stopButton.addActionListener(e -> {
|
||||||
|
if ("Cancel".equals(validateButton.getText())) {
|
||||||
|
cancelValidateInProgress();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
String runId = selectedRunId;
|
String runId = selectedRunId;
|
||||||
|
if (runId != null && client.hasActiveRequest()) {
|
||||||
|
client.cancelActiveRequest();
|
||||||
|
appendProgressToRun(runId, "\n[info] Stream stopped.\n");
|
||||||
|
setRunStatus(runId, "cancelled");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (runId == null) return;
|
if (runId == null) return;
|
||||||
TestRun run = runs.get(runId);
|
TestRun run = runs.get(runId);
|
||||||
if (run == null) return;
|
if (run == null) return;
|
||||||
|
|
||||||
String token = getToken();
|
String token = getToken();
|
||||||
if (token == null || token.trim().isEmpty()) {
|
if (token == null || token.trim().isEmpty()) {
|
||||||
appendProgressToRun(runId, "\n[error] Not validated.\n");
|
appendProgressToRun(runId, "\n[error] Not validated.\n");
|
||||||
@@ -483,7 +549,8 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
String host = hostField.getText().trim();
|
String host = hostField.getText().trim();
|
||||||
String port = portField.getText().trim();
|
String port = portField.getText().trim();
|
||||||
String password = new String(passwordField.getPassword());
|
String password = new String(passwordField.getPassword());
|
||||||
String baseUrl = "http://" + host + ":" + port;
|
String scheme = useHttpsBox.isSelected() ? "https" : "http";
|
||||||
|
String baseUrl = scheme + "://" + host + ":" + port;
|
||||||
int idx = agentModeBox.getSelectedIndex();
|
int idx = agentModeBox.getSelectedIndex();
|
||||||
CyberStrikeAIClient.AgentMode mode = (idx >= 0 && idx < AGENT_MODES.length)
|
CyberStrikeAIClient.AgentMode mode = (idx >= 0 && idx < AGENT_MODES.length)
|
||||||
? AGENT_MODES[idx]
|
? AGENT_MODES[idx]
|
||||||
@@ -567,10 +634,31 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
run.progressBuffer.append(s);
|
run.progressBuffer.append(s);
|
||||||
}
|
}
|
||||||
if (runId.equals(selectedRunId)) {
|
if (runId.equals(selectedRunId)) {
|
||||||
SwingUtilities.invokeLater(() -> {
|
SwingUtilities.invokeLater(() -> appendProgressUi(s, false));
|
||||||
progressArea.append(s);
|
}
|
||||||
progressArea.setCaretPosition(progressArea.getDocument().getLength());
|
}
|
||||||
});
|
|
||||||
|
private void appendProgressUi(String s, boolean forceFollow) {
|
||||||
|
JScrollBar bar = progressScrollPane != null ? progressScrollPane.getVerticalScrollBar() : null;
|
||||||
|
int scrollBefore = bar != null ? bar.getValue() : 0;
|
||||||
|
boolean follow = forceFollow || isScrollNearBottom(progressScrollPane);
|
||||||
|
progressArea.append(s);
|
||||||
|
if (follow) {
|
||||||
|
scrollPaneToBottom(progressScrollPane);
|
||||||
|
} else if (bar != null) {
|
||||||
|
bar.setValue(scrollBefore);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void appendFinalUi(String s, boolean forceFollow) {
|
||||||
|
JScrollBar bar = finalRawScrollPane != null ? finalRawScrollPane.getVerticalScrollBar() : null;
|
||||||
|
int scrollBefore = bar != null ? bar.getValue() : 0;
|
||||||
|
boolean follow = forceFollow || isScrollNearBottom(finalRawScrollPane);
|
||||||
|
finalRawArea.append(s);
|
||||||
|
if (follow) {
|
||||||
|
scrollPaneToBottom(finalRawScrollPane);
|
||||||
|
} else if (bar != null) {
|
||||||
|
bar.setValue(scrollBefore);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -620,10 +708,7 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
run.finalBuffer.append(s);
|
run.finalBuffer.append(s);
|
||||||
}
|
}
|
||||||
if (runId.equals(selectedRunId)) {
|
if (runId.equals(selectedRunId)) {
|
||||||
SwingUtilities.invokeLater(() -> {
|
SwingUtilities.invokeLater(() -> appendFinalUi(s, false));
|
||||||
finalRawArea.append(s);
|
|
||||||
finalRawArea.setCaretPosition(finalRawArea.getDocument().getLength());
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -656,9 +741,9 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
}
|
}
|
||||||
SwingUtilities.invokeLater(() -> {
|
SwingUtilities.invokeLater(() -> {
|
||||||
progressArea.setText(progress);
|
progressArea.setText(progress);
|
||||||
progressArea.setCaretPosition(progressArea.getDocument().getLength());
|
scrollPaneToBottom(progressScrollPane);
|
||||||
finalRawArea.setText(fin);
|
finalRawArea.setText(fin);
|
||||||
finalRawArea.setCaretPosition(finalRawArea.getDocument().getLength());
|
scrollPaneToBottom(finalRawScrollPane);
|
||||||
requestArea.setText(run.requestRaw == null ? "" : run.requestRaw);
|
requestArea.setText(run.requestRaw == null ? "" : run.requestRaw);
|
||||||
responseArea.setText(run.responseRaw == null ? "" : run.responseRaw);
|
responseArea.setText(run.responseRaw == null ? "" : run.responseRaw);
|
||||||
refreshOutputView();
|
refreshOutputView();
|
||||||
@@ -682,25 +767,36 @@ final class CyberStrikeAITab implements ITab {
|
|||||||
|
|
||||||
void clearAndShowStreamHeader(String title) {
|
void clearAndShowStreamHeader(String title) {
|
||||||
SwingUtilities.invokeLater(() -> {
|
SwingUtilities.invokeLater(() -> {
|
||||||
progressArea.setText("");
|
progressArea.setText("[*] " + title + "\n\n");
|
||||||
finalRawArea.setText(title + "\n\n");
|
finalRawArea.setText("");
|
||||||
|
markdownPane.setText("");
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Legacy helpers kept for Validate logging
|
// Legacy helpers kept for Validate logging
|
||||||
void appendStreamLine(String s) {
|
void appendStreamLine(String s) {
|
||||||
if (s == null) return;
|
if (s == null) return;
|
||||||
SwingUtilities.invokeLater(() -> {
|
SwingUtilities.invokeLater(() -> appendProgressUi(s + "\n", false));
|
||||||
progressArea.append(s);
|
|
||||||
progressArea.append("\n");
|
|
||||||
progressArea.setCaretPosition(progressArea.getDocument().getLength());
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void log(String s) {
|
private void log(String s) {
|
||||||
appendStreamLine("[*] " + s);
|
appendStreamLine("[*] " + s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void cancelValidateInProgress() {
|
||||||
|
client.cancelActiveRequest();
|
||||||
|
Thread t = validateThreadRef.getAndSet(null);
|
||||||
|
if (t != null) {
|
||||||
|
t.interrupt();
|
||||||
|
}
|
||||||
|
SwingUtilities.invokeLater(() -> {
|
||||||
|
statusLabel.setText("Cancelled");
|
||||||
|
validateButton.setText("Validate");
|
||||||
|
validateButton.setEnabled(true);
|
||||||
|
});
|
||||||
|
log("Validation cancelled.");
|
||||||
|
}
|
||||||
|
|
||||||
private void applyFilter() {
|
private void applyFilter() {
|
||||||
String q = searchField.getText();
|
String q = searchField.getText();
|
||||||
if (q == null) q = "";
|
if (q == null) q = "";
|
||||||
|
|||||||
@@ -0,0 +1,149 @@
|
|||||||
|
package burp;
|
||||||
|
|
||||||
|
import javax.net.ssl.HostnameVerifier;
|
||||||
|
import javax.net.ssl.HttpsURLConnection;
|
||||||
|
import javax.net.ssl.SSLSocketFactory;
|
||||||
|
import javax.net.ssl.SSLContext;
|
||||||
|
import javax.net.ssl.TrustManager;
|
||||||
|
import javax.net.ssl.X509TrustManager;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.net.HttpURLConnection;
|
||||||
|
import java.net.InetSocketAddress;
|
||||||
|
import java.net.Socket;
|
||||||
|
import java.net.URL;
|
||||||
|
import java.security.cert.X509Certificate;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Opens HTTPS connections without validating server certificates (self-signed / local dev).
|
||||||
|
* Applied per-connection only; does not change JVM-wide defaults for other Burp components.
|
||||||
|
*/
|
||||||
|
final class SslTrustAll {
|
||||||
|
|
||||||
|
private static volatile SSLSocketFactory socketFactory;
|
||||||
|
private static final HostnameVerifier TRUST_ALL_HOSTS = (hostname, session) -> true;
|
||||||
|
|
||||||
|
private SslTrustAll() {
|
||||||
|
}
|
||||||
|
|
||||||
|
static HttpURLConnection open(URL url) throws IOException {
|
||||||
|
return open(url, 5_000, 30_000);
|
||||||
|
}
|
||||||
|
|
||||||
|
static HttpURLConnection open(URL url, int connectTimeoutMs, int readTimeoutMs) throws IOException {
|
||||||
|
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
|
||||||
|
conn.setConnectTimeout(connectTimeoutMs);
|
||||||
|
conn.setReadTimeout(readTimeoutMs);
|
||||||
|
if (conn instanceof HttpsURLConnection) {
|
||||||
|
HttpsURLConnection https = (HttpsURLConnection) conn;
|
||||||
|
https.setSSLSocketFactory(new TimeoutSslSocketFactory(socketFactory(), connectTimeoutMs, readTimeoutMs));
|
||||||
|
https.setHostnameVerifier(TRUST_ALL_HOSTS);
|
||||||
|
}
|
||||||
|
return conn;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static SSLSocketFactory socketFactory() {
|
||||||
|
SSLSocketFactory sf = socketFactory;
|
||||||
|
if (sf != null) {
|
||||||
|
return sf;
|
||||||
|
}
|
||||||
|
synchronized (SslTrustAll.class) {
|
||||||
|
sf = socketFactory;
|
||||||
|
if (sf != null) {
|
||||||
|
return sf;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
TrustManager[] trustAll = new TrustManager[]{
|
||||||
|
new X509TrustManager() {
|
||||||
|
@Override
|
||||||
|
public X509Certificate[] getAcceptedIssuers() {
|
||||||
|
return new X509Certificate[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void checkClientTrusted(X509Certificate[] chain, String authType) {
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void checkServerTrusted(X509Certificate[] chain, String authType) {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
SSLContext ctx = SSLContext.getInstance("TLS");
|
||||||
|
ctx.init(null, trustAll, new java.security.SecureRandom());
|
||||||
|
sf = ctx.getSocketFactory();
|
||||||
|
socketFactory = sf;
|
||||||
|
return sf;
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new RuntimeException("Failed to initialize trust-all TLS", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Ensures TCP connect + socket read respect timeouts (plain HttpURLConnection SSL can hang longer). */
|
||||||
|
private static final class TimeoutSslSocketFactory extends SSLSocketFactory {
|
||||||
|
private final SSLSocketFactory delegate;
|
||||||
|
private final int connectTimeoutMs;
|
||||||
|
private final int readTimeoutMs;
|
||||||
|
|
||||||
|
TimeoutSslSocketFactory(SSLSocketFactory delegate, int connectTimeoutMs, int readTimeoutMs) {
|
||||||
|
this.delegate = delegate;
|
||||||
|
this.connectTimeoutMs = connectTimeoutMs;
|
||||||
|
this.readTimeoutMs = readTimeoutMs;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String[] getDefaultCipherSuites() {
|
||||||
|
return delegate.getDefaultCipherSuites();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String[] getSupportedCipherSuites() {
|
||||||
|
return delegate.getSupportedCipherSuites();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Socket createSocket() throws IOException {
|
||||||
|
return tune(delegate.createSocket());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Socket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException {
|
||||||
|
return tune(delegate.createSocket(s, host, port, autoClose));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Socket createSocket(String host, int port) throws IOException {
|
||||||
|
Socket plain = new Socket();
|
||||||
|
plain.connect(new InetSocketAddress(host, port), connectTimeoutMs);
|
||||||
|
return tune(delegate.createSocket(plain, host, port, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Socket createSocket(String host, int port, java.net.InetAddress localHost, int localPort) throws IOException {
|
||||||
|
Socket plain = new Socket();
|
||||||
|
plain.bind(new InetSocketAddress(localHost, localPort));
|
||||||
|
plain.connect(new InetSocketAddress(host, port), connectTimeoutMs);
|
||||||
|
return tune(delegate.createSocket(plain, host, port, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Socket createSocket(java.net.InetAddress host, int port) throws IOException {
|
||||||
|
Socket plain = new Socket();
|
||||||
|
plain.connect(new InetSocketAddress(host, port), connectTimeoutMs);
|
||||||
|
return tune(delegate.createSocket(plain, host.getHostName(), port, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Socket createSocket(java.net.InetAddress address, int port, java.net.InetAddress localAddress, int localPort) throws IOException {
|
||||||
|
Socket plain = new Socket();
|
||||||
|
plain.bind(new InetSocketAddress(localAddress, localPort));
|
||||||
|
plain.connect(new InetSocketAddress(address, port), connectTimeoutMs);
|
||||||
|
return tune(delegate.createSocket(plain, address.getHostName(), port, true));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Socket tune(Socket socket) throws IOException {
|
||||||
|
socket.setSoTimeout(readTimeoutMs);
|
||||||
|
return socket;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
BIN
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user