mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-22 22:10:06 +02:00
Compare commits
87 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c3d2a41301 | |||
| 1a2e282d46 | |||
| 8129f2147f | |||
| 4a9889f0af | |||
| 732d47a965 | |||
| e22382aab0 | |||
| b6ff80adf2 | |||
| 51f1cfde2f | |||
| b2c8913014 | |||
| ae98288b62 | |||
| 9955e856a0 | |||
| 018544e5f9 | |||
| c1c86e4632 | |||
| 08d77bc12b | |||
| ce73a7b3e4 | |||
| f78f424aab | |||
| e19d8e39bd | |||
| ecf594a25b | |||
| d5759f6d83 | |||
| 81b3f64b15 | |||
| 0e0f1352f0 | |||
| ffba311afd | |||
| d9ed36cfb1 | |||
| b7f80b78ee | |||
| 8f8e5cfff5 | |||
| 120f860640 | |||
| 90cd119a83 | |||
| 56d597e0c5 | |||
| 11ab5cde8f | |||
| 46a7d338a4 | |||
| 46f68cc1d4 | |||
| 7003cdb2e3 | |||
| 4e5e6208bd | |||
| 6a7e78a846 | |||
| 88c6fbfb75 | |||
| 1cd6d0fa90 | |||
| 24390db100 | |||
| c000fe5195 | |||
| 0b4a11d01a | |||
| d433e44a7d | |||
| 7de51fe0ea | |||
| a354cf97e5 | |||
| c180f07c7e | |||
| 15730d3ef4 | |||
| b7fa18b6d4 | |||
| 8d622f63ff | |||
| 20b05146fb | |||
| d8768eae76 | |||
| 9232cee38d | |||
| 6c975e63d2 | |||
| e175523b82 | |||
| ae23427d9e | |||
| 93a2504ce3 | |||
| 09b0479fb3 | |||
| 2bdc9d4fe0 | |||
| 01b3d8056c | |||
| ed479d5e4d | |||
| a49f595231 | |||
| 82cf014a5e | |||
| 508de5fad0 | |||
| 6712344411 | |||
| 7eadccbff6 | |||
| 01b361e4a7 | |||
| f6ce31c961 | |||
| d5a0f93c6c | |||
| 56faefaaf9 | |||
| 16e9c5874a | |||
| 41b5cdde6b | |||
| cf1f8515d9 | |||
| 5e2b30c029 | |||
| 8c7c22369e | |||
| 9b1aba692b | |||
| db730b48c1 | |||
| dfb7dd7390 | |||
| 9f6eb33047 | |||
| 616d87f4cc | |||
| 8d999792b8 | |||
| afae8970d1 | |||
| 4d7330c5c3 | |||
| 8884bfb0b4 | |||
| fb351c80b6 | |||
| 664834e338 | |||
| 95bf62db88 | |||
| 656242614d | |||
| a9d6d8c00e | |||
| 0d6a43c0a8 | |||
| 702f286eb1 |
@@ -112,7 +112,7 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
|
|||||||
- 🔒 Password-protected web UI, audit logs, and SQLite persistence
|
- 🔒 Password-protected web UI, audit logs, and SQLite persistence
|
||||||
- 📚 Knowledge base (RAG) with embedding-based vector retrieval (cosine similarity), optional **Eino Compose** indexing pipeline, and configurable post-retrieval budgets / reranking hooks
|
- 📚 Knowledge base (RAG) with embedding-based vector retrieval (cosine similarity), optional **Eino Compose** indexing pipeline, and configurable post-retrieval budgets / reranking hooks
|
||||||
- 📁 Conversation grouping with pinning, rename, and batch management
|
- 📁 Conversation grouping with pinning, rename, and batch management
|
||||||
- 📂 **Project management**: group conversations and vulnerabilities by project; **shared facts** (project blackboard) persist cross-session context (targets, env, auth notes) with auto-injection for agents and MCP tools (`upsert_project_fact`, `get_project_fact`, …)
|
- 📂 **Project management**: shared facts (blackboard) across sessions, `upsert_project_fact` + `links` to chain paths; attack-chain and project fact graph views
|
||||||
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
|
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
|
||||||
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
|
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
|
||||||
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
|
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
|
||||||
@@ -312,7 +312,7 @@ Requirements / tips:
|
|||||||
### Tool Orchestration & Extensions
|
### Tool Orchestration & Extensions
|
||||||
- **YAML recipes** in `tools/*.yaml` describe commands, arguments, prompts, and metadata.
|
- **YAML recipes** in `tools/*.yaml` describe commands, arguments, prompts, and metadata.
|
||||||
- **Directory hot-reload** – pointing `security.tools_dir` to a folder is usually enough; inline definitions in `config.yaml` remain supported for quick experiments.
|
- **Directory hot-reload** – pointing `security.tools_dir` to a folder is usually enough; inline definitions in `config.yaml` remain supported for quick experiments.
|
||||||
- **Large-result pagination** – outputs beyond 200 KB are stored as artifacts retrievable through the `query_execution_result` tool with paging, filters, and regex search.
|
- **Large tool outputs** – outputs beyond `reduction_max_length_for_trunc` are summarized via Eino reduction with full content persisted under `tmp/reduction/`; use `read_file` on the path in `<persisted-output>`.
|
||||||
- **Result compression** – multi-megabyte logs can be summarized or losslessly compressed before persisting to keep SQLite lean.
|
- **Result compression** – multi-megabyte logs can be summarized or losslessly compressed before persisting to keep SQLite lean.
|
||||||
|
|
||||||
**Creating a custom tool (typical flow)**
|
**Creating a custom tool (typical flow)**
|
||||||
@@ -551,6 +551,11 @@ multi_agent:
|
|||||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
|
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
|
||||||
# eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill }
|
# eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill }
|
||||||
# eino_middleware: plantask_enable, checkpoint_dir, deep_model_retry_max_retries, deep_output_key, ...
|
# eino_middleware: plantask_enable, checkpoint_dir, deep_model_retry_max_retries, deep_output_key, ...
|
||||||
|
project:
|
||||||
|
enabled: true # Enable project blackboard & fact MCP tools
|
||||||
|
fact_index_max_runes: 65000
|
||||||
|
fact_summary_max_runes: 24000
|
||||||
|
default_inject_deprecated: false
|
||||||
```
|
```
|
||||||
|
|
||||||
### Tool Definition Example (`tools/nmap.yaml`)
|
### Tool Definition Example (`tools/nmap.yaml`)
|
||||||
|
|||||||
+7
-2
@@ -111,7 +111,7 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
|||||||
- 🔒 Web 登录保护、审计日志、SQLite 持久化
|
- 🔒 Web 登录保护、审计日志、SQLite 持久化
|
||||||
- 📚 知识库(RAG):向量嵌入与余弦相似度检索(与 Eino `retriever.Retriever` 语义一致),可选 **Eino Compose** 索引流水线及检索后处理(预算、重排等配置项)
|
- 📚 知识库(RAG):向量嵌入与余弦相似度检索(与 Eino `retriever.Retriever` 语义一致),可选 **Eino Compose** 索引流水线及检索后处理(预算、重排等配置项)
|
||||||
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
|
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
|
||||||
- 📂 **项目管理**:按项目归类对话与漏洞;**共享事实**(项目黑板)在多会话间沉淀目标/环境/认证等认知,自动注入 Agent 上下文,支持 MCP 工具读写(`upsert_project_fact`、`get_project_fact` 等)
|
- 📂 **项目管理**:共享事实(黑板)跨会话沉淀认知,`upsert_project_fact` + `links` 串联攻击路径;聊天攻击链与项目事实图可视化
|
||||||
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
|
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
|
||||||
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
|
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
|
||||||
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
|
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
|
||||||
@@ -310,7 +310,7 @@ go build -o cyberstrike-ai cmd/server/main.go
|
|||||||
### 工具编排与扩展
|
### 工具编排与扩展
|
||||||
- `tools/*.yaml` 定义命令、参数、提示词与元数据,可热加载。
|
- `tools/*.yaml` 定义命令、参数、提示词与元数据,可热加载。
|
||||||
- `security.tools_dir` 指向目录即可批量启用;仍支持在主配置里内联定义。
|
- `security.tools_dir` 指向目录即可批量启用;仍支持在主配置里内联定义。
|
||||||
- **大结果分页**:超过 200KB 的输出会保存为附件,可通过 `query_execution_result` 工具分页、过滤、正则检索。
|
- **大工具输出**:超过 `reduction_max_length_for_trunc` 时由 Eino reduction 摘要,完整内容落盘至 `tmp/reduction/`;按 `<persisted-output>` 中的路径用 `read_file` 读取。
|
||||||
- **结果压缩/摘要**:多兆字节日志可先压缩或生成摘要再写入 SQLite,减小档案体积。
|
- **结果压缩/摘要**:多兆字节日志可先压缩或生成摘要再写入 SQLite,减小档案体积。
|
||||||
|
|
||||||
**自定义工具的一般步骤**
|
**自定义工具的一般步骤**
|
||||||
@@ -549,6 +549,11 @@ multi_agent:
|
|||||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
|
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
|
||||||
# eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill }
|
# eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill }
|
||||||
# eino_middleware: plantask_enable、checkpoint_dir、deep_model_retry_max_retries、deep_output_key 等
|
# eino_middleware: plantask_enable、checkpoint_dir、deep_model_retry_max_retries、deep_output_key 等
|
||||||
|
project:
|
||||||
|
enabled: true # 启用项目黑板与事实 MCP 工具
|
||||||
|
fact_index_max_runes: 65000
|
||||||
|
fact_summary_max_runes: 24000
|
||||||
|
default_inject_deprecated: false
|
||||||
```
|
```
|
||||||
|
|
||||||
### 工具模版示例(`tools/nmap.yaml`)
|
### 工具模版示例(`tools/nmap.yaml`)
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"cyberstrike-ai/internal/logger"
|
"cyberstrike-ai/internal/logger"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/security"
|
"cyberstrike-ai/internal/security"
|
||||||
"cyberstrike-ai/internal/storage"
|
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@@ -33,23 +32,6 @@ func main() {
|
|||||||
// 创建安全工具执行器
|
// 创建安全工具执行器
|
||||||
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
|
||||||
|
|
||||||
// 初始化结果存储(与 internal/app/app.go 同样的逻辑)。
|
|
||||||
// stdio 模式下原本不初始化,导致 'exec' 等查询型工具报"结果存储未初始化"。
|
|
||||||
resultStorageDir := "tmp"
|
|
||||||
if cfg.Agent.ResultStorageDir != "" {
|
|
||||||
resultStorageDir = cfg.Agent.ResultStorageDir
|
|
||||||
}
|
|
||||||
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "创建结果存储目录失败: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "初始化结果存储失败: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
executor.SetResultStorage(resultStorage)
|
|
||||||
|
|
||||||
// 注册工具
|
// 注册工具
|
||||||
executor.RegisterTools(mcpServer)
|
executor.RegisterTools(mcpServer)
|
||||||
|
|
||||||
@@ -61,4 +43,3 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+8
-8
@@ -10,7 +10,7 @@
|
|||||||
# ============================================
|
# ============================================
|
||||||
|
|
||||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||||
version: "v1.6.37"
|
version: "v1.6.42"
|
||||||
# 服务器配置
|
# 服务器配置
|
||||||
server:
|
server:
|
||||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||||
@@ -58,7 +58,7 @@ openai:
|
|||||||
api_key: sk-xxxxxxx # API 密钥(必填)
|
api_key: sk-xxxxxxx # API 密钥(必填)
|
||||||
model: qwen3-max # 模型名称(必填)
|
model: qwen3-max # 模型名称(必填)
|
||||||
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
|
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
|
||||||
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinking(extended thinking),mode: off 关闭
|
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort;Claude 4.6+ 为 adaptive + output_config.effort(仅显式配置 effort 时下发);3.7 为 enabled+budget_tokens:10000(文档示例),effort 不映射,自定义预算用 extra_request_fields
|
||||||
reasoning:
|
reasoning:
|
||||||
mode: on # auto | on | off;off 时不附加任何推理扩展字段
|
mode: on # auto | on | off;off 时不附加任何推理扩展字段
|
||||||
effort: high # low | medium | high | max | xhigh(最高档:OpenAI 常用 xhigh,部分网关用 max,原样下发);空表示不指定
|
effort: high # low | medium | high | max | xhigh(最高档:OpenAI 常用 xhigh,部分网关用 max,原样下发);空表示不指定
|
||||||
@@ -92,8 +92,6 @@ fofa:
|
|||||||
# 达到最大迭代次数时,AI 会自动总结测试结果
|
# 达到最大迭代次数时,AI 会自动总结测试结果
|
||||||
agent:
|
agent:
|
||||||
max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖)
|
max_iterations: 12000 # 全局最大迭代次数(单代理 / Deep / Supervisor / Plan-Execute 主执行器 / 子代理均沿用;agents/*.md 中 max_iterations>0 可单独覆盖)
|
||||||
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
|
|
||||||
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
|
|
||||||
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
|
||||||
# system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
# system_prompt_path: prompts/single-agent.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
|
||||||
|
|
||||||
@@ -144,10 +142,10 @@ multi_agent:
|
|||||||
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
|
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
|
||||||
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
|
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
|
||||||
checkpoint_dir: data/eino-checkpoints # P0:进程崩溃/OOM 后同会话自动 ADK Resume;正常结束会删 .ckpt;与「中断并继续」(last_react_*) 是两套机制
|
checkpoint_dir: data/eino-checkpoints # P0:进程崩溃/OOM 后同会话自动 ADK Resume;正常结束会删 .ckpt;与「中断并继续」(last_react_*) 是两套机制
|
||||||
run_retry_max_attempts: 0 # 429/5xx/网络抖动时整轮 Run 指数退避续跑;0=默认 10(与 deep_model_retry 互补,建议保持默认)
|
run_retry_max_attempts: 0 # 429/5xx/网络抖动时可退避重试次数(run loop + summarization 共用 isEinoTransientRunError);0=默认 10
|
||||||
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
|
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
|
||||||
deep_output_key: final_answer # P0:Eino session 写入最终助手结论(框架内部;Deep/Supervisor 主/eino_single)
|
deep_output_key: final_answer # P0:Eino session 写入最终助手结论(框架内部;Deep/Supervisor 主/eino_single)
|
||||||
deep_model_retry_max_retries: 3 # P0:单次 ChatModel API 失败时框架自动重试(超时/502 等);子代理模型不受此项影响
|
deep_model_retry_max_retries: 0 # 已废弃,请用 run_retry_max_attempts;保留字段仅为兼容旧配置
|
||||||
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
||||||
# Eino callbacks + OpenTelemetry:框架级 span(与 Zap 对齐);默认不向终端用户 UI 推 eino_trace_*(见 sse_trace_to_client)
|
# Eino callbacks + OpenTelemetry:框架级 span(与 Zap 对齐);默认不向终端用户 UI 推 eino_trace_*(见 sse_trace_to_client)
|
||||||
eino_callbacks:
|
eino_callbacks:
|
||||||
@@ -310,7 +308,9 @@ roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录
|
|||||||
project:
|
project:
|
||||||
enabled: true
|
enabled: true
|
||||||
# default_project_id: "" # 可选:机器人/批量任务创建对话时的默认项目 ID
|
# default_project_id: "" # 可选:机器人/批量任务创建对话时的默认项目 ID
|
||||||
fact_index_max_runes: 6500
|
fact_index_max_runes: 65000
|
||||||
fact_summary_max_runes: 2400
|
# 事实关系速览段预算(从索引总预算中预留)
|
||||||
|
fact_index_path_max_runes: 10000
|
||||||
|
fact_summary_max_runes: 24000
|
||||||
default_inject_deprecated: false
|
default_inject_deprecated: false
|
||||||
|
|
||||||
|
|||||||
+17
-135
@@ -18,7 +18,6 @@ import (
|
|||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/mcp/builtin"
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
"cyberstrike-ai/internal/storage"
|
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
@@ -32,8 +31,6 @@ type Agent struct {
|
|||||||
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
|
externalMCPMgr *mcp.ExternalMCPManager // 外部MCP管理器
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
maxIterations int
|
maxIterations int
|
||||||
resultStorage ResultStorage // 结果存储
|
|
||||||
largeResultThreshold int // 大结果阈值(字节)
|
|
||||||
mu sync.RWMutex // 添加互斥锁以支持并发更新
|
mu sync.RWMutex // 添加互斥锁以支持并发更新
|
||||||
toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具)
|
toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具)
|
||||||
currentConversationID string // 当前对话ID(用于自动传递给工具)
|
currentConversationID string // 当前对话ID(用于自动传递给工具)
|
||||||
@@ -41,18 +38,6 @@ type Agent struct {
|
|||||||
toolDescriptionMode string // 工具描述模式: "short" | "full",默认 short
|
toolDescriptionMode string // 工具描述模式: "short" | "full",默认 short
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResultStorage 结果存储接口(直接使用 storage 包的类型)
|
|
||||||
type ResultStorage interface {
|
|
||||||
SaveResult(executionID string, toolName string, result string) error
|
|
||||||
GetResult(executionID string) (string, error)
|
|
||||||
GetResultPage(executionID string, page int, limit int) (*storage.ResultPage, error)
|
|
||||||
SearchResult(executionID string, keyword string, useRegex bool) ([]string, error)
|
|
||||||
FilterResult(executionID string, filter string, useRegex bool) ([]string, error)
|
|
||||||
GetResultMetadata(executionID string) (*storage.ResultMetadata, error)
|
|
||||||
GetResultPath(executionID string) string
|
|
||||||
DeleteResult(executionID string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type agentConversationIDKey struct{}
|
type agentConversationIDKey struct{}
|
||||||
|
|
||||||
func withAgentConversationID(ctx context.Context, id string) context.Context {
|
func withAgentConversationID(ctx context.Context, id string) context.Context {
|
||||||
@@ -83,26 +68,6 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
|
|||||||
maxIterations = 30
|
maxIterations = 30
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置大结果阈值,默认50KB
|
|
||||||
largeResultThreshold := 50 * 1024
|
|
||||||
if agentCfg != nil && agentCfg.LargeResultThreshold > 0 {
|
|
||||||
largeResultThreshold = agentCfg.LargeResultThreshold
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置结果存储目录,默认tmp
|
|
||||||
resultStorageDir := "tmp"
|
|
||||||
if agentCfg != nil && agentCfg.ResultStorageDir != "" {
|
|
||||||
resultStorageDir = agentCfg.ResultStorageDir
|
|
||||||
}
|
|
||||||
|
|
||||||
// 初始化结果存储
|
|
||||||
var resultStorage ResultStorage
|
|
||||||
if resultStorageDir != "" {
|
|
||||||
// 导入storage包(避免循环依赖,使用接口)
|
|
||||||
// 这里需要在实际使用时初始化
|
|
||||||
// 暂时设为nil,在需要时初始化
|
|
||||||
}
|
|
||||||
|
|
||||||
// 配置HTTP Transport,优化连接管理和超时设置
|
// 配置HTTP Transport,优化连接管理和超时设置
|
||||||
transport := &http.Transport{
|
transport := &http.Transport{
|
||||||
DialContext: (&net.Dialer{
|
DialContext: (&net.Dialer{
|
||||||
@@ -133,20 +98,11 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
|
|||||||
externalMCPMgr: externalMCPMgr,
|
externalMCPMgr: externalMCPMgr,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
maxIterations: maxIterations,
|
maxIterations: maxIterations,
|
||||||
resultStorage: resultStorage,
|
|
||||||
largeResultThreshold: largeResultThreshold,
|
|
||||||
toolNameMapping: make(map[string]string), // 初始化工具名称映射
|
toolNameMapping: make(map[string]string), // 初始化工具名称映射
|
||||||
toolDescriptionMode: "short",
|
toolDescriptionMode: "short",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetResultStorage 设置结果存储(用于避免循环依赖)
|
|
||||||
func (a *Agent) SetResultStorage(storage ResultStorage) {
|
|
||||||
a.mu.Lock()
|
|
||||||
defer a.mu.Unlock()
|
|
||||||
a.resultStorage = storage
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPromptBaseDir 设置单代理 system_prompt_path 相对路径的基准目录(一般为 config.yaml 所在目录)。
|
// SetPromptBaseDir 设置单代理 system_prompt_path 相对路径的基准目录(一般为 config.yaml 所在目录)。
|
||||||
func (a *Agent) SetPromptBaseDir(dir string) {
|
func (a *Agent) SetPromptBaseDir(dir string) {
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
@@ -663,46 +619,6 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
|
|||||||
}
|
}
|
||||||
|
|
||||||
resultStr := resultText.String()
|
resultStr := resultText.String()
|
||||||
resultSize := len(resultStr)
|
|
||||||
|
|
||||||
// 检测大结果并保存
|
|
||||||
a.mu.RLock()
|
|
||||||
threshold := a.largeResultThreshold
|
|
||||||
storage := a.resultStorage
|
|
||||||
a.mu.RUnlock()
|
|
||||||
|
|
||||||
if resultSize > threshold && storage != nil {
|
|
||||||
// 异步保存大结果
|
|
||||||
go func() {
|
|
||||||
if err := storage.SaveResult(executionID, toolName, resultStr); err != nil {
|
|
||||||
a.logger.Warn("保存大结果失败",
|
|
||||||
zap.String("executionID", executionID),
|
|
||||||
zap.String("toolName", toolName),
|
|
||||||
zap.Error(err),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
a.logger.Info("大结果已保存",
|
|
||||||
zap.String("executionID", executionID),
|
|
||||||
zap.String("toolName", toolName),
|
|
||||||
zap.Int("size", resultSize),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 返回最小化通知
|
|
||||||
lines := strings.Split(resultStr, "\n")
|
|
||||||
filePath := ""
|
|
||||||
if storage != nil {
|
|
||||||
filePath = storage.GetResultPath(executionID)
|
|
||||||
}
|
|
||||||
notification := a.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
|
|
||||||
|
|
||||||
return &ToolExecutionResult{
|
|
||||||
Result: notification,
|
|
||||||
ExecutionID: executionID,
|
|
||||||
IsError: result != nil && result.IsError,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ToolExecutionResult{
|
return &ToolExecutionResult{
|
||||||
Result: resultStr,
|
Result: resultStr,
|
||||||
@@ -711,57 +627,6 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// formatMinimalNotification 格式化最小化通知
|
|
||||||
func (a *Agent) formatMinimalNotification(executionID string, toolName string, size int, lineCount int, filePath string) string {
|
|
||||||
var sb strings.Builder
|
|
||||||
|
|
||||||
sb.WriteString(fmt.Sprintf("工具执行完成。结果已保存(ID: %s)。\n\n", executionID))
|
|
||||||
sb.WriteString("结果信息:\n")
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 工具: %s\n", toolName))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 大小: %d 字节 (%.2f KB)\n", size, float64(size)/1024))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 行数: %d 行\n", lineCount))
|
|
||||||
if filePath != "" {
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 文件路径: %s\n", filePath))
|
|
||||||
}
|
|
||||||
sb.WriteString("\n")
|
|
||||||
sb.WriteString("推荐使用 query_execution_result 工具查询完整结果:\n")
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 查询第一页: query_execution_result(execution_id=\"%s\", page=1, limit=100)\n", executionID))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 搜索关键词: query_execution_result(execution_id=\"%s\", search=\"关键词\")\n", executionID))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 过滤条件: query_execution_result(execution_id=\"%s\", filter=\"error\")\n", executionID))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 正则匹配: query_execution_result(execution_id=\"%s\", search=\"\\\\d+\\\\.\\\\d+\\\\.\\\\d+\\\\.\\\\d+\", use_regex=true)\n", executionID))
|
|
||||||
sb.WriteString("\n")
|
|
||||||
if filePath != "" {
|
|
||||||
sb.WriteString("如果 query_execution_result 工具不满足需求,也可以使用其他工具处理文件:\n")
|
|
||||||
sb.WriteString("\n")
|
|
||||||
sb.WriteString("**分段读取示例:**\n")
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 查看前100行: exec(command=\"head\", args=[\"-n\", \"100\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 查看后100行: exec(command=\"tail\", args=[\"-n\", \"100\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 查看第50-150行: exec(command=\"sed\", args=[\"-n\", \"50,150p\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString("\n")
|
|
||||||
sb.WriteString("**搜索和正则匹配示例:**\n")
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 搜索关键词: exec(command=\"grep\", args=[\"关键词\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 正则匹配IP地址: exec(command=\"grep\", args=[\"-E\", \"\\\\d+\\\\.\\\\d+\\\\.\\\\d+\\\\.\\\\d+\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 不区分大小写搜索: exec(command=\"grep\", args=[\"-i\", \"关键词\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 显示匹配行号: exec(command=\"grep\", args=[\"-n\", \"关键词\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString("\n")
|
|
||||||
sb.WriteString("**过滤和统计示例:**\n")
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 统计总行数: exec(command=\"wc\", args=[\"-l\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 过滤包含error的行: exec(command=\"grep\", args=[\"error\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 排除空行: exec(command=\"grep\", args=[\"-v\", \"^$\", \"%s\"])\n", filePath))
|
|
||||||
sb.WriteString("\n")
|
|
||||||
sb.WriteString("**完整读取(不推荐大文件):**\n")
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 使用 cat 工具: cat(file=\"%s\")\n", filePath))
|
|
||||||
sb.WriteString(fmt.Sprintf(" - 使用 exec 工具: exec(command=\"cat\", args=[\"%s\"])\n", filePath))
|
|
||||||
sb.WriteString("\n")
|
|
||||||
sb.WriteString("**注意:**\n")
|
|
||||||
sb.WriteString(" - 直接读取大文件可能会再次触发大结果保存机制\n")
|
|
||||||
sb.WriteString(" - 建议优先使用分段读取和搜索功能,避免一次性加载整个文件\n")
|
|
||||||
sb.WriteString(" - 正则表达式语法遵循标准 POSIX 正则表达式规范\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
return sb.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateConfig 更新OpenAI配置
|
// UpdateConfig 更新OpenAI配置
|
||||||
func (a *Agent) UpdateConfig(cfg *config.OpenAIConfig) {
|
func (a *Agent) UpdateConfig(cfg *config.OpenAIConfig) {
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
@@ -923,6 +788,23 @@ func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interf
|
|||||||
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateMCPExecutionDisplayResult 将监控库中的工具结果更新为送入模型的展示正文(reduction 后)。
|
||||||
|
func (a *Agent) UpdateMCPExecutionDisplayResult(executionID, resultText string) {
|
||||||
|
if a == nil || strings.TrimSpace(executionID) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
text := resultText
|
||||||
|
if strings.TrimSpace(text) == "" {
|
||||||
|
text = "(无输出)"
|
||||||
|
}
|
||||||
|
tr := &mcp.ToolResult{
|
||||||
|
Content: []mcp.Content{{Type: "text", Text: text}},
|
||||||
|
}
|
||||||
|
if a.mcpServer != nil {
|
||||||
|
_ = a.mcpServer.UpdateToolExecutionResult(executionID, tr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// CancelMCPToolExecutionWithNote 取消一次进行中的 MCP 工具(先内部后外部),与监控页「终止工具」一致;note 非空时合并进返回给模型的文本。
|
// CancelMCPToolExecutionWithNote 取消一次进行中的 MCP 工具(先内部后外部),与监控页「终止工具」一致;note 非空时合并进返回给模型的文本。
|
||||||
func (a *Agent) CancelMCPToolExecutionWithNote(executionID, note string) bool {
|
func (a *Agent) CancelMCPToolExecutionWithNote(executionID, note string) bool {
|
||||||
executionID = strings.TrimSpace(executionID)
|
executionID = strings.TrimSpace(executionID)
|
||||||
|
|||||||
@@ -1,21 +1,16 @@
|
|||||||
package agent
|
package agent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/storage"
|
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
// setupTestAgent 创建测试用的Agent
|
// setupTestAgent 创建测试用的Agent
|
||||||
func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) {
|
func setupTestAgent(t *testing.T) *Agent {
|
||||||
logger := zap.NewNop()
|
logger := zap.NewNop()
|
||||||
mcpServer := mcp.NewServer(logger)
|
mcpServer := mcp.NewServer(logger)
|
||||||
|
|
||||||
@@ -27,204 +22,9 @@ func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) {
|
|||||||
|
|
||||||
agentCfg := &config.AgentConfig{
|
agentCfg := &config.AgentConfig{
|
||||||
MaxIterations: 10,
|
MaxIterations: 10,
|
||||||
LargeResultThreshold: 100, // 设置较小的阈值便于测试
|
|
||||||
ResultStorageDir: "",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10)
|
return NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10)
|
||||||
|
|
||||||
// 创建测试存储
|
|
||||||
tmpDir := filepath.Join(os.TempDir(), "test_agent_storage_"+time.Now().Format("20060102_150405"))
|
|
||||||
testStorage, err := storage.NewFileResultStorage(tmpDir, logger)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("创建测试存储失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
agent.SetResultStorage(testStorage)
|
|
||||||
|
|
||||||
return agent, testStorage
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAgent_FormatMinimalNotification(t *testing.T) {
|
|
||||||
agent, testStorage := setupTestAgent(t)
|
|
||||||
_ = testStorage // 避免未使用变量警告
|
|
||||||
|
|
||||||
executionID := "test_exec_001"
|
|
||||||
toolName := "nmap_scan"
|
|
||||||
size := 50000
|
|
||||||
lineCount := 1000
|
|
||||||
filePath := "tmp/test_exec_001.txt"
|
|
||||||
|
|
||||||
notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount, filePath)
|
|
||||||
|
|
||||||
// 验证通知包含必要信息
|
|
||||||
if !strings.Contains(notification, executionID) {
|
|
||||||
t.Errorf("通知中应该包含执行ID: %s", executionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(notification, toolName) {
|
|
||||||
t.Errorf("通知中应该包含工具名称: %s", toolName)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(notification, "50000") {
|
|
||||||
t.Errorf("通知中应该包含大小信息")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(notification, "1000") {
|
|
||||||
t.Errorf("通知中应该包含行数信息")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(notification, "query_execution_result") {
|
|
||||||
t.Errorf("通知中应该包含查询工具的使用说明")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
|
|
||||||
agent, _ := setupTestAgent(t)
|
|
||||||
|
|
||||||
// 创建模拟的MCP工具结果(大结果)
|
|
||||||
largeResult := &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: strings.Repeat("This is a test line with some content.\n", 1000), // 约50KB
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 模拟MCP服务器返回大结果
|
|
||||||
// 由于我们需要模拟CallTool的行为,这里需要创建一个mock或者使用实际的MCP服务器
|
|
||||||
// 为了简化测试,我们直接测试结果处理逻辑
|
|
||||||
|
|
||||||
// 设置阈值
|
|
||||||
agent.mu.Lock()
|
|
||||||
agent.largeResultThreshold = 1000 // 设置较小的阈值
|
|
||||||
agent.mu.Unlock()
|
|
||||||
|
|
||||||
// 创建执行ID
|
|
||||||
executionID := "test_exec_large_001"
|
|
||||||
toolName := "test_tool"
|
|
||||||
|
|
||||||
// 格式化结果
|
|
||||||
var resultText strings.Builder
|
|
||||||
for _, content := range largeResult.Content {
|
|
||||||
resultText.WriteString(content.Text)
|
|
||||||
resultText.WriteString("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
resultStr := resultText.String()
|
|
||||||
resultSize := len(resultStr)
|
|
||||||
|
|
||||||
// 检测大结果并保存
|
|
||||||
agent.mu.RLock()
|
|
||||||
threshold := agent.largeResultThreshold
|
|
||||||
storage := agent.resultStorage
|
|
||||||
agent.mu.RUnlock()
|
|
||||||
|
|
||||||
if resultSize > threshold && storage != nil {
|
|
||||||
// 保存大结果
|
|
||||||
err := storage.SaveResult(executionID, toolName, resultStr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("保存大结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 生成通知
|
|
||||||
lines := strings.Split(resultStr, "\n")
|
|
||||||
filePath := storage.GetResultPath(executionID)
|
|
||||||
notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
|
|
||||||
|
|
||||||
// 验证通知格式
|
|
||||||
if !strings.Contains(notification, executionID) {
|
|
||||||
t.Errorf("通知中应该包含执行ID")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证结果已保存
|
|
||||||
savedResult, err := storage.GetResult(executionID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("获取保存的结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if savedResult != resultStr {
|
|
||||||
t.Errorf("保存的结果与原始结果不匹配")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
t.Fatal("大结果应该被检测到并保存")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
|
|
||||||
agent, _ := setupTestAgent(t)
|
|
||||||
|
|
||||||
// 创建小结果
|
|
||||||
smallResult := &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "Small result content",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置较大的阈值
|
|
||||||
agent.mu.Lock()
|
|
||||||
agent.largeResultThreshold = 100000 // 100KB
|
|
||||||
agent.mu.Unlock()
|
|
||||||
|
|
||||||
// 格式化结果
|
|
||||||
var resultText strings.Builder
|
|
||||||
for _, content := range smallResult.Content {
|
|
||||||
resultText.WriteString(content.Text)
|
|
||||||
resultText.WriteString("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
resultStr := resultText.String()
|
|
||||||
resultSize := len(resultStr)
|
|
||||||
|
|
||||||
// 检测大结果
|
|
||||||
agent.mu.RLock()
|
|
||||||
threshold := agent.largeResultThreshold
|
|
||||||
storage := agent.resultStorage
|
|
||||||
agent.mu.RUnlock()
|
|
||||||
|
|
||||||
if resultSize > threshold && storage != nil {
|
|
||||||
t.Fatal("小结果不应该被保存")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 小结果应该直接返回
|
|
||||||
if resultSize <= threshold {
|
|
||||||
// 这是预期的行为
|
|
||||||
if resultStr == "" {
|
|
||||||
t.Fatal("小结果应该直接返回,不应该为空")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAgent_SetResultStorage(t *testing.T) {
|
|
||||||
agent, _ := setupTestAgent(t)
|
|
||||||
|
|
||||||
// 创建新的存储
|
|
||||||
tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405"))
|
|
||||||
newStorage, err := storage.NewFileResultStorage(tmpDir, zap.NewNop())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("创建新存储失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 设置新存储
|
|
||||||
agent.SetResultStorage(newStorage)
|
|
||||||
|
|
||||||
// 验证存储已更新
|
|
||||||
agent.mu.RLock()
|
|
||||||
currentStorage := agent.resultStorage
|
|
||||||
agent.mu.RUnlock()
|
|
||||||
|
|
||||||
if currentStorage != newStorage {
|
|
||||||
t.Fatal("存储未正确更新")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 清理
|
|
||||||
os.RemoveAll(tmpDir)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAgent_NewAgent_DefaultValues(t *testing.T) {
|
func TestAgent_NewAgent_DefaultValues(t *testing.T) {
|
||||||
@@ -243,14 +43,6 @@ func TestAgent_NewAgent_DefaultValues(t *testing.T) {
|
|||||||
if agent.maxIterations != 30 {
|
if agent.maxIterations != 30 {
|
||||||
t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations)
|
t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations)
|
||||||
}
|
}
|
||||||
|
|
||||||
agent.mu.RLock()
|
|
||||||
threshold := agent.largeResultThreshold
|
|
||||||
agent.mu.RUnlock()
|
|
||||||
|
|
||||||
if threshold != 50*1024 {
|
|
||||||
t.Errorf("默认阈值不匹配。期望: %d, 实际: %d", 50*1024, threshold)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAgent_NewAgent_CustomConfig(t *testing.T) {
|
func TestAgent_NewAgent_CustomConfig(t *testing.T) {
|
||||||
@@ -265,8 +57,6 @@ func TestAgent_NewAgent_CustomConfig(t *testing.T) {
|
|||||||
|
|
||||||
agentCfg := &config.AgentConfig{
|
agentCfg := &config.AgentConfig{
|
||||||
MaxIterations: 20,
|
MaxIterations: 20,
|
||||||
LargeResultThreshold: 100 * 1024, // 100KB
|
|
||||||
ResultStorageDir: "custom_tmp",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15)
|
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15)
|
||||||
@@ -274,12 +64,4 @@ func TestAgent_NewAgent_CustomConfig(t *testing.T) {
|
|||||||
if agent.maxIterations != 15 {
|
if agent.maxIterations != 15 {
|
||||||
t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations)
|
t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations)
|
||||||
}
|
}
|
||||||
|
|
||||||
agent.mu.RLock()
|
|
||||||
threshold := agent.largeResultThreshold
|
|
||||||
agent.mu.RUnlock()
|
|
||||||
|
|
||||||
if threshold != 100*1024 {
|
|
||||||
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package agent
|
package agent
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"cyberstrike-ai/internal/project"
|
"cyberstrike-ai/internal/projectprompt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultSingleAgentSystemPrompt 单代理(Eino ADK / MCP)内置系统提示;可通过 agent.system_prompt_path 覆盖为文件。
|
// DefaultSingleAgentSystemPrompt 单代理(Eino ADK / MCP)内置系统提示;可通过 agent.system_prompt_path 覆盖为文件。
|
||||||
@@ -107,7 +107,7 @@ func DefaultSingleAgentSystemPrompt() string {
|
|||||||
- 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。
|
- 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。
|
||||||
- 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。
|
- 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。
|
||||||
|
|
||||||
` + project.FactRecordingBlackboardSection(false) + `
|
` + projectprompt.FactRecordingBlackboardSection(false) + `
|
||||||
|
|
||||||
## 技能库(Skills)与知识库
|
## 技能库(Skills)与知识库
|
||||||
|
|
||||||
|
|||||||
+12
-26
@@ -28,7 +28,6 @@ import (
|
|||||||
"cyberstrike-ai/internal/robot"
|
"cyberstrike-ai/internal/robot"
|
||||||
"cyberstrike-ai/internal/security"
|
"cyberstrike-ai/internal/security"
|
||||||
"cyberstrike-ai/internal/skillpackage"
|
"cyberstrike-ai/internal/skillpackage"
|
||||||
"cyberstrike-ai/internal/storage"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -130,23 +129,6 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
externalMCPMgr.StartAllEnabled()
|
externalMCPMgr.StartAllEnabled()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 初始化结果存储
|
|
||||||
resultStorageDir := "tmp"
|
|
||||||
if cfg.Agent.ResultStorageDir != "" {
|
|
||||||
resultStorageDir = cfg.Agent.ResultStorageDir
|
|
||||||
}
|
|
||||||
|
|
||||||
// 确保存储目录存在
|
|
||||||
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
|
|
||||||
return nil, fmt.Errorf("创建结果存储目录失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建结果存储实例
|
|
||||||
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("初始化结果存储失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建Agent
|
// 创建Agent
|
||||||
maxIterations := cfg.Agent.MaxIterations
|
maxIterations := cfg.Agent.MaxIterations
|
||||||
if maxIterations <= 0 {
|
if maxIterations <= 0 {
|
||||||
@@ -155,12 +137,6 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations)
|
agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations)
|
||||||
agent.UpdateToolDescriptionMode(cfg.Security.ToolDescriptionMode)
|
agent.UpdateToolDescriptionMode(cfg.Security.ToolDescriptionMode)
|
||||||
|
|
||||||
// 设置结果存储到Agent
|
|
||||||
agent.SetResultStorage(resultStorage)
|
|
||||||
|
|
||||||
// 设置结果存储到Executor(用于查询工具)
|
|
||||||
executor.SetResultStorage(resultStorage)
|
|
||||||
|
|
||||||
// 初始化知识库模块(如果启用)
|
// 初始化知识库模块(如果启用)
|
||||||
var knowledgeManager *knowledge.Manager
|
var knowledgeManager *knowledge.Manager
|
||||||
var knowledgeRetriever *knowledge.Retriever
|
var knowledgeRetriever *knowledge.Retriever
|
||||||
@@ -322,7 +298,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
plantaskBase := filepath.Join(skillsDir, plantaskRel)
|
plantaskBase := filepath.Join(skillsDir, plantaskRel)
|
||||||
// Match eino_adk_run_loop: checkpoint_dir is used as configured (relative to process CWD when not absolute).
|
// Match eino_adk_run_loop: checkpoint_dir is used as configured (relative to process CWD when not absolute).
|
||||||
checkpointBase := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.CheckpointDir)
|
checkpointBase := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.CheckpointDir)
|
||||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase)
|
reductionRoot := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.ReductionRootDir)
|
||||||
|
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot)
|
||||||
agent.SetPromptBaseDir(configDir)
|
agent.SetPromptBaseDir(configDir)
|
||||||
|
|
||||||
agentsDir := cfg.AgentsDir
|
agentsDir := cfg.AgentsDir
|
||||||
@@ -392,9 +369,10 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
|||||||
// 创建OpenAPI处理器
|
// 创建OpenAPI处理器
|
||||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||||
conversationHandler.SetAudit(auditSvc)
|
conversationHandler.SetAudit(auditSvc)
|
||||||
|
conversationHandler.SetTaskStopper(agentHandler)
|
||||||
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
|
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
|
||||||
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
|
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
|
||||||
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler)
|
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, conversationHandler, agentHandler)
|
||||||
|
|
||||||
// 创建 App 实例(部分字段稍后填充)
|
// 创建 App 实例(部分字段稍后填充)
|
||||||
app := &App{
|
app := &App{
|
||||||
@@ -853,6 +831,7 @@ func setupRoutes(
|
|||||||
protected.PUT("/batch-tasks/:queueId/schedule-enabled", agentHandler.SetBatchQueueScheduleEnabled)
|
protected.PUT("/batch-tasks/:queueId/schedule-enabled", agentHandler.SetBatchQueueScheduleEnabled)
|
||||||
protected.DELETE("/batch-tasks/:queueId", agentHandler.DeleteBatchQueue)
|
protected.DELETE("/batch-tasks/:queueId", agentHandler.DeleteBatchQueue)
|
||||||
protected.PUT("/batch-tasks/:queueId/tasks/:taskId", agentHandler.UpdateBatchTask)
|
protected.PUT("/batch-tasks/:queueId/tasks/:taskId", agentHandler.UpdateBatchTask)
|
||||||
|
protected.POST("/batch-tasks/:queueId/tasks/:taskId/run", agentHandler.RunSingleBatchTask)
|
||||||
protected.POST("/batch-tasks/:queueId/tasks", agentHandler.AddBatchTask)
|
protected.POST("/batch-tasks/:queueId/tasks", agentHandler.AddBatchTask)
|
||||||
protected.DELETE("/batch-tasks/:queueId/tasks/:taskId", agentHandler.DeleteBatchTask)
|
protected.DELETE("/batch-tasks/:queueId/tasks/:taskId", agentHandler.DeleteBatchTask)
|
||||||
|
|
||||||
@@ -900,6 +879,7 @@ func setupRoutes(
|
|||||||
protected.POST("/config/apply", configHandler.ApplyConfig)
|
protected.POST("/config/apply", configHandler.ApplyConfig)
|
||||||
protected.POST("/config/test-openai", configHandler.TestOpenAI)
|
protected.POST("/config/test-openai", configHandler.TestOpenAI)
|
||||||
protected.POST("/config/test-vision", configHandler.TestVision)
|
protected.POST("/config/test-vision", configHandler.TestVision)
|
||||||
|
protected.POST("/config/list-models", configHandler.ListModels)
|
||||||
|
|
||||||
// 系统设置 - 终端(执行命令,提高运维效率)
|
// 系统设置 - 终端(执行命令,提高运维效率)
|
||||||
protected.POST("/terminal/run", terminalHandler.RunCommand)
|
protected.POST("/terminal/run", terminalHandler.RunCommand)
|
||||||
@@ -1091,6 +1071,11 @@ func setupRoutes(
|
|||||||
protected.GET("/projects/:id", projectHandler.GetProject)
|
protected.GET("/projects/:id", projectHandler.GetProject)
|
||||||
protected.PUT("/projects/:id", projectHandler.UpdateProject)
|
protected.PUT("/projects/:id", projectHandler.UpdateProject)
|
||||||
protected.DELETE("/projects/:id", projectHandler.DeleteProject)
|
protected.DELETE("/projects/:id", projectHandler.DeleteProject)
|
||||||
|
protected.GET("/projects/:id/fact-graph", projectHandler.GetFactGraph)
|
||||||
|
protected.GET("/projects/:id/fact-edges", projectHandler.ListFactEdges)
|
||||||
|
protected.POST("/projects/:id/fact-edges", projectHandler.CreateFactEdge)
|
||||||
|
protected.DELETE("/projects/:id/fact-edges/:edgeId", projectHandler.DeleteFactEdge)
|
||||||
|
protected.POST("/projects/:id/promote-attack-chain/:conversationId", projectHandler.PromoteAttackChain)
|
||||||
protected.GET("/projects/:id/facts", projectHandler.ListFacts)
|
protected.GET("/projects/:id/facts", projectHandler.ListFacts)
|
||||||
protected.POST("/projects/:id/facts", projectHandler.CreateFact)
|
protected.POST("/projects/:id/facts", projectHandler.CreateFact)
|
||||||
protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact)
|
protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact)
|
||||||
@@ -1131,6 +1116,7 @@ func setupRoutes(
|
|||||||
c2Routes.POST("/listeners/:id/start", c2Handler.StartListener)
|
c2Routes.POST("/listeners/:id/start", c2Handler.StartListener)
|
||||||
c2Routes.POST("/listeners/:id/stop", c2Handler.StopListener)
|
c2Routes.POST("/listeners/:id/stop", c2Handler.StopListener)
|
||||||
c2Routes.GET("/sessions", c2Handler.ListSessions)
|
c2Routes.GET("/sessions", c2Handler.ListSessions)
|
||||||
|
c2Routes.DELETE("/sessions", c2Handler.DeleteSessions)
|
||||||
c2Routes.GET("/sessions/:id", c2Handler.GetSession)
|
c2Routes.GET("/sessions/:id", c2Handler.GetSession)
|
||||||
c2Routes.DELETE("/sessions/:id", c2Handler.DeleteSession)
|
c2Routes.DELETE("/sessions/:id", c2Handler.DeleteSession)
|
||||||
c2Routes.PUT("/sessions/:id/sleep", c2Handler.SetSessionSleep)
|
c2Routes.PUT("/sessions/:id/sleep", c2Handler.SetSessionSleep)
|
||||||
|
|||||||
@@ -61,6 +61,7 @@ func registerC2ListenerTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webList
|
|||||||
- stop: 停止监听器(需 listener_id)
|
- stop: 停止监听器(需 listener_id)
|
||||||
- delete: 删除监听器(需 listener_id)
|
- delete: 删除监听器(需 listener_id)
|
||||||
监听器类型: tcp_reverse, http_beacon, https_beacon, websocket
|
监听器类型: tcp_reverse, http_beacon, https_beacon, websocket
|
||||||
|
tcp_reverse 默认仅接受 CSB1 加密 Beacon(AES-GCM + ImplantToken)才登记会话;经典 bash/nc 反弹需在 config.allow_legacy_shell=true(公网不推荐)。
|
||||||
端口约束:create/update 的 bind_port 禁止与本平台 Web/API 所用端口相同。当前本服务该端口为 %d(配置项 server.port,随进程启动从配置文件加载)。若 bind_port 与此相同会导致本服务或监听器 bind 失败、Beacon/oneliner 误连到 Web 而非 C2。请为监听器另选空闲端口。`, webListenPort),
|
端口约束:create/update 的 bind_port 禁止与本平台 Web/API 所用端口相同。当前本服务该端口为 %d(配置项 server.port,随进程启动从配置文件加载)。若 bind_port 与此相同会导致本服务或监听器 bind 失败、Beacon/oneliner 误连到 Web 而非 C2。请为监听器另选空闲端口。`, webListenPort),
|
||||||
InputSchema: map[string]interface{}{
|
InputSchema: map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -74,7 +75,7 @@ func registerC2ListenerTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webList
|
|||||||
"bind_port": map[string]interface{}{"type": "integer", "description": fmt.Sprintf("绑定端口(create 必填)。须 ≠ %d(当前本服务 Web/API 端口,配置 server.port)", webListenPort), "minimum": 1, "maximum": 65535},
|
"bind_port": map[string]interface{}{"type": "integer", "description": fmt.Sprintf("绑定端口(create 必填)。须 ≠ %d(当前本服务 Web/API 端口,配置 server.port)", webListenPort), "minimum": 1, "maximum": 65535},
|
||||||
"profile_id": map[string]interface{}{"type": "string", "description": "Malleable Profile ID"},
|
"profile_id": map[string]interface{}{"type": "string", "description": "Malleable Profile ID"},
|
||||||
"remark": map[string]interface{}{"type": "string", "description": "备注"},
|
"remark": map[string]interface{}{"type": "string", "description": "备注"},
|
||||||
"config": map[string]interface{}{"type": "object", "description": "高级配置(beacon 路径/TLS/OPSEC 等),create/update 可用"},
|
"config": map[string]interface{}{"type": "object", "description": "高级配置(beacon 路径/TLS/OPSEC 等),create/update 可用。tcp_reverse 可选 allow_legacy_shell:true 允许未加密经典 shell(默认 false)"},
|
||||||
},
|
},
|
||||||
"required": []string{"action"},
|
"required": []string{"action"},
|
||||||
},
|
},
|
||||||
@@ -222,20 +223,23 @@ func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
|
|||||||
s.RegisterTool(mcp.Tool{
|
s.RegisterTool(mcp.Tool{
|
||||||
Name: builtin.ToolC2Session,
|
Name: builtin.ToolC2Session,
|
||||||
Description: `C2 会话管理。通过 action 参数选择操作:
|
Description: `C2 会话管理。通过 action 参数选择操作:
|
||||||
- list: 列出会话(可按 listener_id/status/os/search 过滤)
|
- list: 列出会话(可按 listener_id/status/os/search/suspicious 过滤)
|
||||||
- get: 获取会话详情及最近任务历史(需 session_id)
|
- get: 获取会话详情及最近任务历史(需 session_id)
|
||||||
- set_sleep: 设置心跳间隔(需 session_id)
|
- set_sleep: 设置心跳间隔(需 session_id)
|
||||||
- kill: 下发 exit 任务让 implant 退出(需 session_id)
|
- kill: 下发 exit 任务让 implant 退出(需 session_id)
|
||||||
- delete: 删除会话记录(需 session_id)`,
|
- delete: 删除单个会话记录(需 session_id)
|
||||||
|
- delete_batch: 批量删除会话(需 session_ids 数组)`,
|
||||||
InputSchema: map[string]interface{}{
|
InputSchema: map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": map[string]interface{}{
|
"properties": map[string]interface{}{
|
||||||
"action": map[string]interface{}{"type": "string", "description": "操作: list/get/set_sleep/kill/delete", "enum": []string{"list", "get", "set_sleep", "kill", "delete"}},
|
"action": map[string]interface{}{"type": "string", "description": "操作: list/get/set_sleep/kill/delete/delete_batch", "enum": []string{"list", "get", "set_sleep", "kill", "delete", "delete_batch"}},
|
||||||
"session_id": map[string]interface{}{"type": "string", "description": "会话 ID(get/set_sleep/kill/delete 需要)"},
|
"session_id": map[string]interface{}{"type": "string", "description": "会话 ID(get/set_sleep/kill/delete 需要)"},
|
||||||
|
"session_ids": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}, "description": "会话 ID 列表(delete_batch)"},
|
||||||
"listener_id": map[string]interface{}{"type": "string", "description": "按监听器过滤(list)"},
|
"listener_id": map[string]interface{}{"type": "string", "description": "按监听器过滤(list)"},
|
||||||
"status": map[string]interface{}{"type": "string", "description": "按状态过滤: active/sleeping/dead/killed(list)"},
|
"status": map[string]interface{}{"type": "string", "description": "按状态过滤: active/sleeping/dead/killed(list)"},
|
||||||
"os": map[string]interface{}{"type": "string", "description": "按 OS 过滤: linux/windows/darwin(list)"},
|
"os": map[string]interface{}{"type": "string", "description": "按 OS 过滤: linux/windows/darwin(list)"},
|
||||||
"search": map[string]interface{}{"type": "string", "description": "模糊搜索 hostname/username/IP(list)"},
|
"search": map[string]interface{}{"type": "string", "description": "模糊搜索 hostname/username/IP(list)"},
|
||||||
|
"suspicious": map[string]interface{}{"type": "boolean", "description": "仅疑似误报:离线且 tcp_* / unknown / PID 0(list)"},
|
||||||
"limit": map[string]interface{}{"type": "integer", "description": "返回数量上限(list)"},
|
"limit": map[string]interface{}{"type": "integer", "description": "返回数量上限(list)"},
|
||||||
"sleep_seconds": map[string]interface{}{"type": "integer", "description": "心跳间隔秒数(set_sleep)"},
|
"sleep_seconds": map[string]interface{}{"type": "integer", "description": "心跳间隔秒数(set_sleep)"},
|
||||||
"jitter_percent": map[string]interface{}{"type": "integer", "description": "抖动百分比 0-100(set_sleep)"},
|
"jitter_percent": map[string]interface{}{"type": "integer", "description": "抖动百分比 0-100(set_sleep)"},
|
||||||
@@ -257,6 +261,9 @@ func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
|
|||||||
if limit := int(getFloat64(params, "limit")); limit > 0 {
|
if limit := int(getFloat64(params, "limit")); limit > 0 {
|
||||||
filter.Limit = limit
|
filter.Limit = limit
|
||||||
}
|
}
|
||||||
|
if v, ok := params["suspicious"].(bool); ok && v {
|
||||||
|
filter.Suspicious = true
|
||||||
|
}
|
||||||
sessions, err := m.DB().ListC2Sessions(filter)
|
sessions, err := m.DB().ListC2Sessions(filter)
|
||||||
return makeC2Result(map[string]interface{}{"sessions": sessions, "count": len(sessions)}, err)
|
return makeC2Result(map[string]interface{}{"sessions": sessions, "count": len(sessions)}, err)
|
||||||
|
|
||||||
@@ -274,8 +281,16 @@ func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
|
|||||||
case "set_sleep":
|
case "set_sleep":
|
||||||
sleep := int(getFloat64(params, "sleep_seconds"))
|
sleep := int(getFloat64(params, "sleep_seconds"))
|
||||||
jitter := int(getFloat64(params, "jitter_percent"))
|
jitter := int(getFloat64(params, "jitter_percent"))
|
||||||
err := m.DB().SetC2SessionSleep(id, sleep, jitter)
|
task, err := m.SetSessionSleep(id, sleep, jitter)
|
||||||
return makeC2Result(map[string]interface{}{"updated": err == nil, "sleep_seconds": sleep, "jitter_percent": jitter}, err)
|
out := map[string]interface{}{
|
||||||
|
"updated": err == nil,
|
||||||
|
"sleep_seconds": sleep,
|
||||||
|
"jitter_percent": jitter,
|
||||||
|
}
|
||||||
|
if task != nil {
|
||||||
|
out["task_id"] = task.ID
|
||||||
|
}
|
||||||
|
return makeC2Result(out, err)
|
||||||
|
|
||||||
case "kill":
|
case "kill":
|
||||||
task, err := m.EnqueueTask(c2.EnqueueTaskInput{
|
task, err := m.EnqueueTask(c2.EnqueueTaskInput{
|
||||||
@@ -292,6 +307,17 @@ func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
|
|||||||
err := m.DB().DeleteC2Session(id)
|
err := m.DB().DeleteC2Session(id)
|
||||||
return makeC2Result(map[string]interface{}{"deleted": err == nil}, err)
|
return makeC2Result(map[string]interface{}{"deleted": err == nil}, err)
|
||||||
|
|
||||||
|
case "delete_batch":
|
||||||
|
rawIDs, _ := params["session_ids"].([]interface{})
|
||||||
|
ids := make([]string, 0, len(rawIDs))
|
||||||
|
for _, v := range rawIDs {
|
||||||
|
if s, ok := v.(string); ok && strings.TrimSpace(s) != "" {
|
||||||
|
ids = append(ids, strings.TrimSpace(s))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
n, err := m.DB().DeleteC2SessionsByIDs(ids)
|
||||||
|
return makeC2Result(map[string]interface{}{"deleted": n}, err)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
|
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
|
||||||
}
|
}
|
||||||
@@ -491,11 +517,11 @@ func registerC2PayloadTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webListe
|
|||||||
Name: builtin.ToolC2Payload,
|
Name: builtin.ToolC2Payload,
|
||||||
Description: fmt.Sprintf(`C2 Payload 生成。通过 action 参数选择操作:
|
Description: fmt.Sprintf(`C2 Payload 生成。通过 action 参数选择操作:
|
||||||
- oneliner: 生成单行 payload。kind 必须与监听器协议一致,否则会失败:
|
- oneliner: 生成单行 payload。kind 必须与监听器协议一致,否则会失败:
|
||||||
• tcp_reverse:裸 TCP 反弹,可用 kind: bash, nc, nc_mkfifo, python, perl, powershell(bash 指 /dev/tcp 类,不是 HTTP)。
|
• tcp_reverse:默认仅支持 build 加密 Beacon;若监听器 config.allow_legacy_shell=true,才可用 kind: bash, nc, nc_mkfifo, python, perl, powershell。
|
||||||
• http_beacon / https_beacon / websocket:仅 HTTP(S) Beacon 轮询,oneliner 只能用 kind: curl_beacon(脚本内用 bash+curl,与「tcp 的 bash」不同)。curl_beacon 返回串末尾含「 &」用于把整个 bash -c 放后台;若用 exec/execute 同步执行,必须整段原样复制(含末尾 &)。若删掉 &,内部 while 死循环占满前台,调用会一直阻塞到超时/杀进程。
|
• http_beacon / https_beacon / websocket:仅 HTTP(S) Beacon 轮询,oneliner 只能用 kind: curl_beacon(脚本内用 bash+curl,与「tcp 的 bash」不同)。curl_beacon 返回串末尾含「 &」用于把整个 bash -c 放后台;若用 exec/execute 同步执行,必须整段原样复制(含末尾 &)。若删掉 &,内部 while 死循环占满前台,调用会一直阻塞到超时/杀进程。
|
||||||
• 需要经典 bash 反弹 shell 时:先 c2_listener create type=tcp_reverse,再对该监听器用 kind=bash。
|
• 公网部署 tcp_reverse 请用 build 生成加密 Beacon,勿开启 allow_legacy_shell。
|
||||||
• 省略 kind 时,会按监听器类型自动选第一个兼容类型(HTTP 系默认为 curl_beacon)。
|
• 省略 kind 时,会按监听器类型自动选第一个兼容类型(HTTP 系默认为 curl_beacon)。
|
||||||
- build: 交叉编译 beacon 二进制。支持 http_beacon / https_beacon / websocket / tcp_reverse(tcp_reverse 下植入端回连后先发魔数 CSB1,再走与 HTTP 相同的 AES-GCM JSON 语义;未发魔数的连接仍按经典交互 shell 处理)。
|
- build: 交叉编译 beacon 二进制。支持 http_beacon / https_beacon / websocket / tcp_reverse(tcp_reverse 植入端回连后先发魔数 CSB1,再经 AES-GCM 解密且校验 ImplantToken 后才登记会话)。
|
||||||
依赖的监听器 bind_port 须避开本服务 Web 端口 %d(配置 server.port,与 c2_listener 描述一致),否则 Beacon 无法正确回连。`, webListenPort),
|
依赖的监听器 bind_port 须避开本服务 Web 端口 %d(配置 server.port,与 c2_listener 描述一致),否则 Beacon 无法正确回连。`, webListenPort),
|
||||||
InputSchema: map[string]interface{}{
|
InputSchema: map[string]interface{}{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -540,6 +566,9 @@ func registerC2PayloadTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webListe
|
|||||||
}
|
}
|
||||||
return makeC2Result(nil, fmt.Errorf("监听器类型 %s 不支持 %s,兼容类型: %v", listener.Type, kind, names))
|
return makeC2Result(nil, fmt.Errorf("监听器类型 %s 不支持 %s,兼容类型: %v", listener.Type, kind, names))
|
||||||
}
|
}
|
||||||
|
if err := c2.ValidateOnelinerForListener(listener, kind); err != nil {
|
||||||
|
return makeC2Result(nil, err)
|
||||||
|
}
|
||||||
input := c2.OnelinerInput{
|
input := c2.OnelinerInput{
|
||||||
Kind: kind,
|
Kind: kind,
|
||||||
Host: host,
|
Host: host,
|
||||||
|
|||||||
@@ -89,6 +89,28 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "可选:关联的漏洞记录 ID",
|
"description": "可选:关联的漏洞记录 ID",
|
||||||
},
|
},
|
||||||
|
"links": map[string]interface{}{
|
||||||
|
"type": "array",
|
||||||
|
"description": "可选:关系边(from → 当前 fact)。finding 至少 1 条 {from:target/*, type:discovered_on};finding 上记录 exploit 用 {from:exploit/*, type:exploits}。省略保留已有边;传 [] 清空全部关系边。",
|
||||||
|
"items": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"from": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "来源 fact_key:存储为 from → 当前 fact",
|
||||||
|
},
|
||||||
|
"type": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "depends_on | leads_to | enables | exploits | discovered_on | contains | part_of | supports",
|
||||||
|
},
|
||||||
|
"confidence": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "confirmed | tentative | deprecated",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"from", "type"},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": []string{"fact_key", "summary"},
|
"required": []string{"fact_key", "summary"},
|
||||||
},
|
},
|
||||||
@@ -124,7 +146,26 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return textResult("错误: "+err.Error(), true), nil
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
}
|
}
|
||||||
|
if _, hasLinks := args["links"]; hasLinks {
|
||||||
|
linkInputs, err := project.ParseFactLinkInputs(args["links"])
|
||||||
|
if err != nil {
|
||||||
|
return textResult("错误: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
convID := agent.ConversationIDFromContext(ctx)
|
||||||
|
if err := project.PersistFactLinksFromParsed(db, projectID, created.FactKey, convID, linkInputs, true); err != nil {
|
||||||
|
return textResult("错误: 保存关系边失败: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
created, _ = db.GetProjectFactByKey(projectID, created.FactKey)
|
||||||
|
} else if parsed := project.ParseLinksFromBody(created.Body); len(parsed) > 0 {
|
||||||
|
if err := project.PersistFactIncomingLinks(db, projectID, created.FactKey, parsed, true); err != nil {
|
||||||
|
return textResult("错误: 从 body 解析边失败: "+err.Error(), true), nil
|
||||||
|
}
|
||||||
|
created, _ = db.GetProjectFactByKey(projectID, created.FactKey)
|
||||||
|
}
|
||||||
msg := fmt.Sprintf("事实已保存。\nfact_key: %s\nid: %s\nconfidence: %s", created.FactKey, created.ID, created.Confidence)
|
msg := fmt.Sprintf("事实已保存。\nfact_key: %s\nid: %s\nconfidence: %s", created.FactKey, created.ID, created.Confidence)
|
||||||
|
if in, _ := db.ListIncomingProjectFactEdges(projectID, created.FactKey); len(in) > 0 {
|
||||||
|
msg += "\n关系边: " + project.FormatFactLinksText(in)
|
||||||
|
}
|
||||||
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
||||||
msg += warn
|
msg += warn
|
||||||
}
|
}
|
||||||
@@ -164,6 +205,18 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi
|
|||||||
if f.SourceConversationID != "" {
|
if f.SourceConversationID != "" {
|
||||||
msg += fmt.Sprintf("\nsource_conversation_id: %s", f.SourceConversationID)
|
msg += fmt.Sprintf("\nsource_conversation_id: %s", f.SourceConversationID)
|
||||||
}
|
}
|
||||||
|
if in, _ := db.ListIncomingProjectFactEdges(projectID, f.FactKey); len(in) > 0 {
|
||||||
|
msg += "\n关系边(from → 本 fact):\n"
|
||||||
|
for _, e := range in {
|
||||||
|
msg += fmt.Sprintf("- %s ← %s (%s)\n", e.EdgeType, e.SourceFactKey, e.Confidence)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if out, _ := db.ListOutgoingProjectFactEdges(projectID, f.FactKey); len(out) > 0 {
|
||||||
|
msg += "指向其他事实:\n"
|
||||||
|
for _, e := range out {
|
||||||
|
msg += fmt.Sprintf("- %s → %s (%s)\n", e.EdgeType, e.TargetFactKey, e.Confidence)
|
||||||
|
}
|
||||||
|
}
|
||||||
msg += "\n\n--- body ---\n" + f.Body
|
msg += "\n\n--- body ---\n" + f.Body
|
||||||
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
||||||
msg += warn
|
msg += warn
|
||||||
|
|||||||
@@ -0,0 +1,203 @@
|
|||||||
|
package attackchain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/project"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
var promoteSlugSanitizer = regexp.MustCompile(`[^a-z0-9._/-]+`)
|
||||||
|
|
||||||
|
// PromoteToProjectResult 攻击链沉淀结果。
|
||||||
|
type PromoteToProjectResult struct {
|
||||||
|
FactsCreated int `json:"facts_created"`
|
||||||
|
FactsUpdated int `json:"facts_updated"`
|
||||||
|
EdgesCreated int `json:"edges_created"`
|
||||||
|
FactKeys []string `json:"fact_keys"`
|
||||||
|
Graph *database.ProjectFactGraph `json:"graph,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PromoteToProject 将对话攻击链沉淀为项目事实与边。
|
||||||
|
func PromoteToProject(db *database.DB, projectID, conversationID string) (*PromoteToProjectResult, error) {
|
||||||
|
if db == nil {
|
||||||
|
return nil, fmt.Errorf("database 未初始化")
|
||||||
|
}
|
||||||
|
projectID = strings.TrimSpace(projectID)
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if projectID == "" || conversationID == "" {
|
||||||
|
return nil, fmt.Errorf("project_id 与 conversation_id 必填")
|
||||||
|
}
|
||||||
|
if _, err := db.GetProject(projectID); err != nil {
|
||||||
|
return nil, fmt.Errorf("项目不存在")
|
||||||
|
}
|
||||||
|
conv, err := db.GetConversation(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("对话不存在")
|
||||||
|
}
|
||||||
|
if pid := strings.TrimSpace(conv.ProjectID); pid != "" && pid != projectID {
|
||||||
|
return nil, fmt.Errorf("对话已绑定其他项目")
|
||||||
|
}
|
||||||
|
|
||||||
|
nodes, err := db.LoadAttackChainNodes(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
edges, err := db.LoadAttackChainEdges(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(nodes) == 0 {
|
||||||
|
return nil, fmt.Errorf("该对话尚无攻击链,请先在对话中生成攻击链")
|
||||||
|
}
|
||||||
|
|
||||||
|
res := &PromoteToProjectResult{}
|
||||||
|
nodeToKey := make(map[string]string, len(nodes))
|
||||||
|
usedKeys := map[string]int{}
|
||||||
|
|
||||||
|
for _, node := range nodes {
|
||||||
|
key := allocatePromoteFactKey(node, usedKeys)
|
||||||
|
nodeToKey[node.ID] = key
|
||||||
|
category := mapPromoteNodeCategory(node.Type)
|
||||||
|
existing, getErr := db.GetProjectFactByKey(projectID, key)
|
||||||
|
f := &database.ProjectFact{
|
||||||
|
ProjectID: projectID,
|
||||||
|
FactKey: key,
|
||||||
|
Category: category,
|
||||||
|
Summary: strings.TrimSpace(node.Label),
|
||||||
|
Body: formatPromotedFactBody(node, conversationID),
|
||||||
|
Confidence: "tentative",
|
||||||
|
SourceConversationID: conversationID,
|
||||||
|
}
|
||||||
|
if getErr == nil && existing != nil {
|
||||||
|
f.ID = existing.ID
|
||||||
|
f.CreatedAt = existing.CreatedAt
|
||||||
|
if strings.TrimSpace(f.Summary) == "" {
|
||||||
|
f.Summary = existing.Summary
|
||||||
|
}
|
||||||
|
if _, err := db.UpsertProjectFact(f); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res.FactsUpdated++
|
||||||
|
} else {
|
||||||
|
if _, err := db.UpsertProjectFact(f); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res.FactsCreated++
|
||||||
|
}
|
||||||
|
res.FactKeys = append(res.FactKeys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, edge := range edges {
|
||||||
|
srcKey, ok1 := nodeToKey[edge.Source]
|
||||||
|
tgtKey, ok2 := nodeToKey[edge.Target]
|
||||||
|
if !ok1 || !ok2 || srcKey == tgtKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
edgeType := mapPromoteEdgeType(edge.Type)
|
||||||
|
incoming, _ := db.ListIncomingProjectFactEdges(projectID, tgtKey)
|
||||||
|
merged := project.MergeLinkFromInputsUnique(promoteFromEdgeInputsFromDB(incoming), []database.ProjectFactEdgeFromInput{{From: srcKey, Type: edgeType}})
|
||||||
|
if err := db.ReplaceIncomingProjectFactEdges(projectID, tgtKey, merged); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res.EdgesCreated++
|
||||||
|
if fact, err := db.GetProjectFactByKey(projectID, tgtKey); err == nil {
|
||||||
|
in, _ := db.ListIncomingProjectFactEdges(projectID, tgtKey)
|
||||||
|
fact.Body = project.SyncBodyLinksSection(fact.Body, in)
|
||||||
|
_, _ = db.UpsertProjectFact(fact)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
graph, _ := project.BuildProjectFactGraph(db, projectID, "full", true)
|
||||||
|
res.Graph = graph
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func promoteFromEdgeInputsFromDB(edges []*database.ProjectFactEdge) []database.ProjectFactEdgeFromInput {
|
||||||
|
out := make([]database.ProjectFactEdgeFromInput, 0, len(edges))
|
||||||
|
for _, e := range edges {
|
||||||
|
out = append(out, database.ProjectFactEdgeFromInput{From: e.SourceFactKey, Type: e.EdgeType, Confidence: e.Confidence})
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapPromoteNodeCategory(nodeType string) string {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(nodeType)) {
|
||||||
|
case "target":
|
||||||
|
return project.FactCategoryTarget
|
||||||
|
case "vulnerability":
|
||||||
|
return project.FactCategoryFinding
|
||||||
|
case "action":
|
||||||
|
return project.FactCategoryChain
|
||||||
|
default:
|
||||||
|
return project.FactCategoryNote
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapPromoteEdgeType(t string) string {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(t)) {
|
||||||
|
case "discovers", "discovered_on", "targets":
|
||||||
|
return "discovered_on"
|
||||||
|
case "exploits":
|
||||||
|
return "exploits"
|
||||||
|
case "enables":
|
||||||
|
return "enables"
|
||||||
|
case "depends_on":
|
||||||
|
return "depends_on"
|
||||||
|
default:
|
||||||
|
return "leads_to"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func allocatePromoteFactKey(node Node, used map[string]int) string {
|
||||||
|
prefix := "chain/"
|
||||||
|
switch strings.ToLower(strings.TrimSpace(node.Type)) {
|
||||||
|
case "target":
|
||||||
|
prefix = "target/"
|
||||||
|
case "vulnerability":
|
||||||
|
prefix = "finding/"
|
||||||
|
case "action":
|
||||||
|
prefix = "chain/"
|
||||||
|
}
|
||||||
|
base := promoteSlugify(node.Label)
|
||||||
|
if base == "" {
|
||||||
|
base = promoteSlugify(node.ID)
|
||||||
|
}
|
||||||
|
if base == "" {
|
||||||
|
base = uuid.New().String()[:8]
|
||||||
|
}
|
||||||
|
key := prefix + base
|
||||||
|
if n, ok := used[key]; ok {
|
||||||
|
n++
|
||||||
|
used[key] = n
|
||||||
|
key = fmt.Sprintf("%s-%d", key, n)
|
||||||
|
} else {
|
||||||
|
used[key] = 1
|
||||||
|
}
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
|
||||||
|
func promoteSlugify(s string) string {
|
||||||
|
s = strings.ToLower(strings.TrimSpace(s))
|
||||||
|
s = strings.NewReplacer(" ", "-", "—", "-", "–", "-", "/", "-").Replace(s)
|
||||||
|
s = promoteSlugSanitizer.ReplaceAllString(s, "-")
|
||||||
|
s = strings.Trim(s, "-")
|
||||||
|
if len(s) > 64 {
|
||||||
|
s = s[:64]
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatPromotedFactBody(node Node, conversationID string) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("## 来源\n")
|
||||||
|
b.WriteString(fmt.Sprintf("- 对话攻击链沉淀\n- source_conversation_id: %s\n- node_id: %s\n- node_type: %s\n\n", conversationID, node.ID, node.Type))
|
||||||
|
b.WriteString("## 摘要\n")
|
||||||
|
b.WriteString(strings.TrimSpace(node.Label))
|
||||||
|
b.WriteString("\n\n## 关联\n- 结构化关系边(自动同步):\n (见项目攻击路径图)\n")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
@@ -20,10 +20,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// TCPReverseListener 监听 TCP 端口,等待目标机反弹连接。
|
// TCPReverseListener 监听 TCP 端口,等待目标机反弹连接。
|
||||||
// 经典模式:纯交互式 raw shell,与 nc / bash -i >& /dev/tcp 兼容。
|
// 默认仅接受加密 TCP Beacon:连接后先发送魔数 CSB1,再经 AES-GCM 解密且校验 ImplantToken 后才登记会话。
|
||||||
// 二进制 Beacon:连接后先发送魔数 CSB1,随后使用与 HTTP Beacon 相同的 AES-GCM JSON 语义(成帧见 tcp_beacon_server.go)。
|
// 可选经典模式(config.allow_legacy_shell=true):纯交互式 raw shell,与 nc / bash -i >& /dev/tcp 兼容,无鉴权,仅建议内网实验。
|
||||||
// 每个新连接自动生成一个 implant_uuid(基于远端地址 + 启动时间 hash),登记为 c2_session;
|
// 任务派发(经典模式):同步 exec —— 收到 task 时直接 send 命令字节并读取输出(带结束标记)。
|
||||||
// 任务派发:使用同步 exec 模式 —— 收到 task 时直接 send 命令字节并读取输出(带结束标记)。
|
|
||||||
type TCPReverseListener struct {
|
type TCPReverseListener struct {
|
||||||
rec *database.C2Listener
|
rec *database.C2Listener
|
||||||
cfg *ListenerConfig
|
cfg *ListenerConfig
|
||||||
@@ -122,12 +121,14 @@ func (l *TCPReverseListener) acceptLoop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleConn 一个连接=一个会话:先识别二进制 TCP Beacon(魔数 CSB1),否则走经典交互式 shell。
|
// handleConn 先识别加密 TCP Beacon(魔数 CSB1 + AES-GCM + Token);未通过则按配置拒绝或走经典 shell。
|
||||||
func (l *TCPReverseListener) handleConn(conn net.Conn) {
|
func (l *TCPReverseListener) handleConn(conn net.Conn) {
|
||||||
br := bufio.NewReader(conn)
|
br := bufio.NewReader(conn)
|
||||||
_ = conn.SetReadDeadline(time.Now().Add(20 * time.Second))
|
remote := conn.RemoteAddr().String()
|
||||||
prefix, err := br.Peek(4)
|
|
||||||
if err == nil && len(prefix) == 4 && string(prefix) == tcpBeaconMagic {
|
_ = conn.SetReadDeadline(time.Now().Add(tcpBeaconPeekTimeout))
|
||||||
|
prefix, peekErr := br.Peek(4)
|
||||||
|
if peekErr == nil && len(prefix) == 4 && string(prefix) == tcpBeaconMagic {
|
||||||
if _, err := br.Discard(4); err != nil {
|
if _, err := br.Discard(4); err != nil {
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
return
|
return
|
||||||
@@ -136,14 +137,22 @@ func (l *TCPReverseListener) handleConn(conn net.Conn) {
|
|||||||
l.handleTCPBeaconSession(conn, br)
|
l.handleTCPBeaconSession(conn, br)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !l.cfg.AllowLegacyShell {
|
||||||
|
l.logger.Debug("tcp_reverse 拒绝未加密连接", zap.String("remote", remote))
|
||||||
|
_ = conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
_ = conn.SetReadDeadline(time.Time{})
|
_ = conn.SetReadDeadline(time.Time{})
|
||||||
l.handleShellConn(conn, br)
|
l.handleShellConn(conn, br)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleShellConn 经典裸 TCP 反弹 shell(与 nc/bash /dev/tcp 兼容)。
|
// handleShellConn 经典裸 TCP 反弹 shell(与 nc/bash /dev/tcp 兼容);需监听器显式开启 allow_legacy_shell。
|
||||||
func (l *TCPReverseListener) handleShellConn(conn net.Conn, br *bufio.Reader) {
|
func (l *TCPReverseListener) handleShellConn(conn net.Conn, br *bufio.Reader) {
|
||||||
remote := conn.RemoteAddr().String()
|
remote := conn.RemoteAddr().String()
|
||||||
host, _, _ := net.SplitHostPort(remote)
|
host, _, _ := net.SplitHostPort(remote)
|
||||||
|
|
||||||
// 用 listener+remote_ip 生成稳定 implant_uuid,使同一来源的重连复用同一会话
|
// 用 listener+remote_ip 生成稳定 implant_uuid,使同一来源的重连复用同一会话
|
||||||
uuidSeed := fmt.Sprintf("%s|%s", l.rec.ID, host)
|
uuidSeed := fmt.Sprintf("%s|%s", l.rec.ID, host)
|
||||||
hash := sha256.Sum256([]byte(uuidSeed))
|
hash := sha256.Sum256([]byte(uuidSeed))
|
||||||
|
|||||||
+41
-1
@@ -381,8 +381,10 @@ func (m *Manager) IngestCheckIn(listenerID string, req ImplantCheckInRequest) (*
|
|||||||
Metadata: req.Metadata,
|
Metadata: req.Metadata,
|
||||||
}
|
}
|
||||||
if existing != nil {
|
if existing != nil {
|
||||||
// 保留原 ID/FirstSeenAt/Note,避免被覆盖
|
// 保留原 ID/FirstSeenAt/Note 与操作员设置的 sleep/jitter,避免被 beacon 心跳上报覆盖
|
||||||
session.FirstSeenAt = existing.FirstSeenAt
|
session.FirstSeenAt = existing.FirstSeenAt
|
||||||
|
session.SleepSeconds = existing.SleepSeconds
|
||||||
|
session.JitterPercent = existing.JitterPercent
|
||||||
if session.Note == "" {
|
if session.Note == "" {
|
||||||
session.Note = existing.Note
|
session.Note = existing.Note
|
||||||
}
|
}
|
||||||
@@ -413,6 +415,44 @@ func (m *Manager) IngestCheckIn(listenerID string, req ImplantCheckInRequest) (*
|
|||||||
return session, nil
|
return session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetSessionSleep 更新会话期望的心跳间隔,并向植入体下发 sleep 任务以尽快生效。
|
||||||
|
func (m *Manager) SetSessionSleep(sessionID string, sleepSeconds, jitterPercent int) (*database.C2Task, error) {
|
||||||
|
if strings.TrimSpace(sessionID) == "" {
|
||||||
|
return nil, ErrInvalidInput
|
||||||
|
}
|
||||||
|
if sleepSeconds < 1 {
|
||||||
|
sleepSeconds = 1
|
||||||
|
}
|
||||||
|
if jitterPercent < 0 {
|
||||||
|
jitterPercent = 0
|
||||||
|
}
|
||||||
|
if jitterPercent > 100 {
|
||||||
|
jitterPercent = 100
|
||||||
|
}
|
||||||
|
if err := m.db.SetC2SessionSleep(sessionID, sleepSeconds, jitterPercent); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
task, err := m.EnqueueTask(EnqueueTaskInput{
|
||||||
|
SessionID: sessionID,
|
||||||
|
TaskType: TaskTypeSleep,
|
||||||
|
Payload: map[string]interface{}{
|
||||||
|
"seconds": sleepSeconds,
|
||||||
|
"jitter": jitterPercent,
|
||||||
|
},
|
||||||
|
Source: "manual",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
m.logger.Warn("sleep 任务入队失败", zap.Error(err), zap.String("session_id", sessionID))
|
||||||
|
}
|
||||||
|
m.publishEvent("info", "session", sessionID, "",
|
||||||
|
fmt.Sprintf("Sleep 已更新: %ds (抖动 %d%%)", sleepSeconds, jitterPercent),
|
||||||
|
map[string]interface{}{
|
||||||
|
"sleep_seconds": sleepSeconds,
|
||||||
|
"jitter_percent": jitterPercent,
|
||||||
|
})
|
||||||
|
return task, nil
|
||||||
|
}
|
||||||
|
|
||||||
// MarkSessionDead 心跳超时检测器调用:标记会话为 dead
|
// MarkSessionDead 心跳超时检测器调用:标记会话为 dead
|
||||||
func (m *Manager) MarkSessionDead(sessionID string) error {
|
func (m *Manager) MarkSessionDead(sessionID string) error {
|
||||||
if err := m.db.SetC2SessionStatus(sessionID, string(SessionDead)); err != nil {
|
if err := m.db.SetC2SessionStatus(sessionID, string(SessionDead)); err != nil {
|
||||||
|
|||||||
@@ -0,0 +1,118 @@
|
|||||||
|
package c2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIngestCheckIn_PreservesOperatorSleepOnHeartbeat(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
|
||||||
|
mgr := NewManager(db, zap.NewNop(), tmp)
|
||||||
|
ln, err := mgr.CreateListener(CreateListenerInput{
|
||||||
|
Name: "t",
|
||||||
|
Type: string(ListenerTypeHTTPBeacon),
|
||||||
|
BindHost: "127.0.0.1",
|
||||||
|
BindPort: 18080,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
first, err := mgr.IngestCheckIn(ln.ID, ImplantCheckInRequest{
|
||||||
|
ImplantUUID: "implant-uuid-1",
|
||||||
|
Hostname: "host1",
|
||||||
|
Username: "user",
|
||||||
|
OS: "darwin",
|
||||||
|
Arch: "amd64",
|
||||||
|
SleepSeconds: 5,
|
||||||
|
JitterPercent: 0,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.SetC2SessionSleep(first.ID, 30, 20); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
second, err := mgr.IngestCheckIn(ln.ID, ImplantCheckInRequest{
|
||||||
|
ImplantUUID: "implant-uuid-1",
|
||||||
|
Hostname: "host1",
|
||||||
|
Username: "user",
|
||||||
|
OS: "darwin",
|
||||||
|
Arch: "amd64",
|
||||||
|
SleepSeconds: 5,
|
||||||
|
JitterPercent: 0,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if second.SleepSeconds != 30 || second.JitterPercent != 20 {
|
||||||
|
t.Fatalf("expected sleep=30 jitter=20, got sleep=%d jitter=%d", second.SleepSeconds, second.JitterPercent)
|
||||||
|
}
|
||||||
|
|
||||||
|
stored, err := db.GetC2Session(first.ID)
|
||||||
|
if err != nil || stored == nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if stored.SleepSeconds != 30 || stored.JitterPercent != 20 {
|
||||||
|
t.Fatalf("db: expected sleep=30 jitter=20, got sleep=%d jitter=%d", stored.SleepSeconds, stored.JitterPercent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetSessionSleep_UpdatesDBAndEnqueuesTask(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = db.Close() })
|
||||||
|
|
||||||
|
mgr := NewManager(db, zap.NewNop(), tmp)
|
||||||
|
ln, err := mgr.CreateListener(CreateListenerInput{
|
||||||
|
Name: "t2",
|
||||||
|
Type: string(ListenerTypeHTTPBeacon),
|
||||||
|
BindHost: "127.0.0.1",
|
||||||
|
BindPort: 18081,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
sess, err := mgr.IngestCheckIn(ln.ID, ImplantCheckInRequest{
|
||||||
|
ImplantUUID: "implant-uuid-2",
|
||||||
|
Hostname: "host2",
|
||||||
|
Username: "user",
|
||||||
|
OS: "linux",
|
||||||
|
Arch: "amd64",
|
||||||
|
SleepSeconds: 5,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
task, err := mgr.SetSessionSleep(sess.ID, 15, 10)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if task == nil || task.TaskType != string(TaskTypeSleep) {
|
||||||
|
t.Fatalf("expected sleep task, got %#v", task)
|
||||||
|
}
|
||||||
|
|
||||||
|
stored, err := db.GetC2Session(sess.ID)
|
||||||
|
if err != nil || stored == nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if stored.SleepSeconds != 15 || stored.JitterPercent != 10 {
|
||||||
|
t.Fatalf("expected sleep=15 jitter=10, got sleep=%d jitter=%d", stored.SleepSeconds, stored.JitterPercent)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
package c2
|
package c2
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OnelinerKind 单行 payload 的语言/形式
|
// OnelinerKind 单行 payload 的语言/形式
|
||||||
@@ -79,6 +82,23 @@ type OnelinerInput struct {
|
|||||||
ImplantToken string // HTTP Beacon 鉴权 token
|
ImplantToken string // HTTP Beacon 鉴权 token
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateOnelinerForListener 校验 oneliner 与监听器配置是否匹配(如 tcp_reverse 默认要求加密 Beacon)。
|
||||||
|
func ValidateOnelinerForListener(listener *database.C2Listener, kind OnelinerKind) error {
|
||||||
|
if listener == nil {
|
||||||
|
return fmt.Errorf("listener is nil")
|
||||||
|
}
|
||||||
|
if ListenerType(listener.Type) == ListenerTypeTCPReverse && tcpOnelinerKinds[kind] {
|
||||||
|
cfg := &ListenerConfig{}
|
||||||
|
if strings.TrimSpace(listener.ConfigJSON) != "" {
|
||||||
|
_ = json.Unmarshal([]byte(listener.ConfigJSON), cfg)
|
||||||
|
}
|
||||||
|
if !cfg.AllowLegacyShell {
|
||||||
|
return fmt.Errorf("监听器未开启 allow_legacy_shell:tcp_reverse 默认仅接受 CSB1 加密 Beacon(AES-GCM + Token);请用 build 生成 beacon,或显式开启 allow_legacy_shell(公网不推荐)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GenerateOneliner 生成单行 payload。
|
// GenerateOneliner 生成单行 payload。
|
||||||
// 设计要点:
|
// 设计要点:
|
||||||
// - 不依赖目标机预装的可执行(除该 oneliner 关键的 bash/python/perl 等);
|
// - 不依赖目标机预装的可执行(除该 oneliner 关键的 bash/python/perl 等);
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ import (
|
|||||||
// tcpBeaconMagic 二进制 Beacon 在反向 TCP 连接建立后首先发送的 4 字节,用于与经典 shell 反弹区分。
|
// tcpBeaconMagic 二进制 Beacon 在反向 TCP 连接建立后首先发送的 4 字节,用于与经典 shell 反弹区分。
|
||||||
const tcpBeaconMagic = "CSB1"
|
const tcpBeaconMagic = "CSB1"
|
||||||
|
|
||||||
|
// tcpBeaconPeekTimeout 等待 CSB1 魔数的探测窗口;合法 Beacon 连接后立即发送魔数。
|
||||||
|
const tcpBeaconPeekTimeout = 2 * time.Second
|
||||||
|
|
||||||
// tcpBeaconMaxFrame 单帧密文(base64 字符串)最大字节数,防止 OOM。
|
// tcpBeaconMaxFrame 单帧密文(base64 字符串)最大字节数,防止 OOM。
|
||||||
const tcpBeaconMaxFrame = 64 << 20
|
const tcpBeaconMaxFrame = 64 << 20
|
||||||
|
|
||||||
|
|||||||
@@ -141,6 +141,8 @@ type ListenerConfig struct {
|
|||||||
MaxConcurrentTasks int `json:"max_concurrent_tasks,omitempty"`
|
MaxConcurrentTasks int `json:"max_concurrent_tasks,omitempty"`
|
||||||
// CallbackHost 植入端/Payload 使用的回连主机名(可选);与 bind_host 分离,便于 NAT/ECS 等场景
|
// CallbackHost 植入端/Payload 使用的回连主机名(可选);与 bind_host 分离,便于 NAT/ECS 等场景
|
||||||
CallbackHost string `json:"callback_host,omitempty"`
|
CallbackHost string `json:"callback_host,omitempty"`
|
||||||
|
// AllowLegacyShell 为 true 时 tcp_reverse 允许未加密的经典 bash/nc 反弹 shell 登记会话(默认 false,公网部署强烈不建议开启)
|
||||||
|
AllowLegacyShell bool `json:"allow_legacy_shell,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ApplyDefaults 对未填字段填默认值;调用方负责持久化时序列化新值
|
// ApplyDefaults 对未填字段填默认值;调用方负责持久化时序列化新值
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ type ProjectConfig struct {
|
|||||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||||
DefaultProjectID string `yaml:"default_project_id,omitempty" json:"default_project_id,omitempty"` // 机器人/批量等无显式项目时绑定的默认项目
|
DefaultProjectID string `yaml:"default_project_id,omitempty" json:"default_project_id,omitempty"` // 机器人/批量等无显式项目时绑定的默认项目
|
||||||
FactIndexMaxRunes int `yaml:"fact_index_max_runes,omitempty" json:"fact_index_max_runes,omitempty"`
|
FactIndexMaxRunes int `yaml:"fact_index_max_runes,omitempty" json:"fact_index_max_runes,omitempty"`
|
||||||
|
FactIndexPathMaxRunes int `yaml:"fact_index_path_max_runes,omitempty" json:"fact_index_path_max_runes,omitempty"`
|
||||||
FactSummaryMaxRunes int `yaml:"fact_summary_max_runes,omitempty" json:"fact_summary_max_runes,omitempty"`
|
FactSummaryMaxRunes int `yaml:"fact_summary_max_runes,omitempty" json:"fact_summary_max_runes,omitempty"`
|
||||||
DefaultInjectDeprecated bool `yaml:"default_inject_deprecated,omitempty" json:"default_inject_deprecated,omitempty"`
|
DefaultInjectDeprecated bool `yaml:"default_inject_deprecated,omitempty" json:"default_inject_deprecated,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -57,6 +58,14 @@ func (c ProjectConfig) FactIndexMaxRunesEffective() int {
|
|||||||
return c.FactIndexMaxRunes
|
return c.FactIndexMaxRunes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FactIndexPathMaxRunesEffective 攻击路径速览段的最大 rune 数(从 fact_index_max_runes 预算中预留)。
|
||||||
|
func (c ProjectConfig) FactIndexPathMaxRunesEffective() int {
|
||||||
|
if c.FactIndexPathMaxRunes <= 0 {
|
||||||
|
return 1000
|
||||||
|
}
|
||||||
|
return c.FactIndexPathMaxRunes
|
||||||
|
}
|
||||||
|
|
||||||
// FactSummaryMaxRunesEffective upsert 时 summary 最大 rune 数(索引一行,宜含验证要点)。
|
// FactSummaryMaxRunesEffective upsert 时 summary 最大 rune 数(索引一行,宜含验证要点)。
|
||||||
func (c ProjectConfig) FactSummaryMaxRunesEffective() int {
|
func (c ProjectConfig) FactSummaryMaxRunesEffective() int {
|
||||||
if c.FactSummaryMaxRunes <= 0 {
|
if c.FactSummaryMaxRunes <= 0 {
|
||||||
@@ -231,7 +240,7 @@ type MultiAgentEinoMiddlewareConfig struct {
|
|||||||
PlantaskRelDir string `yaml:"plantask_rel_dir,omitempty" json:"plantask_rel_dir,omitempty"`
|
PlantaskRelDir string `yaml:"plantask_rel_dir,omitempty" json:"plantask_rel_dir,omitempty"`
|
||||||
// Reduction truncates/offloads large tool outputs (requires eino local backend for Write).
|
// Reduction truncates/offloads large tool outputs (requires eino local backend for Write).
|
||||||
ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"`
|
ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"`
|
||||||
ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // default: os temp + conversation id
|
ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // 非空:落盘根目录(默认 tmp/reduction);其下按 projects/{id} 或 conversations/{id} 隔离
|
||||||
ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000
|
ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000
|
||||||
ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000
|
ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000
|
||||||
ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"`
|
ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"`
|
||||||
@@ -240,7 +249,7 @@ type MultiAgentEinoMiddlewareConfig struct {
|
|||||||
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
||||||
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
||||||
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
||||||
// SummarizationRetryMaxAttempts is extra retries after the first summarization Generate attempt; 0 = default 3.
|
// SummarizationRetryMaxAttempts 已废弃:summarization 与 run loop 共用 run_retry_max_attempts 及 isEinoTransientRunError。
|
||||||
SummarizationRetryMaxAttempts int `yaml:"summarization_retry_max_attempts,omitempty" json:"summarization_retry_max_attempts,omitempty"`
|
SummarizationRetryMaxAttempts int `yaml:"summarization_retry_max_attempts,omitempty" json:"summarization_retry_max_attempts,omitempty"`
|
||||||
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
||||||
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
||||||
@@ -254,9 +263,9 @@ type MultiAgentEinoMiddlewareConfig struct {
|
|||||||
CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"`
|
CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"`
|
||||||
// DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off.
|
// DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off.
|
||||||
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
|
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
|
||||||
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries).
|
// DeepModelRetryMaxRetries 已废弃:临时错误统一由 run loop 内 isEinoTransientRunError + run_retry_max_attempts 处理。
|
||||||
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
|
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
|
||||||
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时 handler 分段续跑次数;0=默认 10。
|
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时可退避重试次数(run loop 与 summarization 共用);0=默认 10。
|
||||||
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
||||||
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
||||||
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
||||||
@@ -594,8 +603,6 @@ type DatabaseConfig struct {
|
|||||||
|
|
||||||
type AgentConfig struct {
|
type AgentConfig struct {
|
||||||
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
||||||
LargeResultThreshold int `yaml:"large_result_threshold" json:"large_result_threshold"` // 大结果阈值(字节),默认50KB
|
|
||||||
ResultStorageDir string `yaml:"result_storage_dir" json:"result_storage_dir"` // 结果存储目录,默认tmp
|
|
||||||
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
|
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
|
||||||
// SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。
|
// SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。
|
||||||
SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"`
|
SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"`
|
||||||
|
|||||||
@@ -69,12 +69,12 @@ func buildAuditLogsWhere(filter ListAuditLogsFilter) (string, []interface{}) {
|
|||||||
args = append(args, filter.ResourceID)
|
args = append(args, filter.ResourceID)
|
||||||
}
|
}
|
||||||
if filter.Since != nil {
|
if filter.Since != nil {
|
||||||
conditions = append(conditions, "created_at >= ?")
|
conditions = append(conditions, sqliteEpochGE("created_at", ">="))
|
||||||
args = append(args, *filter.Since)
|
args = append(args, formatSQLiteUTC(*filter.Since))
|
||||||
}
|
}
|
||||||
if filter.Until != nil {
|
if filter.Until != nil {
|
||||||
conditions = append(conditions, "created_at <= ?")
|
conditions = append(conditions, sqliteEpochGE("created_at", "<="))
|
||||||
args = append(args, *filter.Until)
|
args = append(args, formatSQLiteUTC(*filter.Until))
|
||||||
}
|
}
|
||||||
if q := strings.TrimSpace(filter.Query); q != "" {
|
if q := strings.TrimSpace(filter.Query); q != "" {
|
||||||
like := "%" + q + "%"
|
like := "%" + q + "%"
|
||||||
@@ -93,7 +93,9 @@ func (db *DB) AppendAuditLog(row *AuditLog) error {
|
|||||||
return errors.New("audit id is required")
|
return errors.New("audit id is required")
|
||||||
}
|
}
|
||||||
if row.CreatedAt.IsZero() {
|
if row.CreatedAt.IsZero() {
|
||||||
row.CreatedAt = time.Now()
|
row.CreatedAt = time.Now().UTC()
|
||||||
|
} else {
|
||||||
|
row.CreatedAt = row.CreatedAt.UTC()
|
||||||
}
|
}
|
||||||
if strings.TrimSpace(row.Level) == "" {
|
if strings.TrimSpace(row.Level) == "" {
|
||||||
row.Level = "info"
|
row.Level = "info"
|
||||||
@@ -111,7 +113,7 @@ func (db *DB) AppendAuditLog(row *AuditLog) error {
|
|||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
`
|
`
|
||||||
_, err := db.Exec(query,
|
_, err := db.Exec(query,
|
||||||
row.ID, row.CreatedAt, row.Level, row.Category, row.Action, row.Result,
|
row.ID, formatSQLiteUTC(row.CreatedAt), row.Level, row.Category, row.Action, row.Result,
|
||||||
row.Actor, row.SessionHint, row.ClientIP, row.UserAgent,
|
row.Actor, row.SessionHint, row.ClientIP, row.UserAgent,
|
||||||
row.ResourceType, row.ResourceID, row.Message, detailJSON,
|
row.ResourceType, row.ResourceID, row.Message, detailJSON,
|
||||||
)
|
)
|
||||||
@@ -202,7 +204,7 @@ func (db *DB) ListAuditLogs(filter ListAuditLogsFilter) ([]*AuditLog, error) {
|
|||||||
|
|
||||||
// DeleteAuditLogsBefore removes rows older than cutoff.
|
// DeleteAuditLogsBefore removes rows older than cutoff.
|
||||||
func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) {
|
func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) {
|
||||||
res, err := db.Exec(`DELETE FROM audit_logs WHERE created_at < ?`, cutoff)
|
res, err := db.Exec(`DELETE FROM audit_logs WHERE `+sqliteEpochGE("created_at", "<"), formatSQLiteUTC(cutoff))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildAuditLogsWhere_timeFilterSQL(t *testing.T) {
|
||||||
|
since := time.Date(2026, 6, 16, 17, 2, 0, 0, time.UTC)
|
||||||
|
until := time.Date(2026, 6, 17, 3, 3, 0, 0, time.UTC)
|
||||||
|
where, args := buildAuditLogsWhere(ListAuditLogsFilter{Since: &since, Until: &until})
|
||||||
|
if !strings.Contains(where, "strftime('%s', created_at) >=") {
|
||||||
|
t.Fatalf("expected epoch comparison for since, got %q", where)
|
||||||
|
}
|
||||||
|
if !strings.Contains(where, "strftime('%s', created_at) <=") {
|
||||||
|
t.Fatalf("expected epoch comparison for until, got %q", where)
|
||||||
|
}
|
||||||
|
if len(args) != 2 {
|
||||||
|
t.Fatalf("expected 2 time args, got %d", len(args))
|
||||||
|
}
|
||||||
|
for i, arg := range args {
|
||||||
|
s, ok := arg.(string)
|
||||||
|
if !ok || s == "" {
|
||||||
|
t.Fatalf("arg %d: want non-empty UTC RFC3339 string, got %v", i, arg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListAuditLogs_timeFilterMixedStorageFormats(t *testing.T) {
|
||||||
|
root, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
t.Skip(err)
|
||||||
|
}
|
||||||
|
dbPath := filepath.Join(root, "..", "..", "data", "conversations.db")
|
||||||
|
if _, err := os.Stat(dbPath); err != nil {
|
||||||
|
t.Skip("conversations.db not found")
|
||||||
|
}
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
since, _ := ParseRFC3339Time("2026-06-16T17:02:00Z")
|
||||||
|
until, _ := ParseRFC3339Time("2026-06-17T03:03:00Z")
|
||||||
|
filter := ListAuditLogsFilter{Since: &since, Until: &until, Limit: 50}
|
||||||
|
logs, err := db.ListAuditLogs(filter)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, row := range logs {
|
||||||
|
at := row.CreatedAt.UTC()
|
||||||
|
if at.Before(since) || at.After(until) {
|
||||||
|
t.Fatalf("log %s at %s outside [%s, %s]", row.ID, at, since, until)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -507,6 +507,42 @@ func (db *DB) CancelPendingBatchTasks(queueID string, completedAt time.Time) err
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrepareBatchSingleTaskRun 准备单条执行:可选重置子任务,并更新队列索引与状态
|
||||||
|
func (db *DB) PrepareBatchSingleTaskRun(queueID, taskID string, taskIndex int, resetTask, resumeQueue bool) error {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("开始事务失败: %w", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
if resetTask {
|
||||||
|
_, err = tx.Exec(
|
||||||
|
"UPDATE batch_tasks SET status = ?, conversation_id = NULL, started_at = NULL, completed_at = NULL, error = NULL, result = NULL WHERE queue_id = ? AND id = ?",
|
||||||
|
"pending", queueID, taskID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("重置批量任务状态失败: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if resumeQueue {
|
||||||
|
_, err = tx.Exec(
|
||||||
|
"UPDATE batch_task_queues SET status = ?, current_index = ?, completed_at = NULL, last_run_error = NULL WHERE id = ?",
|
||||||
|
"paused", taskIndex, queueID,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
_, err = tx.Exec(
|
||||||
|
"UPDATE batch_task_queues SET current_index = ?, last_run_error = NULL WHERE id = ?",
|
||||||
|
taskIndex, queueID,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("更新批量任务队列状态失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteBatchTask 删除批量任务
|
// DeleteBatchTask 删除批量任务
|
||||||
func (db *DB) DeleteBatchTask(queueID, taskID string) error {
|
func (db *DB) DeleteBatchTask(queueID, taskID string) error {
|
||||||
_, err := db.Exec(
|
_, err := db.Exec(
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ var ErrNoValidC2EventIDs = errors.New("no valid event ids")
|
|||||||
// ErrNoValidC2TaskIDs 批量删除任务时未提供任何合法 ID
|
// ErrNoValidC2TaskIDs 批量删除任务时未提供任何合法 ID
|
||||||
var ErrNoValidC2TaskIDs = errors.New("no valid task ids")
|
var ErrNoValidC2TaskIDs = errors.New("no valid task ids")
|
||||||
|
|
||||||
|
// ErrNoValidC2SessionIDs 批量删除会话时未提供任何合法 ID
|
||||||
|
var ErrNoValidC2SessionIDs = errors.New("no valid session ids")
|
||||||
|
|
||||||
// validC2TextIDForDelete 校验 C2 文本主键(e_/t_/s_/… 等)用于批量删除入参
|
// validC2TextIDForDelete 校验 C2 文本主键(e_/t_/s_/… 等)用于批量删除入参
|
||||||
func validC2TextIDForDelete(id string) bool {
|
func validC2TextIDForDelete(id string) bool {
|
||||||
if len(id) < 2 || len(id) > 80 {
|
if len(id) < 2 || len(id) > 80 {
|
||||||
@@ -473,6 +476,7 @@ type ListC2SessionsFilter struct {
|
|||||||
Status string // active|sleeping|dead|killed;空表示全部
|
Status string // active|sleeping|dead|killed;空表示全部
|
||||||
OS string
|
OS string
|
||||||
Search string // 模糊匹配 hostname/username/internal_ip
|
Search string // 模糊匹配 hostname/username/internal_ip
|
||||||
|
Suspicious bool // 疑似误报:离线且 hostname 为 tcp_* / 用户名为 unknown / PID 为 0
|
||||||
Limit int // 0 表示无限制
|
Limit int // 0 表示无限制
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -497,6 +501,11 @@ func (db *DB) ListC2Sessions(filter ListC2SessionsFilter) ([]*C2Session, error)
|
|||||||
kw := "%" + filter.Search + "%"
|
kw := "%" + filter.Search + "%"
|
||||||
args = append(args, kw, kw, kw)
|
args = append(args, kw, kw, kw)
|
||||||
}
|
}
|
||||||
|
if filter.Suspicious {
|
||||||
|
conditions = append(conditions, `status = 'dead' AND (
|
||||||
|
hostname LIKE 'tcp_%' OR LOWER(COALESCE(username,'')) = 'unknown' OR COALESCE(pid, 0) = 0
|
||||||
|
)`)
|
||||||
|
}
|
||||||
query := `
|
query := `
|
||||||
SELECT id, listener_id, implant_uuid, COALESCE(hostname,''), COALESCE(username,''),
|
SELECT id, listener_id, implant_uuid, COALESCE(hostname,''), COALESCE(username,''),
|
||||||
COALESCE(os,''), COALESCE(arch,''), COALESCE(pid, 0), COALESCE(process_name,''),
|
COALESCE(os,''), COALESCE(arch,''), COALESCE(pid, 0), COALESCE(process_name,''),
|
||||||
@@ -554,6 +563,44 @@ func (db *DB) DeleteC2Session(id string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteC2SessionsByIDs 按主键批量删除会话
|
||||||
|
func (db *DB) DeleteC2SessionsByIDs(ids []string) (int64, error) {
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
const maxBatch = 500
|
||||||
|
if len(ids) > maxBatch {
|
||||||
|
ids = ids[:maxBatch]
|
||||||
|
}
|
||||||
|
clean := make([]string, 0, len(ids))
|
||||||
|
seen := make(map[string]struct{}, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
id = strings.TrimSpace(id)
|
||||||
|
if !validC2TextIDForDelete(id) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
clean = append(clean, id)
|
||||||
|
}
|
||||||
|
if len(clean) == 0 {
|
||||||
|
return 0, ErrNoValidC2SessionIDs
|
||||||
|
}
|
||||||
|
placeholders := strings.Repeat("?,", len(clean)-1) + "?"
|
||||||
|
args := make([]interface{}, len(clean))
|
||||||
|
for i := range clean {
|
||||||
|
args[i] = clean[i]
|
||||||
|
}
|
||||||
|
query := `DELETE FROM c2_sessions WHERE id IN (` + placeholders + `)`
|
||||||
|
res, err := db.Exec(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return res.RowsAffected()
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
// CRUD:C2 任务
|
// CRUD:C2 任务
|
||||||
// ----------------------------------------------------------------------------
|
// ----------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -382,26 +382,40 @@ func (db *DB) CountConversations(search string) (int, error) {
|
|||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func conversationOrderClause(sortBy, tableAlias string) string {
|
||||||
|
col := "updated_at"
|
||||||
|
if strings.TrimSpace(strings.ToLower(sortBy)) == "created_at" {
|
||||||
|
col = "created_at"
|
||||||
|
}
|
||||||
|
prefix := tableAlias
|
||||||
|
if prefix != "" {
|
||||||
|
prefix += "."
|
||||||
|
}
|
||||||
|
return "ORDER BY " + prefix + col + " DESC"
|
||||||
|
}
|
||||||
|
|
||||||
// ListConversations 列出所有对话
|
// ListConversations 列出所有对话
|
||||||
func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) {
|
func (db *DB) ListConversations(limit, offset int, search, sortBy string) ([]*Conversation, error) {
|
||||||
var rows *sql.Rows
|
var rows *sql.Rows
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if search != "" {
|
if search != "" {
|
||||||
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
|
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
|
||||||
searchPattern := "%" + search + "%"
|
searchPattern := "%" + search + "%"
|
||||||
|
orderClause := conversationOrderClause(sortBy, "c")
|
||||||
rows, err = db.Query(
|
rows, err = db.Query(
|
||||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id
|
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id
|
||||||
FROM conversations c
|
FROM conversations c
|
||||||
WHERE c.title LIKE ?
|
WHERE c.title LIKE ?
|
||||||
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
|
OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)
|
||||||
ORDER BY c.updated_at DESC
|
`+orderClause+`
|
||||||
LIMIT ? OFFSET ?`,
|
LIMIT ? OFFSET ?`,
|
||||||
searchPattern, searchPattern, limit, offset,
|
searchPattern, searchPattern, limit, offset,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
|
orderClause := conversationOrderClause(sortBy, "")
|
||||||
rows, err = db.Query(
|
rows, err = db.Query(
|
||||||
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations ORDER BY updated_at DESC LIMIT ? OFFSET ?",
|
"SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations "+orderClause+" LIMIT ? OFFSET ?",
|
||||||
limit, offset,
|
limit, offset,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -467,11 +481,12 @@ func (db *DB) CountUngroupedConversations() (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ListUngroupedConversations 列出不在任何分组中的对话(最近对话侧栏)。
|
// ListUngroupedConversations 列出不在任何分组中的对话(最近对话侧栏)。
|
||||||
func (db *DB) ListUngroupedConversations(limit, offset int) ([]*Conversation, error) {
|
func (db *DB) ListUngroupedConversations(limit, offset int, sortBy string) ([]*Conversation, error) {
|
||||||
|
orderClause := conversationOrderClause(sortBy, "c")
|
||||||
rows, err := db.Query(
|
rows, err := db.Query(
|
||||||
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id `+
|
`SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id `+
|
||||||
ungroupedConversationsSQL+`
|
ungroupedConversationsSQL+`
|
||||||
ORDER BY c.updated_at DESC
|
`+orderClause+`
|
||||||
LIMIT ? OFFSET ?`,
|
LIMIT ? OFFSET ?`,
|
||||||
limit, offset,
|
limit, offset,
|
||||||
)
|
)
|
||||||
@@ -570,12 +585,14 @@ func (db *DB) DeleteConversation(id string) error {
|
|||||||
// 不返回错误,继续删除对话
|
// 不返回错误,继续删除对话
|
||||||
}
|
}
|
||||||
|
|
||||||
|
projectID, _ := db.GetConversationProjectID(id)
|
||||||
|
|
||||||
// 删除对话(外键CASCADE会自动删除其他相关数据)
|
// 删除对话(外键CASCADE会自动删除其他相关数据)
|
||||||
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", id)
|
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("删除对话失败: %w", err)
|
return fmt.Errorf("删除对话失败: %w", err)
|
||||||
}
|
}
|
||||||
db.removeConversationScopedDirs(id)
|
db.removeConversationScopedDirs(id, projectID)
|
||||||
|
|
||||||
db.logger.Info("对话已删除(漏洞记录已保留)", zap.String("conversationId", id))
|
db.logger.Info("对话已删除(漏洞记录已保留)", zap.String("conversationId", id))
|
||||||
return nil
|
return nil
|
||||||
@@ -613,13 +630,35 @@ func (db *DB) removeConversationScopedDir(base, conversationID, label string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) removeConversationScopedDirs(conversationID string) {
|
func (db *DB) einoReductionBaseDir() string {
|
||||||
// summarization transcript, reduction files, etc.
|
if db == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if base := strings.TrimSpace(db.einoReductionRootDir); base != "" {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
return filepath.Join("tmp", "reduction")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) removeConversationScopedDirs(conversationID, projectID string) {
|
||||||
|
// summarization transcript, etc.
|
||||||
db.removeConversationScopedDir(db.conversationArtifactsDir, conversationID, "conversation_artifacts")
|
db.removeConversationScopedDir(db.conversationArtifactsDir, conversationID, "conversation_artifacts")
|
||||||
// Eino plantask JSON boards (skills_dir/.eino/plantask/<id>/).
|
// Eino plantask JSON boards (skills_dir/.eino/plantask/<id>/).
|
||||||
db.removeConversationScopedDir(db.einoPlantaskBaseDir, conversationID, "plantask")
|
db.removeConversationScopedDir(db.einoPlantaskBaseDir, conversationID, "plantask")
|
||||||
// Eino ADK runner checkpoints (checkpoint_dir/<id>/).
|
// Eino ADK runner checkpoints (checkpoint_dir/<id>/).
|
||||||
db.removeConversationScopedDir(db.einoCheckpointBaseDir, conversationID, "eino_checkpoint")
|
db.removeConversationScopedDir(db.einoCheckpointBaseDir, conversationID, "eino_checkpoint")
|
||||||
|
// Eino reduction persisted tool outputs (tmp/reduction/conversations/<id>/).
|
||||||
|
// Project-bound sessions share projects/<id>/ — skip on single conversation delete.
|
||||||
|
if strings.TrimSpace(projectID) == "" {
|
||||||
|
reductionBase := filepath.Join(db.einoReductionBaseDir(), "conversations")
|
||||||
|
db.removeConversationScopedDir(reductionBase, conversationID, "reduction")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) removeProjectScopedDirs(projectID string) {
|
||||||
|
// Eino reduction persisted tool outputs (tmp/reduction/projects/<id>/).
|
||||||
|
reductionBase := filepath.Join(db.einoReductionBaseDir(), "projects")
|
||||||
|
db.removeConversationScopedDir(reductionBase, projectID, "reduction")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
|
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
|
||||||
|
|||||||
@@ -19,7 +19,8 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
|||||||
|
|
||||||
plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask")
|
plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask")
|
||||||
checkpointBase := filepath.Join(tmp, "eino-checkpoints")
|
checkpointBase := filepath.Join(tmp, "eino-checkpoints")
|
||||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase)
|
reductionBase := filepath.Join(tmp, "reduction")
|
||||||
|
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionBase)
|
||||||
|
|
||||||
conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{})
|
conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -34,6 +35,7 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
|||||||
{db.conversationArtifactsDir, "transcript.txt"},
|
{db.conversationArtifactsDir, "transcript.txt"},
|
||||||
{plantaskBase, "task-1.json"},
|
{plantaskBase, "task-1.json"},
|
||||||
{checkpointBase, "runner-deep.ckpt"},
|
{checkpointBase, "runner-deep.ckpt"},
|
||||||
|
{filepath.Join(reductionBase, "conversations"), "tool-output.txt"},
|
||||||
} {
|
} {
|
||||||
dir := filepath.Join(base.root, seg)
|
dir := filepath.Join(base.root, seg)
|
||||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
@@ -48,10 +50,45 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
|||||||
t.Fatalf("DeleteConversation: %v", err)
|
t.Fatalf("DeleteConversation: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase} {
|
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase, filepath.Join(reductionBase, "conversations")} {
|
||||||
dir := filepath.Join(base, seg)
|
dir := filepath.Join(base, seg)
|
||||||
if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) {
|
if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) {
|
||||||
t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr)
|
t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDeleteProjectRemovesReductionDir(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
dbPath := filepath.Join(tmp, "conversations.db")
|
||||||
|
db, err := NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
reductionBase := filepath.Join(tmp, "reduction")
|
||||||
|
db.SetEinoConversationDirs("", "", reductionBase)
|
||||||
|
|
||||||
|
project, err := db.CreateProject(&Project{Name: "cleanup test"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateProject: %v", err)
|
||||||
|
}
|
||||||
|
seg := sanitizeConversationPathSegment(project.ID)
|
||||||
|
reductionDir := filepath.Join(reductionBase, "projects", seg, "clear")
|
||||||
|
if err := os.MkdirAll(reductionDir, 0o755); err != nil {
|
||||||
|
t.Fatalf("mkdir %s: %v", reductionDir, err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(filepath.Join(reductionDir, "call-1.txt"), []byte("x"), 0o644); err != nil {
|
||||||
|
t.Fatalf("write: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.DeleteProject(project.ID); err != nil {
|
||||||
|
t.Fatalf("DeleteProject: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
projectReductionDir := filepath.Join(reductionBase, "projects", seg)
|
||||||
|
if _, statErr := os.Stat(projectReductionDir); !os.IsNotExist(statErr) {
|
||||||
|
t.Fatalf("expected removed dir %s, stat err=%v", projectReductionDir, statErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ type DB struct {
|
|||||||
conversationArtifactsDir string
|
conversationArtifactsDir string
|
||||||
einoPlantaskBaseDir string // skills_dir + plantask_rel_dir (per-conversation subdirs)
|
einoPlantaskBaseDir string // skills_dir + plantask_rel_dir (per-conversation subdirs)
|
||||||
einoCheckpointBaseDir string // checkpoint_dir root (per-conversation subdirs)
|
einoCheckpointBaseDir string // checkpoint_dir root (per-conversation subdirs)
|
||||||
|
einoReductionRootDir string // reduction_root_dir or default tmp/reduction (conversations/<id> subdirs)
|
||||||
checkpointLoopName string
|
checkpointLoopName string
|
||||||
checkpointStop chan struct{}
|
checkpointStop chan struct{}
|
||||||
checkpointDone chan struct{}
|
checkpointDone chan struct{}
|
||||||
@@ -159,12 +160,14 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
|||||||
|
|
||||||
// SetEinoConversationDirs configures best-effort filesystem cleanup on DeleteConversation.
|
// SetEinoConversationDirs configures best-effort filesystem cleanup on DeleteConversation.
|
||||||
// plantaskBase is skills_root/plantask_rel (no conversation id); checkpointBase is checkpoint_dir root.
|
// plantaskBase is skills_root/plantask_rel (no conversation id); checkpointBase is checkpoint_dir root.
|
||||||
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase string) {
|
// reductionRoot is reduction_root_dir from config; empty uses tmp/reduction (conversation-scoped subdirs only).
|
||||||
|
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot string) {
|
||||||
if db == nil {
|
if db == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
db.einoPlantaskBaseDir = strings.TrimSpace(plantaskBase)
|
db.einoPlantaskBaseDir = strings.TrimSpace(plantaskBase)
|
||||||
db.einoCheckpointBaseDir = strings.TrimSpace(checkpointBase)
|
db.einoCheckpointBaseDir = strings.TrimSpace(checkpointBase)
|
||||||
|
db.einoReductionRootDir = strings.TrimSpace(reductionRoot)
|
||||||
}
|
}
|
||||||
|
|
||||||
// initTables 初始化数据库表
|
// initTables 初始化数据库表
|
||||||
@@ -353,6 +356,22 @@ func (db *DB) initTables() error {
|
|||||||
UNIQUE(project_id, fact_key)
|
UNIQUE(project_id, fact_key)
|
||||||
);`
|
);`
|
||||||
|
|
||||||
|
// 项目事实关系边(黑板 DAG)
|
||||||
|
createProjectFactEdgesTable := `
|
||||||
|
CREATE TABLE IF NOT EXISTS project_fact_edges (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
project_id TEXT NOT NULL,
|
||||||
|
source_fact_key TEXT NOT NULL,
|
||||||
|
target_fact_key TEXT NOT NULL,
|
||||||
|
edge_type TEXT NOT NULL,
|
||||||
|
confidence TEXT NOT NULL DEFAULT 'tentative',
|
||||||
|
source_conversation_id TEXT,
|
||||||
|
created_at DATETIME NOT NULL,
|
||||||
|
updated_at DATETIME NOT NULL,
|
||||||
|
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
|
||||||
|
UNIQUE(project_id, source_fact_key, target_fact_key, edge_type)
|
||||||
|
);`
|
||||||
|
|
||||||
// 创建漏洞表
|
// 创建漏洞表
|
||||||
createVulnerabilitiesTable := `
|
createVulnerabilitiesTable := `
|
||||||
CREATE TABLE IF NOT EXISTS vulnerabilities (
|
CREATE TABLE IF NOT EXISTS vulnerabilities (
|
||||||
@@ -591,6 +610,9 @@ func (db *DB) initTables() error {
|
|||||||
CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id);
|
CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence);
|
CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence);
|
||||||
CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id);
|
CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_project_fact_edges_project ON project_fact_edges(project_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_project_fact_edges_source ON project_fact_edges(project_id, source_fact_key);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_project_fact_edges_target ON project_fact_edges(project_id, target_fact_key);
|
||||||
CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id);
|
CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id);
|
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
|
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
|
||||||
@@ -672,6 +694,10 @@ func (db *DB) initTables() error {
|
|||||||
return fmt.Errorf("创建project_facts表失败: %w", err)
|
return fmt.Errorf("创建project_facts表失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec(createProjectFactEdgesTable); err != nil {
|
||||||
|
return fmt.Errorf("创建project_fact_edges表失败: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
|
if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
|
||||||
return fmt.Errorf("创建vulnerabilities表失败: %w", err)
|
return fmt.Errorf("创建vulnerabilities表失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -72,6 +72,23 @@ func (db *DB) SaveToolExecution(exec *mcp.ToolExecution) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateToolExecutionResult 仅更新结果字段(用于 reduction 后将监控展示与模型上下文对齐)。
|
||||||
|
func (db *DB) UpdateToolExecutionResult(id string, result *mcp.ToolResult) error {
|
||||||
|
id = strings.TrimSpace(id)
|
||||||
|
if id == "" || result == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
resultBytes, err := json.Marshal(result)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = db.Exec(`UPDATE tool_executions SET result = ? WHERE id = ?`, string(resultBytes), id)
|
||||||
|
if err != nil {
|
||||||
|
db.logger.Warn("更新工具执行结果失败", zap.Error(err), zap.String("executionId", id))
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// CountToolExecutions 统计工具执行记录总数
|
// CountToolExecutions 统计工具执行记录总数
|
||||||
func (db *DB) CountToolExecutions(status, toolName string) (int, error) {
|
func (db *DB) CountToolExecutions(status, toolName string) (int, error) {
|
||||||
query := `SELECT COUNT(*) FROM tool_executions`
|
query := `SELECT COUNT(*) FROM tool_executions`
|
||||||
|
|||||||
@@ -195,6 +195,7 @@ func (db *DB) DeleteProject(id string) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("删除项目失败: %w", err)
|
return fmt.Errorf("删除项目失败: %w", err)
|
||||||
}
|
}
|
||||||
|
db.removeProjectScopedDirs(id)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -389,7 +390,7 @@ func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) {
|
|||||||
return f, nil
|
return f, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeprecateProjectFact 将事实标记为 deprecated。
|
// DeprecateProjectFact 将事实标记为 deprecated(关联边同步 deprecated)。
|
||||||
func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
|
func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
|
||||||
res, err := db.Exec(
|
res, err := db.Exec(
|
||||||
`UPDATE project_facts SET confidence = 'deprecated', updated_at = ? WHERE project_id = ? AND fact_key = ?`,
|
`UPDATE project_facts SET confidence = 'deprecated', updated_at = ? WHERE project_id = ? AND fact_key = ?`,
|
||||||
@@ -402,7 +403,7 @@ func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
|
|||||||
if n == 0 {
|
if n == 0 {
|
||||||
return fmt.Errorf("事实不存在")
|
return fmt.Errorf("事实不存在")
|
||||||
}
|
}
|
||||||
return nil
|
return db.DeprecateProjectFactEdgesForKey(projectID, factKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RestoreProjectFact 将已废弃事实恢复为 tentative 或 confirmed(重新参与黑板索引)。
|
// RestoreProjectFact 将已废弃事实恢复为 tentative 或 confirmed(重新参与黑板索引)。
|
||||||
@@ -430,9 +431,16 @@ func (db *DB) RestoreProjectFact(projectID, factKey, confidence string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteProjectFact 删除事实。
|
// DeleteProjectFact 删除事实(级联删除相关边)。
|
||||||
func (db *DB) DeleteProjectFact(id string) error {
|
func (db *DB) DeleteProjectFact(id string) error {
|
||||||
_, err := db.Exec(`DELETE FROM project_facts WHERE id = ?`, id)
|
f, err := db.GetProjectFact(id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := db.DeleteProjectFactEdgesForKey(f.ProjectID, f.FactKey); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = db.Exec(`DELETE FROM project_facts WHERE id = ?`, id)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,410 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ValidProjectFactEdgeTypes 项目事实图允许的边类型。
|
||||||
|
var ValidProjectFactEdgeTypes = map[string]struct{}{
|
||||||
|
"depends_on": {},
|
||||||
|
"leads_to": {},
|
||||||
|
"enables": {},
|
||||||
|
"exploits": {},
|
||||||
|
"discovered_on": {},
|
||||||
|
"contains": {},
|
||||||
|
"part_of": {},
|
||||||
|
"supports": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectFactEdge 项目事实关系边(source → target)。
|
||||||
|
type ProjectFactEdge struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
ProjectID string `json:"project_id"`
|
||||||
|
SourceFactKey string `json:"source_fact_key"`
|
||||||
|
TargetFactKey string `json:"target_fact_key"`
|
||||||
|
EdgeType string `json:"edge_type"`
|
||||||
|
Confidence string `json:"confidence"` // confirmed | tentative | deprecated
|
||||||
|
SourceConversationID string `json:"source_conversation_id,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectFactEdgeInput 写入边时的输入(出边:source → To)。
|
||||||
|
type ProjectFactEdgeInput struct {
|
||||||
|
To string `json:"to"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Confidence string `json:"confidence,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectFactEdgeFromInput 写入入边时的输入(From → 当前事实)。
|
||||||
|
type ProjectFactEdgeFromInput struct {
|
||||||
|
From string `json:"from"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Confidence string `json:"confidence,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectFactGraphNode 图 API 节点。
|
||||||
|
type ProjectFactGraphNode struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
FactKey string `json:"fact_key"`
|
||||||
|
Category string `json:"category"`
|
||||||
|
Label string `json:"label"` // 图节点短标签(截断)
|
||||||
|
Summary string `json:"summary"` // 完整摘要(侧栏等详情用)
|
||||||
|
Confidence string `json:"confidence"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Pinned bool `json:"pinned"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectFactGraphEdge 图 API 边。
|
||||||
|
type ProjectFactGraphEdge struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Source string `json:"source"`
|
||||||
|
Target string `json:"target"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Confidence string `json:"confidence"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProjectFactGraph 项目事实图。
|
||||||
|
type ProjectFactGraph struct {
|
||||||
|
Nodes []ProjectFactGraphNode `json:"nodes"`
|
||||||
|
Edges []ProjectFactGraphEdge `json:"edges"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateProjectFactEdgeType 校验边类型。
|
||||||
|
func ValidateProjectFactEdgeType(edgeType string) error {
|
||||||
|
edgeType = strings.TrimSpace(strings.ToLower(edgeType))
|
||||||
|
if edgeType == "" {
|
||||||
|
return fmt.Errorf("edge type 不能为空")
|
||||||
|
}
|
||||||
|
if _, ok := ValidProjectFactEdgeTypes[edgeType]; !ok {
|
||||||
|
return fmt.Errorf("无效的 edge type: %s", edgeType)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeEdgeConfidence(confidence string) string {
|
||||||
|
confidence = strings.TrimSpace(strings.ToLower(confidence))
|
||||||
|
switch confidence {
|
||||||
|
case "confirmed", "deprecated":
|
||||||
|
return confidence
|
||||||
|
default:
|
||||||
|
return "tentative"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListProjectFactEdgesByProject 列出项目全部边。
|
||||||
|
func (db *DB) ListProjectFactEdgesByProject(projectID string) ([]*ProjectFactEdge, error) {
|
||||||
|
rows, err := db.Query(
|
||||||
|
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||||
|
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||||
|
FROM project_fact_edges
|
||||||
|
WHERE project_id = ?
|
||||||
|
ORDER BY created_at ASC, rowid ASC`,
|
||||||
|
projectID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return scanProjectFactEdges(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListOutgoingProjectFactEdges 列出某事实的全部出边。
|
||||||
|
func (db *DB) ListOutgoingProjectFactEdges(projectID, sourceFactKey string) ([]*ProjectFactEdge, error) {
|
||||||
|
rows, err := db.Query(
|
||||||
|
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||||
|
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||||
|
FROM project_fact_edges
|
||||||
|
WHERE project_id = ? AND source_fact_key = ?
|
||||||
|
ORDER BY created_at ASC, rowid ASC`,
|
||||||
|
projectID, sourceFactKey,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return scanProjectFactEdges(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListIncomingProjectFactEdges 列出某事实的全部入边。
|
||||||
|
func (db *DB) ListIncomingProjectFactEdges(projectID, targetFactKey string) ([]*ProjectFactEdge, error) {
|
||||||
|
rows, err := db.Query(
|
||||||
|
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||||
|
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||||
|
FROM project_fact_edges
|
||||||
|
WHERE project_id = ? AND target_fact_key = ?
|
||||||
|
ORDER BY created_at ASC, rowid ASC`,
|
||||||
|
projectID, targetFactKey,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
return scanProjectFactEdges(rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplaceOutgoingProjectFactEdges 替换某事实的全部出边(links 省略时不调用)。
|
||||||
|
func (db *DB) ReplaceOutgoingProjectFactEdges(projectID, sourceFactKey, sourceConversationID string, inputs []ProjectFactEdgeInput) error {
|
||||||
|
sourceFactKey = strings.TrimSpace(sourceFactKey)
|
||||||
|
if sourceFactKey == "" {
|
||||||
|
return fmt.Errorf("source_fact_key 不能为空")
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(
|
||||||
|
`DELETE FROM project_fact_edges WHERE project_id = ? AND source_fact_key = ?`,
|
||||||
|
projectID, sourceFactKey,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("清除旧边失败: %w", err)
|
||||||
|
}
|
||||||
|
for _, in := range inputs {
|
||||||
|
target := strings.TrimSpace(in.To)
|
||||||
|
if target == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := ValidateFactKey(target); err != nil {
|
||||||
|
return fmt.Errorf("target fact_key 无效 (%s): %w", target, err)
|
||||||
|
}
|
||||||
|
if target == sourceFactKey {
|
||||||
|
return fmt.Errorf("边不能指向自身: %s", sourceFactKey)
|
||||||
|
}
|
||||||
|
if err := ValidateProjectFactEdgeType(in.Type); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
edge := &ProjectFactEdge{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
ProjectID: projectID,
|
||||||
|
SourceFactKey: sourceFactKey,
|
||||||
|
TargetFactKey: target,
|
||||||
|
EdgeType: strings.ToLower(strings.TrimSpace(in.Type)),
|
||||||
|
Confidence: normalizeEdgeConfidence(in.Confidence),
|
||||||
|
SourceConversationID: sourceConversationID,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
if err := db.insertProjectFactEdge(edge); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplaceIncomingProjectFactEdges 替换某事实的全部入边(From 为来源 fact_key)。
|
||||||
|
func (db *DB) ReplaceIncomingProjectFactEdges(projectID, targetFactKey string, inputs []ProjectFactEdgeFromInput) error {
|
||||||
|
targetFactKey = strings.TrimSpace(targetFactKey)
|
||||||
|
if targetFactKey == "" {
|
||||||
|
return fmt.Errorf("target_fact_key 不能为空")
|
||||||
|
}
|
||||||
|
if _, err := db.Exec(
|
||||||
|
`DELETE FROM project_fact_edges WHERE project_id = ? AND target_fact_key = ?`,
|
||||||
|
projectID, targetFactKey,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("清除旧入边失败: %w", err)
|
||||||
|
}
|
||||||
|
for _, in := range inputs {
|
||||||
|
source := strings.TrimSpace(in.From)
|
||||||
|
if source == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := ValidateFactKey(source); err != nil {
|
||||||
|
return fmt.Errorf("source fact_key 无效 (%s): %w", source, err)
|
||||||
|
}
|
||||||
|
if source == targetFactKey {
|
||||||
|
return fmt.Errorf("边不能指向自身: %s", targetFactKey)
|
||||||
|
}
|
||||||
|
if err := ValidateProjectFactEdgeType(in.Type); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sourceConversationID := ""
|
||||||
|
if srcFact, err := db.GetProjectFactByKey(projectID, source); err == nil && srcFact != nil {
|
||||||
|
sourceConversationID = srcFact.SourceConversationID
|
||||||
|
}
|
||||||
|
edge := &ProjectFactEdge{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
ProjectID: projectID,
|
||||||
|
SourceFactKey: source,
|
||||||
|
TargetFactKey: targetFactKey,
|
||||||
|
EdgeType: strings.ToLower(strings.TrimSpace(in.Type)),
|
||||||
|
Confidence: normalizeEdgeConfidence(in.Confidence),
|
||||||
|
SourceConversationID: sourceConversationID,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
UpdatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
if err := db.insertProjectFactEdge(edge); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProjectFactEdge 按 ID 获取边。
|
||||||
|
func (db *DB) GetProjectFactEdge(edgeID string) (*ProjectFactEdge, error) {
|
||||||
|
var e ProjectFactEdge
|
||||||
|
var createdAt, updatedAt string
|
||||||
|
err := db.QueryRow(
|
||||||
|
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||||
|
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||||
|
FROM project_fact_edges WHERE id = ?`, edgeID,
|
||||||
|
).Scan(&e.ID, &e.ProjectID, &e.SourceFactKey, &e.TargetFactKey, &e.EdgeType, &e.Confidence,
|
||||||
|
&e.SourceConversationID, &createdAt, &updatedAt)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("边不存在")
|
||||||
|
}
|
||||||
|
e.CreatedAt = parseDBTime(createdAt)
|
||||||
|
e.UpdatedAt = parseDBTime(updatedAt)
|
||||||
|
return &e, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddProjectFactEdge 新增单条边(已存在则更新 confidence)。
|
||||||
|
func (db *DB) AddProjectFactEdge(projectID string, in ProjectFactEdgeInput, sourceFactKey, sourceConversationID string) (*ProjectFactEdge, error) {
|
||||||
|
sourceFactKey = strings.TrimSpace(sourceFactKey)
|
||||||
|
target := strings.TrimSpace(in.To)
|
||||||
|
if sourceFactKey == "" || target == "" {
|
||||||
|
return nil, fmt.Errorf("source 与 target 必填")
|
||||||
|
}
|
||||||
|
if sourceFactKey == target {
|
||||||
|
return nil, fmt.Errorf("边不能指向自身")
|
||||||
|
}
|
||||||
|
if err := ValidateProjectFactEdgeType(in.Type); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := ValidateFactKey(target); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
e := &ProjectFactEdge{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
ProjectID: projectID,
|
||||||
|
SourceFactKey: sourceFactKey,
|
||||||
|
TargetFactKey: target,
|
||||||
|
EdgeType: strings.ToLower(strings.TrimSpace(in.Type)),
|
||||||
|
Confidence: normalizeEdgeConfidence(in.Confidence),
|
||||||
|
SourceConversationID: sourceConversationID,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
_, err := db.Exec(
|
||||||
|
`INSERT INTO project_fact_edges (
|
||||||
|
id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||||
|
source_conversation_id, created_at, updated_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(project_id, source_fact_key, target_fact_key, edge_type)
|
||||||
|
DO UPDATE SET confidence = excluded.confidence, updated_at = excluded.updated_at`,
|
||||||
|
e.ID, e.ProjectID, e.SourceFactKey, e.TargetFactKey, e.EdgeType, e.Confidence,
|
||||||
|
nullIfEmpty(e.SourceConversationID), e.CreatedAt, e.UpdatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("添加边失败: %w", err)
|
||||||
|
}
|
||||||
|
// 返回最新
|
||||||
|
rows, err := db.Query(
|
||||||
|
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||||
|
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||||
|
FROM project_fact_edges
|
||||||
|
WHERE project_id = ? AND source_fact_key = ? AND target_fact_key = ? AND edge_type = ?`,
|
||||||
|
projectID, sourceFactKey, target, e.EdgeType,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return e, nil
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
list, err := scanProjectFactEdges(rows)
|
||||||
|
if err != nil || len(list) == 0 {
|
||||||
|
return e, nil
|
||||||
|
}
|
||||||
|
return list[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteProjectFactEdge 删除单条边。
|
||||||
|
func (db *DB) DeleteProjectFactEdge(edgeID string) error {
|
||||||
|
res, err := db.Exec(`DELETE FROM project_fact_edges WHERE id = ?`, edgeID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
n, _ := res.RowsAffected()
|
||||||
|
if n == 0 {
|
||||||
|
return fmt.Errorf("边不存在")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) insertProjectFactEdge(e *ProjectFactEdge) error {
|
||||||
|
_, err := db.Exec(
|
||||||
|
`INSERT INTO project_fact_edges (
|
||||||
|
id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||||
|
source_conversation_id, created_at, updated_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
e.ID, e.ProjectID, e.SourceFactKey, e.TargetFactKey, e.EdgeType, e.Confidence,
|
||||||
|
nullIfEmpty(e.SourceConversationID), e.CreatedAt, e.UpdatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("写入边失败: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RenameProjectFactKeyEdges 事实 key 变更时同步边上的引用。
|
||||||
|
func (db *DB) RenameProjectFactKeyEdges(projectID, oldKey, newKey string) error {
|
||||||
|
oldKey = strings.TrimSpace(oldKey)
|
||||||
|
newKey = strings.TrimSpace(newKey)
|
||||||
|
if oldKey == "" || newKey == "" || oldKey == newKey {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if _, err := db.Exec(
|
||||||
|
`UPDATE project_fact_edges SET source_fact_key = ?, updated_at = ?
|
||||||
|
WHERE project_id = ? AND source_fact_key = ?`,
|
||||||
|
newKey, now, projectID, oldKey,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err := db.Exec(
|
||||||
|
`UPDATE project_fact_edges SET target_fact_key = ?, updated_at = ?
|
||||||
|
WHERE project_id = ? AND target_fact_key = ?`,
|
||||||
|
newKey, now, projectID, oldKey,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteProjectFactEdgesForKey 删除与某 fact_key 相关的全部边。
|
||||||
|
func (db *DB) DeleteProjectFactEdgesForKey(projectID, factKey string) error {
|
||||||
|
_, err := db.Exec(
|
||||||
|
`DELETE FROM project_fact_edges
|
||||||
|
WHERE project_id = ? AND (source_fact_key = ? OR target_fact_key = ?)`,
|
||||||
|
projectID, factKey, factKey,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeprecateProjectFactEdgesForKey 将关联边标记为 deprecated。
|
||||||
|
func (db *DB) DeprecateProjectFactEdgesForKey(projectID, factKey string) error {
|
||||||
|
now := time.Now()
|
||||||
|
_, err := db.Exec(
|
||||||
|
`UPDATE project_fact_edges SET confidence = 'deprecated', updated_at = ?
|
||||||
|
WHERE project_id = ? AND (source_fact_key = ? OR target_fact_key = ?)
|
||||||
|
AND confidence != 'deprecated'`,
|
||||||
|
now, projectID, factKey, factKey,
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanProjectFactEdges(rows *sql.Rows) ([]*ProjectFactEdge, error) {
|
||||||
|
var out []*ProjectFactEdge
|
||||||
|
for rows.Next() {
|
||||||
|
var e ProjectFactEdge
|
||||||
|
var createdAt, updatedAt string
|
||||||
|
if err := rows.Scan(
|
||||||
|
&e.ID, &e.ProjectID, &e.SourceFactKey, &e.TargetFactKey, &e.EdgeType, &e.Confidence,
|
||||||
|
&e.SourceConversationID, &createdAt, &updatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
e.CreatedAt = parseDBTime(createdAt)
|
||||||
|
e.UpdatedAt = parseDBTime(updatedAt)
|
||||||
|
out = append(out, &e)
|
||||||
|
}
|
||||||
|
return out, rows.Err()
|
||||||
|
}
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// formatSQLiteUTC stores instants as UTC RFC3339 for consistent SQLite reads/writes.
|
||||||
|
func formatSQLiteUTC(t time.Time) string {
|
||||||
|
return t.UTC().Format(time.RFC3339Nano)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sqliteEpochGE returns SQL comparing column to param as Unix seconds (timezone-safe).
|
||||||
|
func sqliteEpochGE(column, op string) string {
|
||||||
|
return "strftime('%s', " + column + ") " + op + " strftime('%s', ?)"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseRFC3339Time parses API/query timestamps (RFC3339 or RFC3339Nano).
|
||||||
|
func ParseRFC3339Time(value string) (time.Time, error) {
|
||||||
|
value = strings.TrimSpace(value)
|
||||||
|
if value == "" {
|
||||||
|
return time.Time{}, errors.New("empty time value")
|
||||||
|
}
|
||||||
|
if t, err := time.Parse(time.RFC3339Nano, value); err == nil {
|
||||||
|
return t.UTC(), nil
|
||||||
|
}
|
||||||
|
t, err := time.Parse(time.RFC3339, value)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, err
|
||||||
|
}
|
||||||
|
return t.UTC(), nil
|
||||||
|
}
|
||||||
@@ -16,7 +16,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// ExecutionRecorder 可选,在 MCP 工具成功返回且带有 execution id 时回调(用于汇总 mcpExecutionIds)。
|
// ExecutionRecorder 可选,在 MCP 工具成功返回且带有 execution id 时回调(用于汇总 mcpExecutionIds)。
|
||||||
type ExecutionRecorder func(executionID string)
|
// toolCallID 来自 Eino compose.GetToolCallID,用于与 reduction 后的展示结果关联。
|
||||||
|
type ExecutionRecorder func(executionID, toolCallID string)
|
||||||
|
|
||||||
// ToolErrorPrefix 用于把内部 MCP 执行结果中的 IsError 标记传递到多代理上层。
|
// ToolErrorPrefix 用于把内部 MCP 执行结果中的 IsError 标记传递到多代理上层。
|
||||||
// Eino 工具通道目前只支持返回字符串,因此通过前缀标识,随后在多代理 runner 中解析为 success/isError。
|
// Eino 工具通道目前只支持返回字符串,因此通过前缀标识,随后在多代理 runner 中解析为 success/isError。
|
||||||
@@ -178,7 +179,7 @@ func runMCPToolInvocation(
|
|||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
if res.ExecutionID != "" && record != nil {
|
if res.ExecutionID != "" && record != nil {
|
||||||
record(res.ExecutionID)
|
record(res.ExecutionID, compose.GetToolCallID(ctx))
|
||||||
}
|
}
|
||||||
if res.IsError {
|
if res.IsError {
|
||||||
return ToolErrorPrefix + res.Result, nil
|
return ToolErrorPrefix + res.Result, nil
|
||||||
|
|||||||
@@ -2,8 +2,8 @@ package einomcp
|
|||||||
|
|
||||||
import "sync"
|
import "sync"
|
||||||
|
|
||||||
// ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP 桥在每次 InvokableRun 结束时 Fire,
|
// ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP/execute 桥在工具调用结束时 Fire,
|
||||||
// 用于在 ADK 未透出 schema.Tool 事件时仍推送 tool_result、清 pending,避免 UI 卡在「执行中」或迭代末 force-close。
|
// 用于清除 pending tool_call(tool_result 由 ADK schema.Tool 事件推送,含流式工具与 reduction 后正文)。
|
||||||
type ToolInvokeNotifyHolder struct {
|
type ToolInvokeNotifyHolder struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
|
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
|
||||||
|
|||||||
+106
-68
@@ -190,6 +190,21 @@ func (h *AgentHandler) SetAudit(s *audit.Service) {
|
|||||||
h.audit = s
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent).
|
||||||
|
func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
|
||||||
|
if h == nil || conversationID == "" || h.tasks == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if execID := h.tasks.ActiveMCPExecutionID(conversationID); execID != "" {
|
||||||
|
h.agent.CancelMCPToolExecutionWithNote(execID, "")
|
||||||
|
}
|
||||||
|
if ok, err := h.tasks.CancelTask(conversationID, ErrTaskCancelled); ok {
|
||||||
|
h.logger.Info("已取消会话运行中任务", zap.String("conversationId", conversationID))
|
||||||
|
} else if err != nil {
|
||||||
|
h.logger.Warn("取消会话运行中任务失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
||||||
type HitlToolWhitelistSaver interface {
|
type HitlToolWhitelistSaver interface {
|
||||||
MergeHitlToolWhitelistIntoConfig(add []string) error
|
MergeHitlToolWhitelistIntoConfig(add []string) error
|
||||||
@@ -631,40 +646,11 @@ func (h *AgentHandler) runRobotEinoSingleWithRetry(
|
|||||||
assistantMessageID string,
|
assistantMessageID string,
|
||||||
taskStatus *string,
|
taskStatus *string,
|
||||||
) (string, string, error) {
|
) (string, string, error) {
|
||||||
curHist := history
|
resultMA, errMA := multiagent.RunEinoSingleChatModelAgent(
|
||||||
curMsg := finalMessage
|
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||||
segmentUserMessage := finalMessage
|
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
|
||||||
var resultMA *multiagent.RunResult
|
|
||||||
var errMA error
|
|
||||||
var transientRunAttempts int
|
|
||||||
var emptyResponseAttempts int
|
|
||||||
for {
|
|
||||||
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
|
|
||||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
|
|
||||||
conversationID, curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
|
|
||||||
)
|
)
|
||||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
if errMA != nil {
|
||||||
taskCtx, conversationID, resultMA, errMA, &emptyResponseAttempts,
|
|
||||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
|
||||||
)
|
|
||||||
if exhaustedEmpty {
|
|
||||||
errMA = nil
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if handledEmpty {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if errMA == nil {
|
|
||||||
transientRunAttempts = 0
|
|
||||||
emptyResponseAttempts = 0
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
|
||||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
|
||||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
|
||||||
); handled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
*taskStatus = "failed"
|
*taskStatus = "failed"
|
||||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||||
}
|
}
|
||||||
@@ -680,41 +666,12 @@ func (h *AgentHandler) runRobotMultiAgentWithRetry(
|
|||||||
assistantMessageID string,
|
assistantMessageID string,
|
||||||
taskStatus *string,
|
taskStatus *string,
|
||||||
) (string, string, error) {
|
) (string, string, error) {
|
||||||
curHist := history
|
resultMA, errMA := multiagent.RunDeepAgent(
|
||||||
curMsg := finalMessage
|
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||||
segmentUserMessage := finalMessage
|
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback,
|
||||||
var resultMA *multiagent.RunResult
|
|
||||||
var errMA error
|
|
||||||
var transientRunAttempts int
|
|
||||||
var emptyResponseAttempts int
|
|
||||||
for {
|
|
||||||
resultMA, errMA = multiagent.RunDeepAgent(
|
|
||||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
|
|
||||||
conversationID, curMsg, curHist, roleTools, progressCallback,
|
|
||||||
h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
|
h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
|
||||||
)
|
)
|
||||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
if errMA != nil {
|
||||||
taskCtx, conversationID, resultMA, errMA, &emptyResponseAttempts,
|
|
||||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
|
||||||
)
|
|
||||||
if exhaustedEmpty {
|
|
||||||
errMA = nil
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if handledEmpty {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if errMA == nil {
|
|
||||||
transientRunAttempts = 0
|
|
||||||
emptyResponseAttempts = 0
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
|
||||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
|
||||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
|
||||||
); handled {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
*taskStatus = "failed"
|
*taskStatus = "failed"
|
||||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||||
}
|
}
|
||||||
@@ -1185,6 +1142,8 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
flushResponsePlan()
|
flushResponsePlan()
|
||||||
|
// 助手正文开始前,推理流通常已结束;落库以便刷新后「渗透测试详情」可回放
|
||||||
|
flushThinkingStreams()
|
||||||
respPlan.meta = nil
|
respPlan.meta = nil
|
||||||
if dataMap, ok := data.(map[string]interface{}); ok {
|
if dataMap, ok := data.(map[string]interface{}); ok {
|
||||||
respPlan.meta = make(map[string]interface{}, len(dataMap))
|
respPlan.meta = make(map[string]interface{}, len(dataMap))
|
||||||
@@ -1220,6 +1179,19 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
}
|
}
|
||||||
if eventType == "response" {
|
if eventType == "response" {
|
||||||
flushResponsePlan()
|
flushResponsePlan()
|
||||||
|
flushThinkingStreams()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if eventType == "done" {
|
||||||
|
flushResponsePlan()
|
||||||
|
flushThinkingStreams()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 流式思考/推理结束:聚合落库(与 eino_agent_reply_stream_end 同理)
|
||||||
|
if eventType == "thinking_stream_end" || eventType == "reasoning_chain_stream_end" {
|
||||||
|
flushResponsePlan()
|
||||||
|
flushThinkingStreams()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1294,7 +1266,10 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
|||||||
|
|
||||||
// 保存过程详情到数据库(排除 response/done;response 正文已在 messages 表)
|
// 保存过程详情到数据库(排除 response/done;response 正文已在 messages 表)
|
||||||
// response_start/response_delta 已聚合为 planning,不落逐条。
|
// response_start/response_delta 已聚合为 planning,不落逐条。
|
||||||
|
// [Eino] agent 心跳 progress 仅用于实时进度标题,不落库以免时间线刷屏。
|
||||||
|
skipEinoAgentHeartbeat := eventType == "progress" && strings.HasPrefix(strings.TrimSpace(message), "[Eino] ")
|
||||||
if assistantMessageID != "" &&
|
if assistantMessageID != "" &&
|
||||||
|
!skipEinoAgentHeartbeat &&
|
||||||
eventType != "response" &&
|
eventType != "response" &&
|
||||||
eventType != "done" &&
|
eventType != "done" &&
|
||||||
eventType != "response_start" &&
|
eventType != "response_start" &&
|
||||||
@@ -1663,6 +1638,7 @@ func (h *AgentHandler) ListBatchQueues(c *gin.Context) {
|
|||||||
// StartBatchQueue 开始执行批量任务队列
|
// StartBatchQueue 开始执行批量任务队列
|
||||||
func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
|
func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
|
||||||
queueID := c.Param("queueId")
|
queueID := c.Param("queueId")
|
||||||
|
h.batchTaskManager.ClearSingleRunTask(queueID)
|
||||||
ok, err := h.startBatchQueueExecution(queueID, false)
|
ok, err := h.startBatchQueueExecution(queueID, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
@@ -1694,6 +1670,7 @@ func (h *AgentHandler) RerunBatchQueue(c *gin.Context) {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "重置队列失败"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "重置队列失败"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
h.batchTaskManager.ClearSingleRunTask(queueID)
|
||||||
ok, err := h.startBatchQueueExecution(queueID, false)
|
ok, err := h.startBatchQueueExecution(queueID, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
@@ -1893,6 +1870,53 @@ func (h *AgentHandler) AddBatchTask(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"message": "任务已添加", "task": task, "queue": queue})
|
c.JSON(http.StatusOK, gin.H{"message": "任务已添加", "task": task, "queue": queue})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RunSingleBatchTask 单条执行指定子任务(可覆盖已成功项),完成后暂停队列
|
||||||
|
func (h *AgentHandler) RunSingleBatchTask(c *gin.Context) {
|
||||||
|
queueID := c.Param("queueId")
|
||||||
|
taskID := c.Param("taskId")
|
||||||
|
|
||||||
|
if err := h.batchTaskManager.PrepareSingleTaskRun(queueID, taskID); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.batchTaskManager.SetSingleRunTask(queueID, taskID)
|
||||||
|
|
||||||
|
// 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动
|
||||||
|
if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused {
|
||||||
|
h.forceUnmarkBatchQueueRunning(queueID)
|
||||||
|
}
|
||||||
|
|
||||||
|
autoStarted := true
|
||||||
|
autoStartMsg := "已开始单条执行"
|
||||||
|
ok, startErr := h.startBatchQueueExecution(queueID, false)
|
||||||
|
if startErr != nil {
|
||||||
|
h.batchTaskManager.ClearSingleRunTask(queueID)
|
||||||
|
autoStarted = false
|
||||||
|
autoStartMsg = "任务已准备就绪,但自动启动失败: " + startErr.Error()
|
||||||
|
} else if !ok {
|
||||||
|
h.batchTaskManager.ClearSingleRunTask(queueID)
|
||||||
|
autoStarted = false
|
||||||
|
autoStartMsg = "任务已准备就绪,但队列不存在"
|
||||||
|
}
|
||||||
|
|
||||||
|
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
|
if !exists {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "task", "run_single_batch_task", "单条执行批量子任务", "batch_task", taskID, map[string]interface{}{
|
||||||
|
"batch_queue_id": queueID,
|
||||||
|
"auto_started": autoStarted,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"message": autoStartMsg,
|
||||||
|
"queue": queue,
|
||||||
|
"autoStarted": autoStarted,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteBatchTask 删除批量任务
|
// DeleteBatchTask 删除批量任务
|
||||||
func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
|
func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
|
||||||
queueID := c.Param("queueId")
|
queueID := c.Param("queueId")
|
||||||
@@ -1934,6 +1958,10 @@ func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) {
|
|||||||
delete(h.batchRunning, queueID)
|
delete(h.batchRunning, queueID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *AgentHandler) forceUnmarkBatchQueueRunning(queueID string) {
|
||||||
|
h.unmarkBatchQueueRunning(queueID)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
|
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
|
||||||
expr := strings.TrimSpace(cronExpr)
|
expr := strings.TrimSpace(cronExpr)
|
||||||
if expr == "" {
|
if expr == "" {
|
||||||
@@ -2081,6 +2109,10 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
|
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
|
||||||
h.batchTaskManager.MoveToNextTask(queueID)
|
h.batchTaskManager.MoveToNextTask(queueID)
|
||||||
|
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
||||||
|
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
|
||||||
|
break
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
conversationID = conv.ID
|
conversationID = conv.ID
|
||||||
@@ -2218,12 +2250,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
var runErr error
|
var runErr error
|
||||||
switch {
|
switch {
|
||||||
case useBatchMulti:
|
case useBatchMulti:
|
||||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
|
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
|
||||||
default:
|
default:
|
||||||
if h.config == nil {
|
if h.config == nil {
|
||||||
runErr = fmt.Errorf("服务器配置未加载")
|
runErr = fmt.Errorf("服务器配置未加载")
|
||||||
} else {
|
} else {
|
||||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, conversationID, finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
|
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2337,6 +2369,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
|||||||
// 移动到下一个任务
|
// 移动到下一个任务
|
||||||
h.batchTaskManager.MoveToNextTask(queueID)
|
h.batchTaskManager.MoveToNextTask(queueID)
|
||||||
|
|
||||||
|
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
||||||
|
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
|
||||||
|
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
// 检查是否被取消或暂停
|
// 检查是否被取消或暂停
|
||||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||||
if queue.Status == "cancelled" || queue.Status == "paused" {
|
if queue.Status == "cancelled" || queue.Status == "paused" {
|
||||||
|
|||||||
@@ -3,10 +3,14 @@ package handler
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/openai"
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
@@ -46,3 +50,50 @@ func TestCreateProgressCallback_ConcurrentToolEvents(t *testing.T) {
|
|||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestCreateProgressCallback_FlushesReasoningOnDone 流式推理聚合须在 done/response 时落库,刷新后可回放。
|
||||||
|
func TestCreateProgressCallback_FlushesReasoningOnDone(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
db, err := database.NewDB(filepath.Join(tmp, "test.sqlite"), zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewDB: %v", err)
|
||||||
|
}
|
||||||
|
defer os.RemoveAll(tmp)
|
||||||
|
|
||||||
|
conv, err := db.CreateConversation("test", database.ConversationCreateMeta{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateConversation: %v", err)
|
||||||
|
}
|
||||||
|
asst, err := db.AddMessage(conv.ID, "assistant", "处理中...", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AddMessage: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := &AgentHandler{logger: zap.NewNop(), db: db}
|
||||||
|
cb := h.createProgressCallback(context.Background(), nil, conv.ID, asst.ID, nil)
|
||||||
|
|
||||||
|
streamID := "eino-reasoning-test-1"
|
||||||
|
cb("reasoning_chain_stream_start", " ", map[string]interface{}{
|
||||||
|
"streamId": streamID,
|
||||||
|
"source": "eino",
|
||||||
|
})
|
||||||
|
cb("reasoning_chain_stream_delta", "step one", openai.WithSSEAccumulated(map[string]interface{}{
|
||||||
|
"streamId": streamID,
|
||||||
|
}, "step one"))
|
||||||
|
cb("done", "", map[string]interface{}{"conversationId": conv.ID})
|
||||||
|
|
||||||
|
details, err := db.GetProcessDetails(asst.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetProcessDetails: %v", err)
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for _, d := range details {
|
||||||
|
if d.EventType == "reasoning_chain" && d.Message == "step one" {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatalf("expected reasoning_chain persisted on done, got %+v", details)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
|
||||||
|
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
@@ -20,12 +19,12 @@ func auditFilterFromQuery(c *gin.Context) database.ListAuditLogsFilter {
|
|||||||
ResourceID: c.Query("resource_id"),
|
ResourceID: c.Query("resource_id"),
|
||||||
}
|
}
|
||||||
if since := c.Query("since"); since != "" {
|
if since := c.Query("since"); since != "" {
|
||||||
if t, err := time.Parse(time.RFC3339, since); err == nil {
|
if t, err := database.ParseRFC3339Time(since); err == nil {
|
||||||
filter.Since = &t
|
filter.Since = &t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if until := c.Query("until"); until != "" {
|
if until := c.Query("until"); until != "" {
|
||||||
if t, err := time.Parse(time.RFC3339, until); err == nil {
|
if t, err := database.ParseRFC3339Time(until); err == nil {
|
||||||
filter.Until = &t
|
filter.Until = &t
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ type BatchTaskManager struct {
|
|||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
queues map[string]*BatchTaskQueue
|
queues map[string]*BatchTaskQueue
|
||||||
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
|
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
|
||||||
|
singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,6 +94,7 @@ func NewBatchTaskManager(logger *zap.Logger) *BatchTaskManager {
|
|||||||
logger: logger,
|
logger: logger,
|
||||||
queues: make(map[string]*BatchTaskQueue),
|
queues: make(map[string]*BatchTaskQueue),
|
||||||
taskCancels: make(map[string]context.CancelFunc),
|
taskCancels: make(map[string]context.CancelFunc),
|
||||||
|
singleRunTasks: make(map[string]string),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -864,6 +866,138 @@ func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask,
|
|||||||
return task, nil
|
return task, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引
|
||||||
|
func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
||||||
|
var cancelFunc context.CancelFunc
|
||||||
|
var siblingRunningIDs []string
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
queue, exists := m.queues[queueID]
|
||||||
|
if !exists {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return fmt.Errorf("队列不存在")
|
||||||
|
}
|
||||||
|
|
||||||
|
var task *BatchTask
|
||||||
|
taskIndex := -1
|
||||||
|
for i, t := range queue.Tasks {
|
||||||
|
if t.ID == taskID {
|
||||||
|
taskIndex = i
|
||||||
|
task = t
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if task == nil {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return fmt.Errorf("任务不存在")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !queueAllowsSingleTaskRunLocked(queue, task) {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return fmt.Errorf("队列正在执行或未就绪,无法单条执行")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项
|
||||||
|
if queue.Status == BatchQueueStatusPaused {
|
||||||
|
if c, ok := m.taskCancels[queueID]; ok {
|
||||||
|
cancelFunc = c
|
||||||
|
delete(m.taskCancels, queueID)
|
||||||
|
}
|
||||||
|
for _, t := range queue.Tasks {
|
||||||
|
if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning {
|
||||||
|
siblingRunningIDs = append(siblingRunningIDs, t.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
needsReset := task.Status != BatchTaskStatusPending
|
||||||
|
resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if cancelFunc != nil {
|
||||||
|
cancelFunc()
|
||||||
|
}
|
||||||
|
const staleRunMsg = "为单条执行其它任务,已中止"
|
||||||
|
for _, sid := range siblingRunningIDs {
|
||||||
|
m.UpdateTaskStatus(queueID, sid, BatchTaskStatusCancelled, "", staleRunMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
queue, exists = m.queues[queueID]
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("队列不存在")
|
||||||
|
}
|
||||||
|
|
||||||
|
task = nil
|
||||||
|
taskIndex = -1
|
||||||
|
for i, t := range queue.Tasks {
|
||||||
|
if t.ID == taskID {
|
||||||
|
taskIndex = i
|
||||||
|
task = t
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if task == nil {
|
||||||
|
return fmt.Errorf("任务不存在")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.db != nil {
|
||||||
|
if err := m.db.PrepareBatchSingleTaskRun(queueID, taskID, taskIndex, needsReset, resumeQueue); err != nil {
|
||||||
|
return fmt.Errorf("准备单条执行失败: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsReset {
|
||||||
|
task.Status = BatchTaskStatusPending
|
||||||
|
task.ConversationID = ""
|
||||||
|
task.StartedAt = nil
|
||||||
|
task.CompletedAt = nil
|
||||||
|
task.Error = ""
|
||||||
|
task.Result = ""
|
||||||
|
}
|
||||||
|
queue.CurrentIndex = taskIndex
|
||||||
|
queue.LastRunError = ""
|
||||||
|
if resumeQueue {
|
||||||
|
queue.Status = BatchQueueStatusPaused
|
||||||
|
queue.CompletedAt = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSingleRunTask 标记队列仅执行指定子任务,完成后自动暂停
|
||||||
|
func (m *BatchTaskManager) SetSingleRunTask(queueID, taskID string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if m.singleRunTasks == nil {
|
||||||
|
m.singleRunTasks = make(map[string]string)
|
||||||
|
}
|
||||||
|
m.singleRunTasks[queueID] = taskID
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearSingleRunTask 清除单条执行标记
|
||||||
|
func (m *BatchTaskManager) ClearSingleRunTask(queueID string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
delete(m.singleRunTasks, queueID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TakeSingleRunTaskIfMatch 若刚完成的子任务为单条执行目标,则清除标记并返回 true
|
||||||
|
func (m *BatchTaskManager) TakeSingleRunTaskIfMatch(queueID, taskID string) bool {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if m.singleRunTasks == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if m.singleRunTasks[queueID] != taskID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
delete(m.singleRunTasks, queueID)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteTask 删除任务(队列空闲时可删;执行中任务不可删)
|
// DeleteTask 删除任务(队列空闲时可删;执行中任务不可删)
|
||||||
func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error {
|
func (m *BatchTaskManager) DeleteTask(queueID, taskID string) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
@@ -936,6 +1070,25 @@ func queueAllowsTaskListMutationLocked(queue *BatchTaskQueue) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// queueAllowsSingleTaskRunLocked 是否允许对指定子任务发起单条执行(必须在持有 BatchTaskManager.mu 下调用)
|
||||||
|
func queueAllowsSingleTaskRunLocked(queue *BatchTaskQueue, task *BatchTask) bool {
|
||||||
|
if queue == nil || task == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if task.Status == BatchTaskStatusRunning {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if queue.Status == BatchQueueStatusRunning {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch queue.Status {
|
||||||
|
case BatchQueueStatusPending, BatchQueueStatusPaused, BatchQueueStatusCompleted, BatchQueueStatusCancelled:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetNextTask 获取下一个待执行的任务
|
// GetNextTask 获取下一个待执行的任务
|
||||||
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
|
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
|
|||||||
+58
-3
@@ -1,6 +1,7 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -277,6 +278,9 @@ func (h *C2Handler) ListSessions(c *gin.Context) {
|
|||||||
filter.Limit = n
|
filter.Limit = n
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if c.Query("suspicious") == "1" || strings.EqualFold(c.Query("suspicious"), "true") {
|
||||||
|
filter.Suspicious = true
|
||||||
|
}
|
||||||
|
|
||||||
sessions, err := h.mgr().DB().ListC2Sessions(filter)
|
sessions, err := h.mgr().DB().ListC2Sessions(filter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -324,7 +328,37 @@ func (h *C2Handler) DeleteSession(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
c.JSON(http.StatusOK, gin.H{"deleted": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSessionSleep 设置会话的 sleep/jitter
|
// DeleteSessions 批量删除会话(请求体 JSON: {"ids":["s_xxx",...]})
|
||||||
|
func (h *C2Handler) DeleteSessions(c *gin.Context) {
|
||||||
|
var req struct {
|
||||||
|
IDs []string `json:"ids"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(req.IDs) == 0 {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "ids is required"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
n, err := h.mgr().DB().DeleteC2SessionsByIDs(req.IDs)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, database.ErrNoValidC2SessionIDs) {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if h.audit != nil {
|
||||||
|
h.audit.RecordOK(c, "c2", "session_delete", "批量删除 C2 会话", "c2_session", "", map[string]interface{}{
|
||||||
|
"count": n, "ids": req.IDs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"deleted": n})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSessionSleep 设置会话的 sleep/jitter,并下发 sleep 任务到植入体
|
||||||
func (h *C2Handler) SetSessionSleep(c *gin.Context) {
|
func (h *C2Handler) SetSessionSleep(c *gin.Context) {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
var req struct {
|
var req struct {
|
||||||
@@ -335,12 +369,33 @@ func (h *C2Handler) SetSessionSleep(c *gin.Context) {
|
|||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if req.SleepSeconds < 1 {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "sleep_seconds must be >= 1"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.JitterPercent < 0 || req.JitterPercent > 100 {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "jitter_percent must be 0-100"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if err := h.mgr().DB().SetC2SessionSleep(id, req.SleepSeconds, req.JitterPercent); err != nil {
|
task, err := h.mgr().SetSessionSleep(id, req.SleepSeconds, req.JitterPercent)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "session not found"})
|
||||||
|
return
|
||||||
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"updated": true})
|
out := gin.H{
|
||||||
|
"updated": true,
|
||||||
|
"sleep_seconds": req.SleepSeconds,
|
||||||
|
"jitter_percent": req.JitterPercent,
|
||||||
|
}
|
||||||
|
if task != nil {
|
||||||
|
out["task_id"] = task.ID
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
|
|||||||
+74
-10
@@ -689,8 +689,6 @@ type UpdateConfigRequest struct {
|
|||||||
// 避免旧版「整包替换 *AgentConfig」时,未传的整型字段被反序列化为 0 误覆盖(例如 tool_timeout_minutes 变成 0)。
|
// 避免旧版「整包替换 *AgentConfig」时,未传的整型字段被反序列化为 0 误覆盖(例如 tool_timeout_minutes 变成 0)。
|
||||||
type AgentConfigUpdate struct {
|
type AgentConfigUpdate struct {
|
||||||
MaxIterations *int `json:"max_iterations,omitempty"`
|
MaxIterations *int `json:"max_iterations,omitempty"`
|
||||||
LargeResultThreshold *int `json:"large_result_threshold,omitempty"`
|
|
||||||
ResultStorageDir *string `json:"result_storage_dir,omitempty"`
|
|
||||||
ToolTimeoutMinutes *int `json:"tool_timeout_minutes,omitempty"`
|
ToolTimeoutMinutes *int `json:"tool_timeout_minutes,omitempty"`
|
||||||
SystemPromptPath *string `json:"system_prompt_path,omitempty"`
|
SystemPromptPath *string `json:"system_prompt_path,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -702,12 +700,6 @@ func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) {
|
|||||||
if src.MaxIterations != nil {
|
if src.MaxIterations != nil {
|
||||||
dst.MaxIterations = *src.MaxIterations
|
dst.MaxIterations = *src.MaxIterations
|
||||||
}
|
}
|
||||||
if src.LargeResultThreshold != nil {
|
|
||||||
dst.LargeResultThreshold = *src.LargeResultThreshold
|
|
||||||
}
|
|
||||||
if src.ResultStorageDir != nil {
|
|
||||||
dst.ResultStorageDir = *src.ResultStorageDir
|
|
||||||
}
|
|
||||||
if src.ToolTimeoutMinutes != nil {
|
if src.ToolTimeoutMinutes != nil {
|
||||||
dst.ToolTimeoutMinutes = *src.ToolTimeoutMinutes
|
dst.ToolTimeoutMinutes = *src.ToolTimeoutMinutes
|
||||||
}
|
}
|
||||||
@@ -1076,6 +1068,80 @@ func (h *ConfigHandler) TestOpenAI(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListModelsRequest 获取模型列表请求(OpenAI 兼容 GET /models)。
|
||||||
|
type ListModelsRequest struct {
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
BaseURL string `json:"base_url"`
|
||||||
|
APIKey string `json:"api_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListModels 代理调用上游 GET /models,返回可用模型 id 列表。
|
||||||
|
func (h *ConfigHandler) ListModels(c *gin.Context) {
|
||||||
|
var req ListModelsRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数: " + err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := strings.TrimSpace(req.Provider)
|
||||||
|
if provider == "" {
|
||||||
|
provider = "openai"
|
||||||
|
}
|
||||||
|
if strings.EqualFold(provider, "claude") {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"supported": false,
|
||||||
|
"error": "Claude (Anthropic Messages API) 不支持自动获取模型列表,请手动填写",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(req.APIKey) == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "API Key 不能为空"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := strings.TrimSuffix(strings.TrimSpace(req.BaseURL), "/")
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://api.openai.com/v1"
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpCfg := &config.OpenAIConfig{
|
||||||
|
Provider: provider,
|
||||||
|
BaseURL: baseURL,
|
||||||
|
APIKey: strings.TrimSpace(req.APIKey),
|
||||||
|
}
|
||||||
|
client := openai.NewClient(tmpCfg, nil, h.logger)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
models, err := client.ListModels(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if apiErr, ok := err.(*openai.APIError); ok {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"supported": true,
|
||||||
|
"error": fmt.Sprintf("API 返回错误 (HTTP %d): %s", apiErr.StatusCode, apiErr.Body),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": false,
|
||||||
|
"supported": true,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"success": true,
|
||||||
|
"supported": true,
|
||||||
|
"models": models,
|
||||||
|
"count": len(models),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// TestVisionRequest 测试 Vision 模型连接;vision.api_key/base_url 留空时可传 openai 段作回退。
|
// TestVisionRequest 测试 Vision 模型连接;vision.api_key/base_url 留空时可传 openai 段作回退。
|
||||||
type TestVisionRequest struct {
|
type TestVisionRequest struct {
|
||||||
Vision config.VisionConfig `json:"vision"`
|
Vision config.VisionConfig `json:"vision"`
|
||||||
@@ -1532,8 +1598,6 @@ func updateAgentConfig(doc *yaml.Node, agent config.AgentConfig) {
|
|||||||
agentNode := ensureMap(root, "agent")
|
agentNode := ensureMap(root, "agent")
|
||||||
setIntInMap(agentNode, "max_iterations", agent.MaxIterations)
|
setIntInMap(agentNode, "max_iterations", agent.MaxIterations)
|
||||||
setIntInMap(agentNode, "tool_timeout_minutes", agent.ToolTimeoutMinutes)
|
setIntInMap(agentNode, "tool_timeout_minutes", agent.ToolTimeoutMinutes)
|
||||||
setIntInMap(agentNode, "large_result_threshold", agent.LargeResultThreshold)
|
|
||||||
setStringInMap(agentNode, "result_storage_dir", agent.ResultStorageDir)
|
|
||||||
setStringInMap(agentNode, "system_prompt_path", agent.SystemPromptPath)
|
setStringInMap(agentNode, "system_prompt_path", agent.SystemPromptPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,11 +12,17 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ConversationTaskStopper cancels in-flight agent work when a conversation is removed.
|
||||||
|
type ConversationTaskStopper interface {
|
||||||
|
CancelRunningTaskForConversation(conversationID string)
|
||||||
|
}
|
||||||
|
|
||||||
// ConversationHandler 对话处理器
|
// ConversationHandler 对话处理器
|
||||||
type ConversationHandler struct {
|
type ConversationHandler struct {
|
||||||
db *database.DB
|
db *database.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
audit *audit.Service
|
audit *audit.Service
|
||||||
|
taskStopper ConversationTaskStopper
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAudit wires platform audit logging.
|
// SetAudit wires platform audit logging.
|
||||||
@@ -24,6 +30,11 @@ func (h *ConversationHandler) SetAudit(s *audit.Service) {
|
|||||||
h.audit = s
|
h.audit = s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTaskStopper wires cancellation of in-flight agent tasks on conversation delete.
|
||||||
|
func (h *ConversationHandler) SetTaskStopper(stopper ConversationTaskStopper) {
|
||||||
|
h.taskStopper = stopper
|
||||||
|
}
|
||||||
|
|
||||||
// NewConversationHandler 创建新的对话处理器
|
// NewConversationHandler 创建新的对话处理器
|
||||||
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
|
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
|
||||||
return &ConversationHandler{
|
return &ConversationHandler{
|
||||||
@@ -105,17 +116,18 @@ func (h *ConversationHandler) ListConversations(c *gin.Context) {
|
|||||||
|
|
||||||
excludeGrouped := strings.TrimSpace(search) == "" &&
|
excludeGrouped := strings.TrimSpace(search) == "" &&
|
||||||
(c.Query("exclude_grouped") == "true" || c.Query("exclude_grouped") == "1")
|
(c.Query("exclude_grouped") == "true" || c.Query("exclude_grouped") == "1")
|
||||||
|
sortBy := strings.TrimSpace(c.Query("sort_by"))
|
||||||
|
|
||||||
var conversations []*database.Conversation
|
var conversations []*database.Conversation
|
||||||
var total int
|
var total int
|
||||||
var err error
|
var err error
|
||||||
if excludeGrouped {
|
if excludeGrouped {
|
||||||
conversations, err = h.db.ListUngroupedConversations(limit, offset)
|
conversations, err = h.db.ListUngroupedConversations(limit, offset, sortBy)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
total, err = h.db.CountUngroupedConversations()
|
total, err = h.db.CountUngroupedConversations()
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
conversations, err = h.db.ListConversations(limit, offset, search)
|
conversations, err = h.db.ListConversations(limit, offset, search, sortBy)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
total, err = h.db.CountConversations(search)
|
total, err = h.db.CountConversations(search)
|
||||||
}
|
}
|
||||||
@@ -244,6 +256,10 @@ func (h *ConversationHandler) UpdateConversation(c *gin.Context) {
|
|||||||
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
|
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
|
||||||
id := c.Param("id")
|
id := c.Param("id")
|
||||||
|
|
||||||
|
if h.taskStopper != nil {
|
||||||
|
h.taskStopper.CancelRunningTaskForConversation(id)
|
||||||
|
}
|
||||||
|
|
||||||
if err := h.db.DeleteConversation(id); err != nil {
|
if err := h.db.DeleteConversation(id); err != nil {
|
||||||
h.logger.Error("删除对话失败", zap.Error(err))
|
h.logger.Error("删除对话失败", zap.Error(err))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
|||||||
@@ -0,0 +1,30 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestConversationHandlerDeleteConversationCancelsRunningTask(t *testing.T) {
|
||||||
|
tm := NewAgentTaskManager()
|
||||||
|
ctx, cancel := context.WithCancelCause(context.Background())
|
||||||
|
_, err := tm.StartTask("conv-1", "hello", cancel)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartTask: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
h := &AgentHandler{tasks: tm, logger: zap.NewNop()}
|
||||||
|
h.CancelRunningTaskForConversation("conv-1")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("task context was not cancelled")
|
||||||
|
}
|
||||||
|
if cause := context.Cause(ctx); cause != ErrTaskCancelled {
|
||||||
|
t.Fatalf("expected ErrTaskCancelled, got %v", cause)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,31 +2,11 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
"cyberstrike-ai/internal/multiagent"
|
"cyberstrike-ai/internal/multiagent"
|
||||||
|
|
||||||
"go.uber.org/zap"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
|
|
||||||
if h.config != nil {
|
|
||||||
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
|
|
||||||
}
|
|
||||||
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
|
|
||||||
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
|
|
||||||
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
|
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
|
||||||
func (h *AgentHandler) applyEinoTraceResumeSegment(
|
func (h *AgentHandler) applyEinoTraceResumeSegment(
|
||||||
conversationID string,
|
conversationID string,
|
||||||
@@ -45,136 +25,3 @@ func (h *AgentHandler) applyEinoTraceResumeSegment(
|
|||||||
*curFinalMessage = segmentUserMessage
|
*curFinalMessage = segmentUserMessage
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
|
|
||||||
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
|
|
||||||
func (h *AgentHandler) applyEinoTransientRetrySegment(
|
|
||||||
conversationID string,
|
|
||||||
result *multiagent.RunResult,
|
|
||||||
curHistory *[]agent.ChatMessage,
|
|
||||||
curFinalMessage *string,
|
|
||||||
segmentUserMessage string,
|
|
||||||
) {
|
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
|
||||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
|
||||||
}
|
|
||||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
|
||||||
*curHistory = hist
|
|
||||||
}
|
|
||||||
if s := strings.TrimSpace(segmentUserMessage); s != "" {
|
|
||||||
*curFinalMessage = segmentUserMessage
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
|
|
||||||
func (h *AgentHandler) handleEinoTransientRetryContinue(
|
|
||||||
baseCtx context.Context,
|
|
||||||
conversationID string,
|
|
||||||
result *multiagent.RunResult,
|
|
||||||
runErr error,
|
|
||||||
transientAttempts *int,
|
|
||||||
curHistory *[]agent.ChatMessage,
|
|
||||||
curFinalMessage *string,
|
|
||||||
segmentUserMessage string,
|
|
||||||
progressCallback func(eventType, message string, data interface{}),
|
|
||||||
sendProgress func(msg string, extra map[string]interface{}),
|
|
||||||
) (handled bool, fatal error) {
|
|
||||||
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
maxAttempts := h.einoRunRetryMaxAttempts()
|
|
||||||
*transientAttempts++
|
|
||||||
if *transientAttempts > maxAttempts {
|
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
|
||||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
|
||||||
}
|
|
||||||
return false, errors.New("transient retry exhausted: " + runErr.Error())
|
|
||||||
}
|
|
||||||
attemptNo := *transientAttempts
|
|
||||||
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
|
|
||||||
if progressCallback != nil {
|
|
||||||
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
|
||||||
"conversationId": conversationID,
|
|
||||||
"source": "eino",
|
|
||||||
"attempt": attemptNo,
|
|
||||||
"maxAttempts": maxAttempts,
|
|
||||||
"backoffSec": int(backoff.Seconds()),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-baseCtx.Done():
|
|
||||||
return false, context.Cause(baseCtx)
|
|
||||||
case <-time.After(backoff):
|
|
||||||
}
|
|
||||||
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
|
||||||
if progressCallback != nil {
|
|
||||||
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
|
||||||
"conversationId": conversationID,
|
|
||||||
"source": "eino",
|
|
||||||
"attempt": attemptNo,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if sendProgress != nil {
|
|
||||||
sendProgress("正在重试…", map[string]interface{}{
|
|
||||||
"conversationId": conversationID,
|
|
||||||
"source": "transient_retry",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleEinoEmptyResponseContinue 在 SSE 任务循环内处理「正常结束但无助手正文」;返回 exhausted=true 时由外层按成功结束(保留占位文案)。
|
|
||||||
// 与临时错误重试一致:仅恢复轨迹并保留本请求原始 user 文案,不向模型注入续跑说明。
|
|
||||||
func (h *AgentHandler) handleEinoEmptyResponseContinue(
|
|
||||||
baseCtx context.Context,
|
|
||||||
conversationID string,
|
|
||||||
result *multiagent.RunResult,
|
|
||||||
runErr error,
|
|
||||||
emptyResponseAttempts *int,
|
|
||||||
curHistory *[]agent.ChatMessage,
|
|
||||||
curFinalMessage *string,
|
|
||||||
segmentUserMessage string,
|
|
||||||
progressCallback func(eventType, message string, data interface{}),
|
|
||||||
sendProgress func(msg string, extra map[string]interface{}),
|
|
||||||
) (handled bool, exhausted bool) {
|
|
||||||
if !errors.Is(runErr, multiagent.ErrEmptyResponseContinue) {
|
|
||||||
return false, false
|
|
||||||
}
|
|
||||||
maxAttempts := h.einoRunRetryMaxAttempts()
|
|
||||||
*emptyResponseAttempts++
|
|
||||||
if *emptyResponseAttempts > maxAttempts {
|
|
||||||
if h.logger != nil {
|
|
||||||
h.logger.Warn("eino empty response auto resume exhausted",
|
|
||||||
zap.String("conversationId", conversationID),
|
|
||||||
zap.Int("maxAttempts", maxAttempts))
|
|
||||||
}
|
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
|
||||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
|
||||||
}
|
|
||||||
return false, true
|
|
||||||
}
|
|
||||||
attemptNo := *emptyResponseAttempts
|
|
||||||
if h.logger != nil {
|
|
||||||
h.logger.Info("eino empty response, auto resume from trace",
|
|
||||||
zap.String("conversationId", conversationID),
|
|
||||||
zap.Int("attempt", attemptNo),
|
|
||||||
zap.Int("maxAttempts", maxAttempts))
|
|
||||||
}
|
|
||||||
if progressCallback != nil {
|
|
||||||
progressCallback("eino_empty_response_continue", fmt.Sprintf("未捕获到助手正文,正在基于轨迹自动续跑(%d/%d)…", attemptNo, maxAttempts), map[string]interface{}{
|
|
||||||
"conversationId": conversationID,
|
|
||||||
"source": "eino",
|
|
||||||
"attempt": attemptNo,
|
|
||||||
"maxAttempts": maxAttempts,
|
|
||||||
"resumeKind": "trace_segment",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
|
||||||
if sendProgress != nil {
|
|
||||||
sendProgress("已恢复上下文,正在继续推理…", map[string]interface{}{
|
|
||||||
"conversationId": conversationID,
|
|
||||||
"source": "empty_response_continue",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return true, false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -119,7 +119,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
|
|
||||||
var cancelWithCause context.CancelCauseFunc
|
var cancelWithCause context.CancelCauseFunc
|
||||||
curFinalMessage := prep.FinalMessage
|
curFinalMessage := prep.FinalMessage
|
||||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
|
||||||
curHistory := prep.History
|
curHistory := prep.History
|
||||||
roleTools := prep.RoleTools
|
roleTools := prep.RoleTools
|
||||||
|
|
||||||
@@ -177,8 +176,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
taskOwned = true
|
taskOwned = true
|
||||||
|
|
||||||
var cumulativeMCPExecutionIDs []string
|
var cumulativeMCPExecutionIDs []string
|
||||||
var transientRunAttempts int
|
|
||||||
var emptyResponseAttempts int
|
|
||||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||||
var mainIterationOffset int
|
var mainIterationOffset int
|
||||||
|
|
||||||
@@ -224,8 +221,10 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
h.config,
|
h.config,
|
||||||
&h.config.MultiAgent,
|
&h.config.MultiAgent,
|
||||||
h.agent,
|
h.agent,
|
||||||
|
h.db,
|
||||||
h.logger,
|
h.logger,
|
||||||
conversationID,
|
conversationID,
|
||||||
|
h.conversationProjectID(conversationID),
|
||||||
curFinalMessage,
|
curFinalMessage,
|
||||||
curHistory,
|
curHistory,
|
||||||
roleTools,
|
roleTools,
|
||||||
@@ -238,54 +237,11 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
|
||||||
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
|
|
||||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
|
||||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
|
||||||
)
|
|
||||||
if exhaustedEmpty {
|
|
||||||
runErr = nil
|
|
||||||
transientRunAttempts = 0
|
|
||||||
timeoutCancel()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if handledEmpty {
|
|
||||||
mainIterationOffset += segmentMainIterationMax
|
|
||||||
transientRunAttempts = 0
|
|
||||||
timeoutCancel()
|
|
||||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
|
||||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
|
||||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
|
||||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
|
||||||
transientRunAttempts = 0
|
|
||||||
emptyResponseAttempts = 0
|
|
||||||
timeoutCancel()
|
timeoutCancel()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
|
||||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
|
||||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
|
||||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
|
||||||
)
|
|
||||||
if handled {
|
|
||||||
mainIterationOffset += segmentMainIterationMax
|
|
||||||
timeoutCancel()
|
|
||||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
|
||||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
|
||||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
|
||||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if fatalErr != nil {
|
|
||||||
runErr = fatalErr
|
|
||||||
}
|
|
||||||
|
|
||||||
cause := context.Cause(baseCtx)
|
cause := context.Cause(baseCtx)
|
||||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
@@ -310,8 +266,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
|||||||
"source": "interrupt_continue",
|
"source": "interrupt_continue",
|
||||||
})
|
})
|
||||||
mainIterationOffset += segmentMainIterationMax
|
mainIterationOffset += segmentMainIterationMax
|
||||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
|
||||||
transientRunAttempts = 0
|
|
||||||
timeoutCancel()
|
timeoutCancel()
|
||||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
@@ -446,16 +400,16 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
|||||||
curMsg := prep.FinalMessage
|
curMsg := prep.FinalMessage
|
||||||
var result *multiagent.RunResult
|
var result *multiagent.RunResult
|
||||||
var runErr error
|
var runErr error
|
||||||
var transientRunAttempts int
|
|
||||||
var emptyResponseAttempts int
|
|
||||||
for {
|
for {
|
||||||
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
||||||
taskCtx,
|
taskCtx,
|
||||||
h.config,
|
h.config,
|
||||||
&h.config.MultiAgent,
|
&h.config.MultiAgent,
|
||||||
h.agent,
|
h.agent,
|
||||||
|
h.db,
|
||||||
h.logger,
|
h.logger,
|
||||||
prep.ConversationID,
|
prep.ConversationID,
|
||||||
|
h.conversationProjectID(prep.ConversationID),
|
||||||
curMsg,
|
curMsg,
|
||||||
curHist,
|
curHist,
|
||||||
prep.RoleTools,
|
prep.RoleTools,
|
||||||
@@ -463,28 +417,9 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
|||||||
chatReasoningToClientIntent(req.Reasoning),
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
h.projectBlackboardBlock(prep.ConversationID),
|
h.projectBlackboardBlock(prep.ConversationID),
|
||||||
)
|
)
|
||||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
|
||||||
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
|
|
||||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
|
||||||
)
|
|
||||||
if exhaustedEmpty {
|
|
||||||
runErr = nil
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if handledEmpty {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if handled, fatalErr := h.handleEinoTransientRetryContinue(
|
|
||||||
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
|
|
||||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
|
||||||
); handled {
|
|
||||||
continue
|
|
||||||
} else if fatalErr != nil {
|
|
||||||
runErr = fatalErr
|
|
||||||
}
|
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -136,7 +136,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
|
|
||||||
var cancelWithCause context.CancelCauseFunc
|
var cancelWithCause context.CancelCauseFunc
|
||||||
curFinalMessage := prep.FinalMessage
|
curFinalMessage := prep.FinalMessage
|
||||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
|
||||||
curHistory := prep.History
|
curHistory := prep.History
|
||||||
roleTools := prep.RoleTools
|
roleTools := prep.RoleTools
|
||||||
orch := strings.TrimSpace(req.Orchestration)
|
orch := strings.TrimSpace(req.Orchestration)
|
||||||
@@ -187,8 +186,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
|
|
||||||
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
||||||
var cumulativeMCPExecutionIDs []string
|
var cumulativeMCPExecutionIDs []string
|
||||||
var transientRunAttempts int
|
|
||||||
var emptyResponseAttempts int
|
|
||||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||||
var mainIterationOffset int
|
var mainIterationOffset int
|
||||||
|
|
||||||
@@ -234,8 +231,10 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
h.config,
|
h.config,
|
||||||
&h.config.MultiAgent,
|
&h.config.MultiAgent,
|
||||||
h.agent,
|
h.agent,
|
||||||
|
h.db,
|
||||||
h.logger,
|
h.logger,
|
||||||
conversationID,
|
conversationID,
|
||||||
|
h.conversationProjectID(conversationID),
|
||||||
curFinalMessage,
|
curFinalMessage,
|
||||||
curHistory,
|
curHistory,
|
||||||
roleTools,
|
roleTools,
|
||||||
@@ -250,54 +249,11 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
|
||||||
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
|
|
||||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
|
||||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
|
||||||
)
|
|
||||||
if exhaustedEmpty {
|
|
||||||
runErr = nil
|
|
||||||
transientRunAttempts = 0
|
|
||||||
timeoutCancel()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if handledEmpty {
|
|
||||||
mainIterationOffset += segmentMainIterationMax
|
|
||||||
transientRunAttempts = 0
|
|
||||||
timeoutCancel()
|
|
||||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
|
||||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
|
||||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
|
||||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
|
||||||
transientRunAttempts = 0
|
|
||||||
emptyResponseAttempts = 0
|
|
||||||
timeoutCancel()
|
timeoutCancel()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
|
||||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
|
||||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
|
||||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
|
||||||
)
|
|
||||||
if handled {
|
|
||||||
mainIterationOffset += segmentMainIterationMax
|
|
||||||
timeoutCancel()
|
|
||||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
|
||||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
|
||||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
|
||||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if fatalErr != nil {
|
|
||||||
runErr = fatalErr
|
|
||||||
}
|
|
||||||
|
|
||||||
cause := context.Cause(baseCtx)
|
cause := context.Cause(baseCtx)
|
||||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
@@ -322,8 +278,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
|||||||
"source": "interrupt_continue",
|
"source": "interrupt_continue",
|
||||||
})
|
})
|
||||||
mainIterationOffset += segmentMainIterationMax
|
mainIterationOffset += segmentMainIterationMax
|
||||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
|
||||||
transientRunAttempts = 0
|
|
||||||
timeoutCancel()
|
timeoutCancel()
|
||||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||||
@@ -458,16 +412,16 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
|||||||
curMsg := prep.FinalMessage
|
curMsg := prep.FinalMessage
|
||||||
var result *multiagent.RunResult
|
var result *multiagent.RunResult
|
||||||
var runErr error
|
var runErr error
|
||||||
var transientRunAttempts int
|
|
||||||
var emptyResponseAttempts int
|
|
||||||
for {
|
for {
|
||||||
result, runErr = multiagent.RunDeepAgent(
|
result, runErr = multiagent.RunDeepAgent(
|
||||||
taskCtx,
|
taskCtx,
|
||||||
h.config,
|
h.config,
|
||||||
&h.config.MultiAgent,
|
&h.config.MultiAgent,
|
||||||
h.agent,
|
h.agent,
|
||||||
|
h.db,
|
||||||
h.logger,
|
h.logger,
|
||||||
prep.ConversationID,
|
prep.ConversationID,
|
||||||
|
h.conversationProjectID(prep.ConversationID),
|
||||||
curMsg,
|
curMsg,
|
||||||
curHist,
|
curHist,
|
||||||
prep.RoleTools,
|
prep.RoleTools,
|
||||||
@@ -477,28 +431,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
|||||||
chatReasoningToClientIntent(req.Reasoning),
|
chatReasoningToClientIntent(req.Reasoning),
|
||||||
h.projectBlackboardBlock(prep.ConversationID),
|
h.projectBlackboardBlock(prep.ConversationID),
|
||||||
)
|
)
|
||||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
|
||||||
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
|
|
||||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
|
||||||
)
|
|
||||||
if exhaustedEmpty {
|
|
||||||
runErr = nil
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if handledEmpty {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
if handled, fatalErr := h.handleEinoTransientRetryContinue(
|
|
||||||
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
|
|
||||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
|
||||||
); handled {
|
|
||||||
continue
|
|
||||||
} else if fatalErr != nil {
|
|
||||||
runErr = fatalErr
|
|
||||||
}
|
|
||||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||||
}
|
}
|
||||||
|
|||||||
+139
-34
@@ -2,10 +2,8 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/storage"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
@@ -15,17 +13,15 @@ import (
|
|||||||
type OpenAPIHandler struct {
|
type OpenAPIHandler struct {
|
||||||
db *database.DB
|
db *database.DB
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
resultStorage storage.ResultStorage
|
|
||||||
conversationHdlr *ConversationHandler
|
conversationHdlr *ConversationHandler
|
||||||
agentHdlr *AgentHandler
|
agentHdlr *AgentHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOpenAPIHandler 创建新的OpenAPI处理器
|
// NewOpenAPIHandler 创建新的OpenAPI处理器
|
||||||
func NewOpenAPIHandler(db *database.DB, logger *zap.Logger, resultStorage storage.ResultStorage, conversationHdlr *ConversationHandler, agentHdlr *AgentHandler) *OpenAPIHandler {
|
func NewOpenAPIHandler(db *database.DB, logger *zap.Logger, conversationHdlr *ConversationHandler, agentHdlr *AgentHandler) *OpenAPIHandler {
|
||||||
return &OpenAPIHandler{
|
return &OpenAPIHandler{
|
||||||
db: db,
|
db: db,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
resultStorage: resultStorage,
|
|
||||||
conversationHdlr: conversationHdlr,
|
conversationHdlr: conversationHdlr,
|
||||||
agentHdlr: agentHdlr,
|
agentHdlr: agentHdlr,
|
||||||
}
|
}
|
||||||
@@ -2468,17 +2464,108 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
"parameters": []map[string]interface{}{
|
"parameters": []map[string]interface{}{
|
||||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
{"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}},
|
{"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
{"name": "include_links", "in": "query", "schema": map[string]interface{}{"type": "boolean"}},
|
||||||
|
{"name": "include_link_counts", "in": "query", "schema": map[string]interface{}{"type": "boolean"}},
|
||||||
},
|
},
|
||||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条"}},
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条(可含 link_counts / outgoing_links)"}},
|
||||||
},
|
},
|
||||||
"post": map[string]interface{}{
|
"post": map[string]interface{}{
|
||||||
"tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST",
|
"tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST",
|
||||||
"parameters": []map[string]interface{}{
|
"parameters": []map[string]interface{}{
|
||||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
},
|
},
|
||||||
|
"requestBody": map[string]interface{}{
|
||||||
|
"required": true,
|
||||||
|
"content": map[string]interface{}{
|
||||||
|
"application/json": map[string]interface{}{
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"fact_key": map[string]interface{}{"type": "string"},
|
||||||
|
"summary": map[string]interface{}{"type": "string"},
|
||||||
|
"links": map[string]interface{}{
|
||||||
|
"type": "array",
|
||||||
|
"items": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"to": map[string]interface{}{"type": "string"},
|
||||||
|
"type": map[string]interface{}{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"links_text": map[string]interface{}{"type": "string", "description": "type: fact_key 每行一条"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}},
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"/api/projects/{id}/fact-graph": map[string]interface{}{
|
||||||
|
"get": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "获取项目事实攻击路径图", "operationId": "getProjectFactGraph",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
{"name": "view", "in": "query", "schema": map[string]interface{}{"type": "string", "enum": []string{"path", "full"}, "default": "path"}},
|
||||||
|
{"name": "exclude_deprecated", "in": "query", "schema": map[string]interface{}{"type": "boolean", "default": true}},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "nodes + edges"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"/api/projects/{id}/fact-edges": map[string]interface{}{
|
||||||
|
"get": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "列出项目全部事实边", "operationId": "listProjectFactEdges",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "边列表"}},
|
||||||
|
},
|
||||||
|
"post": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "添加事实边", "operationId": "createProjectFactEdge",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
},
|
||||||
|
"requestBody": map[string]interface{}{
|
||||||
|
"required": true,
|
||||||
|
"content": map[string]interface{}{
|
||||||
|
"application/json": map[string]interface{}{
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"required": []string{"source_fact_key", "target_fact_key", "edge_type"},
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"source_fact_key": map[string]interface{}{"type": "string"},
|
||||||
|
"target_fact_key": map[string]interface{}{"type": "string"},
|
||||||
|
"edge_type": map[string]interface{}{"type": "string"},
|
||||||
|
"confidence": map[string]interface{}{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "边已创建"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"/api/projects/{id}/fact-edges/{edgeId}": map[string]interface{}{
|
||||||
|
"delete": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "删除事实边", "operationId": "deleteProjectFactEdge",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
{"name": "edgeId", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "删除成功"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"/api/projects/{id}/promote-attack-chain/{conversationId}": map[string]interface{}{
|
||||||
|
"post": map[string]interface{}{
|
||||||
|
"tags": []string{"项目管理"}, "summary": "将对话攻击链沉淀到项目事实图", "operationId": "promoteAttackChainToProject",
|
||||||
|
"parameters": []map[string]interface{}{
|
||||||
|
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
{"name": "conversationId", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "沉淀结果(facts/edges/graph)"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
"/api/vulnerabilities": map[string]interface{}{
|
"/api/vulnerabilities": map[string]interface{}{
|
||||||
"get": map[string]interface{}{
|
"get": map[string]interface{}{
|
||||||
"tags": []string{"漏洞管理"},
|
"tags": []string{"漏洞管理"},
|
||||||
@@ -5034,6 +5121,51 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"/api/config/list-models": map[string]interface{}{
|
||||||
|
"post": map[string]interface{}{
|
||||||
|
"tags": []string{"配置管理"},
|
||||||
|
"summary": "获取模型列表",
|
||||||
|
"description": "代理调用 OpenAI 兼容 GET /models,返回可用模型 id 列表。Claude 不支持。",
|
||||||
|
"operationId": "listModels",
|
||||||
|
"requestBody": map[string]interface{}{
|
||||||
|
"required": true,
|
||||||
|
"content": map[string]interface{}{
|
||||||
|
"application/json": map[string]interface{}{
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"required": []string{"api_key"},
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"provider": map[string]interface{}{"type": "string", "description": "LLM提供商(openai/claude)", "example": "openai"},
|
||||||
|
"base_url": map[string]interface{}{"type": "string", "description": "API基地址(可选)"},
|
||||||
|
"api_key": map[string]interface{}{"type": "string", "description": "API密钥"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"responses": map[string]interface{}{
|
||||||
|
"200": map[string]interface{}{
|
||||||
|
"description": "获取结果",
|
||||||
|
"content": map[string]interface{}{
|
||||||
|
"application/json": map[string]interface{}{
|
||||||
|
"schema": map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"success": map[string]interface{}{"type": "boolean"},
|
||||||
|
"supported": map[string]interface{}{"type": "boolean"},
|
||||||
|
"error": map[string]interface{}{"type": "string"},
|
||||||
|
"models": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}},
|
||||||
|
"count": map[string]interface{}{"type": "integer"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"400": map[string]interface{}{"description": "参数错误"},
|
||||||
|
"401": map[string]interface{}{"description": "未授权"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
// ==================== 终端 ====================
|
// ==================== 终端 ====================
|
||||||
"/api/terminal/run": map[string]interface{}{
|
"/api/terminal/run": map[string]interface{}{
|
||||||
@@ -6354,35 +6486,8 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) {
|
|||||||
vulnerabilities[i] = *v
|
vulnerabilities[i] = *v
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取执行结果(从MCP执行记录中获取)
|
// 获取执行结果(历史大结果由 Eino reduction 落盘,此处不再聚合文件存储)
|
||||||
executionResults := []map[string]interface{}{}
|
executionResults := []map[string]interface{}{}
|
||||||
for _, msg := range messages {
|
|
||||||
if len(msg.MCPExecutionIDs) > 0 {
|
|
||||||
for _, execID := range msg.MCPExecutionIDs {
|
|
||||||
// 尝试从结果存储中获取执行结果
|
|
||||||
if h.resultStorage != nil {
|
|
||||||
result, err := h.resultStorage.GetResult(execID)
|
|
||||||
if err == nil && result != "" {
|
|
||||||
// 获取元数据以获取工具名称和创建时间
|
|
||||||
metadata, err := h.resultStorage.GetResultMetadata(execID)
|
|
||||||
toolName := "unknown"
|
|
||||||
createdAt := time.Now()
|
|
||||||
if err == nil && metadata != nil {
|
|
||||||
toolName = metadata.ToolName
|
|
||||||
createdAt = metadata.CreatedAt
|
|
||||||
}
|
|
||||||
executionResults = append(executionResults, map[string]interface{}{
|
|
||||||
"id": execID,
|
|
||||||
"toolName": toolName,
|
|
||||||
"status": "success",
|
|
||||||
"result": result,
|
|
||||||
"createdAt": createdAt.Format(time.RFC3339),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
response := map[string]interface{}{
|
response := map[string]interface{}{
|
||||||
"conversationId": conv.ID,
|
"conversationId": conv.ID,
|
||||||
|
|||||||
+239
-5
@@ -1,10 +1,12 @@
|
|||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/attackchain"
|
||||||
"cyberstrike-ai/internal/database"
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/project"
|
"cyberstrike-ai/internal/project"
|
||||||
|
|
||||||
@@ -223,6 +225,12 @@ func (h *ProjectHandler) DeleteProject(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type factLinkRequest struct {
|
||||||
|
From string `json:"from"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Confidence string `json:"confidence,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
type upsertFactRequest struct {
|
type upsertFactRequest struct {
|
||||||
FactKey string `json:"fact_key" binding:"required"`
|
FactKey string `json:"fact_key" binding:"required"`
|
||||||
Category string `json:"category"`
|
Category string `json:"category"`
|
||||||
@@ -231,6 +239,8 @@ type upsertFactRequest struct {
|
|||||||
Confidence string `json:"confidence"`
|
Confidence string `json:"confidence"`
|
||||||
Pinned bool `json:"pinned"`
|
Pinned bool `json:"pinned"`
|
||||||
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
|
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
|
||||||
|
Links []factLinkRequest `json:"links"`
|
||||||
|
LinksText *string `json:"links_text"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。
|
// updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。
|
||||||
@@ -243,6 +253,74 @@ type updateFactRequest struct {
|
|||||||
Pinned *bool `json:"pinned"`
|
Pinned *bool `json:"pinned"`
|
||||||
RelatedVulnerabilityID *string `json:"related_vulnerability_id"`
|
RelatedVulnerabilityID *string `json:"related_vulnerability_id"`
|
||||||
ClearBody bool `json:"clear_body"`
|
ClearBody bool `json:"clear_body"`
|
||||||
|
Links *[]factLinkRequest `json:"links"`
|
||||||
|
LinksText *string `json:"links_text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func factLinksFromRequest(links []factLinkRequest, linksText *string) (*project.ParsedFactLinks, error) {
|
||||||
|
if len(links) > 0 {
|
||||||
|
parsed := &project.ParsedFactLinks{}
|
||||||
|
for i, l := range links {
|
||||||
|
from := strings.TrimSpace(l.From)
|
||||||
|
edgeType := strings.TrimSpace(l.Type)
|
||||||
|
if from == "" {
|
||||||
|
return nil, fmt.Errorf("links[%d] 须含 from", i)
|
||||||
|
}
|
||||||
|
if edgeType == "" {
|
||||||
|
return nil, fmt.Errorf("links[%d] 须含 type", i)
|
||||||
|
}
|
||||||
|
parsed.Incoming = append(parsed.Incoming, database.ProjectFactEdgeFromInput{
|
||||||
|
From: from, Type: edgeType, Confidence: strings.TrimSpace(l.Confidence),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return parsed, nil
|
||||||
|
}
|
||||||
|
if linksText != nil {
|
||||||
|
in, err := project.ParseFactLinksText(*linksText)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &project.ParsedFactLinks{Incoming: in}, nil
|
||||||
|
}
|
||||||
|
return &project.ParsedFactLinks{Incoming: []database.ProjectFactEdgeFromInput{}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type factWithLinksResponse struct {
|
||||||
|
*database.ProjectFact
|
||||||
|
OutgoingLinks []*database.ProjectFactEdge `json:"outgoing_links,omitempty"`
|
||||||
|
IncomingLinks []*database.ProjectFactEdge `json:"incoming_links,omitempty"`
|
||||||
|
LinkCounts *project.LinkCounts `json:"link_counts,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ProjectHandler) applyFactLinksAfterUpsert(projectID string, fact *database.ProjectFact, links []factLinkRequest, linksText *string, explicitLinks, parseBody bool) error {
|
||||||
|
if explicitLinks {
|
||||||
|
parsed, err := factLinksFromRequest(links, linksText)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return project.PersistFactLinksFromParsed(h.db, projectID, fact.FactKey, fact.SourceConversationID, parsed, true)
|
||||||
|
}
|
||||||
|
if parseBody {
|
||||||
|
inputs := project.ParseLinksFromBody(fact.Body)
|
||||||
|
if inputs == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return project.PersistFactIncomingLinks(h.db, projectID, fact.FactKey, inputs, true)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *ProjectHandler) factResponseWithLinks(projectID string, f *database.ProjectFact, includeLinks bool) interface{} {
|
||||||
|
if !includeLinks || f == nil {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
out, _ := h.db.ListOutgoingProjectFactEdges(projectID, f.FactKey)
|
||||||
|
in, _ := h.db.ListIncomingProjectFactEdges(projectID, f.FactKey)
|
||||||
|
return &factWithLinksResponse{
|
||||||
|
ProjectFact: f,
|
||||||
|
OutgoingLinks: out,
|
||||||
|
IncomingLinks: in,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListFacts GET /api/projects/:id/facts (fact_key 查询参数可获取单条详情)
|
// ListFacts GET /api/projects/:id/facts (fact_key 查询参数可获取单条详情)
|
||||||
@@ -254,7 +332,8 @@ func (h *ProjectHandler) ListFacts(c *gin.Context) {
|
|||||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, f)
|
includeLinks := c.Query("include_links") == "1" || c.Query("include_links") == "true"
|
||||||
|
c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, f, includeLinks))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||||
@@ -285,7 +364,52 @@ func (h *ProjectHandler) ListFacts(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
list = filtered
|
list = filtered
|
||||||
}
|
}
|
||||||
|
includeLinkCounts := c.Query("include_link_counts") == "1" || c.Query("include_link_counts") == "true"
|
||||||
|
if !includeLinkCounts {
|
||||||
c.JSON(http.StatusOK, list)
|
c.JSON(http.StatusOK, list)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
counts, err := project.LoadProjectFactLinkCounts(h.db, projectID)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out := make([]factWithLinksResponse, 0, len(list))
|
||||||
|
for _, f := range list {
|
||||||
|
item := factWithLinksResponse{ProjectFact: f}
|
||||||
|
if c, ok := counts[f.FactKey]; ok {
|
||||||
|
cc := c
|
||||||
|
item.LinkCounts = &cc
|
||||||
|
}
|
||||||
|
out = append(out, item)
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFactGraph GET /api/projects/:id/fact-graph?view=path|full
|
||||||
|
func (h *ProjectHandler) GetFactGraph(c *gin.Context) {
|
||||||
|
projectID := c.Param("id")
|
||||||
|
if _, err := h.db.GetProject(projectID); err != nil {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
view := c.DefaultQuery("view", "path")
|
||||||
|
excludeDeprecated := true
|
||||||
|
if v := c.Query("exclude_deprecated"); v == "0" || v == "false" {
|
||||||
|
excludeDeprecated = false
|
||||||
|
}
|
||||||
|
graph, err := project.BuildProjectFactGraph(h.db, projectID, view, excludeDeprecated)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if graph.Nodes == nil {
|
||||||
|
graph.Nodes = []database.ProjectFactGraphNode{}
|
||||||
|
}
|
||||||
|
if graph.Edges == nil {
|
||||||
|
graph.Edges = []database.ProjectFactGraphEdge{}
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, graph)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateFact POST /api/projects/:id/facts
|
// CreateFact POST /api/projects/:id/facts
|
||||||
@@ -295,8 +419,9 @@ func (h *ProjectHandler) CreateFact(c *gin.Context) {
|
|||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
projectID := c.Param("id")
|
||||||
f := &database.ProjectFact{
|
f := &database.ProjectFact{
|
||||||
ProjectID: c.Param("id"),
|
ProjectID: projectID,
|
||||||
FactKey: req.FactKey,
|
FactKey: req.FactKey,
|
||||||
Category: req.Category,
|
Category: req.Category,
|
||||||
Summary: req.Summary,
|
Summary: req.Summary,
|
||||||
@@ -310,16 +435,24 @@ func (h *ProjectHandler) CreateFact(c *gin.Context) {
|
|||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, created)
|
explicitLinks := req.Links != nil || req.LinksText != nil
|
||||||
|
if err := h.applyFactLinksAfterUpsert(projectID, created, req.Links, req.LinksText, explicitLinks, !explicitLinks); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
created, _ = h.db.GetProjectFactByKey(projectID, created.FactKey)
|
||||||
|
c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, created, true))
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateFact PUT /api/projects/:id/facts/:factId
|
// UpdateFact PUT /api/projects/:id/facts/:factId
|
||||||
func (h *ProjectHandler) UpdateFact(c *gin.Context) {
|
func (h *ProjectHandler) UpdateFact(c *gin.Context) {
|
||||||
|
projectID := c.Param("id")
|
||||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||||
if err != nil || existing.ProjectID != c.Param("id") {
|
if err != nil || existing.ProjectID != projectID {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
oldFactKey := existing.FactKey
|
||||||
var req updateFactRequest
|
var req updateFactRequest
|
||||||
if err := c.ShouldBindJSON(&req); err != nil {
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
@@ -355,7 +488,29 @@ func (h *ProjectHandler) UpdateFact(c *gin.Context) {
|
|||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, updated)
|
if oldFactKey != updated.FactKey {
|
||||||
|
if err := h.db.RenameProjectFactKeyEdges(projectID, oldFactKey, updated.FactKey); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if req.Links != nil || req.LinksText != nil {
|
||||||
|
var links []factLinkRequest
|
||||||
|
if req.Links != nil {
|
||||||
|
links = *req.Links
|
||||||
|
}
|
||||||
|
if err := h.applyFactLinksAfterUpsert(projectID, updated, links, req.LinksText, true, false); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if req.ClearBody || req.Body != nil {
|
||||||
|
if err := h.applyFactLinksAfterUpsert(projectID, updated, nil, nil, false, true); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
updated, _ = h.db.GetProjectFactByKey(projectID, updated.FactKey)
|
||||||
|
c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, updated, true))
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteFact DELETE /api/projects/:id/facts/:factId
|
// DeleteFact DELETE /api/projects/:id/facts/:factId
|
||||||
@@ -408,3 +563,82 @@ func (h *ProjectHandler) RestoreFact(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type createFactEdgeRequest struct {
|
||||||
|
SourceFactKey string `json:"source_fact_key" binding:"required"`
|
||||||
|
TargetFactKey string `json:"target_fact_key" binding:"required"`
|
||||||
|
EdgeType string `json:"edge_type" binding:"required"`
|
||||||
|
Confidence string `json:"confidence"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListFactEdges GET /api/projects/:id/fact-edges
|
||||||
|
func (h *ProjectHandler) ListFactEdges(c *gin.Context) {
|
||||||
|
projectID := c.Param("id")
|
||||||
|
edges, err := h.db.ListProjectFactEdgesByProject(projectID)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if edges == nil {
|
||||||
|
edges = []*database.ProjectFactEdge{}
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, edges)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateFactEdge POST /api/projects/:id/fact-edges
|
||||||
|
func (h *ProjectHandler) CreateFactEdge(c *gin.Context) {
|
||||||
|
projectID := c.Param("id")
|
||||||
|
var req createFactEdgeRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
edge, err := h.db.AddProjectFactEdge(projectID, database.ProjectFactEdgeInput{
|
||||||
|
To: req.TargetFactKey,
|
||||||
|
Type: req.EdgeType,
|
||||||
|
Confidence: req.Confidence,
|
||||||
|
}, req.SourceFactKey, "")
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if f, err := h.db.GetProjectFactByKey(projectID, req.TargetFactKey); err == nil {
|
||||||
|
in, _ := h.db.ListIncomingProjectFactEdges(projectID, req.TargetFactKey)
|
||||||
|
f.Body = project.SyncBodyLinksSection(f.Body, in)
|
||||||
|
_, _ = h.db.UpsertProjectFact(f)
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, edge)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteFactEdge DELETE /api/projects/:id/fact-edges/:edgeId
|
||||||
|
func (h *ProjectHandler) DeleteFactEdge(c *gin.Context) {
|
||||||
|
projectID := c.Param("id")
|
||||||
|
edgeID := c.Param("edgeId")
|
||||||
|
edge, err := h.db.GetProjectFactEdge(edgeID)
|
||||||
|
if err != nil || edge.ProjectID != projectID {
|
||||||
|
c.JSON(http.StatusNotFound, gin.H{"error": "边不存在"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.db.DeleteProjectFactEdge(edgeID); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if f, err := h.db.GetProjectFactByKey(projectID, edge.TargetFactKey); err == nil {
|
||||||
|
in, _ := h.db.ListIncomingProjectFactEdges(projectID, edge.TargetFactKey)
|
||||||
|
f.Body = project.SyncBodyLinksSection(f.Body, in)
|
||||||
|
_, _ = h.db.UpsertProjectFact(f)
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// PromoteAttackChain POST /api/projects/:id/promote-attack-chain/:conversationId
|
||||||
|
func (h *ProjectHandler) PromoteAttackChain(c *gin.Context) {
|
||||||
|
projectID := c.Param("id")
|
||||||
|
conversationID := c.Param("conversationId")
|
||||||
|
result, err := attackchain.PromoteToProject(h.db, projectID, conversationID)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, result)
|
||||||
|
}
|
||||||
|
|||||||
@@ -30,3 +30,19 @@ func (h *AgentHandler) projectBlackboardBlock(conversationID string) string {
|
|||||||
}
|
}
|
||||||
return strings.TrimSpace(block)
|
return strings.TrimSpace(block)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// conversationProjectID 返回对话绑定的项目 ID;未绑定或查询失败时返回空字符串。
|
||||||
|
func (h *AgentHandler) conversationProjectID(conversationID string) string {
|
||||||
|
if h == nil || h.db == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
conversationID = strings.TrimSpace(conversationID)
|
||||||
|
if conversationID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
projectID, err := h.db.GetConversationProjectID(conversationID)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(projectID)
|
||||||
|
}
|
||||||
|
|||||||
@@ -447,7 +447,7 @@ func (h *RobotHandler) cmdUnbindProject(platform, userID string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *RobotHandler) cmdList() string {
|
func (h *RobotHandler) cmdList() string {
|
||||||
convs, err := h.db.ListConversations(50, 0, "")
|
convs, err := h.db.ListConversations(50, 0, "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "获取对话列表失败: " + err.Error()
|
return "获取对话列表失败: " + err.Error()
|
||||||
}
|
}
|
||||||
@@ -594,6 +594,9 @@ func (h *RobotHandler) cmdDelete(platform, userID, convID string) string {
|
|||||||
h.mu.Unlock()
|
h.mu.Unlock()
|
||||||
h.deleteSessionBinding(sk)
|
h.deleteSessionBinding(sk)
|
||||||
}
|
}
|
||||||
|
if h.agentHandler != nil {
|
||||||
|
h.agentHandler.CancelRunningTaskForConversation(convID)
|
||||||
|
}
|
||||||
if err := h.db.DeleteConversation(convID); err != nil {
|
if err := h.db.DeleteConversation(convID); err != nil {
|
||||||
return "删除失败: " + err.Error()
|
return "删除失败: " + err.Error()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import (
|
|||||||
// MonitorStorage 监控数据存储接口
|
// MonitorStorage 监控数据存储接口
|
||||||
type MonitorStorage interface {
|
type MonitorStorage interface {
|
||||||
SaveToolExecution(exec *ToolExecution) error
|
SaveToolExecution(exec *ToolExecution) error
|
||||||
|
UpdateToolExecutionResult(id string, result *ToolResult) error
|
||||||
LoadToolExecutions() ([]*ToolExecution, error)
|
LoadToolExecutions() ([]*ToolExecution, error)
|
||||||
GetToolExecution(id string) (*ToolExecution, error)
|
GetToolExecution(id string) (*ToolExecution, error)
|
||||||
SaveToolStats(toolName string, stats *ToolStats) error
|
SaveToolStats(toolName string, stats *ToolStats) error
|
||||||
@@ -963,6 +964,26 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
|
|||||||
return executionID
|
return executionID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateToolExecutionResult 将监控库中的工具结果更新为送入模型的展示正文(如 reduction 后的 persisted-output)。
|
||||||
|
func (s *Server) UpdateToolExecutionResult(executionID string, result *ToolResult) error {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
executionID = strings.TrimSpace(executionID)
|
||||||
|
if executionID == "" || result == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
if exec, ok := s.executions[executionID]; ok && exec != nil {
|
||||||
|
exec.Result = result
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
if s.storage != nil {
|
||||||
|
return s.storage.UpdateToolExecutionResult(executionID, result)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长
|
// cleanupOldExecutions 清理旧的执行记录,防止内存无限增长
|
||||||
func (s *Server) cleanupOldExecutions() {
|
func (s *Server) cleanupOldExecutions() {
|
||||||
if len(s.executions) <= s.maxExecutionsInMemory {
|
if len(s.executions) <= s.maxExecutionsInMemory {
|
||||||
|
|||||||
@@ -88,6 +88,7 @@ type einoADKRunLoopArgs struct {
|
|||||||
// 在完成时写入 MCP 监控;execute 仍由 eino_execute_monitor 记录,此处跳过。
|
// 在完成时写入 MCP 监控;execute 仍由 eino_execute_monitor 记录,此处跳过。
|
||||||
FilesystemMonitorAgent *agent.Agent
|
FilesystemMonitorAgent *agent.Agent
|
||||||
FilesystemMonitorRecord einomcp.ExecutionRecorder
|
FilesystemMonitorRecord einomcp.ExecutionRecorder
|
||||||
|
MCPExecutionBinder *MCPExecutionBinder
|
||||||
|
|
||||||
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,MCP 桥 Fire 以补全 tool_result。
|
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,MCP 桥 Fire 以补全 tool_result。
|
||||||
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
|
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
|
||||||
@@ -285,53 +286,63 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
executeStdoutDupMu.Unlock()
|
executeStdoutDupMu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
var toolResultSent sync.Map // toolCallID -> struct{};与 ADK Tool 消息去重,避免 bridge 与事件流各推一次
|
var toolResultSent sync.Map // toolCallID -> struct{};ADK Tool 事件去重(权威正文来自 reduction 处理后的 agent 上下文)
|
||||||
if args.ToolInvokeNotify != nil {
|
tryEmitToolResultProgress := func(toolName, content, toolCallID string, isErr bool, agentName string) {
|
||||||
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
if progress == nil {
|
||||||
tid := strings.TrimSpace(toolCallID)
|
|
||||||
removePendingByID(tid)
|
|
||||||
if tid == "" || progress == nil {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, loaded := toolResultSent.LoadOrStore(tid, struct{}{}); loaded {
|
toolName = strings.TrimSpace(toolName)
|
||||||
return
|
if toolName == "" {
|
||||||
|
toolName = "unknown"
|
||||||
}
|
}
|
||||||
isErr := !success || invokeErr != nil
|
preview := content
|
||||||
body := content
|
|
||||||
if invokeErr != nil {
|
|
||||||
// 保留已流式累计的 stdout(如 execute 超时前的一半输出),避免 tool_result 只剩错误串、模型与 UI 丢失上下文
|
|
||||||
tail := friendlyEinoExecuteInvokeTail(invokeErr)
|
|
||||||
// execute 流式包装可能已把超时句写入 content(供 ADK tool 与流式 delta);勿重复拼接
|
|
||||||
if tail != "" && strings.Contains(content, tail) {
|
|
||||||
body = content
|
|
||||||
} else if strings.TrimSpace(content) != "" {
|
|
||||||
body = strings.TrimRight(content, "\n") + "\n\n" + tail
|
|
||||||
} else {
|
|
||||||
body = tail
|
|
||||||
}
|
|
||||||
isErr = true
|
|
||||||
}
|
|
||||||
recordPendingExecuteStdoutDup(toolName, body, isErr)
|
|
||||||
preview := body
|
|
||||||
if len(preview) > 200 {
|
if len(preview) > 200 {
|
||||||
preview = preview[:200] + "..."
|
preview = preview[:200] + "..."
|
||||||
}
|
}
|
||||||
agentTag := strings.TrimSpace(einoAgent)
|
data := map[string]interface{}{
|
||||||
if agentTag == "" {
|
|
||||||
agentTag = orchestratorName
|
|
||||||
}
|
|
||||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), map[string]interface{}{
|
|
||||||
"toolName": toolName,
|
"toolName": toolName,
|
||||||
"success": !isErr,
|
"success": !isErr,
|
||||||
"isError": isErr,
|
"isError": isErr,
|
||||||
"result": body,
|
"result": content,
|
||||||
"resultPreview": preview,
|
"resultPreview": preview,
|
||||||
"toolCallId": tid,
|
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"einoAgent": agentTag,
|
"einoAgent": agentName,
|
||||||
"einoRole": einoRoleTag(agentTag),
|
"einoRole": einoRoleTag(agentName),
|
||||||
"source": "eino",
|
"source": "eino",
|
||||||
})
|
}
|
||||||
|
tid := strings.TrimSpace(toolCallID)
|
||||||
|
if tid == "" {
|
||||||
|
if inferred, ok := popNextPendingForAgent(agentName); ok {
|
||||||
|
tid = inferred.ToolCallID
|
||||||
|
} else if inferred, ok := popNextPendingForAgent(orchestratorName); ok {
|
||||||
|
tid = inferred.ToolCallID
|
||||||
|
} else if inferred, ok := popNextPendingForAgent(""); ok {
|
||||||
|
tid = inferred.ToolCallID
|
||||||
|
} else if inferred, ok := popAnyPending(); ok {
|
||||||
|
tid = inferred.ToolCallID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if tid != "" {
|
||||||
|
removePendingByID(tid)
|
||||||
|
if _, loaded := toolResultSent.LoadOrStore(tid, struct{}{}); loaded {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data["toolCallId"] = tid
|
||||||
|
toolCallID = tid
|
||||||
|
}
|
||||||
|
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
||||||
|
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
|
||||||
|
if args.FilesystemMonitorAgent != nil && args.MCPExecutionBinder != nil {
|
||||||
|
if execID := args.MCPExecutionBinder.ExecutionID(toolCallID); execID != "" {
|
||||||
|
args.FilesystemMonitorAgent.UpdateMCPExecutionDisplayResult(execID, content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
|
||||||
|
}
|
||||||
|
if args.ToolInvokeNotify != nil {
|
||||||
|
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||||
|
removePendingByID(strings.TrimSpace(toolCallID))
|
||||||
|
// tool_result 仅由下方 ADK schema.Tool 事件推送,正文与送入模型的上下文一致(含 reduction 截断)。
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -372,6 +383,12 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
runner := adk.NewRunner(ctx, runnerCfg)
|
runner := adk.NewRunner(ctx, runnerCfg)
|
||||||
|
startRunnerIter := func(runMsgs []adk.Message) *adk.AsyncIterator[*adk.AgentEvent] {
|
||||||
|
if checkPointID != "" {
|
||||||
|
return runner.Run(ctx, runMsgs, adk.WithCheckPointID(checkPointID))
|
||||||
|
}
|
||||||
|
return runner.Run(ctx, runMsgs)
|
||||||
|
}
|
||||||
var iter *adk.AsyncIterator[*adk.AgentEvent]
|
var iter *adk.AsyncIterator[*adk.AgentEvent]
|
||||||
if cpStore != nil && checkPointID != "" {
|
if cpStore != nil && checkPointID != "" {
|
||||||
if _, existed, getErr := cpStore.Get(ctx, checkPointID); getErr != nil {
|
if _, existed, getErr := cpStore.Get(ctx, checkPointID); getErr != nil {
|
||||||
@@ -411,12 +428,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if iter == nil {
|
if iter == nil {
|
||||||
if checkPointID != "" {
|
iter = startRunnerIter(msgs)
|
||||||
iter = runner.Run(ctx, msgs, adk.WithCheckPointID(checkPointID))
|
|
||||||
} else {
|
|
||||||
iter = runner.Run(ctx, msgs)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
transientRetrier := newEinoTransientRunRetrier(einoTransientRunRetryPolicyFromArgs(args))
|
||||||
handleRunErr := func(runErr error) error {
|
handleRunErr := func(runErr error) error {
|
||||||
if runErr == nil {
|
if runErr == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -469,26 +483,60 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
return runErr
|
return runErr
|
||||||
}
|
}
|
||||||
|
|
||||||
// maybeRetryTransientRun:不在此层 runner.Run/Resume;由 handler 落库 + loadHistoryFromAgentTrace 分段续跑(同中断并继续)。
|
maybeRetryTransientRun := func(runErr error) (restarted bool, fatal error) {
|
||||||
maybeRetryTransientRun := func(runErr error) (retry bool, fatal error) {
|
if runErr == nil {
|
||||||
if runErr == nil || !isEinoTransientRunError(runErr) {
|
return false, nil
|
||||||
|
}
|
||||||
|
if !isEinoTransientRunError(runErr) {
|
||||||
return false, handleRunErr(runErr)
|
return false, handleRunErr(runErr)
|
||||||
}
|
}
|
||||||
|
restarted, restartMsgs, ctxSource, backoff, retErr := transientRetrier.tryRetry(
|
||||||
|
ctx, runErr, args, baseMsgs, runAccumulatedMsgs, baseAccumulatedCount,
|
||||||
|
)
|
||||||
|
if retErr != nil {
|
||||||
|
flushAllPendingAsFailed(runErr)
|
||||||
if logger != nil {
|
if logger != nil {
|
||||||
logger.Warn("eino transient error, ending run segment for handler resume",
|
logger.Warn("eino transient retry exhausted",
|
||||||
|
zap.Error(retErr),
|
||||||
|
zap.String("orchestration", orchMode),
|
||||||
|
zap.Int("maxAttempts", transientRetrier.maxAttempts()))
|
||||||
|
}
|
||||||
|
return false, retErr
|
||||||
|
}
|
||||||
|
if !restarted {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
attemptNo := transientRetrier.attempt()
|
||||||
|
maxAttempts := transientRetrier.maxAttempts()
|
||||||
|
if logger != nil {
|
||||||
|
logger.Warn("eino transient error, retrying after backoff",
|
||||||
zap.Error(runErr),
|
zap.Error(runErr),
|
||||||
zap.String("orchestration", orchMode))
|
zap.String("orchestration", orchMode),
|
||||||
|
zap.Int("attempt", attemptNo),
|
||||||
|
zap.Int("maxAttempts", maxAttempts),
|
||||||
|
zap.Duration("backoff", backoff))
|
||||||
}
|
}
|
||||||
if progress != nil {
|
if progress != nil {
|
||||||
progress("eino_run_retry", "遇到临时错误(限流或网络波动),将保存上下文并重试…", map[string]interface{}{
|
progress("eino_run_retry", fmt.Sprintf("遇到临时错误(限流或网络波动),%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"source": "eino",
|
"source": "eino",
|
||||||
"orchestration": orchMode,
|
"orchestration": orchMode,
|
||||||
"error": runErr.Error(),
|
"error": runErr.Error(),
|
||||||
"resumeKind": "trace_segment",
|
"attempt": attemptNo,
|
||||||
|
"maxAttempts": maxAttempts,
|
||||||
|
"backoffSec": int(backoff.Seconds()),
|
||||||
|
})
|
||||||
|
progress("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
"orchestration": orchMode,
|
||||||
|
"attempt": attemptNo,
|
||||||
|
"contextSource": string(ctxSource),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return false, ErrTransientRetryContinue
|
msgs = restartMsgs
|
||||||
|
iter = startRunnerIter(msgs)
|
||||||
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
takePartial := func(runErr error) (*RunResult, error) {
|
takePartial := func(runErr error) (*RunResult, error) {
|
||||||
@@ -572,9 +620,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if ev.Err != nil {
|
if ev.Err != nil {
|
||||||
if _, retErr := maybeRetryTransientRun(ev.Err); retErr != nil {
|
restarted, retErr := maybeRetryTransientRun(ev.Err)
|
||||||
|
if retErr != nil {
|
||||||
return takePartial(retErr)
|
return takePartial(retErr)
|
||||||
}
|
}
|
||||||
|
if restarted {
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if ev.AgentName != "" && progress != nil {
|
if ev.AgentName != "" && progress != nil {
|
||||||
iterEinoAgent := orchestratorName
|
iterEinoAgent := orchestratorName
|
||||||
@@ -619,7 +671,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
einoLastAgent = ev.AgentName
|
// 仅在代理切换时更新进度标题;同一代理的每个 ADK 事件不再重复刷 progress。
|
||||||
|
if einoLastAgent != ev.AgentName {
|
||||||
progress("progress", fmt.Sprintf("[Eino] %s", ev.AgentName), map[string]interface{}{
|
progress("progress", fmt.Sprintf("[Eino] %s", ev.AgentName), map[string]interface{}{
|
||||||
"conversationId": conversationID,
|
"conversationId": conversationID,
|
||||||
"einoAgent": ev.AgentName,
|
"einoAgent": ev.AgentName,
|
||||||
@@ -627,11 +680,57 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"orchestration": orchMode,
|
"orchestration": orchMode,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
einoLastAgent = ev.AgentName
|
||||||
|
}
|
||||||
if ev.Output == nil || ev.Output.MessageOutput == nil {
|
if ev.Output == nil || ev.Output.MessageOutput == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
mv := ev.Output.MessageOutput
|
mv := ev.Output.MessageOutput
|
||||||
|
|
||||||
|
if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
|
||||||
|
toolName := strings.TrimSpace(mv.ToolName)
|
||||||
|
var toolBuf strings.Builder
|
||||||
|
streamToolCallID := ""
|
||||||
|
var toolStreamRecvErr error
|
||||||
|
for {
|
||||||
|
chunk, rerr := mv.MessageStream.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
toolStreamRecvErr = rerr
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if chunk == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if chunk.Content != "" {
|
||||||
|
toolBuf.WriteString(chunk.Content)
|
||||||
|
}
|
||||||
|
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
|
||||||
|
streamToolCallID = tid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
content := toolBuf.String()
|
||||||
|
isErr := false
|
||||||
|
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||||
|
isErr = true
|
||||||
|
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
||||||
|
}
|
||||||
|
if streamToolCallID != "" {
|
||||||
|
opts := []schema.ToolMessageOption{schema.WithToolName(toolName)}
|
||||||
|
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.ToolMessage(content, streamToolCallID, opts...))
|
||||||
|
}
|
||||||
|
tryEmitToolResultProgress(toolName, content, streamToolCallID, isErr, ev.AgentName)
|
||||||
|
if toolStreamRecvErr != nil && logger != nil {
|
||||||
|
logger.Warn("eino tool result stream recv error",
|
||||||
|
zap.Error(toolStreamRecvErr),
|
||||||
|
zap.String("agent", ev.AgentName),
|
||||||
|
zap.String("tool", toolName))
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if mv.IsStreaming && mv.MessageStream != nil {
|
if mv.IsStreaming && mv.MessageStream != nil {
|
||||||
mainStreamID := fmt.Sprintf("eino-main-%s-%d", conversationID, atomic.AddInt64(&mainResponseStreamSeq, 1))
|
mainStreamID := fmt.Sprintf("eino-main-%s-%d", conversationID, atomic.AddInt64(&mainResponseStreamSeq, 1))
|
||||||
streamHeaderSent := false
|
streamHeaderSent := false
|
||||||
@@ -785,6 +884,16 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if progress != nil && reasoningStreamID != "" && strings.TrimSpace(reasoningBuf) != "" {
|
||||||
|
progress("reasoning_chain_stream_end", openai.DisplayReasoningContent(strings.TrimSpace(reasoningBuf)), map[string]interface{}{
|
||||||
|
"streamId": reasoningStreamID,
|
||||||
|
"conversationId": conversationID,
|
||||||
|
"source": "eino",
|
||||||
|
"einoAgent": ev.AgentName,
|
||||||
|
"einoRole": einoRoleTag(ev.AgentName),
|
||||||
|
"orchestration": orchMode,
|
||||||
|
})
|
||||||
|
}
|
||||||
if streamsMainAssistant(ev.AgentName) {
|
if streamsMainAssistant(ev.AgentName) {
|
||||||
s := strings.TrimSpace(mainAssistantBuf)
|
s := strings.TrimSpace(mainAssistantBuf)
|
||||||
if mainAssistDupTarget != "" {
|
if mainAssistDupTarget != "" {
|
||||||
@@ -883,9 +992,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
"einoRole": einoRoleTag(ev.AgentName),
|
"einoRole": einoRoleTag(ev.AgentName),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
if _, retErr := maybeRetryTransientRun(streamRecvErr); retErr != nil {
|
restarted, retErr := maybeRetryTransientRun(streamRecvErr)
|
||||||
|
if retErr != nil {
|
||||||
return takePartial(retErr)
|
return takePartial(retErr)
|
||||||
}
|
}
|
||||||
|
if restarted {
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -963,7 +1076,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if mv.Role == schema.Tool && progress != nil {
|
if (mv.Role == schema.Tool || msg.Role == schema.Tool) && progress != nil {
|
||||||
toolName := msg.ToolName
|
toolName := msg.ToolName
|
||||||
if toolName == "" {
|
if toolName == "" {
|
||||||
toolName = mv.ToolName
|
toolName = mv.ToolName
|
||||||
@@ -976,46 +1089,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
preview := content
|
|
||||||
if len(preview) > 200 {
|
|
||||||
preview = preview[:200] + "..."
|
|
||||||
}
|
|
||||||
data := map[string]interface{}{
|
|
||||||
"toolName": toolName,
|
|
||||||
"success": !isErr,
|
|
||||||
"isError": isErr,
|
|
||||||
"result": content,
|
|
||||||
"resultPreview": preview,
|
|
||||||
"conversationId": conversationID,
|
|
||||||
"einoAgent": ev.AgentName,
|
|
||||||
"einoRole": einoRoleTag(ev.AgentName),
|
|
||||||
"source": "eino",
|
|
||||||
}
|
|
||||||
toolCallID := strings.TrimSpace(msg.ToolCallID)
|
toolCallID := strings.TrimSpace(msg.ToolCallID)
|
||||||
if toolCallID == "" {
|
tryEmitToolResultProgress(toolName, content, toolCallID, isErr, ev.AgentName)
|
||||||
if inferred, ok := popNextPendingForAgent(ev.AgentName); ok {
|
|
||||||
toolCallID = inferred.ToolCallID
|
|
||||||
} else if inferred, ok := popNextPendingForAgent(orchestratorName); ok {
|
|
||||||
toolCallID = inferred.ToolCallID
|
|
||||||
} else if inferred, ok := popNextPendingForAgent(""); ok {
|
|
||||||
toolCallID = inferred.ToolCallID
|
|
||||||
} else if inferred, ok := popAnyPending(); ok {
|
|
||||||
toolCallID = inferred.ToolCallID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if toolCallID != "" {
|
|
||||||
removePendingByID(toolCallID)
|
|
||||||
if _, loaded := toolResultSent.LoadOrStore(toolCallID, struct{}{}); loaded {
|
|
||||||
// ToolInvokeNotify 可能已推过 tool_result(如 execute 流式包装里 Fire 仅携带截断后的 stdout),
|
|
||||||
// 此处仍应用 ADK Tool 消息中的完整内容刷新去重基准,避免模型复述全文时与截断串比对失败而重复展示「助手输出」。
|
|
||||||
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data["toolCallId"] = toolCallID
|
|
||||||
}
|
|
||||||
recordPendingExecuteStdoutDup(toolName, content, isErr)
|
|
||||||
recordEinoADKFilesystemToolMonitor(args.FilesystemMonitorAgent, args.FilesystemMonitorRecord, toolName, toolCallID, runAccumulatedMsgs, content, isErr)
|
|
||||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1027,32 +1102,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
|||||||
orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs),
|
orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs),
|
||||||
lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
|
lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
|
||||||
)
|
)
|
||||||
if shouldEinoEmptyResponseContinue(out, emptyHint, len(runAccumulatedMsgs), baseAccumulatedCount) {
|
|
||||||
if logger != nil {
|
|
||||||
logger.Info("eino empty response, ending run segment for handler resume",
|
|
||||||
zap.String("conversationId", conversationID),
|
|
||||||
zap.String("orchestration", orchMode),
|
|
||||||
zap.Int("traceMessages", len(runAccumulatedMsgs)))
|
|
||||||
}
|
|
||||||
if progress != nil {
|
|
||||||
progress("eino_empty_response_continue", "会话已结束但未产生助手正文,正在基于轨迹自动续跑…", map[string]interface{}{
|
|
||||||
"conversationId": conversationID,
|
|
||||||
"source": "eino",
|
|
||||||
"resumeKind": "trace_segment",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return out, ErrEmptyResponseContinue
|
|
||||||
}
|
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func shouldEinoEmptyResponseContinue(out *RunResult, emptyHint string, accumulatedLen, baseCount int) bool {
|
|
||||||
if out == nil || accumulatedLen <= baseCount {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return strings.TrimSpace(out.Response) == strings.TrimSpace(emptyHint)
|
|
||||||
}
|
|
||||||
|
|
||||||
func persistTraceSource(args *einoADKRunLoopArgs, fallback []adk.Message) []adk.Message {
|
func persistTraceSource(args *einoADKRunLoopArgs, fallback []adk.Message) []adk.Message {
|
||||||
if args != nil && args.ModelFacingTrace != nil {
|
if args != nil && args.ModelFacingTrace != nil {
|
||||||
if snap := args.ModelFacingTrace.Snapshot(); len(snap) > 0 {
|
if snap := args.ModelFacingTrace.Snapshot(); len(snap) > 0 {
|
||||||
|
|||||||
@@ -1,21 +0,0 @@
|
|||||||
package multiagent
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestShouldEinoEmptyResponseContinue(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
hint := "(empty hint)"
|
|
||||||
out := &RunResult{Response: hint}
|
|
||||||
if !shouldEinoEmptyResponseContinue(out, hint, 3, 1) {
|
|
||||||
t.Fatal("expected continue when response is empty hint and trace grew")
|
|
||||||
}
|
|
||||||
if shouldEinoEmptyResponseContinue(out, hint, 1, 1) {
|
|
||||||
t.Fatal("expected no continue when trace did not grow")
|
|
||||||
}
|
|
||||||
if shouldEinoEmptyResponseContinue(&RunResult{Response: "hello"}, hint, 3, 1) {
|
|
||||||
t.Fatal("expected no continue when response has content")
|
|
||||||
}
|
|
||||||
if shouldEinoEmptyResponseContinue(nil, hint, 3, 1) {
|
|
||||||
t.Fatal("expected no continue for nil result")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
|
|
||||||
// newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId),
|
// newEinoExecuteMonitorCallback 在 Eino filesystem execute 结束时写入 MCP 监控库并 recorder(executionId),
|
||||||
// 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片。
|
// 与 CallTool 路径一致,供助手消息展示「渗透测试详情」芯片。
|
||||||
func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(command, stdout string, success bool, invokeErr error) {
|
func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRecorder) func(toolCallID, command, stdout string, success bool, invokeErr error) {
|
||||||
return func(command, stdout string, success bool, invokeErr error) {
|
return func(toolCallID, command, stdout string, success bool, invokeErr error) {
|
||||||
if ag == nil || recorder == nil {
|
if ag == nil || recorder == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -25,7 +25,7 @@ func newEinoExecuteMonitorCallback(ag *agent.Agent, recorder einomcp.ExecutionRe
|
|||||||
args := map[string]interface{}{"command": command}
|
args := map[string]interface{}{"command": command}
|
||||||
id := ag.RecordLocalToolExecution("execute", args, stdout, err)
|
id := ag.RecordLocalToolExecution("execute", args, stdout, err)
|
||||||
if id != "" {
|
if id != "" {
|
||||||
recorder(id)
|
recorder(id, toolCallID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,6 +34,15 @@ func einoExecuteTimeoutUserHint() string {
|
|||||||
return "已超时终止 · Timed out"
|
return "已超时终止 · Timed out"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// einoExecuteRecvErrIsToolTimeout 判断 Recv 错误是否由 agent.tool_timeout_minutes 触发。
|
||||||
|
// WithTimeout 到期后 local 侧常报 canceled / exit -1,但 execCtx.Err() 仍为 DeadlineExceeded。
|
||||||
|
func einoExecuteRecvErrIsToolTimeout(rerr error, tctx context.Context) bool {
|
||||||
|
if tctx != nil && errors.Is(tctx.Err(), context.DeadlineExceeded) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return errors.Is(rerr, context.DeadlineExceeded)
|
||||||
|
}
|
||||||
|
|
||||||
// einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShell(cloudwego eino-ext local.Local)。
|
// einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShell(cloudwego eino-ext local.Local)。
|
||||||
// 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连,
|
// 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连,
|
||||||
// streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。
|
// streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。
|
||||||
@@ -53,7 +62,7 @@ type einoStreamingShellWrap struct {
|
|||||||
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
|
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
|
||||||
toolTimeoutMinutes int
|
toolTimeoutMinutes int
|
||||||
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
|
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
|
||||||
recordMonitor func(command, stdout string, success bool, invokeErr error)
|
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
@@ -83,15 +92,25 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
if execCancel != nil {
|
if execCancel != nil {
|
||||||
execCancel()
|
execCancel()
|
||||||
}
|
}
|
||||||
|
if einoExecuteRecvErrIsToolTimeout(err, execCtx) {
|
||||||
|
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||||||
if w.recordMonitor != nil {
|
if w.recordMonitor != nil {
|
||||||
w.recordMonitor(userCmd, "", false, err)
|
w.recordMonitor(tid, userCmd, hint, false, context.DeadlineExceeded)
|
||||||
|
}
|
||||||
|
if w.invokeNotify != nil && tid != "" {
|
||||||
|
w.invokeNotify.Fire(tid, "execute", agentTag, false, hint, context.DeadlineExceeded)
|
||||||
|
}
|
||||||
|
return schema.StreamReaderFromArray([]*filesystem.ExecuteResponse{{Output: hint}}), nil
|
||||||
|
}
|
||||||
|
if w.recordMonitor != nil {
|
||||||
|
w.recordMonitor(tid, userCmd, "", false, err)
|
||||||
}
|
}
|
||||||
if w.invokeNotify != nil && tid != "" {
|
if w.invokeNotify != nil && tid != "" {
|
||||||
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if sr == nil || w.invokeNotify == nil || tid == "" {
|
if sr == nil || w.invokeNotify == nil {
|
||||||
if execCancel != nil {
|
if execCancel != nil {
|
||||||
execCancel()
|
execCancel()
|
||||||
}
|
}
|
||||||
@@ -107,7 +126,6 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
}
|
}
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
const maxCapture = 16 * 1024
|
|
||||||
success := true
|
success := true
|
||||||
var invokeErr error
|
var invokeErr error
|
||||||
exitCode := 0
|
exitCode := 0
|
||||||
@@ -121,6 +139,11 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
success = false
|
success = false
|
||||||
invokeErr = rerr
|
invokeErr = rerr
|
||||||
|
// 单次 execute 超时须与 MCP 工具一致:写入工具结果尾标、继续迭代,不得向 ADK 流注入硬错误。
|
||||||
|
if einoExecuteRecvErrIsToolTimeout(rerr, tctx) {
|
||||||
|
invokeErr = context.DeadlineExceeded
|
||||||
|
break
|
||||||
|
}
|
||||||
_ = outW.Send(nil, rerr)
|
_ = outW.Send(nil, rerr)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -130,15 +153,10 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
exitCode = *resp.ExitCode
|
exitCode = *resp.ExitCode
|
||||||
}
|
}
|
||||||
var appended string
|
var appended string
|
||||||
if remain := maxCapture - sb.Len(); remain > 0 {
|
if resp.Output != "" {
|
||||||
out := resp.Output
|
sb.WriteString(resp.Output)
|
||||||
if len(out) > remain {
|
appended = resp.Output
|
||||||
out = out[:remain]
|
|
||||||
}
|
}
|
||||||
sb.WriteString(out)
|
|
||||||
appended = out
|
|
||||||
}
|
|
||||||
// 仅推送写入 sb 的片段,与末尾 Fire/recordMonitor 的截断累计一致,避免最终 tool_result 短于已展示增量。
|
|
||||||
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
||||||
w.outputChunk("execute", tid, appended)
|
w.outputChunk("execute", tid, appended)
|
||||||
}
|
}
|
||||||
@@ -167,16 +185,10 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
|||||||
if w.outputChunk != nil && tid != "" {
|
if w.outputChunk != nil && tid != "" {
|
||||||
w.outputChunk("execute", tid, hint)
|
w.outputChunk("execute", tid, hint)
|
||||||
}
|
}
|
||||||
if remain := maxCapture - sb.Len(); remain > 0 {
|
sb.WriteString(hint)
|
||||||
h := hint
|
|
||||||
if len(h) > remain {
|
|
||||||
h = h[:remain]
|
|
||||||
}
|
|
||||||
sb.WriteString(h)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if w.recordMonitor != nil {
|
if w.recordMonitor != nil {
|
||||||
w.recordMonitor(command, sb.String(), success, invokeErr)
|
w.recordMonitor(tid, command, sb.String(), success, invokeErr)
|
||||||
}
|
}
|
||||||
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
||||||
outW.Close()
|
outW.Close()
|
||||||
|
|||||||
@@ -0,0 +1,138 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/einomcp"
|
||||||
|
|
||||||
|
"github.com/cloudwego/eino/adk/filesystem"
|
||||||
|
"github.com/cloudwego/eino/schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockStreamingShell struct {
|
||||||
|
immediateErr error
|
||||||
|
recvErr error
|
||||||
|
output string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||||
|
if m.immediateErr != nil {
|
||||||
|
return nil, m.immediateErr
|
||||||
|
}
|
||||||
|
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
|
||||||
|
go func() {
|
||||||
|
defer outW.Close()
|
||||||
|
if strings.TrimSpace(m.output) != "" {
|
||||||
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: m.output}, nil)
|
||||||
|
}
|
||||||
|
if m.recvErr != nil {
|
||||||
|
_ = outW.Send(nil, m.recvErr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return outR, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoExecuteRecvErrIsToolTimeout(t *testing.T) {
|
||||||
|
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
|
<-tctx.Done()
|
||||||
|
|
||||||
|
if !einoExecuteRecvErrIsToolTimeout(context.Canceled, tctx) {
|
||||||
|
t.Fatal("expected canceled recv with deadline exec ctx to count as tool timeout")
|
||||||
|
}
|
||||||
|
if !einoExecuteRecvErrIsToolTimeout(context.DeadlineExceeded, nil) {
|
||||||
|
t.Fatal("expected DeadlineExceeded recv without tctx")
|
||||||
|
}
|
||||||
|
if einoExecuteRecvErrIsToolTimeout(errors.New("exit status 1"), context.Background()) {
|
||||||
|
t.Fatal("unexpected timeout for generic error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_ToolTimeoutImmediateErrIsSoft(t *testing.T) {
|
||||||
|
inner := &mockStreamingShell{immediateErr: context.DeadlineExceeded}
|
||||||
|
wrap := &einoStreamingShellWrap{
|
||||||
|
inner: inner,
|
||||||
|
toolTimeoutMinutes: 60,
|
||||||
|
}
|
||||||
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "true"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("immediate tool timeout must return soft stream, got err: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
|
||||||
|
var got strings.Builder
|
||||||
|
for {
|
||||||
|
resp, rerr := sr.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("outer stream must not hard-fail, got: %v", rerr)
|
||||||
|
}
|
||||||
|
if resp != nil && resp.Output != "" {
|
||||||
|
got.WriteString(resp.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !strings.Contains(got.String(), einoExecuteTimeoutUserHint()) {
|
||||||
|
t.Fatalf("expected timeout hint, got: %q", got.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_ToolTimeoutRecvErrIsSoft(t *testing.T) {
|
||||||
|
inner := &mockStreamingShell{recvErr: context.DeadlineExceeded}
|
||||||
|
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
|
wrap := &einoStreamingShellWrap{
|
||||||
|
inner: inner,
|
||||||
|
invokeNotify: notify,
|
||||||
|
toolTimeoutMinutes: 60,
|
||||||
|
}
|
||||||
|
// 生产路径由 Eino compose 注入 toolCallID;单测通过已过期 execCtx 识别 tool_timeout 软错误。
|
||||||
|
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
|
<-tctx.Done()
|
||||||
|
|
||||||
|
sr, err := wrap.ExecuteStreaming(tctx, &filesystem.ExecuteRequest{Command: "sleep 999"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStreaming: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
|
||||||
|
var got strings.Builder
|
||||||
|
for {
|
||||||
|
resp, rerr := sr.Recv()
|
||||||
|
if errors.Is(rerr, io.EOF) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
t.Fatalf("outer stream must not hard-fail on tool timeout, got: %v", rerr)
|
||||||
|
}
|
||||||
|
if resp != nil && resp.Output != "" {
|
||||||
|
got.WriteString(resp.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !strings.Contains(got.String(), einoExecuteTimeoutUserHint()) {
|
||||||
|
t.Fatalf("expected timeout hint in stream, got: %q", got.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEinoStreamingShellWrap_NonTimeoutRecvErrStillHard(t *testing.T) {
|
||||||
|
inner := &mockStreamingShell{recvErr: errors.New("broken pipe")}
|
||||||
|
wrap := &einoStreamingShellWrap{inner: inner}
|
||||||
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "true"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ExecuteStreaming: %v", err)
|
||||||
|
}
|
||||||
|
defer sr.Close()
|
||||||
|
|
||||||
|
_, rerr := sr.Recv()
|
||||||
|
if rerr == nil || errors.Is(rerr, io.EOF) {
|
||||||
|
t.Fatal("expected hard stream error for non-timeout failure")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -96,6 +96,6 @@ func recordEinoADKFilesystemToolMonitor(
|
|||||||
}
|
}
|
||||||
id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr)
|
id := ag.RecordLocalToolExecution(storedName, args, resultText, invErr)
|
||||||
if id != "" {
|
if id != "" {
|
||||||
rec(id)
|
rec(id, toolCallID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -103,14 +103,26 @@ func mergeAlwaysVisibleToolNames(configured []string) []string {
|
|||||||
return merged
|
return merged
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddlewareConfig, convID string, loc *localbk.Local, logger *zap.Logger) (adk.ChatModelAgentMiddleware, error) {
|
func reductionCacheRootDir(configuredBase, projectID, conversationID string) string {
|
||||||
|
base := strings.TrimSpace(configuredBase)
|
||||||
|
if base == "" {
|
||||||
|
base = filepath.Join("tmp", "reduction")
|
||||||
|
}
|
||||||
|
if pid := strings.TrimSpace(projectID); pid != "" {
|
||||||
|
return filepath.Join(base, "projects", sanitizeEinoPathSegment(pid))
|
||||||
|
}
|
||||||
|
conv := strings.TrimSpace(conversationID)
|
||||||
|
if conv == "" {
|
||||||
|
conv = "default"
|
||||||
|
}
|
||||||
|
return filepath.Join(base, "conversations", sanitizeEinoPathSegment(conv))
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddlewareConfig, projectID, convID string, loc *localbk.Local, logger *zap.Logger) (adk.ChatModelAgentMiddleware, error) {
|
||||||
if loc == nil {
|
if loc == nil {
|
||||||
return nil, fmt.Errorf("reduction: local backend nil")
|
return nil, fmt.Errorf("reduction: local backend nil")
|
||||||
}
|
}
|
||||||
root := strings.TrimSpace(mw.ReductionRootDir)
|
root := reductionCacheRootDir(mw.ReductionRootDir, projectID, convID)
|
||||||
if root == "" {
|
|
||||||
root = filepath.Join(os.TempDir(), "cyberstrike-reduction", sanitizeEinoPathSegment(convID))
|
|
||||||
}
|
|
||||||
if err := os.MkdirAll(root, 0o755); err != nil {
|
if err := os.MkdirAll(root, 0o755); err != nil {
|
||||||
return nil, fmt.Errorf("reduction root: %w", err)
|
return nil, fmt.Errorf("reduction root: %w", err)
|
||||||
}
|
}
|
||||||
@@ -148,6 +160,7 @@ func prependEinoMiddlewares(
|
|||||||
einoLoc *localbk.Local,
|
einoLoc *localbk.Local,
|
||||||
skillsRoot string,
|
skillsRoot string,
|
||||||
conversationID string,
|
conversationID string,
|
||||||
|
projectID string,
|
||||||
logger *zap.Logger,
|
logger *zap.Logger,
|
||||||
) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, toolSearchActive bool, err error) {
|
) (outTools []tool.BaseTool, extraHandlers []adk.ChatModelAgentMiddleware, toolSearchActive bool, err error) {
|
||||||
if mw == nil {
|
if mw == nil {
|
||||||
@@ -167,7 +180,7 @@ func prependEinoMiddlewares(
|
|||||||
if place == einoMWSub && !mw.ReductionSubAgents {
|
if place == einoMWSub && !mw.ReductionSubAgents {
|
||||||
// skip
|
// skip
|
||||||
} else {
|
} else {
|
||||||
redMW, rerr := buildReductionMiddleware(ctx, *mw, conversationID, einoLoc, logger)
|
redMW, rerr := buildReductionMiddleware(ctx, *mw, projectID, conversationID, einoLoc, logger)
|
||||||
if rerr != nil {
|
if rerr != nil {
|
||||||
return nil, nil, false, rerr
|
return nil, nil, false, rerr
|
||||||
}
|
}
|
||||||
@@ -230,17 +243,14 @@ func prependEinoMiddlewares(
|
|||||||
return outTools, extraHandlers, toolSearchActive, nil
|
return outTools, extraHandlers, toolSearchActive, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry *adk.ModelRetryConfig, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
||||||
if ma == nil {
|
if ma == nil {
|
||||||
return "", nil, nil
|
return "", nil
|
||||||
}
|
}
|
||||||
mw := ma.EinoMiddleware
|
mw := ma.EinoMiddleware
|
||||||
if k := strings.TrimSpace(mw.DeepOutputKey); k != "" {
|
if k := strings.TrimSpace(mw.DeepOutputKey); k != "" {
|
||||||
outputKey = k
|
outputKey = k
|
||||||
}
|
}
|
||||||
if mw.DeepModelRetryMaxRetries > 0 {
|
|
||||||
retry = &adk.ModelRetryConfig{MaxRetries: mw.DeepModelRetryMaxRetries}
|
|
||||||
}
|
|
||||||
prefix := strings.TrimSpace(mw.TaskToolDescriptionPrefix)
|
prefix := strings.TrimSpace(mw.TaskToolDescriptionPrefix)
|
||||||
if prefix != "" {
|
if prefix != "" {
|
||||||
taskDesc = func(ctx context.Context, agents []adk.Agent) (string, error) {
|
taskDesc = func(ctx context.Context, agents []adk.Agent) (string, error) {
|
||||||
@@ -261,5 +271,5 @@ func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry
|
|||||||
return prefix + "\n可用子代理(按名称 transfer / task 调用):" + strings.Join(names, "、"), nil
|
return prefix + "\n可用子代理(按名称 transfer / task 调用):" + strings.Join(names, "、"), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return outputKey, retry, taskDesc
|
return outputKey, taskDesc
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,12 +3,31 @@ package multiagent
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/components/tool"
|
"github.com/cloudwego/eino/components/tool"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestReductionCacheRootDir(t *testing.T) {
|
||||||
|
got := reductionCacheRootDir("", "proj-1", "conv-1")
|
||||||
|
want := filepath.Join("tmp", "reduction", "projects", "proj-1")
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("project scope: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
got = reductionCacheRootDir("", "", "conv-abc")
|
||||||
|
want = filepath.Join("tmp", "reduction", "conversations", "conv-abc")
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("conversation scope: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
custom := reductionCacheRootDir("/data/cache", "p1", "c1")
|
||||||
|
if !strings.HasSuffix(custom, filepath.Join("projects", "p1")) {
|
||||||
|
t.Fatalf("custom base should still scope by project, got %q", custom)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type stubTool struct{ name string }
|
type stubTool struct{ name string }
|
||||||
|
|
||||||
func (s stubTool) Info(_ context.Context) (*schema.ToolInfo, error) {
|
func (s stubTool) Info(_ context.Context) (*schema.ToolInfo, error) {
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
|
|
||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
@@ -29,6 +30,8 @@ type PlanExecuteRootArgs struct {
|
|||||||
MwCfg *config.MultiAgentEinoMiddlewareConfig
|
MwCfg *config.MultiAgentEinoMiddlewareConfig
|
||||||
// ConversationID is used for transcript/isolation paths in middleware.
|
// ConversationID is used for transcript/isolation paths in middleware.
|
||||||
ConversationID string
|
ConversationID string
|
||||||
|
DB *database.DB
|
||||||
|
ProjectID string
|
||||||
Logger *zap.Logger
|
Logger *zap.Logger
|
||||||
// ModelName is used for model input token estimation logs.
|
// ModelName is used for model input token estimation logs.
|
||||||
ModelName string
|
ModelName string
|
||||||
@@ -93,7 +96,7 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
|||||||
}
|
}
|
||||||
// 4. summarization(最后,与 Deep/Supervisor 一致)
|
// 4. summarization(最后,与 Deep/Supervisor 一致)
|
||||||
if a.AppCfg != nil {
|
if a.AppCfg != nil {
|
||||||
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.Logger)
|
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.DB, a.ProjectID, a.Logger)
|
||||||
if sumErr != nil {
|
if sumErr != nil {
|
||||||
return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr)
|
return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
|
|
||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
"cyberstrike-ai/internal/project"
|
"cyberstrike-ai/internal/project"
|
||||||
@@ -32,8 +33,10 @@ func RunEinoSingleChatModelAgent(
|
|||||||
appCfg *config.Config,
|
appCfg *config.Config,
|
||||||
ma *config.MultiAgentConfig,
|
ma *config.MultiAgentConfig,
|
||||||
ag *agent.Agent,
|
ag *agent.Agent,
|
||||||
|
db *database.DB,
|
||||||
logger *zap.Logger,
|
logger *zap.Logger,
|
||||||
conversationID string,
|
conversationID string,
|
||||||
|
projectID string,
|
||||||
userMessage string,
|
userMessage string,
|
||||||
history []agent.ChatMessage,
|
history []agent.ChatMessage,
|
||||||
roleTools []string,
|
roleTools []string,
|
||||||
@@ -58,10 +61,12 @@ func RunEinoSingleChatModelAgent(
|
|||||||
|
|
||||||
var mcpIDsMu sync.Mutex
|
var mcpIDsMu sync.Mutex
|
||||||
var mcpIDs []string
|
var mcpIDs []string
|
||||||
recorder := func(id string) {
|
mcpExecBinder := NewMCPExecutionBinder()
|
||||||
|
recorder := func(id, toolCallID string) {
|
||||||
if id == "" {
|
if id == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
mcpExecBinder.Bind(toolCallID, id)
|
||||||
mcpIDsMu.Lock()
|
mcpIDsMu.Lock()
|
||||||
mcpIDs = append(mcpIDs, id)
|
mcpIDs = append(mcpIDs, id)
|
||||||
mcpIDsMu.Unlock()
|
mcpIDsMu.Unlock()
|
||||||
@@ -75,29 +80,15 @@ func RunEinoSingleChatModelAgent(
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
toolOutputChunk := func(toolName, toolCallID, chunk string) {
|
|
||||||
if progress == nil || toolCallID == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
progress("tool_result_delta", chunk, map[string]interface{}{
|
|
||||||
"toolName": toolName,
|
|
||||||
"toolCallId": toolCallID,
|
|
||||||
"index": 0,
|
|
||||||
"total": 0,
|
|
||||||
"iteration": 0,
|
|
||||||
"source": "eino",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
einoExecMonitor := newEinoExecuteMonitorCallback(ag, recorder)
|
||||||
mainDefs := ag.ToolsForRole(roleTools)
|
mainDefs := ag.ToolsForRole(roleTools)
|
||||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk, toolInvokeNotify, einoSingleAgentName)
|
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, einoSingleAgentName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
mainToolsForCfg, mainOrchestratorPre, singleToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
mainToolsForCfg, mainOrchestratorPre, singleToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, projectID, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("eino single eino 中间件: %w", err)
|
return nil, fmt.Errorf("eino single eino 中间件: %w", err)
|
||||||
}
|
}
|
||||||
@@ -132,7 +123,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
return nil, fmt.Errorf("eino single 模型: %w", err)
|
return nil, fmt.Errorf("eino single 模型: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, logger)
|
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, db, projectID, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("eino single summarization: %w", err)
|
return nil, fmt.Errorf("eino single summarization: %w", err)
|
||||||
}
|
}
|
||||||
@@ -145,7 +136,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
}
|
}
|
||||||
if einoSkillMW != nil {
|
if einoSkillMW != nil {
|
||||||
if einoFSTools && einoLoc != nil {
|
if einoFSTools && einoLoc != nil {
|
||||||
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
fsMw, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, einoSingleAgentName, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
|
||||||
if fsErr != nil {
|
if fsErr != nil {
|
||||||
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
return nil, fmt.Errorf("eino single filesystem 中间件: %w", fsErr)
|
||||||
}
|
}
|
||||||
@@ -197,13 +188,10 @@ func RunEinoSingleChatModelAgent(
|
|||||||
MaxIterations: maxIter,
|
MaxIterations: maxIter,
|
||||||
Handlers: handlers,
|
Handlers: handlers,
|
||||||
}
|
}
|
||||||
outKey, modelRetry, _ := deepExtrasFromConfig(ma)
|
outKey, _ := deepExtrasFromConfig(ma)
|
||||||
if outKey != "" {
|
if outKey != "" {
|
||||||
chatCfg.OutputKey = outKey
|
chatCfg.OutputKey = outKey
|
||||||
}
|
}
|
||||||
if modelRetry != nil {
|
|
||||||
chatCfg.ModelRetryConfig = modelRetry
|
|
||||||
}
|
|
||||||
|
|
||||||
chatAgent, err := adk.NewChatModelAgent(ctx, chatCfg)
|
chatAgent, err := adk.NewChatModelAgent(ctx, chatCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -237,6 +225,7 @@ func RunEinoSingleChatModelAgent(
|
|||||||
McpIDs: &mcpIDs,
|
McpIDs: &mcpIDs,
|
||||||
FilesystemMonitorAgent: ag,
|
FilesystemMonitorAgent: ag,
|
||||||
FilesystemMonitorRecord: recorder,
|
FilesystemMonitorRecord: recorder,
|
||||||
|
MCPExecutionBinder: mcpExecBinder,
|
||||||
ToolInvokeNotify: toolInvokeNotify,
|
ToolInvokeNotify: toolInvokeNotify,
|
||||||
DA: chatAgent,
|
DA: chatAgent,
|
||||||
ModelFacingTrace: modelFacingTrace,
|
ModelFacingTrace: modelFacingTrace,
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ func subAgentFilesystemMiddleware(
|
|||||||
loc *localbk.Local,
|
loc *localbk.Local,
|
||||||
invokeNotify *einomcp.ToolInvokeNotifyHolder,
|
invokeNotify *einomcp.ToolInvokeNotifyHolder,
|
||||||
einoAgentName string,
|
einoAgentName string,
|
||||||
recordMonitor func(command, stdout string, success bool, invokeErr error),
|
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error),
|
||||||
toolTimeoutMinutes int,
|
toolTimeoutMinutes int,
|
||||||
outputChunk func(toolName, toolCallID, chunk string),
|
outputChunk func(toolName, toolCallID, chunk string),
|
||||||
) (adk.ChatModelAgentMiddleware, error) {
|
) (adk.ChatModelAgentMiddleware, error) {
|
||||||
|
|||||||
@@ -9,7 +9,9 @@ import (
|
|||||||
|
|
||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
copenai "cyberstrike-ai/internal/openai"
|
copenai "cyberstrike-ai/internal/openai"
|
||||||
|
"cyberstrike-ai/internal/project"
|
||||||
|
|
||||||
"github.com/bytedance/sonic"
|
"github.com/bytedance/sonic"
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
@@ -20,8 +22,6 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultSummarizationRetryMax = 3
|
|
||||||
|
|
||||||
// einoSummarizeUserInstruction:压缩历史时保留渗透测试关键信息。
|
// einoSummarizeUserInstruction:压缩历史时保留渗透测试关键信息。
|
||||||
const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史。
|
const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史。
|
||||||
|
|
||||||
@@ -40,6 +40,8 @@ func newEinoSummarizationMiddleware(
|
|||||||
appCfg *config.Config,
|
appCfg *config.Config,
|
||||||
mwCfg *config.MultiAgentEinoMiddlewareConfig,
|
mwCfg *config.MultiAgentEinoMiddlewareConfig,
|
||||||
conversationID string,
|
conversationID string,
|
||||||
|
db *database.DB,
|
||||||
|
projectID string,
|
||||||
logger *zap.Logger,
|
logger *zap.Logger,
|
||||||
) (adk.ChatModelAgentMiddleware, error) {
|
) (adk.ChatModelAgentMiddleware, error) {
|
||||||
if summaryModel == nil || appCfg == nil {
|
if summaryModel == nil || appCfg == nil {
|
||||||
@@ -93,10 +95,8 @@ func newEinoSummarizationMiddleware(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
retryMax := defaultSummarizationRetryMax
|
retryPolicy := einoTransientRunRetryPolicyFromMW(mwCfg)
|
||||||
if mwCfg != nil && mwCfg.SummarizationRetryMaxAttempts > 0 {
|
retryMax := retryPolicy.maxAttempts
|
||||||
retryMax = mwCfg.SummarizationRetryMaxAttempts
|
|
||||||
}
|
|
||||||
|
|
||||||
// ModelOptions apply only to summarization Generate (same ChatModel instance as the agent).
|
// ModelOptions apply only to summarization Generate (same ChatModel instance as the agent).
|
||||||
// Strip thinking/reasoning on this call path; mark requests for empty-choices diagnostics.
|
// Strip thinking/reasoning on this call path; mark requests for empty-choices diagnostics.
|
||||||
@@ -133,17 +133,25 @@ func newEinoSummarizationMiddleware(
|
|||||||
Retry: &summarization.RetryConfig{
|
Retry: &summarization.RetryConfig{
|
||||||
MaxRetries: &retryMax,
|
MaxRetries: &retryMax,
|
||||||
ShouldRetry: func(_ context.Context, _ adk.Message, err error) bool {
|
ShouldRetry: func(_ context.Context, _ adk.Message, err error) bool {
|
||||||
if err != nil && logger != nil {
|
retry := isEinoTransientRunError(err)
|
||||||
logger.Warn("eino summarization generate attempt failed, will retry if attempts remain",
|
if retry && logger != nil {
|
||||||
|
logger.Warn("eino summarization generate transient error, will retry if attempts remain",
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
zap.Int("max_retries", retryMax),
|
zap.Int("max_retries", retryMax),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
return err != nil
|
return retry
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
|
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
|
||||||
return summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax)
|
out, ferr := summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax)
|
||||||
|
if ferr != nil {
|
||||||
|
return nil, ferr
|
||||||
|
}
|
||||||
|
if appCfg != nil {
|
||||||
|
out = refreshFactIndexInMessages(out, db, projectID, appCfg.Project, logger)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
},
|
},
|
||||||
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
|
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
|
||||||
if transcriptPath != "" && len(before.Messages) > 0 {
|
if transcriptPath != "" && len(before.Messages) > 0 {
|
||||||
@@ -176,6 +184,50 @@ func newEinoSummarizationMiddleware(
|
|||||||
return mw, nil
|
return mw, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// refreshFactIndexInMessages 在 summarization 压缩后,用 DB 最新索引替换 system 中已有的项目黑板索引段。
|
||||||
|
func refreshFactIndexInMessages(msgs []adk.Message, db *database.DB, projectID string, cfg config.ProjectConfig, logger *zap.Logger) []adk.Message {
|
||||||
|
if db == nil || !cfg.Enabled {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
projectID = strings.TrimSpace(projectID)
|
||||||
|
if projectID == "" {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
freshIndex, err := project.BuildFactIndexBlock(db, projectID, cfg)
|
||||||
|
if err != nil {
|
||||||
|
if logger != nil {
|
||||||
|
logger.Warn("summarization: 刷新项目黑板索引失败", zap.String("projectId", projectID), zap.Error(err))
|
||||||
|
}
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
freshIndex = strings.TrimSpace(freshIndex)
|
||||||
|
if freshIndex == "" {
|
||||||
|
return msgs
|
||||||
|
}
|
||||||
|
|
||||||
|
changed := false
|
||||||
|
out := make([]adk.Message, len(msgs))
|
||||||
|
for i, msg := range msgs {
|
||||||
|
if msg == nil || msg.Role != schema.System {
|
||||||
|
out[i] = msg
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newContent, ok := project.ReplaceFactIndexSection(msg.Content, freshIndex)
|
||||||
|
if !ok {
|
||||||
|
out[i] = msg
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cloned := *msg
|
||||||
|
cloned.Content = newContent
|
||||||
|
out[i] = &cloned
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
if changed && logger != nil {
|
||||||
|
logger.Info("summarization: 已刷新项目黑板索引", zap.String("projectId", projectID))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
// summarizeFinalizeWithRecentAssistantToolTrail 在摘要消息后保留最近 assistant/tool 轨迹,避免压缩后执行链断裂。
|
// summarizeFinalizeWithRecentAssistantToolTrail 在摘要消息后保留最近 assistant/tool 轨迹,避免压缩后执行链断裂。
|
||||||
//
|
//
|
||||||
// 关键不变量:tool_call ↔ tool_result 的 pair 必须整体保留或整体丢弃。
|
// 关键不变量:tool_call ↔ tool_result 的 pair 必须整体保留或整体丢弃。
|
||||||
|
|||||||
@@ -7,9 +7,14 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/project"
|
||||||
|
|
||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
"github.com/cloudwego/eino/adk/middlewares/summarization"
|
"github.com/cloudwego/eino/adk/middlewares/summarization"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
// fixedTokenCounter 让 tool 消息按 tokensPerToolMessage 计,其它消息按 1 计。
|
// fixedTokenCounter 让 tool 消息按 tokensPerToolMessage 计,其它消息按 1 计。
|
||||||
@@ -389,9 +394,11 @@ func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
|
|||||||
"你是CyberStrikeAI,是一个专业的网络安全渗透测试专家。",
|
"你是CyberStrikeAI,是一个专业的网络安全渗透测试专家。",
|
||||||
"高强度扫描要求:全力出击",
|
"高强度扫描要求:全力出击",
|
||||||
"",
|
"",
|
||||||
|
project.FactIndexSectionStartMarker,
|
||||||
"## 项目黑板索引(project: 123, id: abc)",
|
"## 项目黑板索引(project: 123, id: abc)",
|
||||||
"(暂无事实)",
|
"(暂无事实)",
|
||||||
"需要写入请使用 upsert_project_fact。",
|
"需要写入请使用 upsert_project_fact。",
|
||||||
|
project.FactIndexSectionEndMarker,
|
||||||
"",
|
"",
|
||||||
"# Skills System",
|
"# Skills System",
|
||||||
"**How to Use Skills**",
|
"**How to Use Skills**",
|
||||||
@@ -419,7 +426,7 @@ func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
|
|||||||
func TestFormatSummarizationTranscript_OmitsBloatedSystem(t *testing.T) {
|
func TestFormatSummarizationTranscript_OmitsBloatedSystem(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
msgs := []adk.Message{
|
msgs := []adk.Message{
|
||||||
schema.SystemMessage("以下是当前会话绑定的工具名称索引\n- nmap\n\n你是CyberStrikeAI\n## 项目黑板索引(project: p1, id: x)\n(暂无事实)\n# Skills System\nboiler"),
|
schema.SystemMessage("以下是当前会话绑定的工具名称索引\n- nmap\n\n你是CyberStrikeAI\n" + project.FactIndexSectionStartMarker + "\n## 项目黑板索引(project: p1, id: x)\n(暂无事实)\n" + project.FactIndexSectionEndMarker + "\n# Skills System\nboiler"),
|
||||||
schema.UserMessage("hello"),
|
schema.UserMessage("hello"),
|
||||||
schema.AssistantMessage("reply", nil),
|
schema.AssistantMessage("reply", nil),
|
||||||
}
|
}
|
||||||
@@ -434,3 +441,51 @@ func TestFormatSummarizationTranscript_OmitsBloatedSystem(t *testing.T) {
|
|||||||
t.Fatalf("dynamic blackboard missing: %q", out)
|
t.Fatalf("dynamic blackboard missing: %q", out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRefreshFactIndexInMessages(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "summarize-facts.db")
|
||||||
|
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
proj, err := db.CreateProject(&database.Project{Name: "summarize-proj"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := config.ProjectConfig{Enabled: true}
|
||||||
|
oldIndex, err := project.BuildFactIndexBlock(db, proj.ID, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.UpsertProjectFact(&database.ProjectFact{
|
||||||
|
ProjectID: proj.ID,
|
||||||
|
FactKey: "target/host",
|
||||||
|
Category: "target",
|
||||||
|
Summary: "fresh host fact",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs := []adk.Message{
|
||||||
|
schema.SystemMessage("instruction\n\n" + oldIndex),
|
||||||
|
schema.UserMessage("hi"),
|
||||||
|
}
|
||||||
|
|
||||||
|
out := refreshFactIndexInMessages(msgs, db, proj.ID, cfg, nil)
|
||||||
|
sys := out[0].Content
|
||||||
|
if strings.Contains(sys, "(暂无事实)") {
|
||||||
|
t.Fatalf("expected refreshed index, got: %q", sys)
|
||||||
|
}
|
||||||
|
if !strings.Contains(sys, "fresh host fact") {
|
||||||
|
t.Fatalf("expected new fact in index: %q", sys)
|
||||||
|
}
|
||||||
|
if !strings.Contains(sys, "instruction") {
|
||||||
|
t.Fatalf("non-index system content should be preserved: %q", sys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"github.com/cloudwego/eino/adk"
|
"github.com/cloudwego/eino/adk"
|
||||||
"github.com/cloudwego/eino/schema"
|
"github.com/cloudwego/eino/schema"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/project"
|
||||||
|
|
||||||
"github.com/bytedance/sonic"
|
"github.com/bytedance/sonic"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,7 +21,6 @@ const (
|
|||||||
transcriptToolIndexStartMarker = "以下是当前会话绑定的工具名称索引"
|
transcriptToolIndexStartMarker = "以下是当前会话绑定的工具名称索引"
|
||||||
transcriptPersonaStartMarker = "你是CyberStrikeAI"
|
transcriptPersonaStartMarker = "你是CyberStrikeAI"
|
||||||
transcriptSkillsSystemMarker = "# Skills System"
|
transcriptSkillsSystemMarker = "# Skills System"
|
||||||
transcriptProjectBlackboardMarker = "## 项目黑板索引"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// formatSummarizationTranscript renders pre-compaction messages for transcript.txt.
|
// formatSummarizationTranscript renders pre-compaction messages for transcript.txt.
|
||||||
@@ -88,11 +89,17 @@ func stripSkillsSystemBoilerplate(s string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func extractProjectBlackboardSection(s string) string {
|
func extractProjectBlackboardSection(s string) string {
|
||||||
idx := strings.Index(s, transcriptProjectBlackboardMarker)
|
start := strings.Index(s, project.FactIndexSectionStartMarker)
|
||||||
if idx < 0 {
|
if start < 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return strings.TrimSpace(s[idx:])
|
section := s[start:]
|
||||||
|
end := strings.Index(section, project.FactIndexSectionEndMarker)
|
||||||
|
if end < 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
section = section[:end+len(project.FactIndexSectionEndMarker)]
|
||||||
|
return strings.TrimSpace(section)
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendTranscriptSection(sb *strings.Builder, role schema.RoleType, body string) {
|
func appendTranscriptSection(sb *strings.Builder, role schema.RoleType, body string) {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package multiagent
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -17,8 +18,9 @@ const (
|
|||||||
defaultEinoRunRetryMaxBackoff = 30 * time.Second
|
defaultEinoRunRetryMaxBackoff = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// isEinoTransientRunError 判断 ADK 运行期错误是否适合指数退避续跑(429、5xx、网络抖动等)。
|
// isEinoTransientRunError 是 Eino 运行期「可退避重试 vs 直接失败」的唯一判据。
|
||||||
// 用户取消、超时、迭代上限等由 run loop 单独处理,不在此列。
|
// 429/5xx/网络抖动等返回 true;用户取消、超时、迭代上限、鉴权失败等返回 false。
|
||||||
|
// 其它模块(run loop、summarization 等)只调用本函数,不在别处维护平行规则。
|
||||||
func isEinoTransientRunError(err error) bool {
|
func isEinoTransientRunError(err error) bool {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false
|
||||||
@@ -78,6 +80,68 @@ func isEinoTransientRunError(err error) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type einoTransientRunRetryPolicy struct {
|
||||||
|
maxAttempts int
|
||||||
|
maxBackoff time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func einoTransientRunRetryPolicyFromArgs(args *einoADKRunLoopArgs) einoTransientRunRetryPolicy {
|
||||||
|
return einoTransientRunRetryPolicy{
|
||||||
|
maxAttempts: einoRunRetryMaxAttempts(args),
|
||||||
|
maxBackoff: einoRunRetryMaxBackoff(args),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func einoTransientRunRetryPolicyFromMW(mw *config.MultiAgentEinoMiddlewareConfig) einoTransientRunRetryPolicy {
|
||||||
|
maxBackoff := defaultEinoRunRetryMaxBackoff
|
||||||
|
if mw != nil && mw.RunRetryMaxBackoffSec > 0 {
|
||||||
|
maxBackoff = time.Duration(mw.RunRetryMaxBackoffSec) * time.Second
|
||||||
|
}
|
||||||
|
return einoTransientRunRetryPolicy{
|
||||||
|
maxAttempts: RunRetryMaxAttemptsFromConfig(mw),
|
||||||
|
maxBackoff: maxBackoff,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// einoTransientRunRetrier 在 run loop 内对临时错误做指数退避并重启 Runner(唯一重试执行层)。
|
||||||
|
type einoTransientRunRetrier struct {
|
||||||
|
policy einoTransientRunRetryPolicy
|
||||||
|
attempts int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newEinoTransientRunRetrier(policy einoTransientRunRetryPolicy) *einoTransientRunRetrier {
|
||||||
|
return &einoTransientRunRetrier{policy: policy}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryRetry 对临时错误退避后返回重启消息;次数用尽返回 exhausted 错误。
|
||||||
|
func (r *einoTransientRunRetrier) tryRetry(
|
||||||
|
ctx context.Context,
|
||||||
|
runErr error,
|
||||||
|
args *einoADKRunLoopArgs,
|
||||||
|
baseMsgs, accumulated []adk.Message,
|
||||||
|
baseCount int,
|
||||||
|
) (restarted bool, restartMsgs []adk.Message, ctxSource einoRunRestartContextSource, backoff time.Duration, fatal error) {
|
||||||
|
if runErr == nil || !isEinoTransientRunError(runErr) {
|
||||||
|
return false, nil, "", 0, runErr
|
||||||
|
}
|
||||||
|
r.attempts++
|
||||||
|
if r.attempts > r.policy.maxAttempts {
|
||||||
|
return false, nil, "", 0, fmt.Errorf("transient retry exhausted after %d attempts: %w", r.policy.maxAttempts, runErr)
|
||||||
|
}
|
||||||
|
backoff = einoTransientRetryBackoff(r.attempts-1, r.policy.maxBackoff)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return false, nil, "", 0, ctx.Err()
|
||||||
|
case <-time.After(backoff):
|
||||||
|
}
|
||||||
|
restartMsgs, ctxSource = einoMessagesForRunRestart(args, baseMsgs, accumulated, baseCount)
|
||||||
|
return true, restartMsgs, ctxSource, backoff, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *einoTransientRunRetrier) attempt() int { return r.attempts }
|
||||||
|
|
||||||
|
func (r *einoTransientRunRetrier) maxAttempts() int { return r.policy.maxAttempts }
|
||||||
|
|
||||||
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
||||||
if args != nil && args.RunRetryMaxAttempts > 0 {
|
if args != nil && args.RunRetryMaxAttempts > 0 {
|
||||||
return args.RunRetryMaxAttempts
|
return args.RunRetryMaxAttempts
|
||||||
@@ -85,7 +149,7 @@ func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
|||||||
return defaultEinoRunRetryMaxAttempts
|
return defaultEinoRunRetryMaxAttempts
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunRetryMaxAttemptsFromConfig 供 handler 分段续跑计数(与 eino_middleware.run_retry_max_attempts 一致)。
|
// RunRetryMaxAttemptsFromConfig 与 eino_middleware.run_retry_max_attempts 一致。
|
||||||
func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
|
func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
|
||||||
if mw != nil && mw.RunRetryMaxAttempts > 0 {
|
if mw != nil && mw.RunRetryMaxAttempts > 0 {
|
||||||
return mw.RunRetryMaxAttempts
|
return mw.RunRetryMaxAttempts
|
||||||
@@ -93,15 +157,6 @@ func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) in
|
|||||||
return defaultEinoRunRetryMaxAttempts
|
return defaultEinoRunRetryMaxAttempts
|
||||||
}
|
}
|
||||||
|
|
||||||
// TransientRetryBackoff 供 handler 在分段续跑前退避。
|
|
||||||
func TransientRetryBackoff(attempt int, maxBackoffSec int) time.Duration {
|
|
||||||
max := defaultEinoRunRetryMaxBackoff
|
|
||||||
if maxBackoffSec > 0 {
|
|
||||||
max = time.Duration(maxBackoffSec) * time.Second
|
|
||||||
}
|
|
||||||
return einoTransientRetryBackoff(attempt, max)
|
|
||||||
}
|
|
||||||
|
|
||||||
func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration {
|
func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration {
|
||||||
if args != nil && args.RunRetryMaxBackoffSec > 0 {
|
if args != nil && args.RunRetryMaxBackoffSec > 0 {
|
||||||
return time.Duration(args.RunRetryMaxBackoffSec) * time.Second
|
return time.Duration(args.RunRetryMaxBackoffSec) * time.Second
|
||||||
|
|||||||
@@ -102,10 +102,3 @@ func TestAppendUserMessageIfNeeded(t *testing.T) {
|
|||||||
t.Fatalf("should not duplicate user message: len=%d", len(dup))
|
t.Fatalf("should not duplicate user message: len=%d", len(dup))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestErrTransientRetryContinue(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
if !errors.Is(ErrTransientRetryContinue, ErrTransientRetryContinue) {
|
|
||||||
t.Fatal("sentinel should match")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -5,11 +5,3 @@ import "errors"
|
|||||||
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
|
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
|
||||||
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
|
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
|
||||||
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
|
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
|
||||||
|
|
||||||
// ErrTransientRetryContinue 表示 Run 因 429/网络等临时错误结束,应由 handler 落库轨迹后
|
|
||||||
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue 同级的「分段续跑」语义)。
|
|
||||||
var ErrTransientRetryContinue = errors.New("agent transient: retry after persisting trace")
|
|
||||||
|
|
||||||
// ErrEmptyResponseContinue 表示 Eino ADK 会话正常结束但未捕获到助手正文,应由 handler 落库轨迹后
|
|
||||||
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue / ErrTransientRetryContinue 同级)。
|
|
||||||
var ErrEmptyResponseContinue = errors.New("agent empty response: continue after persisting trace")
|
|
||||||
|
|||||||
@@ -0,0 +1,31 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// MCPExecutionBinder maps ADK toolCallID → MCP monitor execution ID for a single agent run.
|
||||||
|
type MCPExecutionBinder struct {
|
||||||
|
byToolCall map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMCPExecutionBinder() *MCPExecutionBinder {
|
||||||
|
return &MCPExecutionBinder{byToolCall: make(map[string]string)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *MCPExecutionBinder) Bind(toolCallID, executionID string) {
|
||||||
|
if b == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tid := strings.TrimSpace(toolCallID)
|
||||||
|
eid := strings.TrimSpace(executionID)
|
||||||
|
if tid == "" || eid == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.byToolCall[tid] = eid
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *MCPExecutionBinder) ExecutionID(toolCallID string) string {
|
||||||
|
if b == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return b.byToolCall[strings.TrimSpace(toolCallID)]
|
||||||
|
}
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
package multiagent
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestMCPExecutionBinder(t *testing.T) {
|
||||||
|
b := NewMCPExecutionBinder()
|
||||||
|
b.Bind("call-1", "exec-1")
|
||||||
|
if got := b.ExecutionID("call-1"); got != "exec-1" {
|
||||||
|
t.Fatalf("expected exec-1, got %q", got)
|
||||||
|
}
|
||||||
|
if got := b.ExecutionID("missing"); got != "" {
|
||||||
|
t.Fatalf("expected empty, got %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"cyberstrike-ai/internal/agent"
|
"cyberstrike-ai/internal/agent"
|
||||||
"cyberstrike-ai/internal/agents"
|
"cyberstrike-ai/internal/agents"
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
"cyberstrike-ai/internal/einomcp"
|
"cyberstrike-ai/internal/einomcp"
|
||||||
"cyberstrike-ai/internal/openai"
|
"cyberstrike-ai/internal/openai"
|
||||||
"cyberstrike-ai/internal/project"
|
"cyberstrike-ai/internal/project"
|
||||||
@@ -56,8 +57,10 @@ func RunDeepAgent(
|
|||||||
appCfg *config.Config,
|
appCfg *config.Config,
|
||||||
ma *config.MultiAgentConfig,
|
ma *config.MultiAgentConfig,
|
||||||
ag *agent.Agent,
|
ag *agent.Agent,
|
||||||
|
db *database.DB,
|
||||||
logger *zap.Logger,
|
logger *zap.Logger,
|
||||||
conversationID string,
|
conversationID string,
|
||||||
|
projectID string,
|
||||||
userMessage string,
|
userMessage string,
|
||||||
history []agent.ChatMessage,
|
history []agent.ChatMessage,
|
||||||
roleTools []string,
|
roleTools []string,
|
||||||
@@ -107,10 +110,12 @@ func RunDeepAgent(
|
|||||||
|
|
||||||
var mcpIDsMu sync.Mutex
|
var mcpIDsMu sync.Mutex
|
||||||
var mcpIDs []string
|
var mcpIDs []string
|
||||||
recorder := func(id string) {
|
mcpExecBinder := NewMCPExecutionBinder()
|
||||||
|
recorder := func(id, toolCallID string) {
|
||||||
if id == "" {
|
if id == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
mcpExecBinder.Bind(toolCallID, id)
|
||||||
mcpIDsMu.Lock()
|
mcpIDsMu.Lock()
|
||||||
mcpIDs = append(mcpIDs, id)
|
mcpIDs = append(mcpIDs, id)
|
||||||
mcpIDsMu.Unlock()
|
mcpIDsMu.Unlock()
|
||||||
@@ -128,21 +133,6 @@ func RunDeepAgent(
|
|||||||
|
|
||||||
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
toolInvokeNotify := einomcp.NewToolInvokeNotifyHolder()
|
||||||
mainDefs := ag.ToolsForRole(roleTools)
|
mainDefs := ag.ToolsForRole(roleTools)
|
||||||
toolOutputChunk := func(toolName, toolCallID, chunk string) {
|
|
||||||
// When toolCallId is missing, frontend ignores tool_result_delta.
|
|
||||||
if progress == nil || toolCallID == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
progress("tool_result_delta", chunk, map[string]interface{}{
|
|
||||||
"toolName": toolName,
|
|
||||||
"toolCallId": toolCallID,
|
|
||||||
// index/total/iteration are optional for UI; we don't know them in this bridge.
|
|
||||||
"index": 0,
|
|
||||||
"total": 0,
|
|
||||||
"iteration": 0,
|
|
||||||
"source": "eino",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 30 * time.Minute,
|
Timeout: 30 * time.Minute,
|
||||||
@@ -210,19 +200,19 @@ func RunDeepAgent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
subDefs := ag.ToolsForRole(roleTools)
|
subDefs := ag.ToolsForRole(roleTools)
|
||||||
subTools, err := einomcp.ToolsFromDefinitions(ag, holder, subDefs, recorder, toolOutputChunk, toolInvokeNotify, id)
|
subTools, err := einomcp.ToolsFromDefinitions(ag, holder, subDefs, recorder, nil, toolInvokeNotify, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("子代理 %q 工具: %w", id, err)
|
return nil, fmt.Errorf("子代理 %q 工具: %w", id, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
subToolsForCfg, subPre, subToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWSub, subTools, einoLoc, skillsRoot, conversationID, logger)
|
subToolsForCfg, subPre, subToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWSub, subTools, einoLoc, skillsRoot, conversationID, projectID, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("子代理 %q eino 中间件: %w", id, err)
|
return nil, fmt.Errorf("子代理 %q eino 中间件: %w", id, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
subMax := resolveMaxIterations(appCfg, sub.MaxIterations)
|
subMax := resolveMaxIterations(appCfg, sub.MaxIterations)
|
||||||
|
|
||||||
subSumMw, err := newEinoSummarizationMiddleware(ctx, subModel, appCfg, &ma.EinoMiddleware, conversationID, logger)
|
subSumMw, err := newEinoSummarizationMiddleware(ctx, subModel, appCfg, &ma.EinoMiddleware, conversationID, db, projectID, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("子代理 %q summarization 中间件: %w", id, err)
|
return nil, fmt.Errorf("子代理 %q summarization 中间件: %w", id, err)
|
||||||
}
|
}
|
||||||
@@ -233,7 +223,7 @@ func RunDeepAgent(
|
|||||||
}
|
}
|
||||||
if einoSkillMW != nil {
|
if einoSkillMW != nil {
|
||||||
if einoFSTools && einoLoc != nil {
|
if einoFSTools && einoLoc != nil {
|
||||||
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
subFs, fsErr := subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, id, einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
|
||||||
if fsErr != nil {
|
if fsErr != nil {
|
||||||
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
|
return nil, fmt.Errorf("子代理 %q filesystem 中间件: %w", id, fsErr)
|
||||||
}
|
}
|
||||||
@@ -293,7 +283,7 @@ func RunDeepAgent(
|
|||||||
return nil, fmt.Errorf("多代理主模型: %w", err)
|
return nil, fmt.Errorf("多代理主模型: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, logger)
|
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, db, projectID, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("多代理主 summarization 中间件: %w", err)
|
return nil, fmt.Errorf("多代理主 summarization 中间件: %w", err)
|
||||||
}
|
}
|
||||||
@@ -320,11 +310,11 @@ func RunDeepAgent(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, toolOutputChunk, toolInvokeNotify, orchestratorName)
|
mainTools, err := einomcp.ToolsFromDefinitions(ag, holder, mainDefs, recorder, nil, toolInvokeNotify, orchestratorName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mainToolsForCfg, mainOrchestratorPre, mainToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, logger)
|
mainToolsForCfg, mainOrchestratorPre, mainToolSearchActive, err := prependEinoMiddlewares(ctx, &ma.EinoMiddleware, einoMWMain, mainTools, einoLoc, skillsRoot, conversationID, projectID, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -371,7 +361,7 @@ func RunDeepAgent(
|
|||||||
inner: einoLoc,
|
inner: einoLoc,
|
||||||
invokeNotify: toolInvokeNotify,
|
invokeNotify: toolInvokeNotify,
|
||||||
einoAgentName: orchestratorName,
|
einoAgentName: orchestratorName,
|
||||||
outputChunk: toolOutputChunk,
|
outputChunk: nil,
|
||||||
recordMonitor: einoExecMonitor,
|
recordMonitor: einoExecMonitor,
|
||||||
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
|
toolTimeoutMinutes: agentToolTimeoutMinutes(appCfg),
|
||||||
}
|
}
|
||||||
@@ -426,7 +416,7 @@ func RunDeepAgent(
|
|||||||
EmitInternalEvents: true,
|
EmitInternalEvents: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
deepOutKey, modelRetry, taskGen := deepExtrasFromConfig(ma)
|
deepOutKey, taskGen := deepExtrasFromConfig(ma)
|
||||||
|
|
||||||
var da adk.Agent
|
var da adk.Agent
|
||||||
switch orchMode {
|
switch orchMode {
|
||||||
@@ -438,7 +428,7 @@ func RunDeepAgent(
|
|||||||
// 构建 filesystem 中间件(与 Deep sub-agent 一致)
|
// 构建 filesystem 中间件(与 Deep sub-agent 一致)
|
||||||
var peFsMw adk.ChatModelAgentMiddleware
|
var peFsMw adk.ChatModelAgentMiddleware
|
||||||
if einoSkillMW != nil && einoFSTools && einoLoc != nil {
|
if einoSkillMW != nil && einoFSTools && einoLoc != nil {
|
||||||
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), toolOutputChunk)
|
peFsMw, err = subAgentFilesystemMiddleware(ctx, einoLoc, toolInvokeNotify, "executor", einoExecMonitor, agentToolTimeoutMinutes(appCfg), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
|
return nil, fmt.Errorf("plan_execute filesystem 中间件: %w", err)
|
||||||
}
|
}
|
||||||
@@ -453,6 +443,8 @@ func RunDeepAgent(
|
|||||||
AppCfg: appCfg,
|
AppCfg: appCfg,
|
||||||
MwCfg: &ma.EinoMiddleware,
|
MwCfg: &ma.EinoMiddleware,
|
||||||
ConversationID: conversationID,
|
ConversationID: conversationID,
|
||||||
|
DB: db,
|
||||||
|
ProjectID: projectID,
|
||||||
Logger: logger,
|
Logger: logger,
|
||||||
ModelName: appCfg.OpenAI.Model,
|
ModelName: appCfg.OpenAI.Model,
|
||||||
ExecPreMiddlewares: mainOrchestratorPre,
|
ExecPreMiddlewares: mainOrchestratorPre,
|
||||||
@@ -481,9 +473,6 @@ func RunDeepAgent(
|
|||||||
Handlers: supHandlers,
|
Handlers: supHandlers,
|
||||||
Exit: &adk.ExitTool{},
|
Exit: &adk.ExitTool{},
|
||||||
}
|
}
|
||||||
if modelRetry != nil {
|
|
||||||
supCfg.ModelRetryConfig = modelRetry
|
|
||||||
}
|
|
||||||
if deepOutKey != "" {
|
if deepOutKey != "" {
|
||||||
supCfg.OutputKey = deepOutKey
|
supCfg.OutputKey = deepOutKey
|
||||||
}
|
}
|
||||||
@@ -517,9 +506,6 @@ func RunDeepAgent(
|
|||||||
if deepOutKey != "" {
|
if deepOutKey != "" {
|
||||||
dcfg.OutputKey = deepOutKey
|
dcfg.OutputKey = deepOutKey
|
||||||
}
|
}
|
||||||
if modelRetry != nil {
|
|
||||||
dcfg.ModelRetryConfig = modelRetry
|
|
||||||
}
|
|
||||||
if taskGen != nil {
|
if taskGen != nil {
|
||||||
dcfg.TaskToolDescriptionGenerator = taskGen
|
dcfg.TaskToolDescriptionGenerator = taskGen
|
||||||
}
|
}
|
||||||
@@ -565,6 +551,7 @@ func RunDeepAgent(
|
|||||||
McpIDs: &mcpIDs,
|
McpIDs: &mcpIDs,
|
||||||
FilesystemMonitorAgent: ag,
|
FilesystemMonitorAgent: ag,
|
||||||
FilesystemMonitorRecord: recorder,
|
FilesystemMonitorRecord: recorder,
|
||||||
|
MCPExecutionBinder: mcpExecBinder,
|
||||||
ToolInvokeNotify: toolInvokeNotify,
|
ToolInvokeNotify: toolInvokeNotify,
|
||||||
DA: da,
|
DA: da,
|
||||||
ModelFacingTrace: modelFacingTrace,
|
ModelFacingTrace: modelFacingTrace,
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ package openai
|
|||||||
// Auth: Bearer → x-api-key
|
// Auth: Bearer → x-api-key
|
||||||
// Tools: OpenAI tools[] → Claude tools[] (input_schema)
|
// Tools: OpenAI tools[] → Claude tools[] (input_schema)
|
||||||
//
|
//
|
||||||
// Extended thinking: 顶层 `thinking` 从 OpenAI 请求体透传;响应中 `thinking` block 映射为
|
// Extended thinking: 顶层 `thinking` / `output_config` 从 OpenAI 请求体透传;响应中 `thinking` block 映射为
|
||||||
// `reasoning_content`(可读前缀 + 内部 JSON 尾缀以保留 signature,供多轮工具续跑;UI 用 openai.DisplayReasoningContent 剥离)。
|
// `reasoning_content`(可读前缀 + 内部 JSON 尾缀以保留 signature,供多轮工具续跑;UI 用 openai.DisplayReasoningContent 剥离)。
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -42,6 +42,7 @@ type claudeRequest struct {
|
|||||||
Tools []claudeTool `json:"tools,omitempty"`
|
Tools []claudeTool `json:"tools,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
Thinking json.RawMessage `json:"thinking,omitempty"`
|
Thinking json.RawMessage `json:"thinking,omitempty"`
|
||||||
|
OutputConfig json.RawMessage `json:"output_config,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type claudeMessage struct {
|
type claudeMessage struct {
|
||||||
@@ -304,12 +305,17 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extended thinking (Anthropic top-level); merged from Eino ExtraFields / admin extras.
|
// Extended thinking + effort (Anthropic top-level); merged from Eino ExtraFields / admin extras.
|
||||||
if th, ok := oai["thinking"]; ok && th != nil {
|
if th, ok := oai["thinking"]; ok && th != nil {
|
||||||
if raw, err := json.Marshal(th); err == nil && len(raw) > 0 && string(raw) != "null" {
|
if raw, err := json.Marshal(th); err == nil && len(raw) > 0 && string(raw) != "null" {
|
||||||
req.Thinking = json.RawMessage(raw)
|
req.Thinking = json.RawMessage(raw)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if oc, ok := oai["output_config"]; ok && oc != nil {
|
||||||
|
if raw, err := json.Marshal(oc); err == nil && len(raw) > 0 && string(raw) != "null" {
|
||||||
|
req.OutputConfig = json.RawMessage(raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,6 +73,39 @@ func TestConvertOpenAIToClaude_AssistantReasoningReplay(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConvertOpenAIToClaude_OutputConfigEffort(t *testing.T) {
|
||||||
|
payload := map[string]interface{}{
|
||||||
|
"model": "claude-opus-4-8",
|
||||||
|
"messages": []interface{}{
|
||||||
|
map[string]interface{}{"role": "user", "content": "hi"},
|
||||||
|
},
|
||||||
|
"thinking": map[string]interface{}{
|
||||||
|
"type": "adaptive",
|
||||||
|
"display": "summarized",
|
||||||
|
},
|
||||||
|
"output_config": map[string]interface{}{
|
||||||
|
"effort": "high",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
req, err := convertOpenAIToClaude(payload)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(req.Thinking) == 0 {
|
||||||
|
t.Fatal("expected thinking")
|
||||||
|
}
|
||||||
|
if len(req.OutputConfig) == 0 {
|
||||||
|
t.Fatal("expected output_config")
|
||||||
|
}
|
||||||
|
var oc map[string]interface{}
|
||||||
|
if err := json.Unmarshal(req.OutputConfig, &oc); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if oc["effort"] != "high" {
|
||||||
|
t.Fatalf("effort=%v", oc["effort"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestClaudeToOpenAIResponseJSON_Thinking(t *testing.T) {
|
func TestClaudeToOpenAIResponseJSON_Thinking(t *testing.T) {
|
||||||
claudeBody := []byte(`{
|
claudeBody := []byte(`{
|
||||||
"id":"msg_1","type":"message","role":"assistant","model":"x","stop_reason":"end_turn",
|
"id":"msg_1","type":"message","role":"assistant","model":"x","stop_reason":"end_turn",
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
@@ -535,3 +536,81 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
|
|||||||
|
|
||||||
return full.String(), toolCalls, finishReason, nil
|
return full.String(), toolCalls, finishReason, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ModelsListResponse 表示 OpenAI 兼容 GET /models 响应。
|
||||||
|
type ModelsListResponse struct {
|
||||||
|
Object string `json:"object"`
|
||||||
|
Data []struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object,omitempty"`
|
||||||
|
OwnedBy string `json:"owned_by,omitempty"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListModels 调用 GET {baseURL}/models 获取可用模型 id 列表(按字典序)。
|
||||||
|
func (c *Client) ListModels(ctx context.Context) ([]string, error) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, fmt.Errorf("openai client is not initialized")
|
||||||
|
}
|
||||||
|
if c.config == nil {
|
||||||
|
return nil, fmt.Errorf("openai config is nil")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(c.config.APIKey) == "" {
|
||||||
|
return nil, fmt.Errorf("openai api key is empty")
|
||||||
|
}
|
||||||
|
if c.isClaude() {
|
||||||
|
return nil, fmt.Errorf("claude provider does not support models list API")
|
||||||
|
}
|
||||||
|
|
||||||
|
baseURL := strings.TrimSuffix(c.config.BaseURL, "/")
|
||||||
|
if baseURL == "" {
|
||||||
|
baseURL = "https://api.openai.com/v1"
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build openai models request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", "Bearer "+c.config.APIKey)
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("call openai models api: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("read openai models response: %w", err)
|
||||||
|
}
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, &APIError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
Body: string(respBody),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var list ModelsListResponse
|
||||||
|
if err := json.Unmarshal(respBody, &list); err != nil {
|
||||||
|
return nil, fmt.Errorf("decode openai models response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]struct{}, len(list.Data))
|
||||||
|
models := make([]string, 0, len(list.Data))
|
||||||
|
for _, item := range list.Data {
|
||||||
|
id := strings.TrimSpace(item.ID)
|
||||||
|
if id == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[id]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
models = append(models, id)
|
||||||
|
}
|
||||||
|
sort.Strings(models)
|
||||||
|
if len(models) == 0 {
|
||||||
|
return nil, fmt.Errorf("models list is empty")
|
||||||
|
}
|
||||||
|
return models, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package project
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
@@ -22,7 +21,13 @@ func AppendSystemPromptBlock(base, block string) string {
|
|||||||
return base + "\n\n" + block
|
return base + "\n\n" + block
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildFactIndexBlock 为 Agent 系统提示生成项目黑板索引(仅 key + summary,不含 body)。
|
const (
|
||||||
|
factIndexFooterGetDetail = "需要完整内容(攻击链、POC、请求响应等)时必须调用 get_project_fact(fact_key),禁止凭摘要臆造细节。"
|
||||||
|
factIndexFooterWriteHint = "写入事实 links 时用 from(来源 fact_key → 当前 fact),如 finding 上 {from:target/*, type:discovered_on};body 写可复现全流程(发现/利用类 fact_key 建议 finding|chain|exploit|poc/ 前缀)。"
|
||||||
|
factIndexFooterEmpty = "需要写入请使用 upsert_project_fact;需要详情请调用 get_project_fact(fact_key)。"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BuildFactIndexBlock 为 Agent 系统提示生成项目黑板索引(key + summary + 关系边 + 攻击路径,不含 body)。
|
||||||
func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) {
|
func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) {
|
||||||
if db == nil || !cfg.Enabled {
|
if db == nil || !cfg.Enabled {
|
||||||
return "", nil
|
return "", nil
|
||||||
@@ -41,27 +46,38 @@ func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectCo
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
allEdges, _ := db.ListProjectFactEdgesByProject(projectID)
|
||||||
|
_, incomingByTarget := indexEdgeGroupMaps(allEdges)
|
||||||
|
|
||||||
if len(facts) == 0 {
|
if len(facts) == 0 {
|
||||||
return fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n(暂无事实)\n需要写入请使用 upsert_project_fact;需要详情请调用 get_project_fact(fact_key)。", proj.Name, proj.ID), nil
|
return wrapFactIndexBlock(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n(暂无事实)\n%s", proj.Name, proj.ID, factIndexFooterEmpty)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.SliceStable(facts, func(i, j int) bool {
|
sortFactsForIndex(facts)
|
||||||
if facts[i].Pinned != facts[j].Pinned {
|
|
||||||
return facts[i].Pinned
|
|
||||||
}
|
|
||||||
return facts[i].UpdatedAt.After(facts[j].UpdatedAt)
|
|
||||||
})
|
|
||||||
|
|
||||||
maxRunes := cfg.FactIndexMaxRunesEffective()
|
maxRunes := cfg.FactIndexMaxRunesEffective()
|
||||||
|
pathMaxRunes := cfg.FactIndexPathMaxRunesEffective()
|
||||||
|
footer := factIndexFooterGetDetail + "\n" + factIndexFooterWriteHint
|
||||||
|
footerRunes := len([]rune(footer))
|
||||||
|
factsBudget := maxRunes - pathMaxRunes - footerRunes
|
||||||
|
if factsBudget < 800 {
|
||||||
|
factsBudget = maxRunes - footerRunes
|
||||||
|
pathMaxRunes = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
indexedKeys := make(map[string]struct{}, len(facts))
|
||||||
var b strings.Builder
|
var b strings.Builder
|
||||||
b.WriteString(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n", proj.Name, proj.ID))
|
b.WriteString(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n", proj.Name, proj.ID))
|
||||||
used := len([]rune(b.String()))
|
used := len([]rune(b.String()))
|
||||||
omitted := 0
|
omitted := 0
|
||||||
|
|
||||||
for _, f := range facts {
|
for _, f := range facts {
|
||||||
line := fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, strings.TrimSpace(f.Summary), f.Confidence)
|
indexedKeys[f.FactKey] = struct{}{}
|
||||||
|
line := fmt.Sprintf("- [%s] %s — %s (%s)", f.FactKey, f.Category, strings.TrimSpace(f.Summary), f.Confidence)
|
||||||
|
line += FormatFactIndexLinksHint(f.FactKey, incomingByTarget[f.FactKey])
|
||||||
|
line += "\n"
|
||||||
lineRunes := len([]rune(line))
|
lineRunes := len([]rune(line))
|
||||||
if used+lineRunes > maxRunes {
|
if used+lineRunes > factsBudget {
|
||||||
omitted++
|
omitted++
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -72,7 +88,12 @@ func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectCo
|
|||||||
if omitted > 0 {
|
if omitted > 0 {
|
||||||
b.WriteString(fmt.Sprintf("\n(另有 %d 条未列入索引,请使用 list_project_facts 或 search_project_facts 查询。)\n", omitted))
|
b.WriteString(fmt.Sprintf("\n(另有 %d 条未列入索引,请使用 list_project_facts 或 search_project_facts 查询。)\n", omitted))
|
||||||
}
|
}
|
||||||
b.WriteString("需要完整内容(攻击链、POC、请求响应等)时必须调用 get_project_fact(fact_key),禁止凭摘要臆造细节。\n")
|
|
||||||
b.WriteString("写入事实时:summary 写「什么+在哪+如何验证」;body 写可复现全流程(发现/利用类 fact_key 建议 finding|chain|exploit|poc/ 前缀)。\n")
|
if pathSection := BuildFactPathOverviewSection(allEdges, indexedKeys, pathMaxRunes); pathSection != "" {
|
||||||
return b.String(), nil
|
b.WriteString("\n")
|
||||||
|
b.WriteString(pathSection)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(footer)
|
||||||
|
return wrapFactIndexBlock(b.String()), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// FactIndexSectionHeading 黑板索引可读标题行前缀(块内保留,供 Agent 阅读)。
|
||||||
|
const FactIndexSectionHeading = "## 项目黑板索引"
|
||||||
|
|
||||||
|
// FactIndexSectionStartMarker / EndMarker:HTML 注释边界,供程序化替换;对模型无指令语义。
|
||||||
|
const (
|
||||||
|
FactIndexSectionStartMarker = "<!-- fact-index-start -->"
|
||||||
|
FactIndexSectionEndMarker = "<!-- fact-index-end -->"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReplaceFactIndexSection 用 freshIndex 替换 content 中已有的项目黑板索引段。
|
||||||
|
// freshIndex 须为 BuildFactIndexBlock 的完整输出。起止 HTML 注释缺失时返回 (_, false)。
|
||||||
|
func ReplaceFactIndexSection(content, freshIndex string) (string, bool) {
|
||||||
|
freshIndex = strings.TrimSpace(freshIndex)
|
||||||
|
if freshIndex == "" {
|
||||||
|
return content, false
|
||||||
|
}
|
||||||
|
start, ok := factIndexSectionStart(content)
|
||||||
|
if !ok {
|
||||||
|
return content, false
|
||||||
|
}
|
||||||
|
end, ok := factIndexSectionEnd(content, start)
|
||||||
|
if !ok || end <= start {
|
||||||
|
return content, false
|
||||||
|
}
|
||||||
|
return content[:start] + freshIndex + content[end:], true
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapFactIndexBlock 为 BuildFactIndexBlock 正文加上统一起止 HTML 注释边界。
|
||||||
|
func wrapFactIndexBlock(content string) string {
|
||||||
|
content = strings.TrimSpace(content)
|
||||||
|
return FactIndexSectionStartMarker + "\n" + content + "\n" + FactIndexSectionEndMarker + "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
func factIndexSectionStart(content string) (int, bool) {
|
||||||
|
idx := strings.Index(content, FactIndexSectionStartMarker)
|
||||||
|
if idx < 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return idx, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func factIndexSectionEnd(content string, start int) (int, bool) {
|
||||||
|
if start < 0 || start >= len(content) {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
tail := content[start:]
|
||||||
|
idx := strings.LastIndex(tail, FactIndexSectionEndMarker)
|
||||||
|
if idx < 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
return start + idx + len(FactIndexSectionEndMarker), true
|
||||||
|
}
|
||||||
@@ -0,0 +1,154 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func sampleFactIndexWithFacts(projectLabel, summary string) string {
|
||||||
|
return wrapFactIndexBlock("## 项目黑板索引(project: " + projectLabel + ", id: x)\n" +
|
||||||
|
"- [target/a] target — " + summary + " (tentative)\n" +
|
||||||
|
factIndexFooterGetDetail + "\n" +
|
||||||
|
factIndexFooterWriteHint)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplaceFactIndexSection(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
oldIndex := sampleFactIndexWithFacts("p1", "old summary")
|
||||||
|
newIndex := sampleFactIndexWithFacts("p1", "new summary")
|
||||||
|
|
||||||
|
t.Run("replaces index before next section", func(t *testing.T) {
|
||||||
|
content := "你是助手\n\n" + oldIndex + "\n\n## 图片分析\n看截图"
|
||||||
|
out, ok := ReplaceFactIndexSection(content, newIndex)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected replacement")
|
||||||
|
}
|
||||||
|
if strings.Contains(out, "old summary") {
|
||||||
|
t.Fatalf("old index should be gone: %q", out)
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "new summary") || !strings.Contains(out, "## 图片分析") {
|
||||||
|
t.Fatalf("expected new index and preserved vision section: %q", out)
|
||||||
|
}
|
||||||
|
if strings.Count(out, FactIndexSectionStartMarker) != 1 || strings.Count(out, FactIndexSectionEndMarker) != 1 {
|
||||||
|
t.Fatalf("expected exactly one start/end marker pair: %q", out)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("replaces index at end", func(t *testing.T) {
|
||||||
|
content := "## 项目测试范围\nscope\n\n" + oldIndex
|
||||||
|
out, ok := ReplaceFactIndexSection(content, newIndex)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected replacement")
|
||||||
|
}
|
||||||
|
if !strings.Contains(out, "## 项目测试范围") || !strings.Contains(out, "new summary") {
|
||||||
|
t.Fatalf("scope preserved, index updated: %q", out)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("summary with false markdown header does not truncate early", func(t *testing.T) {
|
||||||
|
summaryWithFakeHeader := "see\n\n## fake header in summary"
|
||||||
|
old := sampleFactIndexWithFacts("p1", summaryWithFakeHeader)
|
||||||
|
newIdx := sampleFactIndexWithFacts("p1", "new summary")
|
||||||
|
content := old + "\n\n## 图片分析\nvision"
|
||||||
|
out, ok := ReplaceFactIndexSection(content, newIdx)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected replacement")
|
||||||
|
}
|
||||||
|
if strings.Contains(out, "fake header in summary") {
|
||||||
|
t.Fatalf("old index tail should be fully removed: %q", out)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("summary containing end marker text does not truncate early", func(t *testing.T) {
|
||||||
|
summary := "note " + FactIndexSectionEndMarker + " in summary"
|
||||||
|
old := sampleFactIndexWithFacts("p1", summary)
|
||||||
|
newIdx := sampleFactIndexWithFacts("p1", "clean")
|
||||||
|
content := old + "\n\n## 图片分析\nvision"
|
||||||
|
out, ok := ReplaceFactIndexSection(content, newIdx)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected replacement")
|
||||||
|
}
|
||||||
|
if strings.Contains(out, "in summary") {
|
||||||
|
t.Fatalf("old block should be fully removed: %q", out)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing html markers does not replace", func(t *testing.T) {
|
||||||
|
legacy := "## 项目黑板索引(project: p1, id: x)\n- [a] note — old (tentative)\n"
|
||||||
|
newIdx := sampleFactIndexWithFacts("p1", "new")
|
||||||
|
out, ok := ReplaceFactIndexSection("prefix\n\n"+legacy, newIdx)
|
||||||
|
if ok {
|
||||||
|
t.Fatalf("expected no replacement without markers: %q", out)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty facts block", func(t *testing.T) {
|
||||||
|
oldEmpty := wrapFactIndexBlock("## 项目黑板索引(project: p1, id: x)\n(暂无事实)\n" + factIndexFooterEmpty)
|
||||||
|
newEmpty := sampleFactIndexWithFacts("p1", "first fact")
|
||||||
|
out, ok := ReplaceFactIndexSection(oldEmpty, newEmpty)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected replacement")
|
||||||
|
}
|
||||||
|
if strings.Contains(out, "(暂无事实)") {
|
||||||
|
t.Fatalf("old empty block should be gone: %q", out)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no marker", func(t *testing.T) {
|
||||||
|
_, ok := ReplaceFactIndexSection("no blackboard here", newIndex)
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected false when marker missing")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty fresh index", func(t *testing.T) {
|
||||||
|
_, ok := ReplaceFactIndexSection(oldIndex, " ")
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected false for empty fresh index")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFactIndexSectionBounds_useHTMLMarkers(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
body := sampleFactIndexWithFacts("p", "line with\n\n## not a real section") + "TAIL_SHOULD_DROP"
|
||||||
|
start, ok := factIndexSectionStart(body)
|
||||||
|
if !ok || !strings.HasPrefix(body[start:], FactIndexSectionStartMarker) {
|
||||||
|
t.Fatalf("start should be at html start marker, got %d", start)
|
||||||
|
}
|
||||||
|
end, ok := factIndexSectionEnd(body, start)
|
||||||
|
if !ok || body[end:] != "\nTAIL_SHOULD_DROP" {
|
||||||
|
t.Fatalf("end should be after end marker, got remainder %q", body[end:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildFactIndexBlock_includesHTMLMarkers(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||||
|
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
proj, err := db.CreateProject(&database.Project{Name: "marker-proj"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
block, err := BuildFactIndexBlock(db, proj.ID, config.ProjectConfig{Enabled: true})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(strings.TrimSpace(block), FactIndexSectionStartMarker) {
|
||||||
|
t.Fatalf("block should start with start marker: %q", block)
|
||||||
|
}
|
||||||
|
if !strings.Contains(block, FactIndexSectionEndMarker) {
|
||||||
|
t.Fatalf("block should include end marker: %q", block)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,256 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
bodyDepFactLine = regexp.MustCompile(`(?im)^[\s\-*]*依赖事实\s*[::]\s*([a-z0-9][a-z0-9._/-]*)`)
|
||||||
|
bodyRelFactLine = regexp.MustCompile(`(?im)^[\s\-*]*相关\s*fact_key\s*[::]\s*([a-z0-9][a-z0-9._/-]*)`)
|
||||||
|
bodyAssocSection = regexp.MustCompile(`(?im)^##\s*关联\s*$`)
|
||||||
|
bodySyncLinksHead = "结构化关系边(自动同步)"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseLinksFromBody 从 body「关联」段落解析 from 语义的关系边(无显式 links 时的兜底)。
|
||||||
|
func ParseLinksFromBody(body string) []database.ProjectFactEdgeFromInput {
|
||||||
|
body = strings.TrimSpace(body)
|
||||||
|
if body == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
seen := map[string]struct{}{}
|
||||||
|
var out []database.ProjectFactEdgeFromInput
|
||||||
|
add := func(key, edgeType string) {
|
||||||
|
key = strings.TrimSpace(key)
|
||||||
|
if key == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := database.ValidateFactKey(key); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sig := edgeType + "\x00" + key
|
||||||
|
if _, ok := seen[sig]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen[sig] = struct{}{}
|
||||||
|
out = append(out, database.ProjectFactEdgeFromInput{From: key, Type: edgeType})
|
||||||
|
}
|
||||||
|
for _, m := range bodyDepFactLine.FindAllStringSubmatch(body, -1) {
|
||||||
|
if len(m) > 1 {
|
||||||
|
add(m[1], "depends_on")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, m := range bodyRelFactLine.FindAllStringSubmatch(body, -1) {
|
||||||
|
if len(m) > 1 {
|
||||||
|
add(m[1], "supports")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 自动同步块:type: key
|
||||||
|
syncBlock := extractBodySyncLinksBlock(body)
|
||||||
|
for _, line := range strings.Split(syncBlock, "\n") {
|
||||||
|
line = strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(line), "-"))
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
edgeType, source, ok := strings.Cut(line, ":")
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
edgeType = strings.TrimSpace(edgeType)
|
||||||
|
source = strings.TrimSpace(source)
|
||||||
|
if err := database.ValidateProjectFactEdgeType(edgeType); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
add(source, edgeType)
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractBodySyncLinksBlock(body string) string {
|
||||||
|
lines := strings.Split(body, "\n")
|
||||||
|
var b strings.Builder
|
||||||
|
inAssoc := false
|
||||||
|
inSync := false
|
||||||
|
for _, line := range lines {
|
||||||
|
trim := strings.TrimSpace(line)
|
||||||
|
if bodyAssocSection.MatchString(trim) {
|
||||||
|
inAssoc = true
|
||||||
|
inSync = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if inAssoc && strings.HasPrefix(trim, "## ") && !strings.HasPrefix(trim, "## 关联") {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if inAssoc && strings.Contains(trim, bodySyncLinksHead) {
|
||||||
|
inSync = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if inSync {
|
||||||
|
if trim == "" || strings.HasPrefix(trim, "-") || strings.Contains(trim, ":") {
|
||||||
|
if strings.HasPrefix(trim, "-") || (strings.Contains(trim, ":") && !strings.Contains(trim, "related_vulnerability")) {
|
||||||
|
b.WriteString(trim)
|
||||||
|
b.WriteByte('\n')
|
||||||
|
}
|
||||||
|
} else if strings.HasPrefix(trim, "##") {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncBodyLinksSection 将入边镜像写入 body 的「关联」段(人读用;结构化以 links 为准)。
|
||||||
|
func SyncBodyLinksSection(body string, edges []*database.ProjectFactEdge) string {
|
||||||
|
body = strings.TrimSpace(body)
|
||||||
|
block := formatBodySyncLinksBlock(edges)
|
||||||
|
if block == "" {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
if body == "" {
|
||||||
|
return "## 关联\n" + block
|
||||||
|
}
|
||||||
|
lines := strings.Split(body, "\n")
|
||||||
|
var out []string
|
||||||
|
inAssoc := false
|
||||||
|
replaced := false
|
||||||
|
for i := 0; i < len(lines); i++ {
|
||||||
|
trim := strings.TrimSpace(lines[i])
|
||||||
|
if bodyAssocSection.MatchString(trim) {
|
||||||
|
inAssoc = true
|
||||||
|
out = append(out, lines[i])
|
||||||
|
// 跳过旧同步块
|
||||||
|
j := i + 1
|
||||||
|
for j < len(lines) {
|
||||||
|
t := strings.TrimSpace(lines[j])
|
||||||
|
if strings.HasPrefix(t, "## ") {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if strings.Contains(t, bodySyncLinksHead) {
|
||||||
|
for j < len(lines) {
|
||||||
|
t2 := strings.TrimSpace(lines[j])
|
||||||
|
if t2 != "" && !strings.HasPrefix(t2, "-") && !strings.Contains(t2, ":") && !strings.Contains(t2, bodySyncLinksHead) {
|
||||||
|
if strings.HasPrefix(t2, "##") {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
j++
|
||||||
|
if j < len(lines) && strings.HasPrefix(strings.TrimSpace(lines[j]), "## ") {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if j >= len(lines) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if j > i+1 && strings.TrimSpace(lines[j-1]) == "" && strings.HasPrefix(strings.TrimSpace(lines[j]), "## ") {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
out = append(out, block)
|
||||||
|
i = j - 1
|
||||||
|
replaced = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, lines[i])
|
||||||
|
}
|
||||||
|
if !replaced {
|
||||||
|
if !inAssoc {
|
||||||
|
out = append(out, "", "## 关联", block)
|
||||||
|
} else {
|
||||||
|
out = append(out, block)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(strings.Join(out, "\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatBodySyncLinksBlock(edges []*database.ProjectFactEdge) string {
|
||||||
|
if len(edges) == 0 {
|
||||||
|
return fmt.Sprintf("- %s:\n (暂无)", bodySyncLinksHead)
|
||||||
|
}
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("- ")
|
||||||
|
b.WriteString(bodySyncLinksHead)
|
||||||
|
b.WriteString(":\n")
|
||||||
|
for _, e := range edges {
|
||||||
|
b.WriteString(fmt.Sprintf(" - %s: %s\n", e.EdgeType, e.SourceFactKey))
|
||||||
|
}
|
||||||
|
return strings.TrimRight(b.String(), "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveFactLinksForUpsert 合并显式 links、links_text 与 body 解析结果。
|
||||||
|
func ResolveFactLinksForUpsert(explicit []database.ProjectFactEdgeFromInput, linksText *string, body string, explicitSet bool) ([]database.ProjectFactEdgeFromInput, bool, error) {
|
||||||
|
if explicitSet {
|
||||||
|
if len(explicit) > 0 {
|
||||||
|
return explicit, true, nil
|
||||||
|
}
|
||||||
|
if linksText != nil {
|
||||||
|
parsed, err := ParseFactLinksText(*linksText)
|
||||||
|
if err != nil {
|
||||||
|
return nil, true, err
|
||||||
|
}
|
||||||
|
if parsed == nil {
|
||||||
|
return []database.ProjectFactEdgeFromInput{}, true, nil
|
||||||
|
}
|
||||||
|
return parsed, true, nil
|
||||||
|
}
|
||||||
|
return []database.ProjectFactEdgeFromInput{}, true, nil
|
||||||
|
}
|
||||||
|
if parsed := ParseLinksFromBody(body); len(parsed) > 0 {
|
||||||
|
return parsed, true, nil
|
||||||
|
}
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeLinkFromInputsUnique 合并多组 from 入边输入并去重。
|
||||||
|
func MergeLinkFromInputsUnique(groups ...[]database.ProjectFactEdgeFromInput) []database.ProjectFactEdgeFromInput {
|
||||||
|
seen := map[string]struct{}{}
|
||||||
|
var out []database.ProjectFactEdgeFromInput
|
||||||
|
for _, g := range groups {
|
||||||
|
for _, in := range g {
|
||||||
|
sig := in.Type + "\x00" + in.From
|
||||||
|
if _, ok := seen[sig]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := database.ValidateProjectFactEdgeType(in.Type); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := database.ValidateFactKey(in.From); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[sig] = struct{}{}
|
||||||
|
out = append(out, in)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// MergeLinkInputsUnique 合并多组 link 输入并去重(内部出边写入用)。
|
||||||
|
func MergeLinkInputsUnique(groups ...[]database.ProjectFactEdgeInput) []database.ProjectFactEdgeInput {
|
||||||
|
seen := map[string]struct{}{}
|
||||||
|
var out []database.ProjectFactEdgeInput
|
||||||
|
for _, g := range groups {
|
||||||
|
for _, in := range g {
|
||||||
|
sig := in.Type + "\x00" + in.To
|
||||||
|
if _, ok := seen[sig]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := database.ValidateProjectFactEdgeType(in.Type); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := database.ValidateFactKey(in.To); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[sig] = struct{}{}
|
||||||
|
out = append(out, in)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
@@ -0,0 +1,68 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseLinksFromBodyDependsOn(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
body := "## 关联\n- 依赖事实: target/api\n- 相关 fact_key: auth/session"
|
||||||
|
links := ParseLinksFromBody(body)
|
||||||
|
if len(links) != 2 {
|
||||||
|
t.Fatalf("want 2 links, got %d", len(links))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncBodyLinksSection(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
body := "## 结论\nx\n\n## 关联\n- 依赖事实: old/key"
|
||||||
|
edges := []*database.ProjectFactEdge{{EdgeType: "discovered_on", SourceFactKey: "target/a"}}
|
||||||
|
out := SyncBodyLinksSection(body, edges)
|
||||||
|
if !strings.Contains(out, "discovered_on: target/a") {
|
||||||
|
t.Fatalf("missing synced edge: %q", out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFactGraphIntegration(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
dbPath := filepath.Join(dir, "test.db")
|
||||||
|
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
p, err := db.CreateProject(&database.Project{Name: "g"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, spec := range []struct{ key, cat, summary string }{
|
||||||
|
{"target/root", "target", "root"},
|
||||||
|
{"finding/x", "finding", "finding x"},
|
||||||
|
} {
|
||||||
|
_, err := db.UpsertProjectFact(&database.ProjectFact{
|
||||||
|
ProjectID: p.ID, FactKey: spec.key, Category: spec.cat, Summary: spec.summary, Confidence: "confirmed",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "finding/x", []database.ProjectFactEdgeFromInput{
|
||||||
|
{From: "target/root", Type: "discovered_on"},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
graph, err := BuildProjectFactGraph(db, p.ID, "path", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(graph.Nodes) < 2 || len(graph.Edges) < 1 {
|
||||||
|
t.Fatalf("expected graph nodes/edges, got %d/%d", len(graph.Nodes), len(graph.Edges))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,407 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
"cyberstrike-ai/internal/projectprompt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PathGraphCategories 攻击路径视图包含的事实分类。
|
||||||
|
var PathGraphCategories = map[string]struct{}{
|
||||||
|
FactCategoryTarget: {},
|
||||||
|
FactCategoryFinding: {},
|
||||||
|
FactCategoryChain: {},
|
||||||
|
FactCategoryExploit: {},
|
||||||
|
FactCategoryPOC: {},
|
||||||
|
"vuln": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// GraphNodeType 将 fact category 映射为图节点类型(供前端样式与 ELK 分层)。
|
||||||
|
// 优先使用 category;仅 synthetic 节点(vuln:)或无 category 时才回退到 fact_key 前缀。
|
||||||
|
func GraphNodeType(category, factKey string) string {
|
||||||
|
key := strings.ToLower(strings.TrimSpace(factKey))
|
||||||
|
if strings.HasPrefix(key, "vuln:") {
|
||||||
|
return "vulnerability"
|
||||||
|
}
|
||||||
|
c := strings.ToLower(strings.TrimSpace(category))
|
||||||
|
if c != "" {
|
||||||
|
switch c {
|
||||||
|
case FactCategoryTarget:
|
||||||
|
return "target"
|
||||||
|
case FactCategoryExploit:
|
||||||
|
return "exploit"
|
||||||
|
case FactCategoryPOC:
|
||||||
|
return "poc"
|
||||||
|
case FactCategoryChain:
|
||||||
|
return "chain"
|
||||||
|
case FactCategoryFinding:
|
||||||
|
return "finding"
|
||||||
|
case "vuln":
|
||||||
|
return "vulnerability"
|
||||||
|
case FactCategoryAuth:
|
||||||
|
return "auth"
|
||||||
|
case FactCategoryInfra, FactCategoryBusiness:
|
||||||
|
return "infra"
|
||||||
|
case FactCategoryNote:
|
||||||
|
return "note"
|
||||||
|
case "missing":
|
||||||
|
return "missing"
|
||||||
|
default:
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(key, "target/"):
|
||||||
|
return "target"
|
||||||
|
case strings.HasPrefix(key, "exploit/"), strings.HasPrefix(key, "evidence/"):
|
||||||
|
return "exploit"
|
||||||
|
case strings.HasPrefix(key, "poc/"):
|
||||||
|
return "poc"
|
||||||
|
case strings.HasPrefix(key, "chain/"):
|
||||||
|
return "chain"
|
||||||
|
case strings.HasPrefix(key, "finding/"):
|
||||||
|
return "finding"
|
||||||
|
case strings.HasPrefix(key, "auth/"):
|
||||||
|
return "auth"
|
||||||
|
case strings.HasPrefix(key, "infra/"), strings.HasPrefix(key, "business/"):
|
||||||
|
return "infra"
|
||||||
|
default:
|
||||||
|
return "note"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func truncateGraphLabel(summary string, maxRunes int) string {
|
||||||
|
summary = strings.TrimSpace(summary)
|
||||||
|
if summary == "" {
|
||||||
|
return "—"
|
||||||
|
}
|
||||||
|
r := []rune(summary)
|
||||||
|
if len(r) <= maxRunes {
|
||||||
|
return summary
|
||||||
|
}
|
||||||
|
return string(r[:maxRunes]) + "…"
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildProjectFactGraph 构建项目事实图(nodes + edges)。
|
||||||
|
func BuildProjectFactGraph(db *database.DB, projectID string, view string, excludeDeprecated bool) (*database.ProjectFactGraph, error) {
|
||||||
|
if db == nil {
|
||||||
|
return nil, fmt.Errorf("database 未初始化")
|
||||||
|
}
|
||||||
|
projectID = strings.TrimSpace(projectID)
|
||||||
|
if projectID == "" {
|
||||||
|
return nil, fmt.Errorf("project_id 不能为空")
|
||||||
|
}
|
||||||
|
|
||||||
|
view = strings.TrimSpace(strings.ToLower(view))
|
||||||
|
if view == "" {
|
||||||
|
view = "path"
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := database.ProjectFactListFilter{}
|
||||||
|
if excludeDeprecated {
|
||||||
|
filter.ExcludeDeprecated = true
|
||||||
|
}
|
||||||
|
facts, err := db.ListProjectFacts(projectID, filter, 1000, 0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
edges, err := db.ListProjectFactEdgesByProject(projectID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if excludeDeprecated {
|
||||||
|
edges = filterDeprecatedEdges(edges)
|
||||||
|
}
|
||||||
|
|
||||||
|
factByKey := make(map[string]*database.ProjectFact, len(facts))
|
||||||
|
for _, f := range facts {
|
||||||
|
factByKey[f.FactKey] = f
|
||||||
|
}
|
||||||
|
|
||||||
|
pathMode := view == "path"
|
||||||
|
nodeKeys := make(map[string]struct{})
|
||||||
|
|
||||||
|
if pathMode {
|
||||||
|
for _, f := range facts {
|
||||||
|
if isPathGraphFact(f.Category, f.FactKey) {
|
||||||
|
nodeKeys[f.FactKey] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 路径视图中保留作为依赖目标的 auth/infra 节点
|
||||||
|
for _, e := range edges {
|
||||||
|
if _, ok := nodeKeys[e.SourceFactKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if f, ok := factByKey[e.TargetFactKey]; ok && isDependencyGraphFact(f.Category, f.FactKey) {
|
||||||
|
nodeKeys[e.TargetFactKey] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for _, f := range facts {
|
||||||
|
nodeKeys[f.FactKey] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 边上引用的 endpoint 纳入节点集
|
||||||
|
for _, e := range edges {
|
||||||
|
if pathMode {
|
||||||
|
if _, ok := nodeKeys[e.SourceFactKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := nodeKeys[e.TargetFactKey]; ok {
|
||||||
|
// already included
|
||||||
|
} else if f, ok := factByKey[e.TargetFactKey]; !ok {
|
||||||
|
nodeKeys[e.TargetFactKey] = struct{}{} // 占位节点
|
||||||
|
} else if isPathGraphFact(f.Category, f.FactKey) || isDependencyGraphFact(f.Category, f.FactKey) {
|
||||||
|
nodeKeys[e.TargetFactKey] = struct{}{}
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
nodeKeys[e.SourceFactKey] = struct{}{}
|
||||||
|
nodeKeys[e.TargetFactKey] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nodes := make([]database.ProjectFactGraphNode, 0, len(nodeKeys))
|
||||||
|
for key := range nodeKeys {
|
||||||
|
if f, ok := factByKey[key]; ok {
|
||||||
|
nodes = append(nodes, database.ProjectFactGraphNode{
|
||||||
|
ID: f.FactKey,
|
||||||
|
FactKey: f.FactKey,
|
||||||
|
Category: f.Category,
|
||||||
|
Label: truncateGraphLabel(f.Summary, 48),
|
||||||
|
Summary: strings.TrimSpace(f.Summary),
|
||||||
|
Confidence: f.Confidence,
|
||||||
|
Type: GraphNodeType(f.Category, f.FactKey),
|
||||||
|
Pinned: f.Pinned,
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
nodes = append(nodes, database.ProjectFactGraphNode{
|
||||||
|
ID: key,
|
||||||
|
FactKey: key,
|
||||||
|
Category: "missing",
|
||||||
|
Label: key,
|
||||||
|
Confidence: "tentative",
|
||||||
|
Type: "missing",
|
||||||
|
Pinned: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
graphEdges := make([]database.ProjectFactGraphEdge, 0, len(edges))
|
||||||
|
for _, e := range edges {
|
||||||
|
if pathMode {
|
||||||
|
if _, ok := nodeKeys[e.SourceFactKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := nodeKeys[e.TargetFactKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, ok := nodeKeys[e.SourceFactKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := nodeKeys[e.TargetFactKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
graphEdges = append(graphEdges, database.ProjectFactGraphEdge{
|
||||||
|
ID: e.ID,
|
||||||
|
Source: e.SourceFactKey,
|
||||||
|
Target: e.TargetFactKey,
|
||||||
|
Type: e.EdgeType,
|
||||||
|
Confidence: e.Confidence,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// related_vulnerability_id 合成边(source=fact → target=vuln:<id>)
|
||||||
|
for _, f := range facts {
|
||||||
|
if _, ok := nodeKeys[f.FactKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
vid := strings.TrimSpace(f.RelatedVulnerabilityID)
|
||||||
|
if vid == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
vulnNodeID := "vuln:" + vid
|
||||||
|
if _, exists := nodeKeys[vulnNodeID]; !exists {
|
||||||
|
nodeKeys[vulnNodeID] = struct{}{}
|
||||||
|
label := "漏洞"
|
||||||
|
if len(vid) >= 8 {
|
||||||
|
label += " " + vid[:8] + "…"
|
||||||
|
} else {
|
||||||
|
label += " " + vid
|
||||||
|
}
|
||||||
|
nodes = append(nodes, database.ProjectFactGraphNode{
|
||||||
|
ID: vulnNodeID,
|
||||||
|
FactKey: vulnNodeID,
|
||||||
|
Category: "vuln",
|
||||||
|
Label: label,
|
||||||
|
Confidence: f.Confidence,
|
||||||
|
Type: "vulnerability",
|
||||||
|
Pinned: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
graphEdges = append(graphEdges, database.ProjectFactGraphEdge{
|
||||||
|
ID: "vuln-link:" + f.FactKey + ":" + vid,
|
||||||
|
Source: f.FactKey,
|
||||||
|
Target: vulnNodeID,
|
||||||
|
Type: "links_vuln",
|
||||||
|
Confidence: f.Confidence,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &database.ProjectFactGraph{Nodes: nodes, Edges: graphEdges}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func min(a, b int) int {
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func isPathGraphFact(category, factKey string) bool {
|
||||||
|
c := strings.ToLower(strings.TrimSpace(category))
|
||||||
|
if _, ok := PathGraphCategories[c]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if c != "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
key := strings.ToLower(strings.TrimSpace(factKey))
|
||||||
|
for _, p := range []string{"target/", "finding/", "chain/", "exploit/", "poc/", "evidence/"} {
|
||||||
|
if strings.HasPrefix(key, p) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDependencyGraphFact(category, factKey string) bool {
|
||||||
|
c := strings.ToLower(strings.TrimSpace(category))
|
||||||
|
if c == FactCategoryAuth || c == FactCategoryInfra || c == FactCategoryBusiness {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if c != "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
key := strings.ToLower(strings.TrimSpace(factKey))
|
||||||
|
return strings.HasPrefix(key, "auth/") || strings.HasPrefix(key, "infra/") || strings.HasPrefix(key, "business/")
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterDeprecatedEdges(edges []*database.ProjectFactEdge) []*database.ProjectFactEdge {
|
||||||
|
out := make([]*database.ProjectFactEdge, 0, len(edges))
|
||||||
|
for _, e := range edges {
|
||||||
|
if strings.EqualFold(strings.TrimSpace(e.Confidence), "deprecated") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, e)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParsedFactLinks 解析 links 参数(from → 当前 fact)。
|
||||||
|
type ParsedFactLinks struct {
|
||||||
|
Incoming []database.ProjectFactEdgeFromInput
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseFactLinkInputs 从 MCP links 参数解析;空数组表示清空全部入边。
|
||||||
|
func ParseFactLinkInputs(raw interface{}) (*ParsedFactLinks, error) {
|
||||||
|
if raw == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
items, ok := raw.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("links 须为数组")
|
||||||
|
}
|
||||||
|
if len(items) == 0 {
|
||||||
|
return &ParsedFactLinks{
|
||||||
|
Incoming: []database.ProjectFactEdgeFromInput{},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
parsed := &ParsedFactLinks{}
|
||||||
|
for i, item := range items {
|
||||||
|
m, ok := item.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("links[%d] 格式无效", i)
|
||||||
|
}
|
||||||
|
from, _ := m["from"].(string)
|
||||||
|
edgeType, _ := m["type"].(string)
|
||||||
|
from = strings.TrimSpace(from)
|
||||||
|
edgeType = strings.TrimSpace(edgeType)
|
||||||
|
if from == "" {
|
||||||
|
return nil, fmt.Errorf("links[%d] 须含 from", i)
|
||||||
|
}
|
||||||
|
if edgeType == "" {
|
||||||
|
return nil, fmt.Errorf("links[%d] 须含 type", i)
|
||||||
|
}
|
||||||
|
conf, _ := m["confidence"].(string)
|
||||||
|
parsed.Incoming = append(parsed.Incoming, database.ProjectFactEdgeFromInput{
|
||||||
|
From: from, Type: edgeType, Confidence: strings.TrimSpace(conf),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return parsed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseFactLinksText 解析 UI 文本:`type: source_fact_key` 每行一条(from 语义)。
|
||||||
|
func ParseFactLinksText(text string) ([]database.ProjectFactEdgeFromInput, error) {
|
||||||
|
return ParseFactIncomingLinksText(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatFactLinksText 将入边格式化为 UI 文本。
|
||||||
|
func FormatFactLinksText(edges []*database.ProjectFactEdge) string {
|
||||||
|
return FormatFactIncomingLinksText(edges)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseFactIncomingLinksText 解析 UI 入边文本:`type: source_fact_key` 每行一条。
|
||||||
|
func ParseFactIncomingLinksText(text string) ([]database.ProjectFactEdgeFromInput, error) {
|
||||||
|
text = strings.TrimSpace(text)
|
||||||
|
if text == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
var out []database.ProjectFactEdgeFromInput
|
||||||
|
for i, line := range strings.Split(text, "\n") {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" || strings.HasPrefix(line, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
edgeType, source, ok := strings.Cut(line, ":")
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("第 %d 行格式无效,应为 type: fact_key", i+1)
|
||||||
|
}
|
||||||
|
edgeType = strings.TrimSpace(edgeType)
|
||||||
|
source = strings.TrimSpace(source)
|
||||||
|
if edgeType == "" || source == "" {
|
||||||
|
return nil, fmt.Errorf("第 %d 行 type 或 fact_key 为空", i+1)
|
||||||
|
}
|
||||||
|
out = append(out, database.ProjectFactEdgeFromInput{From: source, Type: edgeType})
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatFactIncomingLinksText 将入边格式化为 UI 文本。
|
||||||
|
func FormatFactIncomingLinksText(edges []*database.ProjectFactEdge) string {
|
||||||
|
if len(edges) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var b strings.Builder
|
||||||
|
for i, e := range edges {
|
||||||
|
if i > 0 {
|
||||||
|
b.WriteByte('\n')
|
||||||
|
}
|
||||||
|
b.WriteString(e.EdgeType)
|
||||||
|
b.WriteString(": ")
|
||||||
|
b.WriteString(e.SourceFactKey)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FactEdgeRecordingGuidance 写入边时的 Agent 规范。
|
||||||
|
func FactEdgeRecordingGuidance() string {
|
||||||
|
return projectprompt.FactEdgeRecordingGuidance()
|
||||||
|
}
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ApplyFactOutgoingLinks 替换某事实的出边(links 为 nil 时不修改)。
|
||||||
|
func ApplyFactOutgoingLinks(db *database.DB, projectID, sourceFactKey, sourceConversationID string, links []database.ProjectFactEdgeInput) error {
|
||||||
|
if links == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return db.ReplaceOutgoingProjectFactEdges(projectID, sourceFactKey, sourceConversationID, links)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveFactLinkInputs 合并 links 数组与 links_text 文本(数组优先)。
|
||||||
|
func ResolveFactLinkInputs(links []database.ProjectFactEdgeFromInput, linksText string) ([]database.ProjectFactEdgeFromInput, error) {
|
||||||
|
if len(links) > 0 {
|
||||||
|
return links, nil
|
||||||
|
}
|
||||||
|
return ParseFactLinksText(linksText)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyFactIncomingLinks 替换某事实的入边(links 为 nil 时不修改)。
|
||||||
|
func ApplyFactIncomingLinks(db *database.DB, projectID, targetFactKey string, links []database.ProjectFactEdgeFromInput) error {
|
||||||
|
if links == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return db.ReplaceIncomingProjectFactEdges(projectID, targetFactKey, links)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PersistFactIncomingLinks 写入入边并可选同步当前事实 body「关联」段。
|
||||||
|
func PersistFactIncomingLinks(db *database.DB, projectID, targetFactKey string, links []database.ProjectFactEdgeFromInput, syncBody bool) error {
|
||||||
|
if links == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if err := ApplyFactIncomingLinks(db, projectID, targetFactKey, links); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !syncBody {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
f, err := db.GetProjectFactByKey(projectID, targetFactKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
in, err := db.ListIncomingProjectFactEdges(projectID, targetFactKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f.Body = SyncBodyLinksSection(f.Body, in)
|
||||||
|
_, err = db.UpsertProjectFact(f)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// PersistFactLinksFromParsed 写入解析后的 links(parsed 为 nil 表示不修改)。
|
||||||
|
func PersistFactLinksFromParsed(db *database.DB, projectID, factKey, sourceConversationID string, parsed *ParsedFactLinks, syncBody bool) error {
|
||||||
|
if parsed == nil || parsed.Incoming == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return PersistFactIncomingLinks(db, projectID, factKey, parsed.Incoming, syncBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PersistFactOutgoingLinks 写入出边(图连线等低层 API;body 同步请用 PersistFactIncomingLinks)。
|
||||||
|
func PersistFactOutgoingLinks(db *database.DB, projectID, sourceFactKey, sourceConversationID string, links []database.ProjectFactEdgeInput, syncBody bool) error {
|
||||||
|
if links == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return ApplyFactOutgoingLinks(db, projectID, sourceFactKey, sourceConversationID, links)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkCountMap 项目内各 fact 的入/出边计数。
|
||||||
|
type LinkCountMap map[string]LinkCounts
|
||||||
|
|
||||||
|
// LinkCounts 单 fact 的入/出边数。
|
||||||
|
type LinkCounts struct {
|
||||||
|
Outgoing int `json:"outgoing"`
|
||||||
|
Incoming int `json:"incoming"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadProjectFactLinkCounts 批量加载边计数。
|
||||||
|
func LoadProjectFactLinkCounts(db *database.DB, projectID string) (LinkCountMap, error) {
|
||||||
|
edges, err := db.ListProjectFactEdgesByProject(projectID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m := LinkCountMap{}
|
||||||
|
for _, e := range edges {
|
||||||
|
c := m[e.SourceFactKey]
|
||||||
|
c.Outgoing++
|
||||||
|
m[e.SourceFactKey] = c
|
||||||
|
c = m[e.TargetFactKey]
|
||||||
|
c.Incoming++
|
||||||
|
m[e.TargetFactKey] = c
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,296 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseFactLinksText(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
inputs, err := ParseFactLinksText("discovered_on: target/api\nleads_to: finding/swagger")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(inputs) != 2 {
|
||||||
|
t.Fatalf("want 2 links, got %d", len(inputs))
|
||||||
|
}
|
||||||
|
if inputs[0].Type != "discovered_on" || inputs[0].From != "target/api" {
|
||||||
|
t.Fatalf("unexpected first link: %+v", inputs[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFactIncomingLinksText(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
inputs, err := ParseFactIncomingLinksText("leads_to: finding/swagger\ndepends_on: target/api")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(inputs) != 2 {
|
||||||
|
t.Fatalf("want 2 links, got %d", len(inputs))
|
||||||
|
}
|
||||||
|
if inputs[0].Type != "leads_to" || inputs[0].From != "finding/swagger" {
|
||||||
|
t.Fatalf("unexpected first link: %+v", inputs[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatFactIncomingLinksText(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
text := FormatFactIncomingLinksText([]*database.ProjectFactEdge{
|
||||||
|
{EdgeType: "leads_to", SourceFactKey: "finding/a"},
|
||||||
|
{EdgeType: "depends_on", SourceFactKey: "target/b"},
|
||||||
|
})
|
||||||
|
want := "leads_to: finding/a\ndepends_on: target/b"
|
||||||
|
if text != want {
|
||||||
|
t.Fatalf("got %q want %q", text, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFactLinkInputsEmptyClears(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
parsed, err := ParseFactLinkInputs([]interface{}{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if parsed == nil || parsed.Incoming == nil || len(parsed.Incoming) != 0 {
|
||||||
|
t.Fatalf("empty array should clear incoming links, got %v", parsed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFactLinkInputsFrom(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
raw := []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"from": "target/primary_domain",
|
||||||
|
"type": "discovered_on",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
parsed, err := ParseFactLinkInputs(raw)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if len(parsed.Incoming) != 1 || parsed.Incoming[0].From != "target/primary_domain" {
|
||||||
|
t.Fatalf("unexpected incoming: %+v", parsed.Incoming)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFactLinkInputsRequiresFrom(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
raw := []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"to": "target/primary_domain",
|
||||||
|
"type": "discovered_on",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, err := ParseFactLinkInputs(raw)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when from is missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGraphNodeType(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if GraphNodeType("chain", "chain/x") != "chain" {
|
||||||
|
t.Fatal("chain category")
|
||||||
|
}
|
||||||
|
if GraphNodeType("finding", "finding/x") != "finding" {
|
||||||
|
t.Fatal("finding category")
|
||||||
|
}
|
||||||
|
if GraphNodeType("exploit", "exploit/x") != "exploit" {
|
||||||
|
t.Fatal("exploit category")
|
||||||
|
}
|
||||||
|
if GraphNodeType("finding", "evidence/x") != "finding" {
|
||||||
|
t.Fatal("category should override evidence key prefix")
|
||||||
|
}
|
||||||
|
if GraphNodeType("note", "target/x") != "note" {
|
||||||
|
t.Fatal("category should override target key prefix")
|
||||||
|
}
|
||||||
|
if GraphNodeType("vuln", "finding/x") != "vulnerability" {
|
||||||
|
t.Fatal("vuln category maps to vulnerability node type")
|
||||||
|
}
|
||||||
|
if GraphNodeType("", "target/x") != "target" {
|
||||||
|
t.Fatal("empty category falls back to target key prefix")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildProjectFactGraphPreservesStoredEdgeDirection(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
db, err := database.NewDB(filepath.Join(dir, "test.db"), zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
p, err := db.CreateProject(&database.Project{Name: "path-edges"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, spec := range []struct{ key, cat string }{
|
||||||
|
{"target/primary_domain", "target"},
|
||||||
|
{"chain/full_attack_path", "chain"},
|
||||||
|
{"finding/mysql_public", "finding"},
|
||||||
|
{"exploit/mysql_creds_extract", "exploit"},
|
||||||
|
} {
|
||||||
|
if _, err := db.UpsertProjectFact(&database.ProjectFact{
|
||||||
|
ProjectID: p.ID, FactKey: spec.key, Category: spec.cat, Summary: spec.key, Confidence: "confirmed",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "finding/mysql_public", []database.ProjectFactEdgeFromInput{
|
||||||
|
{From: "target/primary_domain", Type: "discovered_on"},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "finding/mysql_public", []database.ProjectFactEdgeFromInput{
|
||||||
|
{From: "target/primary_domain", Type: "discovered_on"},
|
||||||
|
{From: "exploit/mysql_creds_extract", Type: "exploits"},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "chain/full_attack_path", []database.ProjectFactEdgeFromInput{
|
||||||
|
{From: "target/primary_domain", Type: "discovered_on"},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "exploit/mysql_creds_extract", []database.ProjectFactEdgeFromInput{
|
||||||
|
{From: "chain/full_attack_path", Type: "leads_to"},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
graph, err := BuildProjectFactGraph(db, p.ID, "path", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
want := map[string]struct{}{
|
||||||
|
"target/primary_domain|discovered_on|finding/mysql_public": {},
|
||||||
|
"exploit/mysql_creds_extract|exploits|finding/mysql_public": {},
|
||||||
|
"target/primary_domain|discovered_on|chain/full_attack_path": {},
|
||||||
|
"chain/full_attack_path|leads_to|exploit/mysql_creds_extract": {},
|
||||||
|
}
|
||||||
|
for _, e := range graph.Edges {
|
||||||
|
key := e.Source + "|" + e.Type + "|" + e.Target
|
||||||
|
delete(want, key)
|
||||||
|
}
|
||||||
|
if len(want) > 0 {
|
||||||
|
t.Fatalf("missing expected stored-direction edges: %v", want)
|
||||||
|
}
|
||||||
|
countInOut := func(factKey string) (out, in int) {
|
||||||
|
for _, e := range graph.Edges {
|
||||||
|
if e.Source == factKey {
|
||||||
|
out++
|
||||||
|
}
|
||||||
|
if e.Target == factKey {
|
||||||
|
in++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, in
|
||||||
|
}
|
||||||
|
if out, in := countInOut("chain/full_attack_path"); out != 1 || in != 1 {
|
||||||
|
t.Fatalf("chain/full_attack_path want out=1 in=1 got out=%d in=%d", out, in)
|
||||||
|
}
|
||||||
|
if out, in := countInOut("exploit/mysql_creds_extract"); out != 1 || in != 1 {
|
||||||
|
t.Fatalf("exploit/mysql_creds_extract want out=1 in=1 got out=%d in=%d", out, in)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPersistFactLinksFromUsesFromAsIncoming(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
db, err := database.NewDB(filepath.Join(dir, "test.db"), zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
p, err := db.CreateProject(&database.Project{Name: "from-links"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, spec := range []struct{ key, cat string }{
|
||||||
|
{"target/primary_domain", "target"},
|
||||||
|
{"finding/sqli", "finding"},
|
||||||
|
} {
|
||||||
|
if _, err := db.UpsertProjectFact(&database.ProjectFact{
|
||||||
|
ProjectID: p.ID, FactKey: spec.key, Category: spec.cat, Summary: spec.key, Confidence: "confirmed",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parsed := &ParsedFactLinks{
|
||||||
|
Incoming: []database.ProjectFactEdgeFromInput{
|
||||||
|
{From: "target/primary_domain", Type: "discovered_on"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if err := PersistFactLinksFromParsed(db, p.ID, "finding/sqli", "", parsed, false); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
graph, err := BuildProjectFactGraph(db, p.ID, "path", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
want := "target/primary_domain|discovered_on|finding/sqli"
|
||||||
|
for _, e := range graph.Edges {
|
||||||
|
key := e.Source + "|" + e.Type + "|" + e.Target
|
||||||
|
if key == want {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Fatalf("expected edge %s, got %+v", want, graph.Edges)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatOutgoingLinksHint(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
hint := FormatOutgoingLinksHint([]*database.ProjectFactEdge{
|
||||||
|
{EdgeType: "discovered_on", TargetFactKey: "target/a"},
|
||||||
|
})
|
||||||
|
if hint == "" || hint[0] != ' ' {
|
||||||
|
t.Fatalf("unexpected hint: %q", hint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReplaceIncomingAllowsNotYetCreatedSource(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
db, err := database.NewDB(filepath.Join(dir, "test.db"), zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
p, err := db.CreateProject(&database.Project{Name: "parallel-links"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := db.UpsertProjectFact(&database.ProjectFact{
|
||||||
|
ProjectID: p.ID, FactKey: "exploit/sqli", Category: "exploit", Summary: "exploit", Confidence: "confirmed",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "exploit/sqli", []database.ProjectFactEdgeFromInput{
|
||||||
|
{From: "finding/sqli_endpoint", Type: "exploits"},
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("incoming edge should not require source fact to exist yet: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := db.UpsertProjectFact(&database.ProjectFact{
|
||||||
|
ProjectID: p.ID, FactKey: "finding/sqli_endpoint", Category: "finding", Summary: "finding", Confidence: "confirmed",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
in, err := db.ListIncomingProjectFactEdges(p.ID, "exploit/sqli")
|
||||||
|
if err != nil || len(in) != 1 || in[0].SourceFactKey != "finding/sqli_endpoint" {
|
||||||
|
t.Fatalf("expected persisted edge from finding, got %+v err=%v", in, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateProjectFactEdgeType(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
if err := database.ValidateProjectFactEdgeType("leads_to"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := database.ValidateProjectFactEdgeType("invalid"); err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,231 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
var factIndexEdgeTypeOrder = []string{
|
||||||
|
"discovered_on", "leads_to", "enables", "depends_on", "exploits", "contains", "part_of", "supports",
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterIndexEdges(edges []*database.ProjectFactEdge) []*database.ProjectFactEdge {
|
||||||
|
if len(edges) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]*database.ProjectFactEdge, 0, len(edges))
|
||||||
|
for _, e := range edges {
|
||||||
|
if e == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.EqualFold(strings.TrimSpace(e.Confidence), "deprecated") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
edgeType := strings.ToLower(strings.TrimSpace(e.EdgeType))
|
||||||
|
if _, ok := database.ValidProjectFactEdgeTypes[edgeType]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, e)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func edgeConfidenceSuffix(confidence string) string {
|
||||||
|
c := strings.ToLower(strings.TrimSpace(confidence))
|
||||||
|
if c == "" || c == "confirmed" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return " (" + c + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatRelationHintPart(e *database.ProjectFactEdge) string {
|
||||||
|
return fmt.Sprintf("%s←%s%s", e.EdgeType, e.SourceFactKey, edgeConfidenceSuffix(e.Confidence))
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatOutgoingHintPart(e *database.ProjectFactEdge) string {
|
||||||
|
return fmt.Sprintf("%s→%s%s", e.EdgeType, e.TargetFactKey, edgeConfidenceSuffix(e.Confidence))
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatIncomingHintPart(e *database.ProjectFactEdge) string {
|
||||||
|
return formatRelationHintPart(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinEdgeHintParts(edges []*database.ProjectFactEdge, formatter func(*database.ProjectFactEdge) string) string {
|
||||||
|
parts := make([]string, 0, len(edges))
|
||||||
|
for _, e := range edges {
|
||||||
|
parts = append(parts, formatter(e))
|
||||||
|
}
|
||||||
|
return strings.Join(parts, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatOutgoingLinksHint 黑板索引用出边摘要(全部有效边类型,不截断)。
|
||||||
|
func FormatOutgoingLinksHint(edges []*database.ProjectFactEdge) string {
|
||||||
|
edges = filterIndexEdges(edges)
|
||||||
|
if len(edges) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return " {出边: " + joinEdgeHintParts(edges, formatOutgoingHintPart) + "}"
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatIncomingLinksHint 黑板索引用入边摘要(全部有效边类型,不截断)。
|
||||||
|
func FormatIncomingLinksHint(edges []*database.ProjectFactEdge) string {
|
||||||
|
edges = filterIndexEdges(edges)
|
||||||
|
if len(edges) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return " {入边: " + joinEdgeHintParts(edges, formatIncomingHintPart) + "}"
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatFactIndexLinksHint 黑板索引行内关系边(from → 当前 fact,与 upsert links 一致)。
|
||||||
|
func FormatFactIndexLinksHint(_ string, incoming []*database.ProjectFactEdge) string {
|
||||||
|
in := filterIndexEdges(incoming)
|
||||||
|
if len(in) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return " {关系边: " + joinEdgeHintParts(in, formatRelationHintPart) + "}"
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexEdgeGroupMaps(edges []*database.ProjectFactEdge) (outgoing, incoming map[string][]*database.ProjectFactEdge) {
|
||||||
|
outgoing = map[string][]*database.ProjectFactEdge{}
|
||||||
|
incoming = map[string][]*database.ProjectFactEdge{}
|
||||||
|
for _, e := range filterIndexEdges(edges) {
|
||||||
|
outgoing[e.SourceFactKey] = append(outgoing[e.SourceFactKey], e)
|
||||||
|
incoming[e.TargetFactKey] = append(incoming[e.TargetFactKey], e)
|
||||||
|
}
|
||||||
|
return outgoing, incoming
|
||||||
|
}
|
||||||
|
|
||||||
|
func relationOverviewLine(e *database.ProjectFactEdge) string {
|
||||||
|
return fmt.Sprintf("- %s → %s%s · %s", e.SourceFactKey, e.TargetFactKey, edgeConfidenceSuffix(e.Confidence), e.EdgeType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexEdgeSortKey(e *database.ProjectFactEdge) (int, int, string) {
|
||||||
|
confRank := 0
|
||||||
|
if strings.EqualFold(strings.TrimSpace(e.Confidence), "tentative") {
|
||||||
|
confRank = 1
|
||||||
|
}
|
||||||
|
typeRank := len(factIndexEdgeTypeOrder) + 1
|
||||||
|
for i, t := range factIndexEdgeTypeOrder {
|
||||||
|
if strings.EqualFold(e.EdgeType, t) {
|
||||||
|
typeRank = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return confRank, typeRank, e.SourceFactKey + ">" + e.TargetFactKey + ">" + e.EdgeType
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortIndexOverviewEdges(edges []*database.ProjectFactEdge) {
|
||||||
|
sort.SliceStable(edges, func(i, j int) bool {
|
||||||
|
ci, ti, ki := indexEdgeSortKey(edges[i])
|
||||||
|
cj, tj, kj := indexEdgeSortKey(edges[j])
|
||||||
|
if ci != cj {
|
||||||
|
return ci < cj
|
||||||
|
}
|
||||||
|
if ti != tj {
|
||||||
|
return ti < tj
|
||||||
|
}
|
||||||
|
return ki < kj
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildFactPathOverviewSection 生成事实关系速览(全部有效边类型,不含 body)。
|
||||||
|
func BuildFactPathOverviewSection(edges []*database.ProjectFactEdge, indexedKeys map[string]struct{}, maxRunes int) string {
|
||||||
|
if maxRunes <= 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
candidates := filterIndexEdges(edges)
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
filtered := make([]*database.ProjectFactEdge, 0, len(candidates))
|
||||||
|
for _, e := range candidates {
|
||||||
|
if len(indexedKeys) > 0 {
|
||||||
|
if _, ok := indexedKeys[e.SourceFactKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := indexedKeys[e.TargetFactKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
filtered = append(filtered, e)
|
||||||
|
}
|
||||||
|
if len(filtered) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
sortIndexOverviewEdges(filtered)
|
||||||
|
|
||||||
|
header := "### 攻击路径(事实关系)\n"
|
||||||
|
header += "source → target · type(与攻击路径图/库中方向一致;写入时在目标 fact 的 links 用 from 声明来源)\n"
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString(header)
|
||||||
|
used := len([]rune(header))
|
||||||
|
omitted := 0
|
||||||
|
|
||||||
|
for _, e := range filtered {
|
||||||
|
line := relationOverviewLine(e) + "\n"
|
||||||
|
lineRunes := len([]rune(line))
|
||||||
|
if used+lineRunes > maxRunes {
|
||||||
|
omitted++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b.WriteString(line)
|
||||||
|
used += lineRunes
|
||||||
|
}
|
||||||
|
if omitted > 0 {
|
||||||
|
extra := fmt.Sprintf("(另有 %d 条关系边未列入,请 get_project_fact 查看完整关系。)\n", omitted)
|
||||||
|
if used+len([]rune(extra)) <= maxRunes {
|
||||||
|
b.WriteString(extra)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if used <= len([]rune(header)) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func factIndexSortPriority(f *database.ProjectFact) int {
|
||||||
|
if f == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
score := 0
|
||||||
|
if f.Pinned {
|
||||||
|
score += 1000
|
||||||
|
}
|
||||||
|
c := strings.ToLower(strings.TrimSpace(f.Category))
|
||||||
|
switch c {
|
||||||
|
case FactCategoryTarget:
|
||||||
|
score += 400
|
||||||
|
case FactCategoryFinding, FactCategoryChain:
|
||||||
|
score += 300
|
||||||
|
case FactCategoryExploit, FactCategoryPOC:
|
||||||
|
score += 250
|
||||||
|
case "auth", "infra", "business":
|
||||||
|
score += 200
|
||||||
|
case "note":
|
||||||
|
score += 50
|
||||||
|
default:
|
||||||
|
key := strings.ToLower(strings.TrimSpace(f.FactKey))
|
||||||
|
if strings.HasPrefix(key, "target/") {
|
||||||
|
score += 400
|
||||||
|
} else if strings.HasPrefix(key, "finding/") || strings.HasPrefix(key, "chain/") {
|
||||||
|
score += 300
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.EqualFold(strings.TrimSpace(f.Confidence), "confirmed") {
|
||||||
|
score += 80
|
||||||
|
}
|
||||||
|
return score
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortFactsForIndex(facts []*database.ProjectFact) {
|
||||||
|
sort.SliceStable(facts, func(i, j int) bool {
|
||||||
|
pi, pj := factIndexSortPriority(facts[i]), factIndexSortPriority(facts[j])
|
||||||
|
if pi != pj {
|
||||||
|
return pi > pj
|
||||||
|
}
|
||||||
|
return facts[i].UpdatedAt.After(facts[j].UpdatedAt)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,161 @@
|
|||||||
|
package project
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/config"
|
||||||
|
"cyberstrike-ai/internal/database"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFormatIncomingLinksHint(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
hint := FormatIncomingLinksHint([]*database.ProjectFactEdge{
|
||||||
|
{EdgeType: "discovered_on", SourceFactKey: "finding/x", Confidence: "tentative"},
|
||||||
|
})
|
||||||
|
if !strings.Contains(hint, "入边:") {
|
||||||
|
t.Fatalf("expected 入边 label: %q", hint)
|
||||||
|
}
|
||||||
|
if !strings.Contains(hint, "discovered_on←finding/x") {
|
||||||
|
t.Fatalf("unexpected hint: %q", hint)
|
||||||
|
}
|
||||||
|
if !strings.Contains(hint, "tentative") {
|
||||||
|
t.Fatalf("expected tentative in hint: %q", hint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatIncomingLinksHint_allEdges(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
edges := make([]*database.ProjectFactEdge, 0, 5)
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
edges = append(edges, &database.ProjectFactEdge{
|
||||||
|
EdgeType: "discovered_on",
|
||||||
|
SourceFactKey: fmt.Sprintf("finding/f%d", i),
|
||||||
|
Confidence: "tentative",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
hint := FormatIncomingLinksHint(edges)
|
||||||
|
if strings.Contains(hint, "+") {
|
||||||
|
t.Fatalf("should not truncate with +N: %q", hint)
|
||||||
|
}
|
||||||
|
for i := 1; i <= 5; i++ {
|
||||||
|
if !strings.Contains(hint, fmt.Sprintf("finding/f%d", i)) {
|
||||||
|
t.Fatalf("missing edge f%d in hint: %q", i, hint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatFactIndexLinksHint_incomingOnly(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
in := []*database.ProjectFactEdge{
|
||||||
|
{EdgeType: "discovered_on", SourceFactKey: "target/dev", Confidence: "tentative"},
|
||||||
|
{EdgeType: "exploits", SourceFactKey: "exploit/rce", Confidence: "confirmed"},
|
||||||
|
}
|
||||||
|
hint := FormatFactIndexLinksHint("finding/sqli", in)
|
||||||
|
if !strings.Contains(hint, "关系边:") {
|
||||||
|
t.Fatalf("missing 关系边 label: %q", hint)
|
||||||
|
}
|
||||||
|
if !strings.Contains(hint, "discovered_on←target/dev") {
|
||||||
|
t.Fatalf("missing discovered_on: %q", hint)
|
||||||
|
}
|
||||||
|
if !strings.Contains(hint, "exploits←exploit/rce") {
|
||||||
|
t.Fatalf("missing exploits: %q", hint)
|
||||||
|
}
|
||||||
|
if strings.Contains(hint, "出边") || strings.Contains(hint, "入边") {
|
||||||
|
t.Fatalf("should not use legacy 出边/入边 labels: %q", hint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFormatFactIndexLinksHint_includesAuxiliaryEdgeTypes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
in := []*database.ProjectFactEdge{{EdgeType: "supports", SourceFactKey: "note/log"}}
|
||||||
|
hint := FormatFactIndexLinksHint("finding/x", in)
|
||||||
|
if !strings.Contains(hint, "supports←note/log") {
|
||||||
|
t.Fatalf("supports edge should be included: %q", hint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildFactPathOverviewSection(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
edges := []*database.ProjectFactEdge{
|
||||||
|
{EdgeType: "discovered_on", SourceFactKey: "target/dev", TargetFactKey: "finding/sqli", Confidence: "tentative"},
|
||||||
|
{EdgeType: "exploits", SourceFactKey: "exploit/rce", TargetFactKey: "finding/sqli", Confidence: "confirmed"},
|
||||||
|
{EdgeType: "supports", SourceFactKey: "note/log", TargetFactKey: "finding/sqli"},
|
||||||
|
}
|
||||||
|
keys := map[string]struct{}{
|
||||||
|
"target/dev": {}, "finding/sqli": {}, "exploit/rce": {}, "note/log": {},
|
||||||
|
}
|
||||||
|
section := BuildFactPathOverviewSection(edges, keys, 800)
|
||||||
|
if !strings.Contains(section, "### 攻击路径(事实关系)") {
|
||||||
|
t.Fatalf("missing header: %q", section)
|
||||||
|
}
|
||||||
|
if !strings.Contains(section, "target/dev → finding/sqli") {
|
||||||
|
t.Fatalf("missing discovered_on line: %q", section)
|
||||||
|
}
|
||||||
|
if !strings.Contains(section, "exploit/rce → finding/sqli") {
|
||||||
|
t.Fatalf("missing exploits line: %q", section)
|
||||||
|
}
|
||||||
|
if !strings.Contains(section, "note/log → finding/sqli") {
|
||||||
|
t.Fatalf("supports edge should be included: %q", section)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildFactIndexBlock_withLinksAndPathOverview(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||||
|
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
proj, err := db.CreateProject(&database.Project{Name: "path-proj"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = db.UpsertProjectFact(&database.ProjectFact{
|
||||||
|
ProjectID: proj.ID,
|
||||||
|
FactKey: "target/dev",
|
||||||
|
Category: "target",
|
||||||
|
Summary: "dev 子域",
|
||||||
|
Confidence: "confirmed",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = db.UpsertProjectFact(&database.ProjectFact{
|
||||||
|
ProjectID: proj.ID,
|
||||||
|
FactKey: "finding/sqli",
|
||||||
|
Category: "finding",
|
||||||
|
Summary: "时间盲注",
|
||||||
|
Confidence: "tentative",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = db.AddProjectFactEdge(proj.ID, database.ProjectFactEdgeInput{
|
||||||
|
To: "finding/sqli",
|
||||||
|
Type: "discovered_on",
|
||||||
|
}, "target/dev", "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, err := BuildFactIndexBlock(db, proj.ID, config.ProjectConfig{Enabled: true, FactIndexMaxRunes: 6500, FactIndexPathMaxRunes: 1000})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !strings.Contains(block, "关系边: discovered_on←target/dev") {
|
||||||
|
t.Fatalf("finding line should include relation hint: %q", block)
|
||||||
|
}
|
||||||
|
if !strings.Contains(block, "### 攻击路径(事实关系)") {
|
||||||
|
t.Fatalf("missing relation overview: %q", block)
|
||||||
|
}
|
||||||
|
if !strings.Contains(block, "target/dev → finding/sqli") {
|
||||||
|
t.Fatalf("missing overview edge: %q", block)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,100 +1,23 @@
|
|||||||
package project
|
package project
|
||||||
|
|
||||||
import (
|
import "cyberstrike-ai/internal/projectprompt"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"cyberstrike-ai/internal/mcp/builtin"
|
// FactRecordingIncrementalRhythmMarkdown 见 projectprompt。
|
||||||
)
|
|
||||||
|
|
||||||
// 边渗透边记录:统一节奏文案(agents/*.md 须与 FactRecordingIncrementalRhythmMarkdown 保持一致)。
|
|
||||||
const (
|
|
||||||
factRhythmCore = "勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。"
|
|
||||||
factRhythmCoordinatorSuffix = "委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。"
|
|
||||||
factRhythmSubAgentSuffix = "若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。"
|
|
||||||
)
|
|
||||||
|
|
||||||
// FactRecordingIncrementalRhythmMarkdown 返回边渗透边记录节奏(Markdown,供 agents/*.md 与文档对齐)。
|
|
||||||
func FactRecordingIncrementalRhythmMarkdown(coordinator, subAgent bool) string {
|
func FactRecordingIncrementalRhythmMarkdown(coordinator, subAgent bool) string {
|
||||||
var b strings.Builder
|
return projectprompt.FactRecordingIncrementalRhythmMarkdown(coordinator, subAgent)
|
||||||
b.WriteString("- **边渗透边记录(强制节奏)**:")
|
|
||||||
b.WriteString(factRhythmCore)
|
|
||||||
if coordinator {
|
|
||||||
b.WriteString(factRhythmCoordinatorSuffix)
|
|
||||||
}
|
|
||||||
if subAgent {
|
|
||||||
b.WriteString(factRhythmSubAgentSuffix)
|
|
||||||
}
|
|
||||||
return b.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func factRecordingIncrementalRhythmBuiltin(coordinator, subAgent bool) string {
|
// FactRecordingBlackboardSection 见 projectprompt。
|
||||||
var b strings.Builder
|
|
||||||
b.WriteString("- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 ")
|
|
||||||
b.WriteString(builtin.ToolUpsertProjectFact)
|
|
||||||
b.WriteString("(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 ")
|
|
||||||
b.WriteString(builtin.ToolRecordVulnerability)
|
|
||||||
b.WriteString(";与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。")
|
|
||||||
if coordinator {
|
|
||||||
b.WriteString(factRhythmCoordinatorSuffix)
|
|
||||||
}
|
|
||||||
if subAgent {
|
|
||||||
b.WriteString(factRhythmSubAgentSuffix)
|
|
||||||
}
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// FactRecordingBlackboardSection 项目黑板与漏洞记录的完整系统提示块(单/多 Agent 主代理共用)。
|
|
||||||
// coordinatorDelegate 为 true 时追加「协调者代子代理落库」说明(Deep / plan_execute / supervisor)。
|
|
||||||
func FactRecordingBlackboardSection(coordinatorDelegate bool) string {
|
func FactRecordingBlackboardSection(coordinatorDelegate bool) string {
|
||||||
var b strings.Builder
|
return projectprompt.FactRecordingBlackboardSection(coordinatorDelegate)
|
||||||
b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n")
|
|
||||||
b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 fact_key + 摘要)。**摘要不足时必须调用 ")
|
|
||||||
b.WriteString(builtin.ToolGetProjectFact)
|
|
||||||
b.WriteString("(fact_key) 获取 body,禁止凭摘要臆造细节。**\n\n")
|
|
||||||
b.WriteString(factRecordingIncrementalRhythmBuiltin(coordinatorDelegate, false))
|
|
||||||
b.WriteString("\n\n")
|
|
||||||
b.WriteString("- **环境/目标/认证等认知**(非正式漏洞条目):使用 ")
|
|
||||||
b.WriteString(builtin.ToolUpsertProjectFact)
|
|
||||||
b.WriteString(",fact_key 建议 `category/slug`(如 target/primary_domain),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n")
|
|
||||||
b.WriteString("- **发现与利用上下文**(审计复现):fact_key 建议 finding/、chain/、exploit/、poc/ 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 related_vulnerability_id),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n")
|
|
||||||
b.WriteString("- **可交付漏洞**:使用 ")
|
|
||||||
b.WriteString(builtin.ToolRecordVulnerability)
|
|
||||||
b.WriteString(",含标题、严重程度、类型、目标、证明(POC)、影响、修复建议。记前可先 ")
|
|
||||||
b.WriteString(builtin.ToolListVulnerabilities)
|
|
||||||
b.WriteString(" 查重,详情用 ")
|
|
||||||
b.WriteString(builtin.ToolGetVulnerability)
|
|
||||||
b.WriteString("(id)(默认仅当前项目/会话)。\n")
|
|
||||||
b.WriteString("- 同一发现可能需**各记一次**(事实记**完整攻击链与 exploit 细节**供复现,漏洞记正式 findings)。误报用 ")
|
|
||||||
b.WriteString(builtin.ToolDeprecateProjectFact)
|
|
||||||
b.WriteString(" 或漏洞状态 false_positive。\n")
|
|
||||||
b.WriteString("- 事实多时用 ")
|
|
||||||
b.WriteString(builtin.ToolListProjectFacts)
|
|
||||||
b.WriteString(" / ")
|
|
||||||
b.WriteString(builtin.ToolSearchProjectFacts)
|
|
||||||
b.WriteString(" 检索。\n\n")
|
|
||||||
b.WriteString(FactRecordingGuidanceBlock())
|
|
||||||
b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。")
|
|
||||||
return b.String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FactRecordingSubAgentSection 子代理边渗透边记录(无工具时输出待落库条目)。
|
// FactRecordingSubAgentSection 见 projectprompt。
|
||||||
func FactRecordingSubAgentSection() string {
|
func FactRecordingSubAgentSection() string {
|
||||||
return "## 边渗透边记录\n\n" + factRecordingIncrementalRhythmBuiltin(false, true) + "\n"
|
return projectprompt.FactRecordingSubAgentSection()
|
||||||
}
|
}
|
||||||
|
|
||||||
// FactRecordingBlackboardSectionMarkdown 与 FactRecordingBlackboardSection 等价的 Markdown(工具名为字面量,供 agents/*.md)。
|
// FactRecordingBlackboardSectionMarkdown 见 projectprompt。
|
||||||
func FactRecordingBlackboardSectionMarkdown(coordinatorDelegate bool) string {
|
func FactRecordingBlackboardSectionMarkdown(coordinatorDelegate bool) string {
|
||||||
var b strings.Builder
|
return projectprompt.FactRecordingBlackboardSectionMarkdown(coordinatorDelegate)
|
||||||
b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n")
|
|
||||||
b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**\n\n")
|
|
||||||
b.WriteString(FactRecordingIncrementalRhythmMarkdown(coordinatorDelegate, false))
|
|
||||||
b.WriteString("\n\n")
|
|
||||||
b.WriteString("- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n")
|
|
||||||
b.WriteString("- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n")
|
|
||||||
b.WriteString("- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。\n")
|
|
||||||
b.WriteString("- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。\n")
|
|
||||||
b.WriteString("- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。\n\n")
|
|
||||||
b.WriteString(FactRecordingGuidanceBlock())
|
|
||||||
b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。")
|
|
||||||
return b.String()
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package project
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/projectprompt"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 事实 category 常量(写入 upsert_project_fact 的 category 字段)。
|
// 事实 category 常量(写入 upsert_project_fact 的 category 字段)。
|
||||||
@@ -90,7 +92,8 @@ const attackChainFactBodyTemplate = `## 结论(可验证,一句话)
|
|||||||
|
|
||||||
## 关联
|
## 关联
|
||||||
- related_vulnerability_id: <可选,对应 record_vulnerability 的 id>
|
- related_vulnerability_id: <可选,对应 record_vulnerability 的 id>
|
||||||
- 依赖事实: <fact_key,如 auth/session_cookie>
|
- links(upsert 参数): [{ "from": "<fact_key>", "type": "discovered_on|..." }](from → 当前 fact)
|
||||||
|
- 依赖事实(body 可读镜像): <fact_key,如 auth/session_cookie>
|
||||||
|
|
||||||
## 备注与不确定性
|
## 备注与不确定性
|
||||||
<待验证假设、环境差异、绕过尝试记录>`
|
<待验证假设、环境差异、绕过尝试记录>`
|
||||||
@@ -109,15 +112,7 @@ const envFactBodyTemplate = `## 摘要
|
|||||||
|
|
||||||
// FactRecordingGuidanceBlock 写入系统提示:要求事实沉淀攻击链上下文而非仅结论。
|
// FactRecordingGuidanceBlock 写入系统提示:要求事实沉淀攻击链上下文而非仅结论。
|
||||||
func FactRecordingGuidanceBlock() string {
|
func FactRecordingGuidanceBlock() string {
|
||||||
return `### 事实写入规范(审计复现 / 知识沉淀)
|
return projectprompt.FactRecordingGuidanceBlock()
|
||||||
|
|
||||||
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
|
||||||
- **body**:完整可复现上下文,写入 ` + "`upsert_project_fact`" + ` 的 body 字段;索引不含 body,后续会话须靠 ` + "`get_project_fact`" + ` 取回。
|
|
||||||
- **category / fact_key 建议**:
|
|
||||||
- 环境认知:` + "`target/`" + `、` + "`auth/`" + `、` + "`infra/`" + `、` + "`business/`" + `(body 用环境模板即可)
|
|
||||||
- 发现与利用:` + "`finding/`" + `、` + "`chain/`" + `、` + "`exploit/`" + `、` + "`poc/`" + `(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
|
||||||
- **与漏洞记录分工**:` + "`record_vulnerability`" + ` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
|
||||||
- 更新同一发现时保持相同 ` + "`fact_key`" + ` 覆盖写入,勿散落多个 key 导致上下文丢失。`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SparseBodyWarning 攻击链类事实 body 不足时的工具返回提示(不阻断保存)。
|
// SparseBodyWarning 攻击链类事实 body 不足时的工具返回提示(不阻断保存)。
|
||||||
|
|||||||
@@ -2,10 +2,14 @@ package project
|
|||||||
|
|
||||||
import "strings"
|
import "strings"
|
||||||
|
|
||||||
|
// VisionImageSectionMarker 图片分析 section 标题(与 AppendVisionImageAnalysisIfReady 注入一致)。
|
||||||
|
const VisionImageSectionMarker = "## 图片分析"
|
||||||
|
|
||||||
// VisionImageAnalysisSection 单/多代理共用的图片分析提示(analyze_image;上下文仅保留文字摘要)。
|
// VisionImageAnalysisSection 单/多代理共用的图片分析提示(analyze_image;上下文仅保留文字摘要)。
|
||||||
func VisionImageAnalysisSection() string {
|
func VisionImageAnalysisSection() string {
|
||||||
var b strings.Builder
|
var b strings.Builder
|
||||||
b.WriteString("## 图片分析\n\n")
|
b.WriteString(VisionImageSectionMarker)
|
||||||
|
b.WriteString("\n\n")
|
||||||
b.WriteString("- 遇到图片文件(截图、验证码、登录页、报告配图)时,若存在工具 analyze_image,请传入服务器上的文件路径进行分析。\n")
|
b.WriteString("- 遇到图片文件(截图、验证码、登录页、报告配图)时,若存在工具 analyze_image,请传入服务器上的文件路径进行分析。\n")
|
||||||
b.WriteString("- 不要对二进制图片使用 read_file 指望理解内容;用户消息中「📎 xxx.png: /path」即为可传给 analyze_image 的路径。\n")
|
b.WriteString("- 不要对二进制图片使用 read_file 指望理解内容;用户消息中「📎 xxx.png: /path」即为可传给 analyze_image 的路径。\n")
|
||||||
b.WriteString("- 验证码类:若已从页面或接口保存为本地图片(如 captcha.png),用 analyze_image,question 写明「只输出验证码字符」;识别失败则刷新验证码后重新保存再识;复杂滑块/行为验证码勿指望单次识图成功。\n")
|
b.WriteString("- 验证码类:若已从页面或接口保存为本地图片(如 captcha.png),用 analyze_image,question 写明「只输出验证码字符」;识别失败则刷新验证码后重新保存再识;复杂滑块/行为验证码勿指望单次识图成功。\n")
|
||||||
|
|||||||
@@ -0,0 +1,132 @@
|
|||||||
|
// Package projectprompt 提供项目黑板相关的系统提示文本(纯字符串,无 database 依赖)。
|
||||||
|
// 供 agent / multiagent 等包引用,避免 agent → project 导入环导致 gopls 元数据失败。
|
||||||
|
package projectprompt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"cyberstrike-ai/internal/mcp/builtin"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
factRhythmCore = "勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。"
|
||||||
|
factRhythmCoordinatorSuffix = "委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。"
|
||||||
|
factRhythmSubAgentSuffix = "若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FactRecordingIncrementalRhythmMarkdown 返回边渗透边记录节奏(Markdown,供 agents/*.md 与文档对齐)。
|
||||||
|
func FactRecordingIncrementalRhythmMarkdown(coordinator, subAgent bool) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("- **边渗透边记录(强制节奏)**:")
|
||||||
|
b.WriteString(factRhythmCore)
|
||||||
|
if coordinator {
|
||||||
|
b.WriteString(factRhythmCoordinatorSuffix)
|
||||||
|
}
|
||||||
|
if subAgent {
|
||||||
|
b.WriteString(factRhythmSubAgentSuffix)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func factRecordingIncrementalRhythmBuiltin(coordinator, subAgent bool) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 ")
|
||||||
|
b.WriteString(builtin.ToolUpsertProjectFact)
|
||||||
|
b.WriteString("(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 ")
|
||||||
|
b.WriteString(builtin.ToolRecordVulnerability)
|
||||||
|
b.WriteString(";与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。")
|
||||||
|
if coordinator {
|
||||||
|
b.WriteString(factRhythmCoordinatorSuffix)
|
||||||
|
}
|
||||||
|
if subAgent {
|
||||||
|
b.WriteString(factRhythmSubAgentSuffix)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func factEdgeRecordingGuidance() string {
|
||||||
|
return `### 事实关系边(links)
|
||||||
|
|
||||||
|
- 写入 **finding / chain / exploit / poc** 时,**必须**在 ` + "`upsert_project_fact`" + ` 中提供 ` + "`links`" + `(**推荐 ` + "`from`" + `**:来源 fact 指向当前 fact,即 ` + "`from`" + ` → 当前 ` + "`fact_key`" + `)。
|
||||||
|
- **最少要求**:finding 类至少 1 条 from=target/* + type=discovered_on(即 target → finding);在 finding 上记录 exploit 用 from=exploit/* + type=exploits(即 exploit → finding)。
|
||||||
|
- **常用 type**:` + "`discovered_on`" + `(发现在哪)、` + "`depends_on`" + `(复现前置)、` + "`leads_to`" + `(认知推进)、` + "`enables`" + `(扩大攻击面)、` + "`exploits`" + `(利用关系)、` + "`contains`" + `(资产包含)、` + "`part_of`" + `(属于链/组)、` + "`supports`" + `(证据支撑)。
|
||||||
|
- 更新时:**省略 links 保留已有边**;传入 links 则**替换**全部关系边(from → 当前 fact)。
|
||||||
|
- body 中「依赖事实」段落可与 links 并存(人读);结构化关系以 links 为准。`
|
||||||
|
}
|
||||||
|
|
||||||
|
func factRecordingGuidanceBlock() string {
|
||||||
|
return `### 事实写入规范(审计复现 / 知识沉淀)
|
||||||
|
|
||||||
|
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
||||||
|
- **body**:完整可复现上下文,写入 ` + "`upsert_project_fact`" + ` 的 body 字段;索引不含 body,后续会话须靠 ` + "`get_project_fact`" + ` 取回。
|
||||||
|
- **category / fact_key 建议**:
|
||||||
|
- 环境认知:` + "`target/`" + `、` + "`auth/`" + `、` + "`infra/`" + `、` + "`business/`" + `(body 用环境模板即可)
|
||||||
|
- 发现与利用:` + "`finding/`" + `、` + "`chain/`" + `、` + "`exploit/`" + `、` + "`poc/`" + `(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
||||||
|
- **与漏洞记录分工**:` + "`record_vulnerability`" + ` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
||||||
|
- 更新同一发现时保持相同 ` + "`fact_key`" + ` 覆盖写入,勿散落多个 key 导致上下文丢失。`
|
||||||
|
}
|
||||||
|
|
||||||
|
// FactRecordingBlackboardSection 项目黑板与漏洞记录的完整系统提示块(单/多 Agent 主代理共用)。
|
||||||
|
func FactRecordingBlackboardSection(coordinatorDelegate bool) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n")
|
||||||
|
b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 fact_key + 摘要)。**摘要不足时必须调用 ")
|
||||||
|
b.WriteString(builtin.ToolGetProjectFact)
|
||||||
|
b.WriteString("(fact_key) 获取 body,禁止凭摘要臆造细节。**\n\n")
|
||||||
|
b.WriteString(factRecordingIncrementalRhythmBuiltin(coordinatorDelegate, false))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString("- **环境/目标/认证等认知**(非正式漏洞条目):使用 ")
|
||||||
|
b.WriteString(builtin.ToolUpsertProjectFact)
|
||||||
|
b.WriteString(",fact_key 建议 `category/slug`(如 target/primary_domain),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n")
|
||||||
|
b.WriteString("- **发现与利用上下文**(审计复现):fact_key 建议 finding/、chain/、exploit/、poc/ 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 related_vulnerability_id),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n")
|
||||||
|
b.WriteString("- **可交付漏洞**:使用 ")
|
||||||
|
b.WriteString(builtin.ToolRecordVulnerability)
|
||||||
|
b.WriteString(",含标题、严重程度、类型、目标、证明(POC)、影响、修复建议。记前可先 ")
|
||||||
|
b.WriteString(builtin.ToolListVulnerabilities)
|
||||||
|
b.WriteString(" 查重,详情用 ")
|
||||||
|
b.WriteString(builtin.ToolGetVulnerability)
|
||||||
|
b.WriteString("(id)(默认仅当前项目/会话)。\n")
|
||||||
|
b.WriteString("- 同一发现可能需**各记一次**(事实记**完整攻击链与 exploit 细节**供复现,漏洞记正式 findings)。误报用 ")
|
||||||
|
b.WriteString(builtin.ToolDeprecateProjectFact)
|
||||||
|
b.WriteString(" 或漏洞状态 false_positive。\n")
|
||||||
|
b.WriteString("- 事实多时用 ")
|
||||||
|
b.WriteString(builtin.ToolListProjectFacts)
|
||||||
|
b.WriteString(" / ")
|
||||||
|
b.WriteString(builtin.ToolSearchProjectFacts)
|
||||||
|
b.WriteString(" 检索。\n\n")
|
||||||
|
b.WriteString(factEdgeRecordingGuidance())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(factRecordingGuidanceBlock())
|
||||||
|
b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FactRecordingSubAgentSection 子代理边渗透边记录(无工具时输出待落库条目)。
|
||||||
|
func FactRecordingSubAgentSection() string {
|
||||||
|
return "## 边渗透边记录\n\n" + factRecordingIncrementalRhythmBuiltin(false, true) + "\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
// FactRecordingBlackboardSectionMarkdown 与 FactRecordingBlackboardSection 等价的 Markdown(工具名为字面量,供 agents/*.md)。
|
||||||
|
func FactRecordingBlackboardSectionMarkdown(coordinatorDelegate bool) string {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n")
|
||||||
|
b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**\n\n")
|
||||||
|
b.WriteString(FactRecordingIncrementalRhythmMarkdown(coordinatorDelegate, false))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString("- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n")
|
||||||
|
b.WriteString("- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n")
|
||||||
|
b.WriteString("- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。\n")
|
||||||
|
b.WriteString("- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。\n")
|
||||||
|
b.WriteString("- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。\n\n")
|
||||||
|
b.WriteString(factEdgeRecordingGuidance())
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
b.WriteString(factRecordingGuidanceBlock())
|
||||||
|
b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。")
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FactEdgeRecordingGuidance 写入边时的 Agent 规范(供 project 包复用)。
|
||||||
|
func FactEdgeRecordingGuidance() string { return factEdgeRecordingGuidance() }
|
||||||
|
|
||||||
|
// FactRecordingGuidanceBlock 事实写入规范块(供 project 包复用)。
|
||||||
|
func FactRecordingGuidanceBlock() string { return factRecordingGuidanceBlock() }
|
||||||
+53
-23
@@ -84,8 +84,9 @@ func ApplyToEinoChatModelConfig(cfg *einoopenai.ChatModelConfig, oa *config.Open
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyClaudeExtendedThinking sets Anthropic Messages API `thinking` when absent from ExtraRequestFields.
|
// applyClaudeExtendedThinking sets Anthropic Messages API fields per official guidance:
|
||||||
// Uses adaptive + summarized display by default (per Anthropic guidance for Claude 4.x); Sonnet 3.7 uses enabled+budget.
|
// - Adaptive models (4.6+): thinking.type=adaptive; output_config.effort only when user sets effort (API default is high).
|
||||||
|
// - Sonnet 3.7: thinking.type=enabled + budget_tokens=10000 (doc example); effort is not mapped — use extra_request_fields for custom budget.
|
||||||
func applyClaudeExtendedThinking(cfg *einoopenai.ChatModelConfig, mode, effort, model string) {
|
func applyClaudeExtendedThinking(cfg *einoopenai.ChatModelConfig, mode, effort, model string) {
|
||||||
if cfg == nil || mode == "off" {
|
if cfg == nil || mode == "off" {
|
||||||
return
|
return
|
||||||
@@ -93,31 +94,60 @@ func applyClaudeExtendedThinking(cfg *einoopenai.ChatModelConfig, mode, effort,
|
|||||||
if cfg.ExtraFields == nil {
|
if cfg.ExtraFields == nil {
|
||||||
cfg.ExtraFields = make(map[string]any)
|
cfg.ExtraFields = make(map[string]any)
|
||||||
}
|
}
|
||||||
if _, exists := cfg.ExtraFields["thinking"]; exists {
|
m := strings.ToLower(strings.TrimSpace(model))
|
||||||
|
sonnet37 := isClaudeSonnet37(m)
|
||||||
|
|
||||||
|
if _, exists := cfg.ExtraFields["thinking"]; !exists {
|
||||||
|
cfg.ExtraFields["thinking"] = claudeThinkingForModel(m, sonnet37)
|
||||||
|
}
|
||||||
|
|
||||||
|
applyClaudeOutputConfigEffort(cfg, effort, sonnet37)
|
||||||
|
}
|
||||||
|
|
||||||
|
// claudeSonnet37DefaultBudgetTokens matches Anthropic extended-thinking documentation examples (budget_tokens with max_tokens 16000).
|
||||||
|
const claudeSonnet37DefaultBudgetTokens = 10000
|
||||||
|
|
||||||
|
func isClaudeSonnet37(m string) bool {
|
||||||
|
return strings.Contains(m, "claude-3-7-sonnet") ||
|
||||||
|
strings.Contains(m, "3-7-sonnet") ||
|
||||||
|
strings.Contains(m, "sonnet-3.7")
|
||||||
|
}
|
||||||
|
|
||||||
|
func claudeThinkingForModel(m string, sonnet37 bool) map[string]any {
|
||||||
|
if sonnet37 {
|
||||||
|
return map[string]any{
|
||||||
|
"type": "enabled",
|
||||||
|
"budget_tokens": claudeSonnet37DefaultBudgetTokens,
|
||||||
|
"display": "summarized",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Opus 4.7+: manual enabled+budget rejected — adaptive only.
|
||||||
|
if strings.Contains(m, "opus-4-7") || strings.Contains(m, "opus-4.7") {
|
||||||
|
return map[string]any{
|
||||||
|
"type": "adaptive",
|
||||||
|
"display": "summarized",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return map[string]any{
|
||||||
|
"type": "adaptive",
|
||||||
|
"display": "summarized",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyClaudeOutputConfigEffort sets top-level output_config.effort only when effort is explicitly configured.
|
||||||
|
// Omitted effort uses the API default (high); do not inject effort on mode:on alone.
|
||||||
|
func applyClaudeOutputConfigEffort(cfg *einoopenai.ChatModelConfig, effort string, sonnet37 bool) {
|
||||||
|
if cfg == nil || sonnet37 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m := strings.ToLower(strings.TrimSpace(model))
|
if _, exists := cfg.ExtraFields["output_config"]; exists {
|
||||||
thinking := map[string]any{
|
return
|
||||||
"type": "adaptive",
|
|
||||||
"display": "summarized",
|
|
||||||
}
|
}
|
||||||
// Sonnet 3.7: manual extended thinking is the documented path.
|
e := effortStringForAPI(effort)
|
||||||
if strings.Contains(m, "claude-3-7-sonnet") || strings.Contains(m, "3-7-sonnet") || strings.Contains(m, "sonnet-3.7") {
|
if e == "" {
|
||||||
thinking = map[string]any{
|
return
|
||||||
"type": "enabled",
|
|
||||||
"budget_tokens": 10000,
|
|
||||||
"display": "summarized",
|
|
||||||
}
|
}
|
||||||
}
|
cfg.ExtraFields["output_config"] = map[string]any{"effort": e}
|
||||||
// Opus 4.7+: manual enabled+budget rejected — keep adaptive only.
|
|
||||||
if strings.Contains(m, "opus-4-7") || strings.Contains(m, "opus-4.7") {
|
|
||||||
thinking = map[string]any{
|
|
||||||
"type": "adaptive",
|
|
||||||
"display": "summarized",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ = effort // reserved: map to Anthropic effort / output_config when API stabilizes in one place
|
|
||||||
cfg.ExtraFields["thinking"] = thinking
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func effectiveMode(sr *config.OpenAIReasoningConfig, client *ClientIntent, allowClient bool) string {
|
func effectiveMode(sr *config.OpenAIReasoningConfig, client *ClientIntent, allowClient bool) string {
|
||||||
|
|||||||
@@ -80,3 +80,80 @@ func TestApplyOpenAICompat_maxPassthrough(t *testing.T) {
|
|||||||
t.Fatalf("max effort wire=%q, want max", got)
|
t.Fatalf("max effort wire=%q, want max", got)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestApplyClaude_adaptiveOutputConfigEffort(t *testing.T) {
|
||||||
|
cfg := &einoopenai.ChatModelConfig{}
|
||||||
|
oa := &config.OpenAIConfig{
|
||||||
|
Provider: "claude",
|
||||||
|
Model: "claude-opus-4-8",
|
||||||
|
Reasoning: config.OpenAIReasoningConfig{
|
||||||
|
Mode: "on",
|
||||||
|
Effort: "xhigh",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ApplyToEinoChatModelConfig(cfg, oa, nil)
|
||||||
|
th, ok := cfg.ExtraFields["thinking"].(map[string]any)
|
||||||
|
if !ok || th["type"] != "adaptive" {
|
||||||
|
t.Fatalf("thinking=%#v", cfg.ExtraFields["thinking"])
|
||||||
|
}
|
||||||
|
oc, ok := cfg.ExtraFields["output_config"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected output_config")
|
||||||
|
}
|
||||||
|
if oc["effort"] != "xhigh" {
|
||||||
|
t.Fatalf("effort=%v", oc["effort"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaude_sonnet37OfficialBudget(t *testing.T) {
|
||||||
|
cfg := &einoopenai.ChatModelConfig{}
|
||||||
|
oa := &config.OpenAIConfig{
|
||||||
|
Provider: "claude",
|
||||||
|
Model: "claude-3-7-sonnet-latest",
|
||||||
|
Reasoning: config.OpenAIReasoningConfig{
|
||||||
|
Mode: "on",
|
||||||
|
Effort: "low", // 3.7 has no output_config.effort; effort is not mapped to budget_tokens
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ApplyToEinoChatModelConfig(cfg, oa, nil)
|
||||||
|
th, ok := cfg.ExtraFields["thinking"].(map[string]any)
|
||||||
|
if !ok || th["type"] != "enabled" {
|
||||||
|
t.Fatalf("thinking=%#v", cfg.ExtraFields["thinking"])
|
||||||
|
}
|
||||||
|
if th["budget_tokens"] != claudeSonnet37DefaultBudgetTokens {
|
||||||
|
t.Fatalf("budget_tokens=%v, want official example %d", th["budget_tokens"], claudeSonnet37DefaultBudgetTokens)
|
||||||
|
}
|
||||||
|
if _, hasOC := cfg.ExtraFields["output_config"]; hasOC {
|
||||||
|
t.Fatal("sonnet 3.7 should not set output_config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaude_onWithoutEffortOmitsOutputConfig(t *testing.T) {
|
||||||
|
cfg := &einoopenai.ChatModelConfig{}
|
||||||
|
oa := &config.OpenAIConfig{
|
||||||
|
Provider: "claude",
|
||||||
|
Model: "claude-sonnet-4-6",
|
||||||
|
Reasoning: config.OpenAIReasoningConfig{
|
||||||
|
Mode: "on",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ApplyToEinoChatModelConfig(cfg, oa, nil)
|
||||||
|
if _, hasOC := cfg.ExtraFields["output_config"]; hasOC {
|
||||||
|
t.Fatal("on without explicit effort should omit output_config (API default high)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyClaude_autoWithoutEffortSkipsOutputConfig(t *testing.T) {
|
||||||
|
cfg := &einoopenai.ChatModelConfig{}
|
||||||
|
oa := &config.OpenAIConfig{
|
||||||
|
Provider: "claude",
|
||||||
|
Model: "claude-sonnet-4-6",
|
||||||
|
Reasoning: config.OpenAIReasoningConfig{
|
||||||
|
Mode: "auto",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ApplyToEinoChatModelConfig(cfg, oa, nil)
|
||||||
|
if _, hasOC := cfg.ExtraFields["output_config"]; hasOC {
|
||||||
|
t.Fatal("auto without effort should omit output_config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
|
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/storage"
|
|
||||||
|
|
||||||
"github.com/creack/pty"
|
"github.com/creack/pty"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
@@ -37,19 +36,6 @@ type Executor struct {
|
|||||||
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
|
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
|
||||||
mcpServer *mcp.Server
|
mcpServer *mcp.Server
|
||||||
logger *zap.Logger
|
logger *zap.Logger
|
||||||
resultStorage ResultStorage // 结果存储(用于查询工具)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResultStorage 结果存储接口(直接使用 storage 包的类型)
|
|
||||||
type ResultStorage interface {
|
|
||||||
SaveResult(executionID string, toolName string, result string) error
|
|
||||||
GetResult(executionID string) (string, error)
|
|
||||||
GetResultPage(executionID string, page int, limit int) (*storage.ResultPage, error)
|
|
||||||
SearchResult(executionID string, keyword string, useRegex bool) ([]string, error)
|
|
||||||
FilterResult(executionID string, filter string, useRegex bool) ([]string, error)
|
|
||||||
GetResultMetadata(executionID string) (*storage.ResultMetadata, error)
|
|
||||||
GetResultPath(executionID string) string
|
|
||||||
DeleteResult(executionID string) error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewExecutor 创建新的执行器
|
// NewExecutor 创建新的执行器
|
||||||
@@ -59,18 +45,12 @@ func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.
|
|||||||
toolIndex: make(map[string]*config.ToolConfig),
|
toolIndex: make(map[string]*config.ToolConfig),
|
||||||
mcpServer: mcpServer,
|
mcpServer: mcpServer,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
resultStorage: nil, // 稍后通过 SetResultStorage 设置
|
|
||||||
}
|
}
|
||||||
// 构建工具索引
|
// 构建工具索引
|
||||||
executor.buildToolIndex()
|
executor.buildToolIndex()
|
||||||
return executor
|
return executor
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetResultStorage 设置结果存储
|
|
||||||
func (e *Executor) SetResultStorage(storage ResultStorage) {
|
|
||||||
e.resultStorage = storage
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
|
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
|
||||||
func (e *Executor) buildToolIndex() {
|
func (e *Executor) buildToolIndex() {
|
||||||
e.toolIndex = make(map[string]*config.ToolConfig)
|
e.toolIndex = make(map[string]*config.ToolConfig)
|
||||||
@@ -1245,20 +1225,11 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
|
|||||||
|
|
||||||
// executeInternalTool 执行内部工具(不执行外部命令)
|
// executeInternalTool 执行内部工具(不执行外部命令)
|
||||||
func (e *Executor) executeInternalTool(ctx context.Context, toolName string, command string, args map[string]interface{}) (*mcp.ToolResult, error) {
|
func (e *Executor) executeInternalTool(ctx context.Context, toolName string, command string, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||||
// 提取内部工具类型(去掉 "internal:" 前缀)
|
|
||||||
internalToolType := strings.TrimPrefix(command, "internal:")
|
internalToolType := strings.TrimPrefix(command, "internal:")
|
||||||
|
e.logger.Warn("未知的内部工具",
|
||||||
e.logger.Info("执行内部工具",
|
|
||||||
zap.String("toolName", toolName),
|
zap.String("toolName", toolName),
|
||||||
zap.String("internalToolType", internalToolType),
|
zap.String("internalToolType", internalToolType),
|
||||||
zap.Any("args", args),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// 根据内部工具类型分发处理
|
|
||||||
switch internalToolType {
|
|
||||||
case "query_execution_result":
|
|
||||||
return e.executeQueryExecutionResult(ctx, args)
|
|
||||||
default:
|
|
||||||
return &mcp.ToolResult{
|
return &mcp.ToolResult{
|
||||||
Content: []mcp.Content{
|
Content: []mcp.Content{
|
||||||
{
|
{
|
||||||
@@ -1269,213 +1240,6 @@ func (e *Executor) executeInternalTool(ctx context.Context, toolName string, com
|
|||||||
IsError: true,
|
IsError: true,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// executeQueryExecutionResult 执行查询执行结果工具
|
|
||||||
func (e *Executor) executeQueryExecutionResult(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
|
||||||
// 获取 execution_id 参数
|
|
||||||
executionID, ok := args["execution_id"].(string)
|
|
||||||
if !ok || executionID == "" {
|
|
||||||
return &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "错误: execution_id 参数必需且不能为空",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: true,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取可选参数
|
|
||||||
page := 1
|
|
||||||
if p, ok := args["page"].(float64); ok {
|
|
||||||
page = int(p)
|
|
||||||
}
|
|
||||||
if page < 1 {
|
|
||||||
page = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
limit := 100
|
|
||||||
if l, ok := args["limit"].(float64); ok {
|
|
||||||
limit = int(l)
|
|
||||||
}
|
|
||||||
if limit < 1 {
|
|
||||||
limit = 100
|
|
||||||
}
|
|
||||||
if limit > 500 {
|
|
||||||
limit = 500 // 限制最大每页行数
|
|
||||||
}
|
|
||||||
|
|
||||||
search := ""
|
|
||||||
if s, ok := args["search"].(string); ok {
|
|
||||||
search = s
|
|
||||||
}
|
|
||||||
|
|
||||||
filter := ""
|
|
||||||
if f, ok := args["filter"].(string); ok {
|
|
||||||
filter = f
|
|
||||||
}
|
|
||||||
|
|
||||||
useRegex := false
|
|
||||||
if r, ok := args["use_regex"].(bool); ok {
|
|
||||||
useRegex = r
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查结果存储是否可用
|
|
||||||
if e.resultStorage == nil {
|
|
||||||
return &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: "错误: 结果存储未初始化",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: true,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 执行查询
|
|
||||||
var resultPage *storage.ResultPage
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if search != "" {
|
|
||||||
// 搜索模式
|
|
||||||
matchedLines, err := e.resultStorage.SearchResult(executionID, search, useRegex)
|
|
||||||
if err != nil {
|
|
||||||
return &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: fmt.Sprintf("搜索失败: %v", err),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: true,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
// 对搜索结果进行分页
|
|
||||||
resultPage = paginateLines(matchedLines, page, limit)
|
|
||||||
} else if filter != "" {
|
|
||||||
// 过滤模式
|
|
||||||
filteredLines, err := e.resultStorage.FilterResult(executionID, filter, useRegex)
|
|
||||||
if err != nil {
|
|
||||||
return &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: fmt.Sprintf("过滤失败: %v", err),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: true,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
// 对过滤结果进行分页
|
|
||||||
resultPage = paginateLines(filteredLines, page, limit)
|
|
||||||
} else {
|
|
||||||
// 普通分页查询
|
|
||||||
resultPage, err = e.resultStorage.GetResultPage(executionID, page, limit)
|
|
||||||
if err != nil {
|
|
||||||
return &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: fmt.Sprintf("查询失败: %v", err),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: true,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取元信息
|
|
||||||
metadata, err := e.resultStorage.GetResultMetadata(executionID)
|
|
||||||
if err != nil {
|
|
||||||
// 元信息获取失败不影响查询结果
|
|
||||||
e.logger.Warn("获取结果元信息失败", zap.Error(err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 格式化返回结果
|
|
||||||
var sb strings.Builder
|
|
||||||
sb.WriteString(fmt.Sprintf("查询结果 (执行ID: %s)\n", executionID))
|
|
||||||
|
|
||||||
if metadata != nil {
|
|
||||||
sb.WriteString(fmt.Sprintf("工具: %s | 大小: %d 字节 (%.2f KB) | 总行数: %d\n",
|
|
||||||
metadata.ToolName, metadata.TotalSize, float64(metadata.TotalSize)/1024, metadata.TotalLines))
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.WriteString(fmt.Sprintf("第 %d/%d 页,每页 %d 行,共 %d 行\n\n",
|
|
||||||
resultPage.Page, resultPage.TotalPages, resultPage.Limit, resultPage.TotalLines))
|
|
||||||
|
|
||||||
if len(resultPage.Lines) == 0 {
|
|
||||||
sb.WriteString("没有找到匹配的结果。\n")
|
|
||||||
} else {
|
|
||||||
for i, line := range resultPage.Lines {
|
|
||||||
lineNum := (resultPage.Page-1)*resultPage.Limit + i + 1
|
|
||||||
sb.WriteString(fmt.Sprintf("%d: %s\n", lineNum, line))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.WriteString("\n")
|
|
||||||
if resultPage.Page < resultPage.TotalPages {
|
|
||||||
sb.WriteString(fmt.Sprintf("提示: 使用 page=%d 查看下一页", resultPage.Page+1))
|
|
||||||
if search != "" {
|
|
||||||
sb.WriteString(fmt.Sprintf(",或使用 search=\"%s\" 继续搜索", search))
|
|
||||||
if useRegex {
|
|
||||||
sb.WriteString(" (正则模式)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if filter != "" {
|
|
||||||
sb.WriteString(fmt.Sprintf(",或使用 filter=\"%s\" 继续过滤", filter))
|
|
||||||
if useRegex {
|
|
||||||
sb.WriteString(" (正则模式)")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sb.WriteString("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &mcp.ToolResult{
|
|
||||||
Content: []mcp.Content{
|
|
||||||
{
|
|
||||||
Type: "text",
|
|
||||||
Text: sb.String(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IsError: false,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// paginateLines 对行列表进行分页
|
|
||||||
func paginateLines(lines []string, page int, limit int) *storage.ResultPage {
|
|
||||||
totalLines := len(lines)
|
|
||||||
totalPages := (totalLines + limit - 1) / limit
|
|
||||||
if page < 1 {
|
|
||||||
page = 1
|
|
||||||
}
|
|
||||||
if page > totalPages && totalPages > 0 {
|
|
||||||
page = totalPages
|
|
||||||
}
|
|
||||||
|
|
||||||
start := (page - 1) * limit
|
|
||||||
end := start + limit
|
|
||||||
if end > totalLines {
|
|
||||||
end = totalLines
|
|
||||||
}
|
|
||||||
|
|
||||||
var pageLines []string
|
|
||||||
if start < totalLines {
|
|
||||||
pageLines = lines[start:end]
|
|
||||||
} else {
|
|
||||||
pageLines = []string{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &storage.ResultPage{
|
|
||||||
Lines: pageLines,
|
|
||||||
Page: page,
|
|
||||||
Limit: limit,
|
|
||||||
TotalLines: totalLines,
|
|
||||||
TotalPages: totalPages,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildInputSchema 构建输入模式
|
// buildInputSchema 构建输入模式
|
||||||
func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} {
|
func (e *Executor) buildInputSchema(toolConfig *config.ToolConfig) map[string]interface{} {
|
||||||
|
|||||||
@@ -2,15 +2,12 @@ package security
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"cyberstrike-ai/internal/config"
|
"cyberstrike-ai/internal/config"
|
||||||
"cyberstrike-ai/internal/mcp"
|
"cyberstrike-ai/internal/mcp"
|
||||||
"cyberstrike-ai/internal/storage"
|
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
@@ -28,137 +25,6 @@ func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
|
|||||||
return executor, mcpServer
|
return executor, mcpServer
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupTestStorage 创建测试用的存储
|
|
||||||
func setupTestStorage(t *testing.T) *storage.FileResultStorage {
|
|
||||||
tmpDir := filepath.Join(os.TempDir(), "test_executor_storage_"+time.Now().Format("20060102_150405"))
|
|
||||||
logger := zap.NewNop()
|
|
||||||
|
|
||||||
storage, err := storage.NewFileResultStorage(tmpDir, logger)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("创建测试存储失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return storage
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
|
|
||||||
executor, _ := setupTestExecutor(t)
|
|
||||||
testStorage := setupTestStorage(t)
|
|
||||||
executor.SetResultStorage(testStorage)
|
|
||||||
|
|
||||||
// 准备测试数据
|
|
||||||
executionID := "test_exec_001"
|
|
||||||
toolName := "nmap_scan"
|
|
||||||
result := "Line 1: Port 22 open\nLine 2: Port 80 open\nLine 3: Port 443 open\nLine 4: error occurred"
|
|
||||||
|
|
||||||
// 保存测试结果
|
|
||||||
err := testStorage.SaveResult(executionID, toolName, result)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("保存测试结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// 测试1: 基本查询(第一页)
|
|
||||||
args := map[string]interface{}{
|
|
||||||
"execution_id": executionID,
|
|
||||||
"page": float64(1),
|
|
||||||
"limit": float64(2),
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("执行查询失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if toolResult.IsError {
|
|
||||||
t.Fatalf("查询应该成功,但返回了错误: %s", toolResult.Content[0].Text)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证结果包含预期内容
|
|
||||||
resultText := toolResult.Content[0].Text
|
|
||||||
if !strings.Contains(resultText, executionID) {
|
|
||||||
t.Errorf("结果中应该包含执行ID: %s", executionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(resultText, "第 1/") {
|
|
||||||
t.Errorf("结果中应该包含分页信息")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试2: 搜索功能
|
|
||||||
args2 := map[string]interface{}{
|
|
||||||
"execution_id": executionID,
|
|
||||||
"search": "error",
|
|
||||||
"page": float64(1),
|
|
||||||
"limit": float64(10),
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResult2, err := executor.executeQueryExecutionResult(ctx, args2)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("执行搜索失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if toolResult2.IsError {
|
|
||||||
t.Fatalf("搜索应该成功,但返回了错误: %s", toolResult2.Content[0].Text)
|
|
||||||
}
|
|
||||||
|
|
||||||
resultText2 := toolResult2.Content[0].Text
|
|
||||||
if !strings.Contains(resultText2, "error") {
|
|
||||||
t.Errorf("搜索结果中应该包含关键词: error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试3: 过滤功能
|
|
||||||
args3 := map[string]interface{}{
|
|
||||||
"execution_id": executionID,
|
|
||||||
"filter": "Port",
|
|
||||||
"page": float64(1),
|
|
||||||
"limit": float64(10),
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResult3, err := executor.executeQueryExecutionResult(ctx, args3)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("执行过滤失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if toolResult3.IsError {
|
|
||||||
t.Fatalf("过滤应该成功,但返回了错误: %s", toolResult3.Content[0].Text)
|
|
||||||
}
|
|
||||||
|
|
||||||
resultText3 := toolResult3.Content[0].Text
|
|
||||||
if !strings.Contains(resultText3, "Port") {
|
|
||||||
t.Errorf("过滤结果中应该包含关键词: Port")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试4: 缺少必需参数
|
|
||||||
args4 := map[string]interface{}{
|
|
||||||
"page": float64(1),
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResult4, err := executor.executeQueryExecutionResult(ctx, args4)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("执行查询失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !toolResult4.IsError {
|
|
||||||
t.Fatal("缺少execution_id应该返回错误")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试5: 不存在的执行ID
|
|
||||||
args5 := map[string]interface{}{
|
|
||||||
"execution_id": "nonexistent_id",
|
|
||||||
"page": float64(1),
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResult5, err := executor.executeQueryExecutionResult(ctx, args5)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("执行查询失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !toolResult5.IsError {
|
|
||||||
t.Fatal("不存在的执行ID应该返回错误")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
|
func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
|
||||||
executor, _ := setupTestExecutor(t)
|
executor, _ := setupTestExecutor(t)
|
||||||
|
|
||||||
@@ -182,29 +48,6 @@ func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) {
|
|
||||||
executor, _ := setupTestExecutor(t)
|
|
||||||
// 不设置存储,测试未初始化的情况
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
args := map[string]interface{}{
|
|
||||||
"execution_id": "test_id",
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("执行查询失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !toolResult.IsError {
|
|
||||||
t.Fatal("未初始化的存储应该返回错误")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(toolResult.Content[0].Text, "结果存储未初始化") {
|
|
||||||
t.Errorf("错误消息应该包含'结果存储未初始化'")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T) {
|
func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T) {
|
||||||
executor, _ := setupTestExecutor(t)
|
executor, _ := setupTestExecutor(t)
|
||||||
// 子进程先向 stdout 写无换行字符再长时间 sleep;若与 echo $pid 共享管道且未重定向子进程 stdout,
|
// 子进程先向 stdout 写无换行字符再长时间 sleep;若与 echo $pid 共享管道且未重定向子进程 stdout,
|
||||||
@@ -228,63 +71,58 @@ func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPaginateLines(t *testing.T) {
|
func TestBuildCommandArgs_NmapSkipsEmptyOptionalFlags(t *testing.T) {
|
||||||
lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"}
|
pos1 := 1
|
||||||
|
executor, _ := setupTestExecutor(t)
|
||||||
// 测试第一页
|
toolConfig := &config.ToolConfig{
|
||||||
page := paginateLines(lines, 1, 2)
|
Name: "nmap",
|
||||||
if page.Page != 1 {
|
Command: "nmap",
|
||||||
t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page)
|
Args: []string{"-sT", "-sV", "-sC"},
|
||||||
}
|
Parameters: []config.ParameterConfig{
|
||||||
if page.Limit != 2 {
|
{Name: "target", Type: "string", Required: true, Position: &pos1, Format: "positional"},
|
||||||
t.Errorf("每页行数不匹配。期望: 2, 实际: %d", page.Limit)
|
{Name: "ports", Type: "string", Flag: "-p", Format: "flag"},
|
||||||
}
|
{Name: "timing", Type: "string", Template: "-T{value}", Format: "template"},
|
||||||
if page.TotalLines != 5 {
|
{Name: "nse_scripts", Type: "string", Flag: "--script", Format: "flag"},
|
||||||
t.Errorf("总行数不匹配。期望: 5, 实际: %d", page.TotalLines)
|
{Name: "os_detection", Type: "bool", Flag: "-O", Format: "flag", Default: false},
|
||||||
}
|
{Name: "aggressive", Type: "bool", Flag: "-A", Format: "flag", Default: false},
|
||||||
if page.TotalPages != 3 {
|
{Name: "scan_type", Type: "string", Format: "template", Template: "{value}"},
|
||||||
t.Errorf("总页数不匹配。期望: 3, 实际: %d", page.TotalPages)
|
{Name: "additional_args", Type: "string", Format: "positional"},
|
||||||
}
|
},
|
||||||
if len(page.Lines) != 2 {
|
|
||||||
t.Errorf("第一页行数不匹配。期望: 2, 实际: %d", len(page.Lines))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 测试第二页
|
args := map[string]interface{}{
|
||||||
page2 := paginateLines(lines, 2, 2)
|
"target": "110.52.223.114",
|
||||||
if len(page2.Lines) != 2 {
|
"ports": "21, 22, 80, 443",
|
||||||
t.Errorf("第二页行数不匹配。期望: 2, 实际: %d", len(page2.Lines))
|
"timing": "4",
|
||||||
}
|
"nse_scripts": "",
|
||||||
if page2.Lines[0] != "Line 3" {
|
"scan_type": "",
|
||||||
t.Errorf("第二页第一行不匹配。期望: Line 3, 实际: %s", page2.Lines[0])
|
"os_detection": false,
|
||||||
|
"aggressive": false,
|
||||||
|
"additional_args": "-Pn",
|
||||||
}
|
}
|
||||||
|
|
||||||
// 测试最后一页
|
cmdArgs := executor.buildCommandArgs("nmap", toolConfig, args)
|
||||||
page3 := paginateLines(lines, 3, 2)
|
joined := strings.Join(cmdArgs, " ")
|
||||||
if len(page3.Lines) != 1 {
|
|
||||||
t.Errorf("第三页行数不匹配。期望: 1, 实际: %d", len(page3.Lines))
|
if strings.Contains(joined, "--script") {
|
||||||
|
t.Fatalf("empty nse_scripts must not emit --script, got: %v", cmdArgs)
|
||||||
|
}
|
||||||
|
if !strings.Contains(joined, "110.52.223.114") {
|
||||||
|
t.Fatalf("target missing from args: %v", cmdArgs)
|
||||||
|
}
|
||||||
|
// target 应出现在 -Pn 之前,避免被误当作 --script 的参数
|
||||||
|
pnIdx := indexOf(cmdArgs, "-Pn")
|
||||||
|
targetIdx := indexOf(cmdArgs, "110.52.223.114")
|
||||||
|
if pnIdx < 0 || targetIdx < 0 || targetIdx >= pnIdx {
|
||||||
|
t.Fatalf("expected target before -Pn, got: %v", cmdArgs)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 测试超出范围的页码(应该返回最后一页)
|
func indexOf(slice []string, s string) int {
|
||||||
page4 := paginateLines(lines, 4, 2)
|
for i, v := range slice {
|
||||||
if page4.Page != 3 {
|
if v == s {
|
||||||
t.Errorf("超出范围的页码应该被修正为最后一页。期望: 3, 实际: %d", page4.Page)
|
return i
|
||||||
}
|
|
||||||
if len(page4.Lines) != 1 {
|
|
||||||
t.Errorf("最后一页应该只有1行。实际: %d行", len(page4.Lines))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试无效页码(小于1)
|
|
||||||
page0 := paginateLines(lines, 0, 2)
|
|
||||||
if page0.Page != 1 {
|
|
||||||
t.Errorf("无效页码应该被修正为1。实际: %d", page0.Page)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试空列表
|
|
||||||
emptyPage := paginateLines([]string{}, 1, 10)
|
|
||||||
if emptyPage.TotalLines != 0 {
|
|
||||||
t.Errorf("空列表的总行数应该为0。实际: %d", emptyPage.TotalLines)
|
|
||||||
}
|
|
||||||
if len(emptyPage.Lines) != 0 {
|
|
||||||
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,297 +0,0 @@
|
|||||||
package storage
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"go.uber.org/zap"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ResultStorage 结果存储接口
|
|
||||||
type ResultStorage interface {
|
|
||||||
// SaveResult 保存工具执行结果
|
|
||||||
SaveResult(executionID string, toolName string, result string) error
|
|
||||||
|
|
||||||
// GetResult 获取完整结果
|
|
||||||
GetResult(executionID string) (string, error)
|
|
||||||
|
|
||||||
// GetResultPage 分页获取结果
|
|
||||||
GetResultPage(executionID string, page int, limit int) (*ResultPage, error)
|
|
||||||
|
|
||||||
// SearchResult 搜索结果
|
|
||||||
// useRegex: 如果为 true,将 keyword 作为正则表达式使用;如果为 false,使用简单的字符串包含匹配
|
|
||||||
SearchResult(executionID string, keyword string, useRegex bool) ([]string, error)
|
|
||||||
|
|
||||||
// FilterResult 过滤结果
|
|
||||||
// useRegex: 如果为 true,将 filter 作为正则表达式使用;如果为 false,使用简单的字符串包含匹配
|
|
||||||
FilterResult(executionID string, filter string, useRegex bool) ([]string, error)
|
|
||||||
|
|
||||||
// GetResultMetadata 获取结果元信息
|
|
||||||
GetResultMetadata(executionID string) (*ResultMetadata, error)
|
|
||||||
|
|
||||||
// GetResultPath 获取结果文件路径
|
|
||||||
GetResultPath(executionID string) string
|
|
||||||
|
|
||||||
// DeleteResult 删除结果
|
|
||||||
DeleteResult(executionID string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResultPage 分页结果
|
|
||||||
type ResultPage struct {
|
|
||||||
Lines []string `json:"lines"`
|
|
||||||
Page int `json:"page"`
|
|
||||||
Limit int `json:"limit"`
|
|
||||||
TotalLines int `json:"total_lines"`
|
|
||||||
TotalPages int `json:"total_pages"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResultMetadata 结果元信息
|
|
||||||
type ResultMetadata struct {
|
|
||||||
ExecutionID string `json:"execution_id"`
|
|
||||||
ToolName string `json:"tool_name"`
|
|
||||||
TotalSize int `json:"total_size"`
|
|
||||||
TotalLines int `json:"total_lines"`
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// FileResultStorage 基于文件的结果存储实现
|
|
||||||
type FileResultStorage struct {
|
|
||||||
baseDir string
|
|
||||||
logger *zap.Logger
|
|
||||||
mu sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewFileResultStorage 创建新的文件结果存储
|
|
||||||
func NewFileResultStorage(baseDir string, logger *zap.Logger) (*FileResultStorage, error) {
|
|
||||||
// 确保目录存在
|
|
||||||
if err := os.MkdirAll(baseDir, 0755); err != nil {
|
|
||||||
return nil, fmt.Errorf("创建存储目录失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &FileResultStorage{
|
|
||||||
baseDir: baseDir,
|
|
||||||
logger: logger,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getResultPath 获取结果文件路径
|
|
||||||
func (s *FileResultStorage) getResultPath(executionID string) string {
|
|
||||||
return filepath.Join(s.baseDir, executionID+".txt")
|
|
||||||
}
|
|
||||||
|
|
||||||
// getMetadataPath 获取元数据文件路径
|
|
||||||
func (s *FileResultStorage) getMetadataPath(executionID string) string {
|
|
||||||
return filepath.Join(s.baseDir, executionID+".meta.json")
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaveResult 保存工具执行结果
|
|
||||||
func (s *FileResultStorage) SaveResult(executionID string, toolName string, result string) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
// 保存结果文件
|
|
||||||
resultPath := s.getResultPath(executionID)
|
|
||||||
if err := os.WriteFile(resultPath, []byte(result), 0644); err != nil {
|
|
||||||
return fmt.Errorf("保存结果文件失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 计算统计信息
|
|
||||||
lines := strings.Split(result, "\n")
|
|
||||||
metadata := &ResultMetadata{
|
|
||||||
ExecutionID: executionID,
|
|
||||||
ToolName: toolName,
|
|
||||||
TotalSize: len(result),
|
|
||||||
TotalLines: len(lines),
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存元数据
|
|
||||||
metadataPath := s.getMetadataPath(executionID)
|
|
||||||
metadataJSON, err := json.Marshal(metadata)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("序列化元数据失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(metadataPath, metadataJSON, 0644); err != nil {
|
|
||||||
return fmt.Errorf("保存元数据文件失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.logger.Info("保存工具执行结果",
|
|
||||||
zap.String("executionID", executionID),
|
|
||||||
zap.String("toolName", toolName),
|
|
||||||
zap.Int("size", len(result)),
|
|
||||||
zap.Int("lines", len(lines)),
|
|
||||||
)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetResult 获取完整结果
|
|
||||||
func (s *FileResultStorage) GetResult(executionID string) (string, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
|
|
||||||
resultPath := s.getResultPath(executionID)
|
|
||||||
data, err := os.ReadFile(resultPath)
|
|
||||||
if err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return "", fmt.Errorf("结果不存在: %s", executionID)
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("读取结果文件失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(data), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetResultMetadata 获取结果元信息
|
|
||||||
func (s *FileResultStorage) GetResultMetadata(executionID string) (*ResultMetadata, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
|
|
||||||
metadataPath := s.getMetadataPath(executionID)
|
|
||||||
data, err := os.ReadFile(metadataPath)
|
|
||||||
if err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return nil, fmt.Errorf("结果不存在: %s", executionID)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("读取元数据文件失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var metadata ResultMetadata
|
|
||||||
if err := json.Unmarshal(data, &metadata); err != nil {
|
|
||||||
return nil, fmt.Errorf("解析元数据失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &metadata, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetResultPage 分页获取结果
|
|
||||||
func (s *FileResultStorage) GetResultPage(executionID string, page int, limit int) (*ResultPage, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
|
|
||||||
// 获取完整结果
|
|
||||||
result, err := s.GetResult(executionID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 分割为行
|
|
||||||
lines := strings.Split(result, "\n")
|
|
||||||
totalLines := len(lines)
|
|
||||||
|
|
||||||
// 计算分页
|
|
||||||
totalPages := (totalLines + limit - 1) / limit
|
|
||||||
if page < 1 {
|
|
||||||
page = 1
|
|
||||||
}
|
|
||||||
if page > totalPages && totalPages > 0 {
|
|
||||||
page = totalPages
|
|
||||||
}
|
|
||||||
|
|
||||||
// 计算起始和结束索引
|
|
||||||
start := (page - 1) * limit
|
|
||||||
end := start + limit
|
|
||||||
if end > totalLines {
|
|
||||||
end = totalLines
|
|
||||||
}
|
|
||||||
|
|
||||||
// 提取指定页的行
|
|
||||||
var pageLines []string
|
|
||||||
if start < totalLines {
|
|
||||||
pageLines = lines[start:end]
|
|
||||||
} else {
|
|
||||||
pageLines = []string{}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ResultPage{
|
|
||||||
Lines: pageLines,
|
|
||||||
Page: page,
|
|
||||||
Limit: limit,
|
|
||||||
TotalLines: totalLines,
|
|
||||||
TotalPages: totalPages,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SearchResult 搜索结果
|
|
||||||
func (s *FileResultStorage) SearchResult(executionID string, keyword string, useRegex bool) ([]string, error) {
|
|
||||||
s.mu.RLock()
|
|
||||||
defer s.mu.RUnlock()
|
|
||||||
|
|
||||||
// 获取完整结果
|
|
||||||
result, err := s.GetResult(executionID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果使用正则表达式,先编译正则
|
|
||||||
var regex *regexp.Regexp
|
|
||||||
if useRegex {
|
|
||||||
compiledRegex, err := regexp.Compile(keyword)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("无效的正则表达式: %w", err)
|
|
||||||
}
|
|
||||||
regex = compiledRegex
|
|
||||||
}
|
|
||||||
|
|
||||||
// 分割为行并搜索
|
|
||||||
lines := strings.Split(result, "\n")
|
|
||||||
var matchedLines []string
|
|
||||||
|
|
||||||
for _, line := range lines {
|
|
||||||
var matched bool
|
|
||||||
if useRegex {
|
|
||||||
matched = regex.MatchString(line)
|
|
||||||
} else {
|
|
||||||
matched = strings.Contains(line, keyword)
|
|
||||||
}
|
|
||||||
|
|
||||||
if matched {
|
|
||||||
matchedLines = append(matchedLines, line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return matchedLines, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FilterResult 过滤结果
|
|
||||||
func (s *FileResultStorage) FilterResult(executionID string, filter string, useRegex bool) ([]string, error) {
|
|
||||||
// 过滤和搜索逻辑相同,都是查找包含关键词的行
|
|
||||||
return s.SearchResult(executionID, filter, useRegex)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetResultPath 获取结果文件路径
|
|
||||||
func (s *FileResultStorage) GetResultPath(executionID string) string {
|
|
||||||
return s.getResultPath(executionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteResult 删除结果
|
|
||||||
func (s *FileResultStorage) DeleteResult(executionID string) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
resultPath := s.getResultPath(executionID)
|
|
||||||
metadataPath := s.getMetadataPath(executionID)
|
|
||||||
|
|
||||||
// 删除结果文件
|
|
||||||
if err := os.Remove(resultPath); err != nil && !os.IsNotExist(err) {
|
|
||||||
return fmt.Errorf("删除结果文件失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 删除元数据文件
|
|
||||||
if err := os.Remove(metadataPath); err != nil && !os.IsNotExist(err) {
|
|
||||||
return fmt.Errorf("删除元数据文件失败: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.logger.Info("删除工具执行结果",
|
|
||||||
zap.String("executionID", executionID),
|
|
||||||
)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,453 +0,0 @@
|
|||||||
package storage
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"go.uber.org/zap"
|
|
||||||
)
|
|
||||||
|
|
||||||
// setupTestStorage 创建测试用的存储实例
|
|
||||||
func setupTestStorage(t *testing.T) (*FileResultStorage, string) {
|
|
||||||
tmpDir := filepath.Join(os.TempDir(), "test_result_storage_"+time.Now().Format("20060102_150405"))
|
|
||||||
logger := zap.NewNop()
|
|
||||||
|
|
||||||
storage, err := NewFileResultStorage(tmpDir, logger)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("创建测试存储失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return storage, tmpDir
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanupTestStorage 清理测试数据
|
|
||||||
func cleanupTestStorage(t *testing.T, tmpDir string) {
|
|
||||||
if err := os.RemoveAll(tmpDir); err != nil {
|
|
||||||
t.Logf("清理测试目录失败: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewFileResultStorage(t *testing.T) {
|
|
||||||
tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405"))
|
|
||||||
defer cleanupTestStorage(t, tmpDir)
|
|
||||||
|
|
||||||
logger := zap.NewNop()
|
|
||||||
storage, err := NewFileResultStorage(tmpDir, logger)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("创建存储失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if storage == nil {
|
|
||||||
t.Fatal("存储实例为nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证目录已创建
|
|
||||||
if _, err := os.Stat(tmpDir); os.IsNotExist(err) {
|
|
||||||
t.Fatal("存储目录未创建")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileResultStorage_SaveResult(t *testing.T) {
|
|
||||||
storage, tmpDir := setupTestStorage(t)
|
|
||||||
defer cleanupTestStorage(t, tmpDir)
|
|
||||||
|
|
||||||
executionID := "test_exec_001"
|
|
||||||
toolName := "nmap_scan"
|
|
||||||
result := "Line 1\nLine 2\nLine 3\nLine 4\nLine 5"
|
|
||||||
|
|
||||||
err := storage.SaveResult(executionID, toolName, result)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("保存结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证结果文件存在
|
|
||||||
resultPath := filepath.Join(tmpDir, executionID+".txt")
|
|
||||||
if _, err := os.Stat(resultPath); os.IsNotExist(err) {
|
|
||||||
t.Fatal("结果文件未创建")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证元数据文件存在
|
|
||||||
metadataPath := filepath.Join(tmpDir, executionID+".meta.json")
|
|
||||||
if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
|
|
||||||
t.Fatal("元数据文件未创建")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileResultStorage_GetResult(t *testing.T) {
|
|
||||||
storage, tmpDir := setupTestStorage(t)
|
|
||||||
defer cleanupTestStorage(t, tmpDir)
|
|
||||||
|
|
||||||
executionID := "test_exec_002"
|
|
||||||
toolName := "test_tool"
|
|
||||||
expectedResult := "Test result content\nLine 2\nLine 3"
|
|
||||||
|
|
||||||
// 先保存结果
|
|
||||||
err := storage.SaveResult(executionID, toolName, expectedResult)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("保存结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取结果
|
|
||||||
result, err := storage.GetResult(executionID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("获取结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result != expectedResult {
|
|
||||||
t.Errorf("结果不匹配。期望: %q, 实际: %q", expectedResult, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试不存在的执行ID
|
|
||||||
_, err = storage.GetResult("nonexistent_id")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("应该返回错误")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileResultStorage_GetResultMetadata(t *testing.T) {
|
|
||||||
storage, tmpDir := setupTestStorage(t)
|
|
||||||
defer cleanupTestStorage(t, tmpDir)
|
|
||||||
|
|
||||||
executionID := "test_exec_003"
|
|
||||||
toolName := "test_tool"
|
|
||||||
result := "Line 1\nLine 2\nLine 3"
|
|
||||||
|
|
||||||
// 保存结果
|
|
||||||
err := storage.SaveResult(executionID, toolName, result)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("保存结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取元数据
|
|
||||||
metadata, err := storage.GetResultMetadata(executionID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("获取元数据失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if metadata.ExecutionID != executionID {
|
|
||||||
t.Errorf("执行ID不匹配。期望: %s, 实际: %s", executionID, metadata.ExecutionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if metadata.ToolName != toolName {
|
|
||||||
t.Errorf("工具名称不匹配。期望: %s, 实际: %s", toolName, metadata.ToolName)
|
|
||||||
}
|
|
||||||
|
|
||||||
if metadata.TotalSize != len(result) {
|
|
||||||
t.Errorf("总大小不匹配。期望: %d, 实际: %d", len(result), metadata.TotalSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedLines := len(strings.Split(result, "\n"))
|
|
||||||
if metadata.TotalLines != expectedLines {
|
|
||||||
t.Errorf("总行数不匹配。期望: %d, 实际: %d", expectedLines, metadata.TotalLines)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证创建时间在合理范围内
|
|
||||||
now := time.Now()
|
|
||||||
if metadata.CreatedAt.After(now) || metadata.CreatedAt.Before(now.Add(-time.Second)) {
|
|
||||||
t.Errorf("创建时间不在合理范围内: %v", metadata.CreatedAt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileResultStorage_GetResultPage(t *testing.T) {
|
|
||||||
storage, tmpDir := setupTestStorage(t)
|
|
||||||
defer cleanupTestStorage(t, tmpDir)
|
|
||||||
|
|
||||||
executionID := "test_exec_004"
|
|
||||||
toolName := "test_tool"
|
|
||||||
// 创建包含10行的结果
|
|
||||||
lines := make([]string, 10)
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
lines[i] = fmt.Sprintf("Line %d", i+1)
|
|
||||||
}
|
|
||||||
result := strings.Join(lines, "\n")
|
|
||||||
|
|
||||||
// 保存结果
|
|
||||||
err := storage.SaveResult(executionID, toolName, result)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("保存结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试第一页(每页3行)
|
|
||||||
page, err := storage.GetResultPage(executionID, 1, 3)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("获取第一页失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if page.Page != 1 {
|
|
||||||
t.Errorf("页码不匹配。期望: 1, 实际: %d", page.Page)
|
|
||||||
}
|
|
||||||
|
|
||||||
if page.Limit != 3 {
|
|
||||||
t.Errorf("每页行数不匹配。期望: 3, 实际: %d", page.Limit)
|
|
||||||
}
|
|
||||||
|
|
||||||
if page.TotalLines != 10 {
|
|
||||||
t.Errorf("总行数不匹配。期望: 10, 实际: %d", page.TotalLines)
|
|
||||||
}
|
|
||||||
|
|
||||||
if page.TotalPages != 4 {
|
|
||||||
t.Errorf("总页数不匹配。期望: 4, 实际: %d", page.TotalPages)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(page.Lines) != 3 {
|
|
||||||
t.Errorf("第一页行数不匹配。期望: 3, 实际: %d", len(page.Lines))
|
|
||||||
}
|
|
||||||
|
|
||||||
if page.Lines[0] != "Line 1" {
|
|
||||||
t.Errorf("第一行内容不匹配。期望: Line 1, 实际: %s", page.Lines[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试第二页
|
|
||||||
page2, err := storage.GetResultPage(executionID, 2, 3)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("获取第二页失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(page2.Lines) != 3 {
|
|
||||||
t.Errorf("第二页行数不匹配。期望: 3, 实际: %d", len(page2.Lines))
|
|
||||||
}
|
|
||||||
|
|
||||||
if page2.Lines[0] != "Line 4" {
|
|
||||||
t.Errorf("第二页第一行内容不匹配。期望: Line 4, 实际: %s", page2.Lines[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试最后一页(可能不满一页)
|
|
||||||
page4, err := storage.GetResultPage(executionID, 4, 3)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("获取第四页失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(page4.Lines) != 1 {
|
|
||||||
t.Errorf("第四页行数不匹配。期望: 1, 实际: %d", len(page4.Lines))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试超出范围的页码(应该返回最后一页)
|
|
||||||
page5, err := storage.GetResultPage(executionID, 5, 3)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("获取第五页失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 超出范围的页码会被修正为最后一页,所以应该返回最后一页的内容
|
|
||||||
if page5.Page != 4 {
|
|
||||||
t.Errorf("超出范围的页码应该被修正为最后一页。期望: 4, 实际: %d", page5.Page)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 最后一页应该只有1行
|
|
||||||
if len(page5.Lines) != 1 {
|
|
||||||
t.Errorf("最后一页应该只有1行。实际: %d行", len(page5.Lines))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileResultStorage_SearchResult(t *testing.T) {
|
|
||||||
storage, tmpDir := setupTestStorage(t)
|
|
||||||
defer cleanupTestStorage(t, tmpDir)
|
|
||||||
|
|
||||||
executionID := "test_exec_005"
|
|
||||||
toolName := "test_tool"
|
|
||||||
result := "Line 1: error occurred\nLine 2: success\nLine 3: error again\nLine 4: ok"
|
|
||||||
|
|
||||||
// 保存结果
|
|
||||||
err := storage.SaveResult(executionID, toolName, result)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("保存结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 搜索包含"error"的行(简单字符串匹配)
|
|
||||||
matchedLines, err := storage.SearchResult(executionID, "error", false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("搜索失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(matchedLines) != 2 {
|
|
||||||
t.Errorf("搜索结果数量不匹配。期望: 2, 实际: %d", len(matchedLines))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证搜索结果内容
|
|
||||||
for i, line := range matchedLines {
|
|
||||||
if !strings.Contains(line, "error") {
|
|
||||||
t.Errorf("搜索结果第%d行不包含关键词: %s", i+1, line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试搜索不存在的关键词
|
|
||||||
noMatch, err := storage.SearchResult(executionID, "nonexistent", false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("搜索失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(noMatch) != 0 {
|
|
||||||
t.Errorf("搜索不存在的关键词应该返回空结果。实际: %d行", len(noMatch))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试正则表达式搜索
|
|
||||||
regexMatched, err := storage.SearchResult(executionID, "error.*again", true)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("正则搜索失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(regexMatched) != 1 {
|
|
||||||
t.Errorf("正则搜索结果数量不匹配。期望: 1, 实际: %d", len(regexMatched))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileResultStorage_FilterResult(t *testing.T) {
|
|
||||||
storage, tmpDir := setupTestStorage(t)
|
|
||||||
defer cleanupTestStorage(t, tmpDir)
|
|
||||||
|
|
||||||
executionID := "test_exec_006"
|
|
||||||
toolName := "test_tool"
|
|
||||||
result := "Line 1: warning message\nLine 2: info message\nLine 3: warning again\nLine 4: debug message"
|
|
||||||
|
|
||||||
// 保存结果
|
|
||||||
err := storage.SaveResult(executionID, toolName, result)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("保存结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 过滤包含"warning"的行(简单字符串匹配)
|
|
||||||
filteredLines, err := storage.FilterResult(executionID, "warning", false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("过滤失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(filteredLines) != 2 {
|
|
||||||
t.Errorf("过滤结果数量不匹配。期望: 2, 实际: %d", len(filteredLines))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证过滤结果内容
|
|
||||||
for i, line := range filteredLines {
|
|
||||||
if !strings.Contains(line, "warning") {
|
|
||||||
t.Errorf("过滤结果第%d行不包含关键词: %s", i+1, line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileResultStorage_DeleteResult(t *testing.T) {
|
|
||||||
storage, tmpDir := setupTestStorage(t)
|
|
||||||
defer cleanupTestStorage(t, tmpDir)
|
|
||||||
|
|
||||||
executionID := "test_exec_007"
|
|
||||||
toolName := "test_tool"
|
|
||||||
result := "Test result"
|
|
||||||
|
|
||||||
// 保存结果
|
|
||||||
err := storage.SaveResult(executionID, toolName, result)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("保存结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证文件存在
|
|
||||||
resultPath := filepath.Join(tmpDir, executionID+".txt")
|
|
||||||
metadataPath := filepath.Join(tmpDir, executionID+".meta.json")
|
|
||||||
|
|
||||||
if _, err := os.Stat(resultPath); os.IsNotExist(err) {
|
|
||||||
t.Fatal("结果文件不存在")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
|
|
||||||
t.Fatal("元数据文件不存在")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 删除结果
|
|
||||||
err = storage.DeleteResult(executionID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("删除结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证文件已删除
|
|
||||||
if _, err := os.Stat(resultPath); !os.IsNotExist(err) {
|
|
||||||
t.Fatal("结果文件未被删除")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := os.Stat(metadataPath); !os.IsNotExist(err) {
|
|
||||||
t.Fatal("元数据文件未被删除")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试删除不存在的执行ID(应该不报错)
|
|
||||||
err = storage.DeleteResult("nonexistent_id")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("删除不存在的执行ID不应该报错: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileResultStorage_ConcurrentAccess(t *testing.T) {
|
|
||||||
storage, tmpDir := setupTestStorage(t)
|
|
||||||
defer cleanupTestStorage(t, tmpDir)
|
|
||||||
|
|
||||||
// 并发保存多个结果
|
|
||||||
done := make(chan bool, 10)
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
go func(id int) {
|
|
||||||
executionID := fmt.Sprintf("test_exec_%d", id)
|
|
||||||
toolName := "test_tool"
|
|
||||||
result := fmt.Sprintf("Result %d\nLine 2\nLine 3", id)
|
|
||||||
|
|
||||||
err := storage.SaveResult(executionID, toolName, result)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("并发保存失败 (ID: %s): %v", executionID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 并发读取
|
|
||||||
_, err = storage.GetResult(executionID)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("并发读取失败 (ID: %s): %v", executionID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
done <- true
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 等待所有goroutine完成
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
<-done
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFileResultStorage_LargeResult(t *testing.T) {
|
|
||||||
storage, tmpDir := setupTestStorage(t)
|
|
||||||
defer cleanupTestStorage(t, tmpDir)
|
|
||||||
|
|
||||||
executionID := "test_exec_large"
|
|
||||||
toolName := "test_tool"
|
|
||||||
|
|
||||||
// 创建大结果(1000行)
|
|
||||||
lines := make([]string, 1000)
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
lines[i] = fmt.Sprintf("Line %d: This is a test line with some content", i+1)
|
|
||||||
}
|
|
||||||
result := strings.Join(lines, "\n")
|
|
||||||
|
|
||||||
// 保存大结果
|
|
||||||
err := storage.SaveResult(executionID, toolName, result)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("保存大结果失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 验证元数据
|
|
||||||
metadata, err := storage.GetResultMetadata(executionID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("获取元数据失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if metadata.TotalLines != 1000 {
|
|
||||||
t.Errorf("总行数不匹配。期望: 1000, 实际: %d", metadata.TotalLines)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 测试分页查询大结果
|
|
||||||
page, err := storage.GetResultPage(executionID, 1, 100)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("获取第一页失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if page.TotalPages != 10 {
|
|
||||||
t.Errorf("总页数不匹配。期望: 10, 实际: %d", page.TotalPages)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(page.Lines) != 100 {
|
|
||||||
t.Errorf("第一页行数不匹配。期望: 100, 实际: %d", len(page.Lines))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
+5
-5
@@ -27,13 +27,13 @@ parameters:
|
|||||||
type: "string"
|
type: "string"
|
||||||
description: "数据源(wayback,commoncrawl,otx,urlscan)"
|
description: "数据源(wayback,commoncrawl,otx,urlscan)"
|
||||||
required: false
|
required: false
|
||||||
flag: "-providers"
|
flag: "--providers"
|
||||||
format: "flag"
|
format: "flag"
|
||||||
- name: "include_subs"
|
- name: "include_subs"
|
||||||
type: "bool"
|
type: "bool"
|
||||||
description: "包含子域名"
|
description: "包含子域名"
|
||||||
required: false
|
required: false
|
||||||
flag: "-subs"
|
flag: "--subs"
|
||||||
format: "flag"
|
format: "flag"
|
||||||
default: true
|
default: true
|
||||||
- name: "additional_args"
|
- name: "additional_args"
|
||||||
@@ -42,9 +42,9 @@ parameters:
|
|||||||
额外的Gau参数。用于传递未在参数列表中定义的Gau选项。
|
额外的Gau参数。用于传递未在参数列表中定义的Gau选项。
|
||||||
|
|
||||||
**示例值:**
|
**示例值:**
|
||||||
- "-o output.txt": 输出到文件
|
- "--o output.txt": 输出到文件
|
||||||
- "-t": 线程数
|
- "--threads 4": 线程数
|
||||||
- "-b": 黑名单扩展
|
- "--blacklist ttf,woff,svg,png": 黑名单扩展
|
||||||
|
|
||||||
**注意事项:**
|
**注意事项:**
|
||||||
- 多个参数用空格分隔
|
- 多个参数用空格分隔
|
||||||
|
|||||||
+1178
-47
File diff suppressed because it is too large
Load Diff
+1924
-98
File diff suppressed because it is too large
Load Diff
+230
-3
@@ -258,10 +258,73 @@
|
|||||||
"vulnerabilityManagement": "Vulnerability management",
|
"vulnerabilityManagement": "Vulnerability management",
|
||||||
"addFactCta": "+ Add fact",
|
"addFactCta": "+ Add fact",
|
||||||
"tabFacts": "Fact board",
|
"tabFacts": "Fact board",
|
||||||
|
"tabGraph": "Attack path",
|
||||||
"tabConversations": "Bound conversations",
|
"tabConversations": "Bound conversations",
|
||||||
"tabVulns": "Related vulnerabilities",
|
"tabVulns": "Related vulnerabilities",
|
||||||
"tabSettings": "Settings",
|
"tabSettings": "Settings",
|
||||||
"factToolbarHint": "Index includes key and summary only (must include what + where + how to verify); put attack chain / POC in body, and reproduce via get_project_fact.",
|
"factToolbarHint": "Index includes key and summary only (must include what + where + how to verify); put attack chain / POC in body, and reproduce via get_project_fact.",
|
||||||
|
"graphToolbarHint": "Graph arrows match stored fact links (source → target). Nodes are layered target→infra→finding→exploit. Dashed edges are tentative.",
|
||||||
|
"graphView": "View",
|
||||||
|
"graphViewPath": "Attack path",
|
||||||
|
"graphViewFull": "Full graph",
|
||||||
|
"graphSearchSr": "Search nodes",
|
||||||
|
"graphSearchPlaceholder": "Search nodes…",
|
||||||
|
"graphRefresh": "Refresh",
|
||||||
|
"graphCenter": "Center",
|
||||||
|
"graphEmpty": "No graph data yet. Add links on finding/exploit facts (discovered_on → target/*) to build the path.",
|
||||||
|
"graphEmptyTitle": "Build your attack path",
|
||||||
|
"graphEmptyStep1": "Add target facts (domains, endpoints, scope)",
|
||||||
|
"graphEmptyStep2": "Record findings/exploits with links between facts",
|
||||||
|
"graphEmptyStep3": "Use Connect mode or edit facts to add relationships",
|
||||||
|
"graphEmptyCta": "Add first fact",
|
||||||
|
"graphStats": "Nodes: {{nodes}} | Edges: {{edges}}",
|
||||||
|
"graphStatsNodes": "Nodes",
|
||||||
|
"graphStatsEdges": "Edges",
|
||||||
|
"graphLegendNodes": "Nodes",
|
||||||
|
"graphLegendEdges": "Edges",
|
||||||
|
"graphLegendNodeTarget": "TARGET",
|
||||||
|
"graphLegendNodeInfra": "INFRA",
|
||||||
|
"graphLegendNodeFinding": "FINDING",
|
||||||
|
"graphLegendNodeVuln": "VULN",
|
||||||
|
"graphLegendNodeExploit": "EXPLOIT",
|
||||||
|
"graphLegendNodeMissing": "MISSING",
|
||||||
|
"graphLegendDiscovered": "discovered_on",
|
||||||
|
"graphLegendLeads": "leads_to",
|
||||||
|
"graphLegendExploits": "exploits",
|
||||||
|
"graphLegendTentative": "Tentative (dashed)",
|
||||||
|
"factLinksLabel": "Links (from → this fact)",
|
||||||
|
"factLinksPlaceholder": "discovered_on: target/primary_domain\nexploits: exploit/upload-rce",
|
||||||
|
"factLinksHint": "One per line: type: source_fact_key (source → this fact). Common types: discovered_on, depends_on, leads_to, enables, exploits. Saving replaces all links.",
|
||||||
|
"factIncomingLinksLabel": "Incoming links (read-only)",
|
||||||
|
"factIncomingLinksHint": "Derived from outgoing links on source facts. e.g. finding discovered_on → target/* appears as incoming on the target; edit the source fact's outgoing links.",
|
||||||
|
"factIncomingLinksEmpty": "No incoming links",
|
||||||
|
"graphEdgeFromSelf": "From this node",
|
||||||
|
"graphEdgeToSelf": "To this node",
|
||||||
|
"linksColumn": "Links",
|
||||||
|
"linkCountsTitle": "Outgoing / incoming edge counts",
|
||||||
|
"graphConnect": "Connect",
|
||||||
|
"graphConnectActive": "Connecting…",
|
||||||
|
"graphConnectPickTarget": "Source {{source}} selected — click target node",
|
||||||
|
"graphEdgeTypePrompt": "Edge type (discovered_on / leads_to / depends_on / enables / exploits)",
|
||||||
|
"graphConnectFailed": "Failed to create edge",
|
||||||
|
"graphConnectSuccess": "Edge created",
|
||||||
|
"graphEdgesTitle": "Links",
|
||||||
|
"graphEdgesHint": "Arrow direction matches the database and edit modal (source → target). Click an edge to focus it.",
|
||||||
|
"graphEdgesEmpty": "No links yet",
|
||||||
|
"graphEdgeOutgoing": "Outgoing",
|
||||||
|
"graphEdgeIncoming": "Incoming",
|
||||||
|
"graphEdgeSynthetic": "Auto-generated from fact link; edit the fact to remove",
|
||||||
|
"confirmDeleteGraphEdge": "Delete this link?",
|
||||||
|
"graphEdgeDeleteFailed": "Failed to delete edge",
|
||||||
|
"graphEdgeDeleteSuccess": "Edge deleted",
|
||||||
|
"graphDeleteEdge": "Delete",
|
||||||
|
"viewVulnerability": "View vulnerability",
|
||||||
|
"graphVulnSidebarHint": "Linked vulnerability node. Use the button below to open it in Vulnerability Management.",
|
||||||
|
"promoteAttackChain": "Promote chain",
|
||||||
|
"promoteAttackChainTitle": "Promote conversation attack chain to project facts",
|
||||||
|
"confirmPromoteAttackChain": "Promote this conversation's attack chain into the project? Facts and edges will be created or updated.",
|
||||||
|
"promoteAttackChainFailed": "Promote failed",
|
||||||
|
"promoteAttackChainSuccess": "Promoted: {{facts_created}} new / {{facts_updated}} updated / {{edges_created}} edges",
|
||||||
"searchFactsSr": "Search facts",
|
"searchFactsSr": "Search facts",
|
||||||
"searchFactsPlaceholder": "Search key, summary, body…",
|
"searchFactsPlaceholder": "Search key, summary, body…",
|
||||||
"category": "Category",
|
"category": "Category",
|
||||||
@@ -436,6 +499,9 @@
|
|||||||
"conversationGroups": "Conversation groups",
|
"conversationGroups": "Conversation groups",
|
||||||
"addGroup": "New group",
|
"addGroup": "New group",
|
||||||
"recentConversations": "Recent conversations",
|
"recentConversations": "Recent conversations",
|
||||||
|
"sortConversations": "Sort",
|
||||||
|
"sortByCreatedAt": "Created time",
|
||||||
|
"sortByUpdatedAt": "Updated time",
|
||||||
"batchManage": "Batch manage",
|
"batchManage": "Batch manage",
|
||||||
"paginationShow": "Show {{start}}-{{end}} of {{total}}",
|
"paginationShow": "Show {{start}}-{{end}} of {{total}}",
|
||||||
"paginationRange": "{{start}}-{{end}}/{{total}}",
|
"paginationRange": "{{start}}-{{end}}/{{total}}",
|
||||||
@@ -676,7 +742,12 @@
|
|||||||
"viewConversation": "View conversation",
|
"viewConversation": "View conversation",
|
||||||
"viewVulnerabilities": "View vulnerabilities",
|
"viewVulnerabilities": "View vulnerabilities",
|
||||||
"viewVulnerabilitiesQueueTitle": "View vulnerabilities: open management filtered to this queue",
|
"viewVulnerabilitiesQueueTitle": "View vulnerabilities: open management filtered to this queue",
|
||||||
"retryTask": "Retry",
|
"runSingleTask": "Run task",
|
||||||
|
"confirmRunSingleTask": "Run this task only? The queue will pause when it finishes and will not continue other pending items.",
|
||||||
|
"runSingleTaskFailed": "Failed to run task",
|
||||||
|
"runSingleTaskUnavailable": "Unavailable while the queue or a task is running",
|
||||||
|
"runSingleTaskUnavailableSelf": "This task is running",
|
||||||
|
"runSingleTaskUnavailableQueue": "Queue is running; pause it before running another task individually",
|
||||||
"conversationIdLabel": "Conversation ID",
|
"conversationIdLabel": "Conversation ID",
|
||||||
"statusPending": "Pending",
|
"statusPending": "Pending",
|
||||||
"statusPaused": "Paused",
|
"statusPaused": "Paused",
|
||||||
@@ -1083,6 +1154,7 @@
|
|||||||
"botAgent": "Bot Agent",
|
"botAgent": "Bot Agent",
|
||||||
"ilinkBotId": "iLink Bot ID (filled after bind)",
|
"ilinkBotId": "iLink Bot ID (filled after bind)",
|
||||||
"boundSuccess": "Binding successful. WeChat bot is enabled.",
|
"boundSuccess": "Binding successful. WeChat bot is enabled.",
|
||||||
|
"alreadyBound": "This WeChat account is already bound.",
|
||||||
"openLink": "QR not showing? Open link in WeChat on your phone"
|
"openLink": "QR not showing? Open link in WeChat on your phone"
|
||||||
},
|
},
|
||||||
"wecom": {
|
"wecom": {
|
||||||
@@ -1938,6 +2010,13 @@
|
|||||||
"openaiBaseUrlPlaceholder": "https://api.openai.com/v1",
|
"openaiBaseUrlPlaceholder": "https://api.openai.com/v1",
|
||||||
"openaiApiKeyPlaceholder": "Enter OpenAI API Key",
|
"openaiApiKeyPlaceholder": "Enter OpenAI API Key",
|
||||||
"modelPlaceholder": "gpt-4",
|
"modelPlaceholder": "gpt-4",
|
||||||
|
"fetchModels": "Fetch list",
|
||||||
|
"modelsListFetching": "Fetching model list...",
|
||||||
|
"modelsListSelectPlaceholder": "Select a model",
|
||||||
|
"modelsListSuccess": "Loaded {count} models — use the dropdown on the right, or type in the input",
|
||||||
|
"modelsListFailed": "Failed to fetch model list",
|
||||||
|
"modelsListNeedApiKey": "Please enter API Key first",
|
||||||
|
"modelsListClaudeHint": "Claude does not support auto model list; enter the model name manually",
|
||||||
"maxTotalTokens": "Max Context Tokens",
|
"maxTotalTokens": "Max Context Tokens",
|
||||||
"maxTotalTokensPlaceholder": "120000",
|
"maxTotalTokensPlaceholder": "120000",
|
||||||
"maxTotalTokensHint": "Shared by memory compression and attack chain building. Default: 120000",
|
"maxTotalTokensHint": "Shared by memory compression and attack chain building. Default: 120000",
|
||||||
@@ -2086,14 +2165,35 @@
|
|||||||
"filterResult": "Result",
|
"filterResult": "Result",
|
||||||
"pageSize": "Per page",
|
"pageSize": "Per page",
|
||||||
"statTotal": "Filtered total",
|
"statTotal": "Filtered total",
|
||||||
|
"statSuccess": "Success",
|
||||||
"statFailures": "Failures",
|
"statFailures": "Failures",
|
||||||
"statRecent7d": "Last 7 days",
|
"statRecent7d": "Last 7 days",
|
||||||
"retentionHint": "Audit records are kept for {{days}} days, then purged automatically.",
|
"retentionHint": "Audit records are kept for {{days}} days, then purged automatically.",
|
||||||
"disabledHint": "Audit logging is disabled; new actions are not written.",
|
"disabledHint": "Audit logging is disabled; new actions are not written.",
|
||||||
"filterSince": "From",
|
"filterSince": "From",
|
||||||
"filterUntil": "Until",
|
"filterUntil": "Until",
|
||||||
|
"filterTimeZone": "Timezone: {{tz}} (filter uses your browser's local time)",
|
||||||
|
"datetimePlaceholder": "Select date & time",
|
||||||
|
"timePresets": "Quick range",
|
||||||
|
"preset15m": "Last 15 min",
|
||||||
|
"preset1h": "Last 1 hour",
|
||||||
|
"preset24h": "Last 24 hours",
|
||||||
|
"preset7d": "Last 7 days",
|
||||||
|
"presetToday": "Today",
|
||||||
|
"pickerHour": "Hour",
|
||||||
|
"pickerMinute": "Min",
|
||||||
|
"pickerClear": "Clear",
|
||||||
|
"pickerToday": "Today",
|
||||||
|
"pickerConfirm": "OK",
|
||||||
"filterQuery": "Keyword",
|
"filterQuery": "Keyword",
|
||||||
"filterQueryPlaceholder": "Message / resource ID / action",
|
"filterQueryPlaceholder": "Message / resource ID / action",
|
||||||
|
"colTime": "Time",
|
||||||
|
"colMessage": "Message",
|
||||||
|
"colCategory": "Category",
|
||||||
|
"colAction": "Action",
|
||||||
|
"colResult": "Result",
|
||||||
|
"colIp": "IP",
|
||||||
|
"colResource": "Resource ID",
|
||||||
"cat": {
|
"cat": {
|
||||||
"auth": "Auth",
|
"auth": "Auth",
|
||||||
"config": "Config",
|
"config": "Config",
|
||||||
@@ -2166,6 +2266,93 @@
|
|||||||
"exportDone": "Export complete",
|
"exportDone": "Export complete",
|
||||||
"loading": "Loading...",
|
"loading": "Loading...",
|
||||||
"empty": "No audit records",
|
"empty": "No audit records",
|
||||||
|
"result": {
|
||||||
|
"success": "success",
|
||||||
|
"failure": "failure"
|
||||||
|
},
|
||||||
|
"msg": {
|
||||||
|
"auth": {
|
||||||
|
"login": "Login successful",
|
||||||
|
"login_failed": "Login failed: incorrect password",
|
||||||
|
"logout": "Logged out",
|
||||||
|
"change_password": "Login password changed",
|
||||||
|
"change_password_failed": "Password change failed: current password incorrect"
|
||||||
|
},
|
||||||
|
"config": {
|
||||||
|
"apply": "Configuration applied",
|
||||||
|
"update": "In-memory configuration updated",
|
||||||
|
"apply_fail_kb_init": "Failed to apply config: knowledge base init",
|
||||||
|
"apply_fail_kb_reinit": "Failed to apply config: knowledge base re-init",
|
||||||
|
"apply_fail_c2": "Failed to apply config: C2"
|
||||||
|
},
|
||||||
|
"conversation": {
|
||||||
|
"create": "Conversation created",
|
||||||
|
"delete": "Conversation deleted",
|
||||||
|
"delete_turn": "Conversation turn deleted"
|
||||||
|
},
|
||||||
|
"c2": {
|
||||||
|
"listener_create": "C2 listener created",
|
||||||
|
"listener_delete": "C2 listener deleted",
|
||||||
|
"listener_start": "C2 listener started",
|
||||||
|
"listener_stop": "C2 listener stopped",
|
||||||
|
"session_delete": "C2 session deleted",
|
||||||
|
"task_create": "C2 task created",
|
||||||
|
"task_cancel": "C2 task cancelled",
|
||||||
|
"task_delete": "C2 tasks deleted (batch)"
|
||||||
|
},
|
||||||
|
"webshell": {
|
||||||
|
"connection_create": "WebShell connection created",
|
||||||
|
"connection_delete": "WebShell connection deleted"
|
||||||
|
},
|
||||||
|
"knowledge": {
|
||||||
|
"item_delete": "Knowledge item deleted",
|
||||||
|
"index_rebuild": "Knowledge index rebuilt"
|
||||||
|
},
|
||||||
|
"vulnerability": {
|
||||||
|
"create": "Vulnerability record created",
|
||||||
|
"update": "Vulnerability record updated",
|
||||||
|
"delete": "Vulnerability record deleted",
|
||||||
|
"delete_batch": "Vulnerability records deleted (batch)"
|
||||||
|
},
|
||||||
|
"external_mcp": {
|
||||||
|
"upsert": "External MCP configuration updated",
|
||||||
|
"delete": "External MCP configuration deleted"
|
||||||
|
},
|
||||||
|
"task": {
|
||||||
|
"create_queue": "Batch task queue created",
|
||||||
|
"start_queue": "Batch task queue started",
|
||||||
|
"delete_queue": "Batch task queue deleted",
|
||||||
|
"pause_queue": "Batch task queue paused",
|
||||||
|
"rerun_queue": "Batch task queue rerun",
|
||||||
|
"delete_batch_task": "Batch subtask deleted"
|
||||||
|
},
|
||||||
|
"tool": {
|
||||||
|
"execution_delete": "Tool execution record deleted",
|
||||||
|
"execution_delete_batch": "Tool execution records deleted (batch)"
|
||||||
|
},
|
||||||
|
"file": {
|
||||||
|
"upload": "Chat attachment uploaded",
|
||||||
|
"delete": "Chat attachment deleted"
|
||||||
|
},
|
||||||
|
"hitl": {
|
||||||
|
"decision": "HITL approval decision"
|
||||||
|
},
|
||||||
|
"role": {
|
||||||
|
"create": "Role created",
|
||||||
|
"update": "Role updated",
|
||||||
|
"delete": "Role deleted"
|
||||||
|
},
|
||||||
|
"skill": {
|
||||||
|
"create": "Skill created",
|
||||||
|
"update": "Skill updated",
|
||||||
|
"delete": "Skill deleted"
|
||||||
|
},
|
||||||
|
"agent": {
|
||||||
|
"markdown_create": "Markdown sub-agent created",
|
||||||
|
"markdown_update": "Markdown sub-agent updated",
|
||||||
|
"markdown_delete": "Markdown sub-agent deleted"
|
||||||
|
}
|
||||||
|
},
|
||||||
"paginationShow": "{{start}}-{{end}} of {{total}}",
|
"paginationShow": "{{start}}-{{end}} of {{total}}",
|
||||||
"detailTitle": "Audit detail",
|
"detailTitle": "Audit detail",
|
||||||
"detailTime": "Time",
|
"detailTime": "Time",
|
||||||
@@ -2244,7 +2431,8 @@
|
|||||||
"copyContent": "Copy content",
|
"copyContent": "Copy content",
|
||||||
"correctInfo": "Correct info",
|
"correctInfo": "Correct info",
|
||||||
"errorInfo": "Error info",
|
"errorInfo": "Error info",
|
||||||
"copyError": "Copy error"
|
"copyError": "Copy error",
|
||||||
|
"contentTruncated": "… (display truncated; use read_file on the path in persisted-output for the full file)"
|
||||||
},
|
},
|
||||||
"attackChainModal": {
|
"attackChainModal": {
|
||||||
"title": "Attack chain",
|
"title": "Attack chain",
|
||||||
@@ -2574,6 +2762,11 @@
|
|||||||
},
|
},
|
||||||
"c2": {
|
"c2": {
|
||||||
"clipboardCopied": "Copied to clipboard",
|
"clipboardCopied": "Copied to clipboard",
|
||||||
|
"common": {
|
||||||
|
"justNow": "Just now",
|
||||||
|
"minutesAgo": "{{n}}m ago",
|
||||||
|
"hoursAgo": "{{n}}h ago"
|
||||||
|
},
|
||||||
"fmt": {
|
"fmt": {
|
||||||
"durationMs": "{{n}}ms",
|
"durationMs": "{{n}}ms",
|
||||||
"durationSec": "{{n}}s",
|
"durationSec": "{{n}}s",
|
||||||
@@ -2631,6 +2824,8 @@
|
|||||||
"bindHintExternal": "Use 0.0.0.0 to allow external access",
|
"bindHintExternal": "Use 0.0.0.0 to allow external access",
|
||||||
"callbackHost": "Callback host (optional)",
|
"callbackHost": "Callback host (optional)",
|
||||||
"callbackHostHint": "Public IP or hostname stored for payloads/beacons; separate from bind address. If empty, payload generation falls back to bind address / auto-detect.",
|
"callbackHostHint": "Public IP or hostname stored for payloads/beacons; separate from bind address. If empty, payload generation falls back to bind address / auto-detect.",
|
||||||
|
"allowLegacyShell": "Allow unencrypted classic reverse shell (lab only)",
|
||||||
|
"allowLegacyShellHint": "Off by default. When enabled, raw bash/nc TCP connections register sessions and are vulnerable to internet scanners; use encrypted Beacon builds for production.",
|
||||||
"malleableProfile": "Malleable Profile",
|
"malleableProfile": "Malleable Profile",
|
||||||
"malleableProfileHint": "Optional; HTTP/HTTPS Beacon response headers and traffic disguise. Stop and start the listener again for changes to take effect.",
|
"malleableProfileHint": "Optional; HTTP/HTTPS Beacon response headers and traffic disguise. Stop and start the listener again for changes to take effect.",
|
||||||
"malleableProfileNone": "None",
|
"malleableProfileNone": "None",
|
||||||
@@ -2708,10 +2903,22 @@
|
|||||||
"infoFirstSeen": "First seen",
|
"infoFirstSeen": "First seen",
|
||||||
"infoLastCheckin": "Last check-in",
|
"infoLastCheckin": "Last check-in",
|
||||||
"infoNote": "Note",
|
"infoNote": "Note",
|
||||||
|
"infoNoteEmpty": "No notes",
|
||||||
|
"infoSectionIdentity": "Identity",
|
||||||
|
"infoSectionSystem": "System",
|
||||||
|
"infoSectionNetwork": "Network & beacon",
|
||||||
|
"infoSectionTimeline": "Timeline",
|
||||||
|
"infoSectionNote": "Notes",
|
||||||
"adminYes": "Yes",
|
"adminYes": "Yes",
|
||||||
"adminNo": "No",
|
"adminNo": "No",
|
||||||
"promptSleepSeconds": "Sleep interval (seconds)",
|
"promptSleepSeconds": "Sleep interval (seconds)",
|
||||||
"promptJitterPercent": "Jitter percent (0–100)",
|
"promptJitterPercent": "Jitter percent (0–100)",
|
||||||
|
"sleepModalHint": "Saves to the server and queues a sleep task. The implant applies it on the next task poll; later check-ins keep this config.",
|
||||||
|
"sleepModalTitle": "Beacon interval",
|
||||||
|
"sleepModalCurrent": "Current {{sec}}s · jitter {{jitter}}%",
|
||||||
|
"sleepModalPreview": "Estimated {{min}} – {{max}} s",
|
||||||
|
"sleepModalPresets": "Presets",
|
||||||
|
"toastSleepInvalid": "Sleep interval must be at least 1 second",
|
||||||
"toastSleepUpdated": "Sleep settings updated",
|
"toastSleepUpdated": "Sleep settings updated",
|
||||||
"confirmExitSession": "Send exit command to this session?",
|
"confirmExitSession": "Send exit command to this session?",
|
||||||
"confirmDeleteSession": "Remove this session and related tasks/files from the server? (Does not send exit to the implant; use Kill Session to exit the agent.)",
|
"confirmDeleteSession": "Remove this session and related tasks/files from the server? (Does not send exit to the implant; use Kill Session to exit the agent.)",
|
||||||
@@ -2729,7 +2936,25 @@
|
|||||||
"termWaitFinish": "Please wait for the current command to finish",
|
"termWaitFinish": "Please wait for the current command to finish",
|
||||||
"termCtrlC": "Remote interrupt is not supported in this version",
|
"termCtrlC": "Remote interrupt is not supported in this version",
|
||||||
"termQueued": "[Command queued — will run after the current task completes]",
|
"termQueued": "[Command queued — will run after the current task completes]",
|
||||||
"clearTerminal": "Clear"
|
"clearTerminal": "Clear",
|
||||||
|
"batchDelete": "Delete selected",
|
||||||
|
"deleteFiltered": "Delete filtered",
|
||||||
|
"selectAll": "Select all",
|
||||||
|
"filterAllStatus": "All statuses",
|
||||||
|
"filterAllListeners": "All listeners",
|
||||||
|
"filterSearchPlaceholder": "Search hostname / user / IP",
|
||||||
|
"filterApply": "Filter",
|
||||||
|
"filterReset": "Reset",
|
||||||
|
"filterSuspicious": "Likely false positives",
|
||||||
|
"filterCount": "{{n}} total, {{selected}} selected",
|
||||||
|
"emptyFilter": "No sessions match the current filters",
|
||||||
|
"listEmpty": "No sessions",
|
||||||
|
"selectPromptTitle": "Select a session",
|
||||||
|
"selectPromptHint": "Click a session in the list on the left to view terminal, files, and tasks.",
|
||||||
|
"confirmBatchDelete": "Delete {{n}} selected session(s)? Related tasks and file records will be removed.",
|
||||||
|
"confirmDeleteFiltered": "Delete all {{n}} session(s) in the current filter results?",
|
||||||
|
"toastSelectFirst": "Select at least one session to delete",
|
||||||
|
"toastBatchDeleted": "Deleted {{n}} session(s)"
|
||||||
},
|
},
|
||||||
"tasks": {
|
"tasks": {
|
||||||
"title": "Task Management",
|
"title": "Task Management",
|
||||||
@@ -2752,6 +2977,8 @@
|
|||||||
"pending": "Pending",
|
"pending": "Pending",
|
||||||
"emptyAll": "No tasks yet",
|
"emptyAll": "No tasks yet",
|
||||||
"emptySession": "No tasks for this session",
|
"emptySession": "No tasks for this session",
|
||||||
|
"sessionTaskHistory": "Task history",
|
||||||
|
"sessionTaskCount": "{{n}} tasks",
|
||||||
"colTask": "Task",
|
"colTask": "Task",
|
||||||
"colSession": "Session",
|
"colSession": "Session",
|
||||||
"colType": "Type",
|
"colType": "Type",
|
||||||
|
|||||||
+230
-3
@@ -246,10 +246,73 @@
|
|||||||
"vulnerabilityManagement": "漏洞管理",
|
"vulnerabilityManagement": "漏洞管理",
|
||||||
"addFactCta": "+ 添加事实",
|
"addFactCta": "+ 添加事实",
|
||||||
"tabFacts": "事实黑板",
|
"tabFacts": "事实黑板",
|
||||||
|
"tabGraph": "攻击路径",
|
||||||
"tabConversations": "关联对话",
|
"tabConversations": "关联对话",
|
||||||
"tabVulns": "关联漏洞",
|
"tabVulns": "关联漏洞",
|
||||||
"tabSettings": "设置",
|
"tabSettings": "设置",
|
||||||
"factToolbarHint": "索引仅含 key 与摘要(须含「什么 + 在哪 + 如何验证」);攻击链 / POC 写在 body,Agent 通过 get_project_fact 复现",
|
"factToolbarHint": "索引仅含 key 与摘要(须含「什么 + 在哪 + 如何验证」);攻击链 / POC 写在 body,Agent 通过 get_project_fact 复现",
|
||||||
|
"graphToolbarHint": "攻击路径图箭头与事实存储方向一致(source → target);节点按 target→infra→finding→exploit 分层排布。虚线边为待确认。",
|
||||||
|
"graphView": "视图",
|
||||||
|
"graphViewPath": "攻击路径",
|
||||||
|
"graphViewFull": "完整关系",
|
||||||
|
"graphSearchSr": "搜索节点",
|
||||||
|
"graphSearchPlaceholder": "搜索节点…",
|
||||||
|
"graphRefresh": "刷新",
|
||||||
|
"graphCenter": "居中",
|
||||||
|
"graphEmpty": "暂无路径图数据。为 finding/exploit 类事实添加关系边(discovered_on → target/*)后将在此展示。",
|
||||||
|
"graphEmptyTitle": "构建攻击路径图",
|
||||||
|
"graphEmptyStep1": "添加 target 类事实(目标、域名、入口)",
|
||||||
|
"graphEmptyStep2": "记录 finding / exploit 并在 links 中连边",
|
||||||
|
"graphEmptyStep3": "使用「连边」模式或编辑事实手动补关系",
|
||||||
|
"graphEmptyCta": "添加第一条事实",
|
||||||
|
"graphStats": "节点: {{nodes}} | 边: {{edges}}",
|
||||||
|
"graphStatsNodes": "节点",
|
||||||
|
"graphStatsEdges": "边",
|
||||||
|
"graphLegendNodes": "节点",
|
||||||
|
"graphLegendEdges": "连线",
|
||||||
|
"graphLegendNodeTarget": "TARGET · 目标",
|
||||||
|
"graphLegendNodeInfra": "INFRA · 基础设施",
|
||||||
|
"graphLegendNodeFinding": "FINDING · 发现",
|
||||||
|
"graphLegendNodeVuln": "VULN · 漏洞",
|
||||||
|
"graphLegendNodeExploit": "EXPLOIT · 利用",
|
||||||
|
"graphLegendNodeMissing": "MISSING · 缺失",
|
||||||
|
"graphLegendDiscovered": "discovered_on",
|
||||||
|
"graphLegendLeads": "leads_to",
|
||||||
|
"graphLegendExploits": "exploits",
|
||||||
|
"graphLegendTentative": "待确认(虚线)",
|
||||||
|
"factLinksLabel": "关系边(from → 本事实)",
|
||||||
|
"factLinksPlaceholder": "discovered_on: target/primary_domain\nexploits: exploit/upload-rce",
|
||||||
|
"factLinksHint": "每行一条:type: source_fact_key(来源 → 当前事实)。常用 type:discovered_on、depends_on、leads_to、enables、exploits。保存时替换全部关系边。",
|
||||||
|
"factIncomingLinksLabel": "入边(只读)",
|
||||||
|
"factIncomingLinksHint": "由来源事实的出边产生。例如 finding 的 discovered_on → target/*,在目标上会显示为入边;请编辑来源事实的出边。",
|
||||||
|
"factIncomingLinksEmpty": "暂无入边",
|
||||||
|
"graphEdgeFromSelf": "本节点指出",
|
||||||
|
"graphEdgeToSelf": "指向本节点",
|
||||||
|
"linksColumn": "关系",
|
||||||
|
"linkCountsTitle": "出边数 / 入边数",
|
||||||
|
"graphConnect": "连边",
|
||||||
|
"graphConnectActive": "连边中…",
|
||||||
|
"graphConnectPickTarget": "已选 {{source}},请点击目标节点",
|
||||||
|
"graphEdgeTypePrompt": "边类型(discovered_on / leads_to / depends_on / enables / exploits)",
|
||||||
|
"graphConnectFailed": "创建边失败",
|
||||||
|
"graphConnectSuccess": "边已创建",
|
||||||
|
"graphEdgesTitle": "关系边",
|
||||||
|
"graphEdgesHint": "箭头方向与数据库/编辑弹窗一致(source → target);点击连线可定位。",
|
||||||
|
"graphEdgesEmpty": "暂无关系边",
|
||||||
|
"graphEdgeOutgoing": "出边",
|
||||||
|
"graphEdgeIncoming": "入边",
|
||||||
|
"graphEdgeSynthetic": "由事实关联自动生成,请编辑事实解除",
|
||||||
|
"confirmDeleteGraphEdge": "确定删除此关系边?",
|
||||||
|
"graphEdgeDeleteFailed": "删除边失败",
|
||||||
|
"graphEdgeDeleteSuccess": "边已删除",
|
||||||
|
"graphDeleteEdge": "删边",
|
||||||
|
"viewVulnerability": "查看漏洞",
|
||||||
|
"graphVulnSidebarHint": "关联漏洞节点,点击下方按钮在漏洞管理中查看详情。",
|
||||||
|
"promoteAttackChain": "沉淀攻击链",
|
||||||
|
"promoteAttackChainTitle": "将对话攻击链沉淀为项目事实与边",
|
||||||
|
"confirmPromoteAttackChain": "将该对话的攻击链沉淀到本项目?会创建/更新事实与关系边。",
|
||||||
|
"promoteAttackChainFailed": "沉淀失败",
|
||||||
|
"promoteAttackChainSuccess": "已沉淀:新建 {{facts_created}} / 更新 {{facts_updated}} / 边 {{edges_created}}",
|
||||||
"searchFactsSr": "搜索事实",
|
"searchFactsSr": "搜索事实",
|
||||||
"searchFactsPlaceholder": "搜索 key、摘要、body…",
|
"searchFactsPlaceholder": "搜索 key、摘要、body…",
|
||||||
"category": "分类",
|
"category": "分类",
|
||||||
@@ -424,6 +487,9 @@
|
|||||||
"conversationGroups": "对话分组",
|
"conversationGroups": "对话分组",
|
||||||
"addGroup": "新建分组",
|
"addGroup": "新建分组",
|
||||||
"recentConversations": "最近对话",
|
"recentConversations": "最近对话",
|
||||||
|
"sortConversations": "排序",
|
||||||
|
"sortByCreatedAt": "创建时间",
|
||||||
|
"sortByUpdatedAt": "更新时间",
|
||||||
"batchManage": "批量管理",
|
"batchManage": "批量管理",
|
||||||
"paginationShow": "显示 {{start}}-{{end}} / 共 {{total}}",
|
"paginationShow": "显示 {{start}}-{{end}} / 共 {{total}}",
|
||||||
"paginationRange": "{{start}}-{{end}}/{{total}}",
|
"paginationRange": "{{start}}-{{end}}/{{total}}",
|
||||||
@@ -664,7 +730,12 @@
|
|||||||
"viewConversation": "查看对话",
|
"viewConversation": "查看对话",
|
||||||
"viewVulnerabilities": "查看漏洞",
|
"viewVulnerabilities": "查看漏洞",
|
||||||
"viewVulnerabilitiesQueueTitle": "查看漏洞:打开漏洞管理并筛选本队列",
|
"viewVulnerabilitiesQueueTitle": "查看漏洞:打开漏洞管理并筛选本队列",
|
||||||
"retryTask": "重试",
|
"runSingleTask": "单条执行",
|
||||||
|
"confirmRunSingleTask": "确定执行该任务?仅运行这一条,完成后队列会自动暂停,不会继续执行其他待执行项。",
|
||||||
|
"runSingleTaskFailed": "单条执行失败",
|
||||||
|
"runSingleTaskUnavailable": "队列或任务执行中,暂无法单条执行",
|
||||||
|
"runSingleTaskUnavailableSelf": "该任务正在执行中",
|
||||||
|
"runSingleTaskUnavailableQueue": "队列批量执行中,请暂停后再单条执行其它任务",
|
||||||
"conversationIdLabel": "对话ID",
|
"conversationIdLabel": "对话ID",
|
||||||
"statusPending": "待执行",
|
"statusPending": "待执行",
|
||||||
"statusPaused": "已暂停",
|
"statusPaused": "已暂停",
|
||||||
@@ -1071,6 +1142,7 @@
|
|||||||
"botAgent": "Bot Agent",
|
"botAgent": "Bot Agent",
|
||||||
"ilinkBotId": "iLink Bot ID(绑定后自动填充)",
|
"ilinkBotId": "iLink Bot ID(绑定后自动填充)",
|
||||||
"boundSuccess": "绑定成功,微信机器人已启用。",
|
"boundSuccess": "绑定成功,微信机器人已启用。",
|
||||||
|
"alreadyBound": "该微信已绑定过,无需重复绑定。",
|
||||||
"openLink": "无法显示二维码?点击用手机微信打开链接"
|
"openLink": "无法显示二维码?点击用手机微信打开链接"
|
||||||
},
|
},
|
||||||
"wecom": {
|
"wecom": {
|
||||||
@@ -1926,6 +1998,13 @@
|
|||||||
"openaiBaseUrlPlaceholder": "https://api.openai.com/v1",
|
"openaiBaseUrlPlaceholder": "https://api.openai.com/v1",
|
||||||
"openaiApiKeyPlaceholder": "输入OpenAI API Key",
|
"openaiApiKeyPlaceholder": "输入OpenAI API Key",
|
||||||
"modelPlaceholder": "gpt-4",
|
"modelPlaceholder": "gpt-4",
|
||||||
|
"fetchModels": "获取列表",
|
||||||
|
"modelsListFetching": "正在获取模型列表...",
|
||||||
|
"modelsListSelectPlaceholder": "请选择模型",
|
||||||
|
"modelsListSuccess": "已加载 {count} 个模型,请用右侧下拉框选择,或继续在左侧输入",
|
||||||
|
"modelsListFailed": "获取模型列表失败",
|
||||||
|
"modelsListNeedApiKey": "请先填写 API Key",
|
||||||
|
"modelsListClaudeHint": "Claude 不支持自动获取模型列表,请手动填写",
|
||||||
"maxTotalTokens": "最大上下文 Token 数",
|
"maxTotalTokens": "最大上下文 Token 数",
|
||||||
"maxTotalTokensPlaceholder": "120000",
|
"maxTotalTokensPlaceholder": "120000",
|
||||||
"maxTotalTokensHint": "内存压缩和攻击链构建共用此配置,默认 120000",
|
"maxTotalTokensHint": "内存压缩和攻击链构建共用此配置,默认 120000",
|
||||||
@@ -2074,14 +2153,35 @@
|
|||||||
"filterResult": "结果",
|
"filterResult": "结果",
|
||||||
"pageSize": "每页",
|
"pageSize": "每页",
|
||||||
"statTotal": "当前筛选",
|
"statTotal": "当前筛选",
|
||||||
|
"statSuccess": "成功",
|
||||||
"statFailures": "失败",
|
"statFailures": "失败",
|
||||||
"statRecent7d": "近 7 天",
|
"statRecent7d": "近 7 天",
|
||||||
"retentionHint": "审计记录保留 {{days}} 天,超期自动清理。",
|
"retentionHint": "审计记录保留 {{days}} 天,超期自动清理。",
|
||||||
"disabledHint": "审计功能已关闭,新操作不会写入审计表。",
|
"disabledHint": "审计功能已关闭,新操作不会写入审计表。",
|
||||||
"filterSince": "开始时间",
|
"filterSince": "开始时间",
|
||||||
"filterUntil": "结束时间",
|
"filterUntil": "结束时间",
|
||||||
|
"filterTimeZone": "时区:{{tz}}(筛选按浏览器本地时间)",
|
||||||
|
"datetimePlaceholder": "选择日期时间",
|
||||||
|
"timePresets": "快捷",
|
||||||
|
"preset15m": "最近15分钟",
|
||||||
|
"preset1h": "最近1小时",
|
||||||
|
"preset24h": "最近24小时",
|
||||||
|
"preset7d": "最近7天",
|
||||||
|
"presetToday": "今天",
|
||||||
|
"pickerHour": "时",
|
||||||
|
"pickerMinute": "分",
|
||||||
|
"pickerClear": "清除",
|
||||||
|
"pickerToday": "今天",
|
||||||
|
"pickerConfirm": "确定",
|
||||||
"filterQuery": "关键词",
|
"filterQuery": "关键词",
|
||||||
"filterQueryPlaceholder": "消息 / 资源 ID / 操作名",
|
"filterQueryPlaceholder": "消息 / 资源 ID / 操作名",
|
||||||
|
"colTime": "时间",
|
||||||
|
"colMessage": "说明",
|
||||||
|
"colCategory": "类别",
|
||||||
|
"colAction": "操作",
|
||||||
|
"colResult": "结果",
|
||||||
|
"colIp": "IP",
|
||||||
|
"colResource": "资源 ID",
|
||||||
"cat": {
|
"cat": {
|
||||||
"auth": "认证",
|
"auth": "认证",
|
||||||
"config": "配置",
|
"config": "配置",
|
||||||
@@ -2154,6 +2254,93 @@
|
|||||||
"exportDone": "导出完成",
|
"exportDone": "导出完成",
|
||||||
"loading": "加载中...",
|
"loading": "加载中...",
|
||||||
"empty": "暂无审计记录",
|
"empty": "暂无审计记录",
|
||||||
|
"result": {
|
||||||
|
"success": "成功",
|
||||||
|
"failure": "失败"
|
||||||
|
},
|
||||||
|
"msg": {
|
||||||
|
"auth": {
|
||||||
|
"login": "登录成功",
|
||||||
|
"login_failed": "登录失败:密码错误",
|
||||||
|
"logout": "退出登录",
|
||||||
|
"change_password": "登录密码已修改",
|
||||||
|
"change_password_failed": "修改密码失败:当前密码不正确"
|
||||||
|
},
|
||||||
|
"config": {
|
||||||
|
"apply": "配置已应用",
|
||||||
|
"update": "更新内存配置",
|
||||||
|
"apply_fail_kb_init": "应用配置失败:初始化知识库",
|
||||||
|
"apply_fail_kb_reinit": "应用配置失败:重新初始化知识库",
|
||||||
|
"apply_fail_c2": "应用配置失败:C2"
|
||||||
|
},
|
||||||
|
"conversation": {
|
||||||
|
"create": "创建对话",
|
||||||
|
"delete": "删除对话",
|
||||||
|
"delete_turn": "删除对话轮次"
|
||||||
|
},
|
||||||
|
"c2": {
|
||||||
|
"listener_create": "创建 C2 监听器",
|
||||||
|
"listener_delete": "删除 C2 监听器",
|
||||||
|
"listener_start": "启动 C2 监听器",
|
||||||
|
"listener_stop": "停止 C2 监听器",
|
||||||
|
"session_delete": "删除 C2 会话",
|
||||||
|
"task_create": "创建 C2 任务",
|
||||||
|
"task_cancel": "取消 C2 任务",
|
||||||
|
"task_delete": "批量删除 C2 任务"
|
||||||
|
},
|
||||||
|
"webshell": {
|
||||||
|
"connection_create": "创建 WebShell 连接",
|
||||||
|
"connection_delete": "删除 WebShell 连接"
|
||||||
|
},
|
||||||
|
"knowledge": {
|
||||||
|
"item_delete": "删除知识项",
|
||||||
|
"index_rebuild": "重建知识库索引"
|
||||||
|
},
|
||||||
|
"vulnerability": {
|
||||||
|
"create": "创建漏洞记录",
|
||||||
|
"update": "更新漏洞记录",
|
||||||
|
"delete": "删除漏洞记录",
|
||||||
|
"delete_batch": "批量删除漏洞记录"
|
||||||
|
},
|
||||||
|
"external_mcp": {
|
||||||
|
"upsert": "更新外部 MCP 配置",
|
||||||
|
"delete": "删除外部 MCP 配置"
|
||||||
|
},
|
||||||
|
"task": {
|
||||||
|
"create_queue": "创建批量任务队列",
|
||||||
|
"start_queue": "启动批量任务队列",
|
||||||
|
"delete_queue": "删除批量任务队列",
|
||||||
|
"pause_queue": "暂停批量任务队列",
|
||||||
|
"rerun_queue": "重跑批量任务队列",
|
||||||
|
"delete_batch_task": "删除批量子任务"
|
||||||
|
},
|
||||||
|
"tool": {
|
||||||
|
"execution_delete": "删除工具执行记录",
|
||||||
|
"execution_delete_batch": "批量删除工具执行记录"
|
||||||
|
},
|
||||||
|
"file": {
|
||||||
|
"upload": "上传对话附件",
|
||||||
|
"delete": "删除对话附件"
|
||||||
|
},
|
||||||
|
"hitl": {
|
||||||
|
"decision": "HITL 审批决策"
|
||||||
|
},
|
||||||
|
"role": {
|
||||||
|
"create": "创建角色",
|
||||||
|
"update": "更新角色",
|
||||||
|
"delete": "删除角色"
|
||||||
|
},
|
||||||
|
"skill": {
|
||||||
|
"create": "创建 Skill",
|
||||||
|
"update": "更新 Skill",
|
||||||
|
"delete": "删除 Skill"
|
||||||
|
},
|
||||||
|
"agent": {
|
||||||
|
"markdown_create": "创建 Markdown 子代理",
|
||||||
|
"markdown_update": "更新 Markdown 子代理",
|
||||||
|
"markdown_delete": "删除 Markdown 子代理"
|
||||||
|
}
|
||||||
|
},
|
||||||
"paginationShow": "显示 {{start}}-{{end}} / 共 {{total}} 条",
|
"paginationShow": "显示 {{start}}-{{end}} / 共 {{total}} 条",
|
||||||
"detailTitle": "审计详情",
|
"detailTitle": "审计详情",
|
||||||
"detailTime": "时间",
|
"detailTime": "时间",
|
||||||
@@ -2232,7 +2419,8 @@
|
|||||||
"copyContent": "复制内容",
|
"copyContent": "复制内容",
|
||||||
"correctInfo": "正确信息",
|
"correctInfo": "正确信息",
|
||||||
"errorInfo": "错误信息",
|
"errorInfo": "错误信息",
|
||||||
"copyError": "复制错误"
|
"copyError": "复制错误",
|
||||||
|
"contentTruncated": "…(展示已截断;完整内容见 persisted-output 中的文件路径,用 read_file 读取)"
|
||||||
},
|
},
|
||||||
"attackChainModal": {
|
"attackChainModal": {
|
||||||
"title": "攻击链可视化",
|
"title": "攻击链可视化",
|
||||||
@@ -2562,6 +2750,11 @@
|
|||||||
},
|
},
|
||||||
"c2": {
|
"c2": {
|
||||||
"clipboardCopied": "已复制到剪贴板",
|
"clipboardCopied": "已复制到剪贴板",
|
||||||
|
"common": {
|
||||||
|
"justNow": "刚刚",
|
||||||
|
"minutesAgo": "{{n}} 分钟前",
|
||||||
|
"hoursAgo": "{{n}} 小时前"
|
||||||
|
},
|
||||||
"fmt": {
|
"fmt": {
|
||||||
"durationMs": "{{n}}ms",
|
"durationMs": "{{n}}ms",
|
||||||
"durationSec": "{{n}}秒",
|
"durationSec": "{{n}}秒",
|
||||||
@@ -2619,6 +2812,8 @@
|
|||||||
"bindHintExternal": "使用 0.0.0.0 允许外部访问",
|
"bindHintExternal": "使用 0.0.0.0 允许外部访问",
|
||||||
"callbackHost": "回连地址(可选)",
|
"callbackHost": "回连地址(可选)",
|
||||||
"callbackHostHint": "公网 IP 或域名,写入配置供 Payload/Beacon 使用;与「绑定地址」分离。不填则生成 Payload 时按绑定地址或自动探测。",
|
"callbackHostHint": "公网 IP 或域名,写入配置供 Payload/Beacon 使用;与「绑定地址」分离。不填则生成 Payload 时按绑定地址或自动探测。",
|
||||||
|
"allowLegacyShell": "允许未加密经典反弹 Shell(内网实验)",
|
||||||
|
"allowLegacyShellHint": "默认关闭。开启后 bash/nc 等裸 TCP 连接可登记会话,公网易被扫描器误连;生产环境请使用「生成 Beacon」加密上线。",
|
||||||
"malleableProfile": "Malleable Profile",
|
"malleableProfile": "Malleable Profile",
|
||||||
"malleableProfileHint": "可选;用于 HTTP/HTTPS Beacon 服务端响应头等流量伪装。修改后需停止并重新启动监听器才会生效。",
|
"malleableProfileHint": "可选;用于 HTTP/HTTPS Beacon 服务端响应头等流量伪装。修改后需停止并重新启动监听器才会生效。",
|
||||||
"malleableProfileNone": "不使用",
|
"malleableProfileNone": "不使用",
|
||||||
@@ -2696,10 +2891,22 @@
|
|||||||
"infoFirstSeen": "首次上线",
|
"infoFirstSeen": "首次上线",
|
||||||
"infoLastCheckin": "上次心跳",
|
"infoLastCheckin": "上次心跳",
|
||||||
"infoNote": "备注",
|
"infoNote": "备注",
|
||||||
|
"infoNoteEmpty": "暂无备注",
|
||||||
|
"infoSectionIdentity": "身份信息",
|
||||||
|
"infoSectionSystem": "系统环境",
|
||||||
|
"infoSectionNetwork": "网络与信标",
|
||||||
|
"infoSectionTimeline": "时间线",
|
||||||
|
"infoSectionNote": "备注",
|
||||||
"adminYes": "是",
|
"adminYes": "是",
|
||||||
"adminNo": "否",
|
"adminNo": "否",
|
||||||
"promptSleepSeconds": "Sleep 间隔(秒)",
|
"promptSleepSeconds": "Sleep 间隔(秒)",
|
||||||
"promptJitterPercent": "抖动百分比(0–100)",
|
"promptJitterPercent": "抖动百分比(0–100)",
|
||||||
|
"sleepModalHint": "保存后将写入服务端并下发 sleep 任务;植入体在下次拉取任务后生效,同时后续心跳会同步该配置。",
|
||||||
|
"sleepModalTitle": "心跳配置",
|
||||||
|
"sleepModalCurrent": "当前 {{sec}} 秒 · 抖动 {{jitter}}%",
|
||||||
|
"sleepModalPreview": "预计间隔 {{min}} – {{max}} 秒",
|
||||||
|
"sleepModalPresets": "快捷",
|
||||||
|
"toastSleepInvalid": "Sleep 间隔至少为 1 秒",
|
||||||
"toastSleepUpdated": "Sleep 设置已更新",
|
"toastSleepUpdated": "Sleep 设置已更新",
|
||||||
"confirmExitSession": "向该会话发送退出指令?",
|
"confirmExitSession": "向该会话发送退出指令?",
|
||||||
"confirmDeleteSession": "从服务器删除此会话及其关联任务与文件记录?(不会向植入体发送退出;若需退出目标进程请使用「终止会话」。)",
|
"confirmDeleteSession": "从服务器删除此会话及其关联任务与文件记录?(不会向植入体发送退出;若需退出目标进程请使用「终止会话」。)",
|
||||||
@@ -2717,7 +2924,25 @@
|
|||||||
"termWaitFinish": "请等待当前命令执行完成",
|
"termWaitFinish": "请等待当前命令执行完成",
|
||||||
"termCtrlC": "当前版本暂不支持中断远程命令",
|
"termCtrlC": "当前版本暂不支持中断远程命令",
|
||||||
"termQueued": "[命令已加入队列,将在当前任务完成后执行]",
|
"termQueued": "[命令已加入队列,将在当前任务完成后执行]",
|
||||||
"clearTerminal": "清屏"
|
"clearTerminal": "清屏",
|
||||||
|
"batchDelete": "批量删除",
|
||||||
|
"deleteFiltered": "删除筛选结果",
|
||||||
|
"selectAll": "全选",
|
||||||
|
"filterAllStatus": "全部状态",
|
||||||
|
"filterAllListeners": "全部监听器",
|
||||||
|
"filterSearchPlaceholder": "搜索主机名 / 用户 / IP",
|
||||||
|
"filterApply": "筛选",
|
||||||
|
"filterReset": "重置",
|
||||||
|
"filterSuspicious": "疑似误报",
|
||||||
|
"filterCount": "共 {{n}} 条,已选 {{selected}}",
|
||||||
|
"emptyFilter": "没有符合筛选条件的会话",
|
||||||
|
"listEmpty": "暂无会话",
|
||||||
|
"selectPromptTitle": "选择会话",
|
||||||
|
"selectPromptHint": "在左侧列表中点击一个会话,查看终端、文件与任务详情。",
|
||||||
|
"confirmBatchDelete": "确定删除选中的 {{n}} 个会话?关联任务与文件记录将一并清除。",
|
||||||
|
"confirmDeleteFiltered": "确定删除当前筛选结果中的全部 {{n}} 个会话?",
|
||||||
|
"toastSelectFirst": "请先勾选要删除的会话",
|
||||||
|
"toastBatchDeleted": "已删除 {{n}} 个会话"
|
||||||
},
|
},
|
||||||
"tasks": {
|
"tasks": {
|
||||||
"title": "任务管理",
|
"title": "任务管理",
|
||||||
@@ -2740,6 +2965,8 @@
|
|||||||
"pending": "待处理",
|
"pending": "待处理",
|
||||||
"emptyAll": "暂无任务",
|
"emptyAll": "暂无任务",
|
||||||
"emptySession": "该会话暂无任务",
|
"emptySession": "该会话暂无任务",
|
||||||
|
"sessionTaskHistory": "任务历史",
|
||||||
|
"sessionTaskCount": "共 {{n}} 条",
|
||||||
"colTask": "任务",
|
"colTask": "任务",
|
||||||
"colSession": "会话",
|
"colSession": "会话",
|
||||||
"colType": "类型",
|
"colType": "类型",
|
||||||
|
|||||||
@@ -0,0 +1,428 @@
|
|||||||
|
/**
|
||||||
|
* Audit log datetime picker — cross-browser, locale-aware (SLS-style calendar + time columns).
|
||||||
|
*/
|
||||||
|
(function () {
|
||||||
|
'use strict';
|
||||||
|
|
||||||
|
var registry = {};
|
||||||
|
var popover = null;
|
||||||
|
var activeFieldId = null;
|
||||||
|
var draft = null;
|
||||||
|
var viewYear = 0;
|
||||||
|
var viewMonth = 0;
|
||||||
|
|
||||||
|
function pad2(n) {
|
||||||
|
return String(n).padStart(2, '0');
|
||||||
|
}
|
||||||
|
|
||||||
|
function pickerLocale() {
|
||||||
|
if (typeof auditLocale === 'function') return auditLocale();
|
||||||
|
if (typeof window.__locale === 'string' && window.__locale.startsWith('zh')) return 'zh-CN';
|
||||||
|
return 'en-US';
|
||||||
|
}
|
||||||
|
|
||||||
|
function pickerT(key, fallback) {
|
||||||
|
if (typeof auditT === 'function') return auditT(key, null, fallback);
|
||||||
|
if (typeof t === 'function') {
|
||||||
|
var v = t(key);
|
||||||
|
if (v && v !== key) return v;
|
||||||
|
}
|
||||||
|
return fallback;
|
||||||
|
}
|
||||||
|
|
||||||
|
function partsToStorage(p) {
|
||||||
|
if (!p) return '';
|
||||||
|
return p.y + '-' + pad2(p.m) + '-' + pad2(p.d) + 'T' + pad2(p.h) + ':' + pad2(p.mi);
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseStorage(value) {
|
||||||
|
if (!value) return null;
|
||||||
|
var m = /^(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2})/.exec(String(value).trim());
|
||||||
|
if (!m) return null;
|
||||||
|
return { y: +m[1], m: +m[2], d: +m[3], h: +m[4], mi: +m[5] };
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatDisplay(parts) {
|
||||||
|
if (!parts) return '';
|
||||||
|
var loc = pickerLocale();
|
||||||
|
try {
|
||||||
|
var d = new Date(parts.y, parts.m - 1, parts.d, parts.h, parts.mi, 0, 0);
|
||||||
|
return d.toLocaleString(loc, {
|
||||||
|
year: 'numeric',
|
||||||
|
month: '2-digit',
|
||||||
|
day: '2-digit',
|
||||||
|
hour: '2-digit',
|
||||||
|
minute: '2-digit',
|
||||||
|
hour12: false
|
||||||
|
});
|
||||||
|
} catch (_) {
|
||||||
|
return partsToStorage(parts).replace('T', ' ');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function nowParts() {
|
||||||
|
var n = new Date();
|
||||||
|
return { y: n.getFullYear(), m: n.getMonth() + 1, d: n.getDate(), h: n.getHours(), mi: n.getMinutes() };
|
||||||
|
}
|
||||||
|
|
||||||
|
function startOfTodayParts() {
|
||||||
|
var n = new Date();
|
||||||
|
return { y: n.getFullYear(), m: n.getMonth() + 1, d: n.getDate(), h: 0, mi: 0 };
|
||||||
|
}
|
||||||
|
|
||||||
|
function monthTitle(year, month) {
|
||||||
|
var loc = pickerLocale();
|
||||||
|
if (loc.startsWith('zh')) {
|
||||||
|
return year + '\u5e74' + pad2(month) + '\u6708';
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
return new Date(year, month - 1, 1).toLocaleString(loc, { month: 'long', year: 'numeric' });
|
||||||
|
} catch (_) {
|
||||||
|
return year + '-' + pad2(month);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function weekdayHeaders() {
|
||||||
|
var loc = pickerLocale();
|
||||||
|
if (loc.startsWith('zh')) {
|
||||||
|
return ['\u65e5', '\u4e00', '\u4e8c', '\u4e09', '\u56db', '\u4e94', '\u516d'];
|
||||||
|
}
|
||||||
|
return ['Su', 'Mo', 'Tu', 'We', 'Th', 'Fr', 'Sa'];
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildMonthGrid(year, month) {
|
||||||
|
var first = new Date(year, month - 1, 1);
|
||||||
|
var start = new Date(first);
|
||||||
|
start.setDate(first.getDate() - first.getDay());
|
||||||
|
var cells = [];
|
||||||
|
var cursor = new Date(start);
|
||||||
|
for (var i = 0; i < 42; i++) {
|
||||||
|
cells.push({
|
||||||
|
y: cursor.getFullYear(),
|
||||||
|
m: cursor.getMonth() + 1,
|
||||||
|
d: cursor.getDate(),
|
||||||
|
inMonth: cursor.getMonth() === month - 1
|
||||||
|
});
|
||||||
|
cursor.setDate(cursor.getDate() + 1);
|
||||||
|
}
|
||||||
|
return cells;
|
||||||
|
}
|
||||||
|
|
||||||
|
function ensurePopover() {
|
||||||
|
if (popover) return popover;
|
||||||
|
popover = document.createElement('div');
|
||||||
|
popover.className = 'audit-dt-popover';
|
||||||
|
popover.hidden = true;
|
||||||
|
popover.setAttribute('role', 'dialog');
|
||||||
|
popover.innerHTML =
|
||||||
|
'<div class="audit-dt-popover-inner">' +
|
||||||
|
'<div class="audit-dt-head">' +
|
||||||
|
'<button type="button" class="audit-dt-nav" data-nav="prev" aria-label="prev">‹</button>' +
|
||||||
|
'<span class="audit-dt-month-label"></span>' +
|
||||||
|
'<button type="button" class="audit-dt-nav" data-nav="next" aria-label="next">›</button>' +
|
||||||
|
'</div>' +
|
||||||
|
'<div class="audit-dt-body">' +
|
||||||
|
'<div class="audit-dt-calendar"></div>' +
|
||||||
|
'<div class="audit-dt-time">' +
|
||||||
|
'<div class="audit-dt-time-col" data-part="hour">' +
|
||||||
|
'<span class="audit-dt-time-label audit-dt-hour-label"></span>' +
|
||||||
|
'<div class="audit-dt-time-list"></div>' +
|
||||||
|
'</div>' +
|
||||||
|
'<div class="audit-dt-time-col" data-part="minute">' +
|
||||||
|
'<span class="audit-dt-time-label audit-dt-minute-label"></span>' +
|
||||||
|
'<div class="audit-dt-time-list"></div>' +
|
||||||
|
'</div>' +
|
||||||
|
'</div>' +
|
||||||
|
'</div>' +
|
||||||
|
'<div class="audit-dt-footer">' +
|
||||||
|
'<button type="button" class="audit-dt-footer-btn" data-action="clear"></button>' +
|
||||||
|
'<button type="button" class="audit-dt-footer-btn" data-action="today"></button>' +
|
||||||
|
'<button type="button" class="audit-dt-footer-btn audit-dt-footer-btn--primary" data-action="confirm"></button>' +
|
||||||
|
'</div>' +
|
||||||
|
'</div>';
|
||||||
|
document.body.appendChild(popover);
|
||||||
|
|
||||||
|
popover.addEventListener('click', function (ev) {
|
||||||
|
ev.stopPropagation();
|
||||||
|
var btn = ev.target.closest('[data-nav]');
|
||||||
|
if (btn) {
|
||||||
|
if (btn.getAttribute('data-nav') === 'prev') {
|
||||||
|
viewMonth -= 1;
|
||||||
|
if (viewMonth < 1) { viewMonth = 12; viewYear -= 1; }
|
||||||
|
} else {
|
||||||
|
viewMonth += 1;
|
||||||
|
if (viewMonth > 12) { viewMonth = 1; viewYear += 1; }
|
||||||
|
}
|
||||||
|
renderPopover();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
var dayBtn = ev.target.closest('[data-day]');
|
||||||
|
if (dayBtn && draft) {
|
||||||
|
draft.y = +dayBtn.getAttribute('data-y');
|
||||||
|
draft.m = +dayBtn.getAttribute('data-m');
|
||||||
|
draft.d = +dayBtn.getAttribute('data-d');
|
||||||
|
if (draft.y !== viewYear || draft.m !== viewMonth) {
|
||||||
|
viewYear = draft.y;
|
||||||
|
viewMonth = draft.m;
|
||||||
|
renderCalendar();
|
||||||
|
} else {
|
||||||
|
updateDaySelection();
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
var timeBtn = ev.target.closest('[data-time]');
|
||||||
|
if (timeBtn && draft) {
|
||||||
|
var part = timeBtn.getAttribute('data-part');
|
||||||
|
var val = +timeBtn.getAttribute('data-time');
|
||||||
|
if (part === 'hour') draft.h = val;
|
||||||
|
if (part === 'minute') draft.mi = val;
|
||||||
|
updateTimeSelection();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
var actionBtn = ev.target.closest('[data-action]');
|
||||||
|
if (!actionBtn) return;
|
||||||
|
var action = actionBtn.getAttribute('data-action');
|
||||||
|
if (action === 'clear') {
|
||||||
|
applyValue(activeFieldId, '');
|
||||||
|
closePopover();
|
||||||
|
} else if (action === 'today') {
|
||||||
|
if (draft) {
|
||||||
|
var t = nowParts();
|
||||||
|
draft.y = t.y; draft.m = t.m; draft.d = t.d;
|
||||||
|
viewYear = t.y; viewMonth = t.m;
|
||||||
|
}
|
||||||
|
renderPopover();
|
||||||
|
} else if (action === 'confirm') {
|
||||||
|
applyValue(activeFieldId, partsToStorage(draft));
|
||||||
|
closePopover();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
document.addEventListener('click', onDocumentClick);
|
||||||
|
document.addEventListener('keydown', onDocumentKeydown);
|
||||||
|
document.addEventListener('languagechange', function () {
|
||||||
|
if (!popover.hidden) renderPopover();
|
||||||
|
refreshAllDisplays();
|
||||||
|
});
|
||||||
|
|
||||||
|
return popover;
|
||||||
|
}
|
||||||
|
|
||||||
|
function onDocumentClick(ev) {
|
||||||
|
if (!popover || popover.hidden) return;
|
||||||
|
if (popover.contains(ev.target)) return;
|
||||||
|
if (activeFieldId && registry[activeFieldId] && registry[activeFieldId].wrap.contains(ev.target)) return;
|
||||||
|
closePopover();
|
||||||
|
}
|
||||||
|
|
||||||
|
function onDocumentKeydown(ev) {
|
||||||
|
if (ev.key === 'Escape' && popover && !popover.hidden) {
|
||||||
|
closePopover();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function positionPopover(fieldWrap) {
|
||||||
|
var rect = fieldWrap.getBoundingClientRect();
|
||||||
|
var width = 320;
|
||||||
|
popover.style.width = width + 'px';
|
||||||
|
var left = rect.left;
|
||||||
|
if (left + width > window.innerWidth - 12) {
|
||||||
|
left = Math.max(12, window.innerWidth - width - 12);
|
||||||
|
}
|
||||||
|
popover.style.left = left + 'px';
|
||||||
|
var top = rect.bottom + 6;
|
||||||
|
if (top + 340 > window.innerHeight - 12) {
|
||||||
|
top = Math.max(12, rect.top - 340 - 6);
|
||||||
|
}
|
||||||
|
popover.style.top = top + 'px';
|
||||||
|
}
|
||||||
|
|
||||||
|
function renderCalendar() {
|
||||||
|
if (!popover || !draft) return;
|
||||||
|
popover.querySelector('.audit-dt-month-label').textContent = monthTitle(viewYear, viewMonth);
|
||||||
|
var cal = popover.querySelector('.audit-dt-calendar');
|
||||||
|
var headers = weekdayHeaders();
|
||||||
|
var html = '<div class="audit-dt-weekdays">';
|
||||||
|
headers.forEach(function (h) { html += '<span>' + h + '</span>'; });
|
||||||
|
html += '</div><div class="audit-dt-days">';
|
||||||
|
buildMonthGrid(viewYear, viewMonth).forEach(function (cell) {
|
||||||
|
var cls = 'audit-dt-day';
|
||||||
|
if (!cell.inMonth) cls += ' is-other-month';
|
||||||
|
if (draft && cell.y === draft.y && cell.m === draft.m && cell.d === draft.d) cls += ' is-selected';
|
||||||
|
html += '<button type="button" class="' + cls + '" data-day="1" data-y="' + cell.y +
|
||||||
|
'" data-m="' + cell.m + '" data-d="' + cell.d + '">' + cell.d + '</button>';
|
||||||
|
});
|
||||||
|
html += '</div>';
|
||||||
|
cal.innerHTML = html;
|
||||||
|
}
|
||||||
|
|
||||||
|
function renderTimeLists() {
|
||||||
|
if (!popover || !draft) return;
|
||||||
|
var hourList = popover.querySelector('[data-part="hour"] .audit-dt-time-list');
|
||||||
|
var minuteList = popover.querySelector('[data-part="minute"] .audit-dt-time-list');
|
||||||
|
var hourHtml = '';
|
||||||
|
var minuteHtml = '';
|
||||||
|
var h;
|
||||||
|
for (h = 0; h < 24; h++) {
|
||||||
|
hourHtml += '<button type="button" class="audit-dt-time-item' + (draft && draft.h === h ? ' is-selected' : '') +
|
||||||
|
'" data-part="hour" data-time="' + h + '">' + pad2(h) + '</button>';
|
||||||
|
}
|
||||||
|
for (h = 0; h < 60; h++) {
|
||||||
|
minuteHtml += '<button type="button" class="audit-dt-time-item' + (draft && draft.mi === h ? ' is-selected' : '') +
|
||||||
|
'" data-part="minute" data-time="' + h + '">' + pad2(h) + '</button>';
|
||||||
|
}
|
||||||
|
hourList.innerHTML = hourHtml;
|
||||||
|
minuteList.innerHTML = minuteHtml;
|
||||||
|
scrollTimeSelection(hourList, draft.h);
|
||||||
|
scrollTimeSelection(minuteList, draft.mi);
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateDaySelection() {
|
||||||
|
if (!popover || !draft) return;
|
||||||
|
popover.querySelectorAll('.audit-dt-day').forEach(function (btn) {
|
||||||
|
var selected = +btn.getAttribute('data-y') === draft.y &&
|
||||||
|
+btn.getAttribute('data-m') === draft.m &&
|
||||||
|
+btn.getAttribute('data-d') === draft.d;
|
||||||
|
btn.classList.toggle('is-selected', selected);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateTimeSelection() {
|
||||||
|
if (!popover || !draft) return;
|
||||||
|
var hourList = popover.querySelector('[data-part="hour"] .audit-dt-time-list');
|
||||||
|
var minuteList = popover.querySelector('[data-part="minute"] .audit-dt-time-list');
|
||||||
|
if (!hourList || !minuteList || !hourList.children.length) {
|
||||||
|
renderTimeLists();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
hourList.querySelectorAll('.audit-dt-time-item').forEach(function (btn) {
|
||||||
|
btn.classList.toggle('is-selected', +btn.getAttribute('data-time') === draft.h);
|
||||||
|
});
|
||||||
|
minuteList.querySelectorAll('.audit-dt-time-item').forEach(function (btn) {
|
||||||
|
btn.classList.toggle('is-selected', +btn.getAttribute('data-time') === draft.mi);
|
||||||
|
});
|
||||||
|
scrollTimeSelection(hourList, draft.h);
|
||||||
|
scrollTimeSelection(minuteList, draft.mi);
|
||||||
|
}
|
||||||
|
|
||||||
|
function renderPopover() {
|
||||||
|
if (!popover || !draft) return;
|
||||||
|
popover.querySelector('.audit-dt-hour-label').textContent = pickerT('settingsAudit.pickerHour', 'Hour');
|
||||||
|
popover.querySelector('.audit-dt-minute-label').textContent = pickerT('settingsAudit.pickerMinute', 'Min');
|
||||||
|
popover.querySelector('[data-action="clear"]').textContent = pickerT('settingsAudit.pickerClear', 'Clear');
|
||||||
|
popover.querySelector('[data-action="today"]').textContent = pickerT('settingsAudit.pickerToday', 'Today');
|
||||||
|
popover.querySelector('[data-action="confirm"]').textContent = pickerT('settingsAudit.pickerConfirm', 'OK');
|
||||||
|
renderCalendar();
|
||||||
|
renderTimeLists();
|
||||||
|
}
|
||||||
|
|
||||||
|
function scrollTimeSelection(listEl, value) {
|
||||||
|
var sel = listEl.querySelector('.is-selected');
|
||||||
|
if (sel && sel.scrollIntoView) {
|
||||||
|
sel.scrollIntoView({ block: 'center' });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function openPopover(fieldId) {
|
||||||
|
ensurePopover();
|
||||||
|
var entry = registry[fieldId];
|
||||||
|
if (!entry) return;
|
||||||
|
activeFieldId = fieldId;
|
||||||
|
var stored = entry.wrap.dataset.value || '';
|
||||||
|
draft = parseStorage(stored) || nowParts();
|
||||||
|
viewYear = draft.y;
|
||||||
|
viewMonth = draft.m;
|
||||||
|
renderPopover();
|
||||||
|
positionPopover(entry.wrap);
|
||||||
|
popover.hidden = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
function closePopover() {
|
||||||
|
if (!popover) return;
|
||||||
|
popover.hidden = true;
|
||||||
|
activeFieldId = null;
|
||||||
|
draft = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
function refreshDisplay(fieldId) {
|
||||||
|
var entry = registry[fieldId];
|
||||||
|
if (!entry) return;
|
||||||
|
var parts = parseStorage(entry.wrap.dataset.value || '');
|
||||||
|
entry.input.value = parts ? formatDisplay(parts) : '';
|
||||||
|
entry.input.placeholder = pickerT('settingsAudit.datetimePlaceholder', 'Select date & time');
|
||||||
|
entry.clearBtn.hidden = !parts;
|
||||||
|
}
|
||||||
|
|
||||||
|
function refreshAllDisplays() {
|
||||||
|
Object.keys(registry).forEach(refreshDisplay);
|
||||||
|
}
|
||||||
|
|
||||||
|
function applyValue(fieldId, storageValue) {
|
||||||
|
var entry = registry[fieldId];
|
||||||
|
if (!entry) return;
|
||||||
|
entry.wrap.dataset.value = storageValue || '';
|
||||||
|
refreshDisplay(fieldId);
|
||||||
|
}
|
||||||
|
|
||||||
|
function bindField(fieldId) {
|
||||||
|
var wrap = document.getElementById(fieldId);
|
||||||
|
if (!wrap || wrap.dataset.auditDtBound === '1') return;
|
||||||
|
var input = wrap.querySelector('.audit-datetime-input');
|
||||||
|
var openBtn = wrap.querySelector('.audit-datetime-open-btn');
|
||||||
|
var clearBtn = wrap.querySelector('.audit-datetime-clear-btn');
|
||||||
|
if (!input || !openBtn || !clearBtn) return;
|
||||||
|
|
||||||
|
wrap.dataset.auditDtBound = '1';
|
||||||
|
registry[fieldId] = { wrap: wrap, input: input, clearBtn: clearBtn };
|
||||||
|
|
||||||
|
openBtn.addEventListener('click', function (ev) {
|
||||||
|
ev.preventDefault();
|
||||||
|
ev.stopPropagation();
|
||||||
|
if (!popover || popover.hidden || activeFieldId !== fieldId) {
|
||||||
|
openPopover(fieldId);
|
||||||
|
} else {
|
||||||
|
closePopover();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
input.addEventListener('click', function (ev) {
|
||||||
|
ev.stopPropagation();
|
||||||
|
openPopover(fieldId);
|
||||||
|
});
|
||||||
|
clearBtn.addEventListener('click', function (ev) {
|
||||||
|
ev.preventDefault();
|
||||||
|
ev.stopPropagation();
|
||||||
|
applyValue(fieldId, '');
|
||||||
|
});
|
||||||
|
refreshDisplay(fieldId);
|
||||||
|
}
|
||||||
|
|
||||||
|
window.AuditDatetimePicker = {
|
||||||
|
init: function () {
|
||||||
|
bindField('audit-filter-since-field');
|
||||||
|
bindField('audit-filter-until-field');
|
||||||
|
refreshAllDisplays();
|
||||||
|
},
|
||||||
|
getValue: function (inputId) {
|
||||||
|
var fieldId = inputId === 'audit-filter-since' ? 'audit-filter-since-field' : 'audit-filter-until-field';
|
||||||
|
var entry = registry[fieldId];
|
||||||
|
return entry ? (entry.wrap.dataset.value || '') : '';
|
||||||
|
},
|
||||||
|
setValue: function (inputId, dateObj) {
|
||||||
|
if (!dateObj || Number.isNaN(dateObj.getTime())) return;
|
||||||
|
var fieldId = inputId === 'audit-filter-since' ? 'audit-filter-since-field' : 'audit-filter-until-field';
|
||||||
|
var p = {
|
||||||
|
y: dateObj.getFullYear(),
|
||||||
|
m: dateObj.getMonth() + 1,
|
||||||
|
d: dateObj.getDate(),
|
||||||
|
h: dateObj.getHours(),
|
||||||
|
mi: dateObj.getMinutes()
|
||||||
|
};
|
||||||
|
applyValue(fieldId, partsToStorage(p));
|
||||||
|
},
|
||||||
|
clearAll: function () {
|
||||||
|
applyValue('audit-filter-since-field', '');
|
||||||
|
applyValue('audit-filter-until-field', '');
|
||||||
|
closePopover();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
})();
|
||||||
+352
-56
@@ -4,6 +4,7 @@
|
|||||||
let auditLogsPage = 1;
|
let auditLogsPage = 1;
|
||||||
let auditLogsPageSize = 20;
|
let auditLogsPageSize = 20;
|
||||||
let auditLogsTotal = 0;
|
let auditLogsTotal = 0;
|
||||||
|
let auditLogsCache = [];
|
||||||
|
|
||||||
const AUDIT_PAGE_SIZE_KEY = 'cyberstrike_audit_page_size';
|
const AUDIT_PAGE_SIZE_KEY = 'cyberstrike_audit_page_size';
|
||||||
|
|
||||||
@@ -52,24 +53,113 @@ function auditActionLabel(action) {
|
|||||||
return auditT('settingsAudit.act.' + action, null, action);
|
return auditT('settingsAudit.act.' + action, null, action);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Stored DB messages that share category+action but need distinct i18n keys. */
|
||||||
|
const AUDIT_MSG_BY_STORED_TEXT = {
|
||||||
|
'登录失败:密码错误': 'settingsAudit.msg.auth.login_failed',
|
||||||
|
'修改密码失败:当前密码不正确': 'settingsAudit.msg.auth.change_password_failed',
|
||||||
|
'应用配置失败:初始化知识库': 'settingsAudit.msg.config.apply_fail_kb_init',
|
||||||
|
'应用配置失败:重新初始化知识库': 'settingsAudit.msg.config.apply_fail_kb_reinit',
|
||||||
|
'应用配置失败:C2': 'settingsAudit.msg.config.apply_fail_c2'
|
||||||
|
};
|
||||||
|
|
||||||
|
function auditMessageLabel(log) {
|
||||||
|
if (!log) return '';
|
||||||
|
const raw = (log.message || '').trim();
|
||||||
|
if (raw && AUDIT_MSG_BY_STORED_TEXT[raw]) {
|
||||||
|
return auditT(AUDIT_MSG_BY_STORED_TEXT[raw], null, raw);
|
||||||
|
}
|
||||||
|
const cat = (log.category || '').trim();
|
||||||
|
const act = (log.action || '').trim();
|
||||||
|
const res = (log.result || '').trim();
|
||||||
|
if (cat && act) {
|
||||||
|
if (cat === 'auth' && act === 'login' && res === 'failure') {
|
||||||
|
return auditT('settingsAudit.msg.auth.login_failed', null, raw);
|
||||||
|
}
|
||||||
|
if (cat === 'auth' && act === 'change_password' && res === 'failure') {
|
||||||
|
return auditT('settingsAudit.msg.auth.change_password_failed', null, raw);
|
||||||
|
}
|
||||||
|
const key = 'settingsAudit.msg.' + cat + '.' + act;
|
||||||
|
const translated = auditT(key, null, null);
|
||||||
|
if (translated && translated !== key) return translated;
|
||||||
|
}
|
||||||
|
return raw;
|
||||||
|
}
|
||||||
|
|
||||||
|
function auditResultLabel(result) {
|
||||||
|
if (!result) return '';
|
||||||
|
return auditT('settingsAudit.result.' + result, null, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
function auditLocale() {
|
||||||
|
if (typeof window.__locale === 'string' && window.__locale.length) {
|
||||||
|
return window.__locale.startsWith('zh') ? 'zh-CN' : 'en-US';
|
||||||
|
}
|
||||||
|
return (typeof navigator !== 'undefined' && navigator.language) ? navigator.language : 'en-US';
|
||||||
|
}
|
||||||
|
|
||||||
|
function auditTimezoneShortLabel() {
|
||||||
|
try {
|
||||||
|
const parts = new Intl.DateTimeFormat(auditLocale(), { timeZoneName: 'short' }).formatToParts(new Date());
|
||||||
|
const tz = parts.find(function (p) { return p.type === 'timeZoneName'; });
|
||||||
|
return tz ? tz.value : '';
|
||||||
|
} catch (_) {
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function formatAuditTime(iso) {
|
function formatAuditTime(iso) {
|
||||||
if (!iso) return '';
|
if (!iso) return '';
|
||||||
try {
|
try {
|
||||||
const d = new Date(iso);
|
const d = new Date(iso);
|
||||||
if (Number.isNaN(d.getTime())) return iso;
|
if (Number.isNaN(d.getTime())) return iso;
|
||||||
return d.toLocaleString();
|
return d.toLocaleString(auditLocale(), {
|
||||||
|
year: 'numeric',
|
||||||
|
month: '2-digit',
|
||||||
|
day: '2-digit',
|
||||||
|
hour: '2-digit',
|
||||||
|
minute: '2-digit',
|
||||||
|
second: '2-digit',
|
||||||
|
hour12: false,
|
||||||
|
timeZoneName: 'short'
|
||||||
|
});
|
||||||
} catch (_) {
|
} catch (_) {
|
||||||
return iso;
|
return iso;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Read stored local datetime (YYYY-MM-DDTHH:mm) from custom picker or raw input. */
|
||||||
|
function getAuditFilterDatetimeValue(inputId) {
|
||||||
|
if (typeof window.AuditDatetimePicker !== 'undefined' && typeof window.AuditDatetimePicker.getValue === 'function') {
|
||||||
|
return window.AuditDatetimePicker.getValue(inputId) || '';
|
||||||
|
}
|
||||||
|
var el = document.getElementById(inputId);
|
||||||
|
return el ? (el.value || '') : '';
|
||||||
|
}
|
||||||
|
|
||||||
|
/** datetime-local / picker storage -> UTC RFC3339 for API. */
|
||||||
function auditDatetimeLocalToRFC3339(value) {
|
function auditDatetimeLocalToRFC3339(value) {
|
||||||
if (!value || !value.trim()) return '';
|
if (!value || !value.trim()) return '';
|
||||||
const d = new Date(value);
|
const m = /^(\d{4})-(\d{2})-(\d{2})T(\d{2}):(\d{2})/.exec(value.trim());
|
||||||
|
if (!m) return '';
|
||||||
|
const d = new Date(+m[1], +m[2] - 1, +m[3], +m[4], +m[5], 0, 0);
|
||||||
if (Number.isNaN(d.getTime())) return '';
|
if (Number.isNaN(d.getTime())) return '';
|
||||||
return d.toISOString();
|
return d.toISOString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function updateAuditTimezoneHint() {
|
||||||
|
const el = document.getElementById('audit-filter-timezone-hint');
|
||||||
|
if (!el) return;
|
||||||
|
const tz = auditTimezoneShortLabel();
|
||||||
|
if (!tz) {
|
||||||
|
el.hidden = true;
|
||||||
|
el.textContent = '';
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
el.hidden = false;
|
||||||
|
el.textContent = auditT('settingsAudit.filterTimeZone', { tz: tz },
|
||||||
|
'时区:' + tz + '(筛选按浏览器本地时间,API 使用 UTC)');
|
||||||
|
}
|
||||||
|
|
||||||
function initAuditPageSizeFromStorage() {
|
function initAuditPageSizeFromStorage() {
|
||||||
try {
|
try {
|
||||||
const saved = parseInt(localStorage.getItem(AUDIT_PAGE_SIZE_KEY), 10);
|
const saved = parseInt(localStorage.getItem(AUDIT_PAGE_SIZE_KEY), 10);
|
||||||
@@ -113,6 +203,7 @@ function rebuildAuditActionSelect() {
|
|||||||
actEl.disabled = true;
|
actEl.disabled = true;
|
||||||
actEl.value = '';
|
actEl.value = '';
|
||||||
actEl.title = hint;
|
actEl.title = hint;
|
||||||
|
syncAuditCustomSelect('audit-filter-action');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,6 +220,7 @@ function rebuildAuditActionSelect() {
|
|||||||
if (prev && Array.prototype.some.call(actEl.options, function (o) { return o.value === prev; })) {
|
if (prev && Array.prototype.some.call(actEl.options, function (o) { return o.value === prev; })) {
|
||||||
actEl.value = prev;
|
actEl.value = prev;
|
||||||
}
|
}
|
||||||
|
syncAuditCustomSelect('audit-filter-action');
|
||||||
}
|
}
|
||||||
|
|
||||||
function onAuditCategoryFilterChange() {
|
function onAuditCategoryFilterChange() {
|
||||||
@@ -145,43 +237,17 @@ function buildAuditQueryParams(forExport) {
|
|||||||
const act = document.getElementById('audit-filter-action');
|
const act = document.getElementById('audit-filter-action');
|
||||||
const res = document.getElementById('audit-filter-result');
|
const res = document.getElementById('audit-filter-result');
|
||||||
const q = document.getElementById('audit-filter-q');
|
const q = document.getElementById('audit-filter-q');
|
||||||
const since = document.getElementById('audit-filter-since');
|
|
||||||
const until = document.getElementById('audit-filter-until');
|
|
||||||
if (cat && cat.value) params.set('category', cat.value);
|
if (cat && cat.value) params.set('category', cat.value);
|
||||||
if (act && !act.disabled && act.value) params.set('action', act.value);
|
if (act && !act.disabled && act.value) params.set('action', act.value);
|
||||||
if (res && res.value) params.set('result', res.value);
|
if (res && res.value) params.set('result', res.value);
|
||||||
if (q && q.value.trim()) params.set('q', q.value.trim());
|
if (q && q.value.trim()) params.set('q', q.value.trim());
|
||||||
const sinceISO = since ? auditDatetimeLocalToRFC3339(since.value) : '';
|
const sinceISO = auditDatetimeLocalToRFC3339(getAuditFilterDatetimeValue('audit-filter-since'));
|
||||||
const untilISO = until ? auditDatetimeLocalToRFC3339(until.value) : '';
|
const untilISO = auditDatetimeLocalToRFC3339(getAuditFilterDatetimeValue('audit-filter-until'));
|
||||||
if (sinceISO) params.set('since', sinceISO);
|
if (sinceISO) params.set('since', sinceISO);
|
||||||
if (untilISO) params.set('until', untilISO);
|
if (untilISO) params.set('until', untilISO);
|
||||||
return params.toString();
|
return params.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
async function loadAuditMeta() {
|
|
||||||
if (typeof apiFetch !== 'function') return;
|
|
||||||
const hint = document.getElementById('audit-retention-hint');
|
|
||||||
try {
|
|
||||||
const r = await apiFetch('/api/audit/meta');
|
|
||||||
if (!r.ok) return;
|
|
||||||
const data = await r.json();
|
|
||||||
if (!hint) return;
|
|
||||||
if (!data.enabled) {
|
|
||||||
hint.hidden = false;
|
|
||||||
hint.textContent = auditT('settingsAudit.disabledHint', null, '审计功能已关闭,新操作不会写入审计表。');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
const days = data.retention_days;
|
|
||||||
if (days > 0) {
|
|
||||||
hint.hidden = false;
|
|
||||||
hint.textContent = auditT('settingsAudit.retentionHint', { days: days },
|
|
||||||
'审计记录保留 ' + days + ' 天,超期自动清理。');
|
|
||||||
} else {
|
|
||||||
hint.hidden = true;
|
|
||||||
}
|
|
||||||
} catch (_) { /* ignore */ }
|
|
||||||
}
|
|
||||||
|
|
||||||
async function loadAuditSummary() {
|
async function loadAuditSummary() {
|
||||||
if (typeof apiFetch !== 'function') return;
|
if (typeof apiFetch !== 'function') return;
|
||||||
const wrap = document.getElementById('audit-summary-stats');
|
const wrap = document.getElementById('audit-summary-stats');
|
||||||
@@ -191,10 +257,14 @@ async function loadAuditSummary() {
|
|||||||
const data = await r.json();
|
const data = await r.json();
|
||||||
if (wrap) wrap.hidden = false;
|
if (wrap) wrap.hidden = false;
|
||||||
const elTotal = document.getElementById('audit-stat-total');
|
const elTotal = document.getElementById('audit-stat-total');
|
||||||
|
const elSuccess = document.getElementById('audit-stat-success');
|
||||||
const elFail = document.getElementById('audit-stat-failures');
|
const elFail = document.getElementById('audit-stat-failures');
|
||||||
const elRecent = document.getElementById('audit-stat-recent');
|
const elRecent = document.getElementById('audit-stat-recent');
|
||||||
if (elTotal) elTotal.textContent = String(data.total != null ? data.total : 0);
|
const total = data.total != null ? data.total : 0;
|
||||||
if (elFail) elFail.textContent = String(data.failures != null ? data.failures : 0);
|
const failures = data.failures != null ? data.failures : 0;
|
||||||
|
if (elTotal) elTotal.textContent = String(total);
|
||||||
|
if (elSuccess) elSuccess.textContent = String(Math.max(0, total - failures));
|
||||||
|
if (elFail) elFail.textContent = String(failures);
|
||||||
if (elRecent) elRecent.textContent = String(data.recent_7d != null ? data.recent_7d : 0);
|
if (elRecent) elRecent.textContent = String(data.recent_7d != null ? data.recent_7d : 0);
|
||||||
} catch (_) { /* ignore */ }
|
} catch (_) { /* ignore */ }
|
||||||
}
|
}
|
||||||
@@ -214,7 +284,8 @@ async function loadAuditLogs(page) {
|
|||||||
throw new Error(err.error || r.statusText);
|
throw new Error(err.error || r.statusText);
|
||||||
}
|
}
|
||||||
const data = await r.json();
|
const data = await r.json();
|
||||||
renderAuditLogs(data.logs || []);
|
auditLogsCache = data.logs || [];
|
||||||
|
renderAuditLogs(auditLogsCache);
|
||||||
auditLogsTotal = typeof data.total === 'number' ? data.total : 0;
|
auditLogsTotal = typeof data.total === 'number' ? data.total : 0;
|
||||||
const maxPage = Math.max(1, Math.ceil(auditLogsTotal / auditLogsPageSize));
|
const maxPage = Math.max(1, Math.ceil(auditLogsTotal / auditLogsPageSize));
|
||||||
if (auditLogsPage > maxPage) {
|
if (auditLogsPage > maxPage) {
|
||||||
@@ -234,37 +305,57 @@ async function loadAuditLogs(page) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function auditResultTagClass(result) {
|
||||||
|
return result === 'failure' ? 'audit-tag--fail' : 'audit-tag--ok';
|
||||||
|
}
|
||||||
|
|
||||||
function renderAuditLogs(logs) {
|
function renderAuditLogs(logs) {
|
||||||
const listEl = document.getElementById('audit-log-list');
|
const listEl = document.getElementById('audit-log-list');
|
||||||
if (!listEl) return;
|
if (!listEl) return;
|
||||||
const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); };
|
const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); };
|
||||||
if (!logs.length) {
|
if (!logs.length) {
|
||||||
listEl.innerHTML = '<div class="c2-empty">' + esc(auditT('settingsAudit.empty', null, '暂无审计记录')) + '</div>';
|
listEl.innerHTML = '<div class="audit-log-empty">' + esc(auditT('settingsAudit.empty', null, '暂无审计记录')) + '</div>';
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
listEl.innerHTML = logs.map(function (log) {
|
const dash = '<span class="audit-log-cell-muted">—</span>';
|
||||||
const lvl = log.result === 'failure' ? 'warn' : (log.level || 'info');
|
const head = (
|
||||||
|
'<div class="audit-log-table-wrap">' +
|
||||||
|
'<table class="audit-log-table">' +
|
||||||
|
'<thead><tr>' +
|
||||||
|
'<th data-i18n="settingsAudit.colTime">时间</th>' +
|
||||||
|
'<th data-i18n="settingsAudit.colMessage">说明</th>' +
|
||||||
|
'<th data-i18n="settingsAudit.colCategory">类别</th>' +
|
||||||
|
'<th data-i18n="settingsAudit.colAction">操作</th>' +
|
||||||
|
'<th data-i18n="settingsAudit.colResult">结果</th>' +
|
||||||
|
'<th data-i18n="settingsAudit.colIp">IP</th>' +
|
||||||
|
'<th data-i18n="settingsAudit.colResource">资源 ID</th>' +
|
||||||
|
'</tr></thead><tbody>'
|
||||||
|
);
|
||||||
|
const rows = logs.map(function (log) {
|
||||||
const catLabel = esc(auditCategoryLabel(log.category || ''));
|
const catLabel = esc(auditCategoryLabel(log.category || ''));
|
||||||
const actionLabel = esc(auditActionLabel(log.action || ''));
|
const actionLabel = esc(auditActionLabel(log.action || ''));
|
||||||
const msg = esc(log.message || '');
|
const msg = esc(auditMessageLabel(log));
|
||||||
const ip = esc(log.clientIp || '');
|
const ip = esc(log.clientIp || '');
|
||||||
const when = esc(formatAuditTime(log.createdAt));
|
const when = esc(formatAuditTime(log.createdAt));
|
||||||
const res = esc(log.result || '');
|
const res = esc(auditResultLabel(log.result || ''));
|
||||||
const rid = log.resourceId || '';
|
const rid = log.resourceId ? esc(log.resourceId) : '';
|
||||||
const meta = rid ? (' · ' + esc(rid)) : '';
|
|
||||||
const eid = esc(log.id || '');
|
const eid = esc(log.id || '');
|
||||||
|
const resultCls = auditResultTagClass(log.result || '');
|
||||||
|
const rowClick = 'onclick="showAuditLogDetail(\'' + eid + '\')" ' +
|
||||||
|
'onkeydown="if(event.key===\'Enter\'||event.key===\' \'){event.preventDefault();showAuditLogDetail(\'' + eid + '\')}"';
|
||||||
return (
|
return (
|
||||||
'<div class="c2-event-item audit-log-item" role="button" tabindex="0" ' +
|
'<tr class="audit-log-row" role="button" tabindex="0" ' + rowClick + '>' +
|
||||||
'onclick="showAuditLogDetail(\'' + eid + '\')" ' +
|
'<td class="audit-log-col-time">' + when + '</td>' +
|
||||||
'onkeydown="if(event.key===\'Enter\'||event.key===\' \'){event.preventDefault();showAuditLogDetail(\'' + eid + '\')}">' +
|
'<td class="audit-log-col-msg" title="' + msg + '">' + (msg || dash) + '</td>' +
|
||||||
'<div class="c2-event-level ' + esc(lvl) + '"></div>' +
|
'<td>' + (catLabel ? '<span class="audit-tag audit-tag--cat">' + catLabel + '</span>' : dash) + '</td>' +
|
||||||
'<div class="c2-event-content">' +
|
'<td>' + (actionLabel ? '<span class="audit-tag audit-tag--act">' + actionLabel + '</span>' : dash) + '</td>' +
|
||||||
'<div class="c2-event-message">' + msg + '</div>' +
|
'<td>' + (res ? '<span class="audit-tag ' + resultCls + '">' + res + '</span>' : dash) + '</td>' +
|
||||||
'<div class="c2-event-meta">' + when + ' · ' + catLabel + '/' + actionLabel + ' · ' + res + meta +
|
'<td class="audit-log-col-ip">' + (ip || dash) + '</td>' +
|
||||||
(ip ? ' · IP ' + ip : '') +
|
'<td class="audit-log-col-resource" title="' + rid + '">' + (rid || dash) + '</td>' +
|
||||||
'</div></div></div>'
|
'</tr>'
|
||||||
);
|
);
|
||||||
}).join('');
|
}).join('');
|
||||||
|
listEl.innerHTML = head + rows + '</tbody></table></div>';
|
||||||
if (typeof applyTranslations === 'function') {
|
if (typeof applyTranslations === 'function') {
|
||||||
applyTranslations(listEl);
|
applyTranslations(listEl);
|
||||||
}
|
}
|
||||||
@@ -326,17 +417,58 @@ function resetAuditLogFilters() {
|
|||||||
const act = document.getElementById('audit-filter-action');
|
const act = document.getElementById('audit-filter-action');
|
||||||
const res = document.getElementById('audit-filter-result');
|
const res = document.getElementById('audit-filter-result');
|
||||||
const q = document.getElementById('audit-filter-q');
|
const q = document.getElementById('audit-filter-q');
|
||||||
const since = document.getElementById('audit-filter-since');
|
|
||||||
const until = document.getElementById('audit-filter-until');
|
|
||||||
if (cat) cat.value = '';
|
if (cat) cat.value = '';
|
||||||
if (res) res.value = '';
|
if (res) res.value = '';
|
||||||
if (q) q.value = '';
|
if (q) q.value = '';
|
||||||
if (since) since.value = '';
|
if (typeof window.AuditDatetimePicker !== 'undefined' && typeof window.AuditDatetimePicker.clearAll === 'function') {
|
||||||
if (until) until.value = '';
|
window.AuditDatetimePicker.clearAll();
|
||||||
|
}
|
||||||
rebuildAuditActionSelect();
|
rebuildAuditActionSelect();
|
||||||
|
syncAuditCustomSelect('audit-filter-category');
|
||||||
|
syncAuditCustomSelect('audit-filter-result');
|
||||||
filterAuditLogs();
|
filterAuditLogs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function applyAuditTimePreset(preset) {
|
||||||
|
if (typeof window.AuditDatetimePicker === 'undefined') return;
|
||||||
|
const now = new Date();
|
||||||
|
let since = new Date(now.getTime());
|
||||||
|
let until = new Date(now.getTime());
|
||||||
|
switch (preset) {
|
||||||
|
case '15m':
|
||||||
|
since = new Date(now.getTime() - 15 * 60 * 1000);
|
||||||
|
break;
|
||||||
|
case '1h':
|
||||||
|
since = new Date(now.getTime() - 60 * 60 * 1000);
|
||||||
|
break;
|
||||||
|
case '24h':
|
||||||
|
since = new Date(now.getTime() - 24 * 60 * 60 * 1000);
|
||||||
|
break;
|
||||||
|
case '7d':
|
||||||
|
since = new Date(now.getTime() - 7 * 24 * 60 * 60 * 1000);
|
||||||
|
break;
|
||||||
|
case 'today':
|
||||||
|
since = new Date(now.getFullYear(), now.getMonth(), now.getDate(), 0, 0, 0, 0);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
window.AuditDatetimePicker.setValue('audit-filter-since', since);
|
||||||
|
window.AuditDatetimePicker.setValue('audit-filter-until', until);
|
||||||
|
filterAuditLogs();
|
||||||
|
}
|
||||||
|
|
||||||
|
function initAuditTimePresets() {
|
||||||
|
const wrap = document.getElementById('audit-time-presets');
|
||||||
|
if (!wrap || wrap.dataset.bound === '1') return;
|
||||||
|
wrap.dataset.bound = '1';
|
||||||
|
wrap.addEventListener('click', function (ev) {
|
||||||
|
const btn = ev.target.closest('[data-preset]');
|
||||||
|
if (!btn) return;
|
||||||
|
applyAuditTimePreset(btn.getAttribute('data-preset'));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
/** 资源已被删除/移除的审计操作,不再提供「打开关联资源」 */
|
/** 资源已被删除/移除的审计操作,不再提供「打开关联资源」 */
|
||||||
const AUDIT_ACTIONS_RESOURCE_REMOVED = {
|
const AUDIT_ACTIONS_RESOURCE_REMOVED = {
|
||||||
delete: true,
|
delete: true,
|
||||||
@@ -565,8 +697,8 @@ async function showAuditLogDetail(id) {
|
|||||||
'<div class="modal-body audit-detail-body">' +
|
'<div class="modal-body audit-detail-body">' +
|
||||||
'<p><strong>' + esc(auditT('settingsAudit.detailTime', null, '时间')) + ':</strong> ' + esc(formatAuditTime(log.createdAt)) + '</p>' +
|
'<p><strong>' + esc(auditT('settingsAudit.detailTime', null, '时间')) + ':</strong> ' + esc(formatAuditTime(log.createdAt)) + '</p>' +
|
||||||
'<p><strong>' + esc(auditT('settingsAudit.detailCategory', null, '类别')) + ':</strong> ' + catAction + '</p>' +
|
'<p><strong>' + esc(auditT('settingsAudit.detailCategory', null, '类别')) + ':</strong> ' + catAction + '</p>' +
|
||||||
'<p><strong>' + esc(auditT('settingsAudit.detailResult', null, '结果')) + ':</strong> ' + esc(log.result || '') + '</p>' +
|
'<p><strong>' + esc(auditT('settingsAudit.detailResult', null, '结果')) + ':</strong> ' + esc(auditResultLabel(log.result || '')) + '</p>' +
|
||||||
'<p><strong>' + esc(auditT('settingsAudit.detailMessage', null, '说明')) + ':</strong> ' + esc(log.message || '') + '</p>' +
|
'<p><strong>' + esc(auditT('settingsAudit.detailMessage', null, '说明')) + ':</strong> ' + esc(auditMessageLabel(log)) + '</p>' +
|
||||||
(log.clientIp ? '<p><strong>IP:</strong> ' + esc(log.clientIp) + '</p>' : '') +
|
(log.clientIp ? '<p><strong>IP:</strong> ' + esc(log.clientIp) + '</p>' : '') +
|
||||||
(log.sessionHint ? '<p><strong>' + esc(auditT('settingsAudit.detailSession', null, '会话')) + ':</strong> ' + esc(log.sessionHint) + '</p>' : '') +
|
(log.sessionHint ? '<p><strong>' + esc(auditT('settingsAudit.detailSession', null, '会话')) + ':</strong> ' + esc(log.sessionHint) + '</p>' : '') +
|
||||||
(log.userAgent ? '<p><strong>UA:</strong> ' + esc(log.userAgent) + '</p>' : '') +
|
(log.userAgent ? '<p><strong>UA:</strong> ' + esc(log.userAgent) + '</p>' : '') +
|
||||||
@@ -597,7 +729,171 @@ async function showAuditLogDetail(id) {
|
|||||||
function initAuditLogsSection() {
|
function initAuditLogsSection() {
|
||||||
if (!document.getElementById('audit-log-list')) return;
|
if (!document.getElementById('audit-log-list')) return;
|
||||||
initAuditPageSizeFromStorage();
|
initAuditPageSizeFromStorage();
|
||||||
|
initAuditFilterSelects();
|
||||||
rebuildAuditActionSelect();
|
rebuildAuditActionSelect();
|
||||||
loadAuditMeta();
|
if (typeof window.AuditDatetimePicker !== 'undefined' && typeof window.AuditDatetimePicker.init === 'function') {
|
||||||
|
window.AuditDatetimePicker.init();
|
||||||
|
}
|
||||||
|
initAuditTimePresets();
|
||||||
|
updateAuditTimezoneHint();
|
||||||
loadAuditLogs(1);
|
loadAuditLogs(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function refreshAuditFilterI18n() {
|
||||||
|
const section = document.getElementById('settings-section-audit');
|
||||||
|
if (section && typeof applyTranslations === 'function') {
|
||||||
|
applyTranslations(section);
|
||||||
|
}
|
||||||
|
rebuildAuditActionSelect();
|
||||||
|
syncAuditCustomSelect('audit-filter-category');
|
||||||
|
syncAuditCustomSelect('audit-filter-action');
|
||||||
|
syncAuditCustomSelect('audit-filter-result');
|
||||||
|
updateAuditTimezoneHint();
|
||||||
|
}
|
||||||
|
|
||||||
|
function refreshAuditLogsI18n() {
|
||||||
|
if (!document.getElementById('audit-log-list')) return;
|
||||||
|
refreshAuditFilterI18n();
|
||||||
|
if (auditLogsCache.length) {
|
||||||
|
renderAuditLogs(auditLogsCache);
|
||||||
|
renderAuditLogsPagination();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
document.addEventListener('languagechange', function () {
|
||||||
|
try {
|
||||||
|
refreshAuditLogsI18n();
|
||||||
|
} catch (e) {
|
||||||
|
console.warn('languagechange audit refresh failed', e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
var auditCustomSelectMap = {};
|
||||||
|
var auditFilterSelectsDocListener = false;
|
||||||
|
|
||||||
|
function closeAllAuditCustomSelects() {
|
||||||
|
Object.keys(auditCustomSelectMap).forEach(function (id) {
|
||||||
|
auditCustomSelectMap[id].wrapper.classList.remove('open');
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function syncAuditCustomSelect(selectId) {
|
||||||
|
var reg = auditCustomSelectMap[selectId];
|
||||||
|
if (!reg) return;
|
||||||
|
var select = reg.select;
|
||||||
|
var dropdown = reg.dropdown;
|
||||||
|
var trigger = reg.trigger;
|
||||||
|
var wrapper = reg.wrapper;
|
||||||
|
var valueSpan = trigger.querySelector('.audit-custom-select-value');
|
||||||
|
|
||||||
|
dropdown.innerHTML = '';
|
||||||
|
Array.prototype.forEach.call(select.options, function (opt) {
|
||||||
|
var item = document.createElement('div');
|
||||||
|
item.className = 'audit-custom-select-option';
|
||||||
|
item.setAttribute('role', 'option');
|
||||||
|
item.setAttribute('data-value', opt.value);
|
||||||
|
if (opt.value === select.value) {
|
||||||
|
item.classList.add('is-selected');
|
||||||
|
item.setAttribute('aria-selected', 'true');
|
||||||
|
}
|
||||||
|
var check = document.createElement('span');
|
||||||
|
check.className = 'audit-custom-select-check';
|
||||||
|
check.setAttribute('aria-hidden', 'true');
|
||||||
|
check.textContent = '✓';
|
||||||
|
var label = document.createElement('span');
|
||||||
|
label.className = 'audit-custom-select-label';
|
||||||
|
label.textContent = opt.textContent;
|
||||||
|
item.appendChild(check);
|
||||||
|
item.appendChild(label);
|
||||||
|
dropdown.appendChild(item);
|
||||||
|
});
|
||||||
|
|
||||||
|
var selectedOpt = select.options[select.selectedIndex];
|
||||||
|
if (valueSpan) {
|
||||||
|
valueSpan.textContent = selectedOpt ? selectedOpt.textContent : '';
|
||||||
|
}
|
||||||
|
trigger.disabled = !!select.disabled;
|
||||||
|
wrapper.classList.toggle('is-disabled', !!select.disabled);
|
||||||
|
}
|
||||||
|
|
||||||
|
function enhanceAuditFilterSelect(selectId) {
|
||||||
|
var select = document.getElementById(selectId);
|
||||||
|
if (!select) return;
|
||||||
|
if (select.dataset.auditCustom === '1') {
|
||||||
|
syncAuditCustomSelect(selectId);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
select.dataset.auditCustom = '1';
|
||||||
|
select.classList.add('audit-native-select');
|
||||||
|
select.tabIndex = -1;
|
||||||
|
select.setAttribute('aria-hidden', 'true');
|
||||||
|
|
||||||
|
var wrapper = document.createElement('div');
|
||||||
|
wrapper.className = 'audit-custom-select';
|
||||||
|
|
||||||
|
var trigger = document.createElement('button');
|
||||||
|
trigger.type = 'button';
|
||||||
|
trigger.className = 'audit-custom-select-trigger';
|
||||||
|
trigger.setAttribute('aria-haspopup', 'listbox');
|
||||||
|
var valueSpan = document.createElement('span');
|
||||||
|
valueSpan.className = 'audit-custom-select-value';
|
||||||
|
trigger.appendChild(valueSpan);
|
||||||
|
var caret = document.createElement('span');
|
||||||
|
caret.className = 'audit-custom-select-caret';
|
||||||
|
caret.setAttribute('aria-hidden', 'true');
|
||||||
|
caret.textContent = '▾';
|
||||||
|
trigger.appendChild(caret);
|
||||||
|
|
||||||
|
var dropdown = document.createElement('div');
|
||||||
|
dropdown.className = 'audit-custom-select-dropdown';
|
||||||
|
dropdown.setAttribute('role', 'listbox');
|
||||||
|
|
||||||
|
var parent = select.parentNode;
|
||||||
|
parent.insertBefore(wrapper, select);
|
||||||
|
wrapper.appendChild(trigger);
|
||||||
|
wrapper.appendChild(dropdown);
|
||||||
|
wrapper.appendChild(select);
|
||||||
|
|
||||||
|
auditCustomSelectMap[selectId] = {
|
||||||
|
wrapper: wrapper,
|
||||||
|
trigger: trigger,
|
||||||
|
dropdown: dropdown,
|
||||||
|
select: select
|
||||||
|
};
|
||||||
|
|
||||||
|
trigger.addEventListener('click', function (e) {
|
||||||
|
e.stopPropagation();
|
||||||
|
if (select.disabled) return;
|
||||||
|
var open = wrapper.classList.contains('open');
|
||||||
|
closeAllAuditCustomSelects();
|
||||||
|
if (!open) wrapper.classList.add('open');
|
||||||
|
});
|
||||||
|
|
||||||
|
dropdown.addEventListener('click', function (e) {
|
||||||
|
var opt = e.target.closest('.audit-custom-select-option');
|
||||||
|
if (!opt) return;
|
||||||
|
var val = opt.getAttribute('data-value');
|
||||||
|
if (val === null) val = '';
|
||||||
|
if (select.value !== val) {
|
||||||
|
select.value = val;
|
||||||
|
select.dispatchEvent(new Event('change', { bubbles: true }));
|
||||||
|
}
|
||||||
|
wrapper.classList.remove('open');
|
||||||
|
syncAuditCustomSelect(selectId);
|
||||||
|
});
|
||||||
|
|
||||||
|
syncAuditCustomSelect(selectId);
|
||||||
|
}
|
||||||
|
|
||||||
|
function initAuditFilterSelects() {
|
||||||
|
if (!document.getElementById('audit-filter-category')) return;
|
||||||
|
if (!auditFilterSelectsDocListener) {
|
||||||
|
document.addEventListener('click', function () {
|
||||||
|
closeAllAuditCustomSelects();
|
||||||
|
});
|
||||||
|
auditFilterSelectsDocListener = true;
|
||||||
|
}
|
||||||
|
enhanceAuditFilterSelect('audit-filter-category');
|
||||||
|
enhanceAuditFilterSelect('audit-filter-action');
|
||||||
|
enhanceAuditFilterSelect('audit-filter-result');
|
||||||
|
}
|
||||||
|
|||||||
+704
-99
File diff suppressed because it is too large
Load Diff
+284
-39
@@ -2164,7 +2164,110 @@ function showCopySuccess(button) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Claude extended thinking 内部尾缀(与后端 DisplayReasoningContent 一致,UI 不展示) */
|
||||||
|
const CLAUDE_REASONING_UI_SUFFIX = '\n---CSAI_CLAUDE_THINKING_BLOCKS---\n';
|
||||||
|
|
||||||
|
function normalizeReasoningContentForDisplay(text) {
|
||||||
|
if (text == null) return '';
|
||||||
|
let s = String(text).trim();
|
||||||
|
if (!s) return '';
|
||||||
|
const idx = s.lastIndexOf(CLAUDE_REASONING_UI_SUFFIX);
|
||||||
|
if (idx >= 0) {
|
||||||
|
s = s.slice(0, idx).trim();
|
||||||
|
}
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
function setMessageReasoningContent(messageIdOrEl, reasoningContent) {
|
||||||
|
const el = typeof messageIdOrEl === 'string' ? document.getElementById(messageIdOrEl) : messageIdOrEl;
|
||||||
|
if (!el || !el.dataset) return;
|
||||||
|
const rc = normalizeReasoningContentForDisplay(reasoningContent);
|
||||||
|
if (rc) {
|
||||||
|
el.dataset.reasoningContent = rc;
|
||||||
|
} else {
|
||||||
|
delete el.dataset.reasoningContent;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function getMessageReasoningContent(messageIdOrEl) {
|
||||||
|
const el = typeof messageIdOrEl === 'string' ? document.getElementById(messageIdOrEl) : messageIdOrEl;
|
||||||
|
if (!el || !el.dataset) return '';
|
||||||
|
return normalizeReasoningContentForDisplay(el.dataset.reasoningContent || '');
|
||||||
|
}
|
||||||
|
|
||||||
|
function reasoningTextAlreadyInProcessDetails(processDetails, rc) {
|
||||||
|
if (!rc) return true;
|
||||||
|
const list = Array.isArray(processDetails) ? processDetails : [];
|
||||||
|
for (let i = 0; i < list.length; i++) {
|
||||||
|
const d = list[i];
|
||||||
|
if (!d) continue;
|
||||||
|
const et = d.eventType || '';
|
||||||
|
if (et !== 'reasoning_chain' && et !== 'thinking') continue;
|
||||||
|
const msg = normalizeReasoningContentForDisplay(d.message || '');
|
||||||
|
if (!msg) continue;
|
||||||
|
if (msg === rc || msg.includes(rc) || rc.includes(msg)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 合并 messages.reasoningContent 与 process_details 中的 reasoning_chain,两者都读、都展示(去重后) */
|
||||||
|
function mergeMessageReasoningContentIntoProcessDetails(processDetails, reasoningContent) {
|
||||||
|
const rc = normalizeReasoningContentForDisplay(reasoningContent);
|
||||||
|
const details = Array.isArray(processDetails) ? processDetails.slice() : [];
|
||||||
|
if (!rc || reasoningTextAlreadyInProcessDetails(details, rc)) {
|
||||||
|
return details;
|
||||||
|
}
|
||||||
|
details.push({
|
||||||
|
eventType: 'reasoning_chain',
|
||||||
|
message: rc,
|
||||||
|
data: { source: 'message.reasoningContent' }
|
||||||
|
});
|
||||||
|
return details;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function syncAssistantReasoningContentFromServer(backendMessageId, domAssistantId) {
|
||||||
|
if (!backendMessageId || !domAssistantId || !currentConversationId || typeof apiFetch !== 'function') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
const convRes = await apiFetch(`/api/conversations/${encodeURIComponent(currentConversationId)}?include_process_details=0`);
|
||||||
|
const conv = await convRes.json().catch(() => ({}));
|
||||||
|
if (!convRes.ok || !Array.isArray(conv.messages)) return;
|
||||||
|
const msg = conv.messages.find((m) => m && String(m.id) === String(backendMessageId));
|
||||||
|
if (!msg || !msg.reasoningContent) return;
|
||||||
|
setMessageReasoningContent(domAssistantId, msg.reasoningContent);
|
||||||
|
const pdRes = await apiFetch(`/api/messages/${encodeURIComponent(String(backendMessageId))}/process-details`);
|
||||||
|
const pdJson = await pdRes.json().catch(() => ({}));
|
||||||
|
const details = pdRes.ok && Array.isArray(pdJson.processDetails) ? pdJson.processDetails : [];
|
||||||
|
if (typeof renderProcessDetails === 'function') {
|
||||||
|
renderProcessDetails(domAssistantId, details);
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.warn('syncAssistantReasoningContentFromServer failed', e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
window.normalizeReasoningContentForDisplay = normalizeReasoningContentForDisplay;
|
||||||
|
window.setMessageReasoningContent = setMessageReasoningContent;
|
||||||
|
window.getMessageReasoningContent = getMessageReasoningContent;
|
||||||
|
window.filterNoiseProcessDetails = filterNoiseProcessDetails;
|
||||||
|
window.mergeMessageReasoningContentIntoProcessDetails = mergeMessageReasoningContentIntoProcessDetails;
|
||||||
|
window.syncAssistantReasoningContentFromServer = syncAssistantReasoningContentFromServer;
|
||||||
|
|
||||||
/** 相邻且类型/正文/data 完全一致的过程详情只保留一条(与后端去重一致,避免时间线叠多条相同块) */
|
/** 相邻且类型/正文/data 完全一致的过程详情只保留一条(与后端去重一致,避免时间线叠多条相同块) */
|
||||||
|
function isEinoAgentHeartbeatProgress(detail) {
|
||||||
|
if (!detail || detail.eventType !== 'progress') return false;
|
||||||
|
const msg = String(detail.message != null ? detail.message : '').trim();
|
||||||
|
return /^\[Eino\]\s+\S/.test(msg);
|
||||||
|
}
|
||||||
|
|
||||||
|
function filterNoiseProcessDetails(details) {
|
||||||
|
if (!Array.isArray(details)) return details;
|
||||||
|
return details.filter(function (d) { return !isEinoAgentHeartbeatProgress(d); });
|
||||||
|
}
|
||||||
|
|
||||||
function dedupeConsecutiveProcessDetailRows(details) {
|
function dedupeConsecutiveProcessDetailRows(details) {
|
||||||
if (!Array.isArray(details) || details.length < 2) {
|
if (!Array.isArray(details) || details.length < 2) {
|
||||||
return details;
|
return details;
|
||||||
@@ -2282,20 +2385,28 @@ function renderProcessDetails(messageId, processDetails) {
|
|||||||
detailsContainer.appendChild(contentDiv);
|
detailsContainer.appendChild(contentDiv);
|
||||||
}
|
}
|
||||||
|
|
||||||
// processDetails === null 表示“尚未加载(懒加载)”
|
// processDetails === null 表示“尚未加载(懒加载)”;messages.reasoningContent 可先展示
|
||||||
const isLazyNotLoaded = (processDetails === null);
|
const isLazyNotLoaded = (processDetails === null);
|
||||||
if (isLazyNotLoaded) {
|
const reasoningFromMessage = getMessageReasoningContent(messageElement);
|
||||||
|
if (isLazyNotLoaded && !reasoningFromMessage) {
|
||||||
detailsContainer.dataset.lazyNotLoaded = '1';
|
detailsContainer.dataset.lazyNotLoaded = '1';
|
||||||
detailsContainer.dataset.loaded = '0';
|
detailsContainer.dataset.loaded = '0';
|
||||||
timeline.innerHTML = '<div class="progress-timeline-empty">' +
|
timeline.innerHTML = '<div class="progress-timeline-empty">' +
|
||||||
(typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') +
|
(typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') +
|
||||||
'(点击后加载)</div>';
|
'(点击后加载)</div>';
|
||||||
// 默认折叠
|
|
||||||
timeline.classList.remove('expanded');
|
timeline.classList.remove('expanded');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (isLazyNotLoaded) {
|
||||||
|
detailsContainer.dataset.lazyNotLoaded = '1';
|
||||||
|
detailsContainer.dataset.loaded = '0';
|
||||||
|
processDetails = [];
|
||||||
|
} else {
|
||||||
detailsContainer.dataset.lazyNotLoaded = '0';
|
detailsContainer.dataset.lazyNotLoaded = '0';
|
||||||
detailsContainer.dataset.loaded = '1';
|
detailsContainer.dataset.loaded = '1';
|
||||||
|
}
|
||||||
|
processDetails = mergeMessageReasoningContentIntoProcessDetails(processDetails, reasoningFromMessage);
|
||||||
|
processDetails = filterNoiseProcessDetails(processDetails);
|
||||||
processDetails = dedupeConsecutiveProcessDetailRows(processDetails);
|
processDetails = dedupeConsecutiveProcessDetailRows(processDetails);
|
||||||
if (typeof window.coalesceProcessDetailsToolPairs === 'function') {
|
if (typeof window.coalesceProcessDetailsToolPairs === 'function') {
|
||||||
processDetails = window.coalesceProcessDetailsToolPairs(processDetails);
|
processDetails = window.coalesceProcessDetailsToolPairs(processDetails);
|
||||||
@@ -2427,6 +2538,14 @@ function renderProcessDetails(messageId, processDetails) {
|
|||||||
addTimelineItem(timeline, eventType, timelineOpts);
|
addTimelineItem(timeline, eventType, timelineOpts);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
if (isLazyNotLoaded && reasoningFromMessage) {
|
||||||
|
const lazyHint = document.createElement('div');
|
||||||
|
lazyHint.className = 'progress-timeline-empty progress-timeline-lazy-hint';
|
||||||
|
lazyHint.textContent = (typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') +
|
||||||
|
'(点击后加载完整过程详情)';
|
||||||
|
timeline.appendChild(lazyHint);
|
||||||
|
}
|
||||||
|
|
||||||
// 检查是否有错误或取消事件,如果有,确保详情默认折叠(但仍有待审批 HITL 时保持展开,由 restoreHitlInlineForConversation 处理)
|
// 检查是否有错误或取消事件,如果有,确保详情默认折叠(但仍有待审批 HITL 时保持展开,由 restoreHitlInlineForConversation 处理)
|
||||||
const hasPendingHitlInDetails = processDetails.some(d => d && d.eventType === 'hitl_interrupt');
|
const hasPendingHitlInDetails = processDetails.some(d => d && d.eventType === 'hitl_interrupt');
|
||||||
const hasErrorOrCancelled = processDetails.some(d =>
|
const hasErrorOrCancelled = processDetails.some(d =>
|
||||||
@@ -2533,6 +2652,57 @@ async function batchUpdateButtonToolNames(buttonsContainer, executionIds) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 显示MCP调用详情
|
// 显示MCP调用详情
|
||||||
|
const MCP_DETAIL_MAX_CHARS = 120000;
|
||||||
|
|
||||||
|
function extractMCPResultText(result) {
|
||||||
|
if (!result) return '';
|
||||||
|
const content = result.content;
|
||||||
|
if (typeof content === 'string') return content;
|
||||||
|
if (Array.isArray(content)) {
|
||||||
|
return content
|
||||||
|
.map(item => (item && typeof item === 'object' && typeof item.text === 'string') ? item.text : '')
|
||||||
|
.filter(Boolean)
|
||||||
|
.join('\n\n');
|
||||||
|
}
|
||||||
|
if (content && typeof content === 'object' && typeof content.text === 'string') {
|
||||||
|
return content.text;
|
||||||
|
}
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
|
||||||
|
function truncateMCPDetailText(text, maxChars) {
|
||||||
|
if (text == null) return '';
|
||||||
|
const s = String(text);
|
||||||
|
if (s.length <= maxChars) return s;
|
||||||
|
const hint = typeof window.t === 'function'
|
||||||
|
? window.t('mcpDetailModal.contentTruncated')
|
||||||
|
: '…(展示已截断;完整内容见 persisted-output 中的文件路径,用 read_file 读取)';
|
||||||
|
return s.slice(0, maxChars) + '\n\n' + hint;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 响应结果区 JSON 展示(过大时截断 content 内 text,避免 stringify 卡死页面) */
|
||||||
|
function formatMCPResultJsonForDisplay(result, maxChars) {
|
||||||
|
if (!result) return '{}';
|
||||||
|
const payload = {
|
||||||
|
content: result.content,
|
||||||
|
isError: !!result.isError
|
||||||
|
};
|
||||||
|
let json = JSON.stringify(payload, null, 2);
|
||||||
|
if (json.length <= maxChars) {
|
||||||
|
return json;
|
||||||
|
}
|
||||||
|
const text = extractMCPResultText(result);
|
||||||
|
const truncatedPayload = {
|
||||||
|
content: [{ type: 'text', text: truncateMCPDetailText(text, Math.min(maxChars - 800, MCP_DETAIL_MAX_CHARS)) }],
|
||||||
|
isError: !!result.isError
|
||||||
|
};
|
||||||
|
json = JSON.stringify(truncatedPayload, null, 2);
|
||||||
|
if (json.length > maxChars) {
|
||||||
|
return json.slice(0, maxChars) + '\n…';
|
||||||
|
}
|
||||||
|
return json;
|
||||||
|
}
|
||||||
|
|
||||||
async function showMCPDetail(executionId) {
|
async function showMCPDetail(executionId) {
|
||||||
try {
|
try {
|
||||||
openAppModal('mcp-detail-modal', { focus: false });
|
openAppModal('mcp-detail-modal', { focus: false });
|
||||||
@@ -2594,42 +2764,22 @@ async function showMCPDetail(executionId) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (exec.result) {
|
if (exec.result) {
|
||||||
const responseData = {
|
const agentVisibleText = truncateMCPDetailText(extractMCPResultText(exec.result), MCP_DETAIL_MAX_CHARS);
|
||||||
content: exec.result.content,
|
const emptyText = typeof window.t === 'function' ? window.t('mcpDetailModal.execSuccessNoContent') : '执行成功,未返回可展示的文本内容。';
|
||||||
isError: exec.result.isError
|
|
||||||
};
|
|
||||||
responseElement.textContent = JSON.stringify(responseData, null, 2);
|
|
||||||
|
|
||||||
if (exec.result.isError) {
|
if (exec.result.isError) {
|
||||||
// 错误场景:响应结果标红 + 错误信息区块
|
|
||||||
responseElement.className = 'code-block error';
|
responseElement.className = 'code-block error';
|
||||||
|
responseElement.textContent = formatMCPResultJsonForDisplay(exec.result, MCP_DETAIL_MAX_CHARS);
|
||||||
if (exec.error && errorSection && errorElement) {
|
if (exec.error && errorSection && errorElement) {
|
||||||
errorSection.style.display = 'block';
|
errorSection.style.display = 'block';
|
||||||
errorElement.textContent = exec.error;
|
errorElement.textContent = exec.error;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 成功场景:响应结果保持普通样式,正确信息单独拎出来
|
|
||||||
responseElement.className = 'code-block';
|
responseElement.className = 'code-block';
|
||||||
|
responseElement.textContent = formatMCPResultJsonForDisplay(exec.result, MCP_DETAIL_MAX_CHARS);
|
||||||
if (successSection && successElement) {
|
if (successSection && successElement) {
|
||||||
successSection.style.display = 'block';
|
successSection.style.display = 'block';
|
||||||
let successText = '';
|
successElement.textContent = agentVisibleText || emptyText;
|
||||||
const content = exec.result.content;
|
|
||||||
if (typeof content === 'string') {
|
|
||||||
successText = content;
|
|
||||||
} else if (Array.isArray(content)) {
|
|
||||||
const texts = content
|
|
||||||
.map(item => (item && typeof item === 'object' && typeof item.text === 'string') ? item.text : '')
|
|
||||||
.filter(Boolean);
|
|
||||||
if (texts.length > 0) {
|
|
||||||
successText = texts.join('\n\n');
|
|
||||||
}
|
|
||||||
} else if (content && typeof content === 'object' && typeof content.text === 'string') {
|
|
||||||
successText = content.text;
|
|
||||||
}
|
|
||||||
if (!successText) {
|
|
||||||
successText = typeof window.t === 'function' ? window.t('mcpDetailModal.execSuccessNoContent') : '执行成功,未返回可展示的文本内容。';
|
|
||||||
}
|
|
||||||
successElement.textContent = successText;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -3193,6 +3343,9 @@ async function loadConversation(conversationId) {
|
|||||||
attachDeleteTurnButton(messageEl);
|
attachDeleteTurnButton(messageEl);
|
||||||
}
|
}
|
||||||
if (msg.role === 'assistant') {
|
if (msg.role === 'assistant') {
|
||||||
|
if (messageEl && msg.reasoningContent) {
|
||||||
|
setMessageReasoningContent(messageEl, msg.reasoningContent);
|
||||||
|
}
|
||||||
const hasField = msg && Object.prototype.hasOwnProperty.call(msg, 'processDetails');
|
const hasField = msg && Object.prototype.hasOwnProperty.call(msg, 'processDetails');
|
||||||
renderProcessDetails(messageId, hasField ? (msg.processDetails || []) : null);
|
renderProcessDetails(messageId, hasField ? (msg.processDetails || []) : null);
|
||||||
if (msg.processDetails && msg.processDetails.length > 0) {
|
if (msg.processDetails && msg.processDetails.length > 0) {
|
||||||
@@ -5623,6 +5776,95 @@ let conversationGroupMappingCache = {};
|
|||||||
let pendingGroupMappings = {}; // 待保留的分组映射(用于处理后端API延迟的情况)
|
let pendingGroupMappings = {}; // 待保留的分组映射(用于处理后端API延迟的情况)
|
||||||
let conversationsListLoadSeq = 0; // 对话列表加载序号,避免并发请求导致重复渲染
|
let conversationsListLoadSeq = 0; // 对话列表加载序号,避免并发请求导致重复渲染
|
||||||
const CONVERSATIONS_PAGE_SIZE_KEY = 'cyberstrike.conversations_page_size';
|
const CONVERSATIONS_PAGE_SIZE_KEY = 'cyberstrike.conversations_page_size';
|
||||||
|
const CONVERSATIONS_SORT_KEY = 'cyberstrike.conversations_sort_by';
|
||||||
|
|
||||||
|
function getConversationSortBy() {
|
||||||
|
try {
|
||||||
|
const saved = localStorage.getItem(CONVERSATIONS_SORT_KEY);
|
||||||
|
if (saved === 'created_at' || saved === 'updated_at') return saved;
|
||||||
|
} catch (e) { /* ignore */ }
|
||||||
|
return 'updated_at';
|
||||||
|
}
|
||||||
|
|
||||||
|
let conversationSortBy = getConversationSortBy();
|
||||||
|
|
||||||
|
function getConversationSortTime(conv) {
|
||||||
|
const field = conversationSortBy === 'created_at' ? 'createdAt' : 'updatedAt';
|
||||||
|
const raw = conv && conv[field];
|
||||||
|
if (!raw) return new Date(0);
|
||||||
|
const date = new Date(raw);
|
||||||
|
return isNaN(date.getTime()) ? new Date(0) : date;
|
||||||
|
}
|
||||||
|
|
||||||
|
function updateConversationSortMenuUI() {
|
||||||
|
const menu = document.getElementById('conversation-sort-menu');
|
||||||
|
const btn = document.getElementById('conversation-sort-btn');
|
||||||
|
if (!menu) return;
|
||||||
|
menu.querySelectorAll('.conversation-sort-option').forEach((option) => {
|
||||||
|
const selected = option.dataset.sort === conversationSortBy;
|
||||||
|
option.classList.toggle('is-selected', selected);
|
||||||
|
option.setAttribute('aria-checked', selected ? 'true' : 'false');
|
||||||
|
});
|
||||||
|
if (btn) {
|
||||||
|
btn.setAttribute('aria-expanded', menu.hidden ? 'false' : 'true');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function closeConversationSortMenu() {
|
||||||
|
const menu = document.getElementById('conversation-sort-menu');
|
||||||
|
const btn = document.getElementById('conversation-sort-btn');
|
||||||
|
if (menu) menu.hidden = true;
|
||||||
|
if (btn) btn.setAttribute('aria-expanded', 'false');
|
||||||
|
}
|
||||||
|
|
||||||
|
function toggleConversationSortMenu(event) {
|
||||||
|
if (event) {
|
||||||
|
event.preventDefault();
|
||||||
|
event.stopPropagation();
|
||||||
|
}
|
||||||
|
const menu = document.getElementById('conversation-sort-menu');
|
||||||
|
const btn = document.getElementById('conversation-sort-btn');
|
||||||
|
if (!menu || !btn) return;
|
||||||
|
const willOpen = menu.hidden;
|
||||||
|
closeConversationSortMenu();
|
||||||
|
if (willOpen) {
|
||||||
|
menu.hidden = false;
|
||||||
|
btn.setAttribute('aria-expanded', 'true');
|
||||||
|
updateConversationSortMenuUI();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function setConversationSortBy(sortBy) {
|
||||||
|
const next = sortBy === 'created_at' ? 'created_at' : 'updated_at';
|
||||||
|
if (next === conversationSortBy) {
|
||||||
|
closeConversationSortMenu();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
conversationSortBy = next;
|
||||||
|
try {
|
||||||
|
localStorage.setItem(CONVERSATIONS_SORT_KEY, next);
|
||||||
|
} catch (e) { /* ignore */ }
|
||||||
|
updateConversationSortMenuUI();
|
||||||
|
closeConversationSortMenu();
|
||||||
|
conversationsPagination.page = 1;
|
||||||
|
loadConversationsWithGroups(conversationsSearchQuery);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!window.__conversationSortMenuBound) {
|
||||||
|
window.__conversationSortMenuBound = true;
|
||||||
|
document.addEventListener('click', (event) => {
|
||||||
|
const dropdown = document.getElementById('conversation-sort-dropdown');
|
||||||
|
if (!dropdown || dropdown.contains(event.target)) return;
|
||||||
|
closeConversationSortMenu();
|
||||||
|
});
|
||||||
|
document.addEventListener('keydown', (event) => {
|
||||||
|
if (event.key === 'Escape') closeConversationSortMenu();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
window.toggleConversationSortMenu = toggleConversationSortMenu;
|
||||||
|
window.setConversationSortBy = setConversationSortBy;
|
||||||
|
window.closeConversationSortMenu = closeConversationSortMenu;
|
||||||
|
|
||||||
function getConversationsPageSize() {
|
function getConversationsPageSize() {
|
||||||
try {
|
try {
|
||||||
@@ -5885,6 +6127,9 @@ async function loadConversationsWithGroups(searchQuery = '') {
|
|||||||
const pageSize = conversationsPagination.pageSize;
|
const pageSize = conversationsPagination.pageSize;
|
||||||
const offset = (conversationsPagination.page - 1) * pageSize;
|
const offset = (conversationsPagination.page - 1) * pageSize;
|
||||||
const convParams = new URLSearchParams({ limit: String(pageSize), offset: String(offset) });
|
const convParams = new URLSearchParams({ limit: String(pageSize), offset: String(offset) });
|
||||||
|
if (conversationSortBy === 'created_at') {
|
||||||
|
convParams.set('sort_by', 'created_at');
|
||||||
|
}
|
||||||
if (searchQuery && searchQuery.trim()) {
|
if (searchQuery && searchQuery.trim()) {
|
||||||
convParams.set('search', searchQuery.trim());
|
convParams.set('search', searchQuery.trim());
|
||||||
} else {
|
} else {
|
||||||
@@ -5974,11 +6219,7 @@ async function loadConversationsWithGroups(searchQuery = '') {
|
|||||||
});
|
});
|
||||||
|
|
||||||
// 按时间排序
|
// 按时间排序
|
||||||
const sortByTime = (a, b) => {
|
const sortByTime = (a, b) => getConversationSortTime(b) - getConversationSortTime(a);
|
||||||
const timeA = a.updatedAt ? new Date(a.updatedAt) : new Date(0);
|
|
||||||
const timeB = b.updatedAt ? new Date(b.updatedAt) : new Date(0);
|
|
||||||
return timeB - timeA;
|
|
||||||
};
|
|
||||||
|
|
||||||
pinnedConvs.sort(sortByTime);
|
pinnedConvs.sort(sortByTime);
|
||||||
normalConvs.sort(sortByTime);
|
normalConvs.sort(sortByTime);
|
||||||
@@ -6006,8 +6247,8 @@ async function loadConversationsWithGroups(searchQuery = '') {
|
|||||||
};
|
};
|
||||||
|
|
||||||
normalConvs.forEach(conv => {
|
normalConvs.forEach(conv => {
|
||||||
const dateObj = conv.updatedAt ? new Date(conv.updatedAt) : new Date();
|
const dateObj = getConversationSortTime(conv);
|
||||||
const validDate = isNaN(dateObj.getTime()) ? new Date() : dateObj;
|
const validDate = dateObj.getTime() === 0 ? new Date() : dateObj;
|
||||||
const groupKey = getConversationGroup(validDate, todayStart, sevenDaysCutoff, yesterdayStart);
|
const groupKey = getConversationGroup(validDate, todayStart, sevenDaysCutoff, yesterdayStart);
|
||||||
groups[groupKey].push({
|
groups[groupKey].push({
|
||||||
...conv,
|
...conv,
|
||||||
@@ -6019,8 +6260,8 @@ async function loadConversationsWithGroups(searchQuery = '') {
|
|||||||
|
|
||||||
if (pinnedConvs.length > 0) {
|
if (pinnedConvs.length > 0) {
|
||||||
pinnedConvs.forEach(conv => {
|
pinnedConvs.forEach(conv => {
|
||||||
const dateObj = conv.updatedAt ? new Date(conv.updatedAt) : new Date();
|
const dateObj = getConversationSortTime(conv);
|
||||||
const validDate = isNaN(dateObj.getTime()) ? new Date() : dateObj;
|
const validDate = dateObj.getTime() === 0 ? new Date() : dateObj;
|
||||||
fragment.appendChild(createConversationListItemWithMenu({
|
fragment.appendChild(createConversationListItemWithMenu({
|
||||||
...conv,
|
...conv,
|
||||||
_timeText: formatConversationTimestamp(validDate, todayStart, yesterdayStart),
|
_timeText: formatConversationTimestamp(validDate, todayStart, yesterdayStart),
|
||||||
@@ -7359,8 +7600,11 @@ async function deleteSelectedConversations() {
|
|||||||
for (const id of ids) {
|
for (const id of ids) {
|
||||||
await deleteConversation(id, true); // 跳过内部确认,因为批量删除时已经确认过了
|
await deleteConversation(id, true); // 跳过内部确认,因为批量删除时已经确认过了
|
||||||
}
|
}
|
||||||
closeBatchManageModal();
|
// 删除后保持弹窗打开,便于继续管理剩余对话
|
||||||
loadConversationsWithGroups();
|
const selectAll = document.getElementById('batch-select-all');
|
||||||
|
if (selectAll) {
|
||||||
|
selectAll.checked = false;
|
||||||
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('删除失败:', error);
|
console.error('删除失败:', error);
|
||||||
const failedMsg = typeof window.t === 'function' ? window.t('batchManageModal.deleteFailed') : '删除失败';
|
const failedMsg = typeof window.t === 'function' ? window.t('batchManageModal.deleteFailed') : '删除失败';
|
||||||
@@ -8365,6 +8609,7 @@ function clearGroupSearch() {
|
|||||||
|
|
||||||
// 初始化时加载分组
|
// 初始化时加载分组
|
||||||
document.addEventListener('DOMContentLoaded', async () => {
|
document.addEventListener('DOMContentLoaded', async () => {
|
||||||
|
updateConversationSortMenuUI();
|
||||||
await loadGroups();
|
await loadGroups();
|
||||||
await loadConversationsWithGroups();
|
await loadConversationsWithGroups();
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,680 @@
|
|||||||
|
/**
|
||||||
|
* 项目事实图渲染(Cytoscape + ELK),供项目管理页使用。
|
||||||
|
* 节点采用 SVG 卡片背景(图标 + 多行文字),避免 Cytoscape 原生 label 定位问题。
|
||||||
|
*/
|
||||||
|
(function (global) {
|
||||||
|
'use strict';
|
||||||
|
|
||||||
|
let _cy = null;
|
||||||
|
let _graphData = null;
|
||||||
|
let _onNodeSelect = null;
|
||||||
|
let _onEdgeSelect = null;
|
||||||
|
let _resizeObs = null;
|
||||||
|
|
||||||
|
const EDGE_COLORS = {
|
||||||
|
discovered_on: '#4F46E5',
|
||||||
|
leads_to: '#64748B',
|
||||||
|
enables: '#E11D48',
|
||||||
|
exploits: '#DC2626',
|
||||||
|
depends_on: '#0D9488',
|
||||||
|
contains: '#6366F1',
|
||||||
|
part_of: '#6366F1',
|
||||||
|
supports: '#94A3B8',
|
||||||
|
links_vuln: '#BE123C',
|
||||||
|
};
|
||||||
|
|
||||||
|
const CARD_PAD = 14;
|
||||||
|
const CARD_TEXT_PAD_RIGHT = 12;
|
||||||
|
const CARD_ICON = 36;
|
||||||
|
const CARD_ICON_GAP = 12;
|
||||||
|
const CARD_TEXT_X = CARD_PAD + CARD_ICON + CARD_ICON_GAP;
|
||||||
|
const CARD_MIN_W = 300;
|
||||||
|
const CARD_TARGET_W = 360;
|
||||||
|
const CARD_MIN_H = 88;
|
||||||
|
const CARD_MAX_H = 176;
|
||||||
|
const CARD_HEADER_FS = 11;
|
||||||
|
const CARD_HEADER_LH = 16;
|
||||||
|
const CARD_KEY_FS = 10;
|
||||||
|
const CARD_KEY_LH = 14;
|
||||||
|
const CARD_SUMMARY_FS = 13;
|
||||||
|
const CARD_SUMMARY_LH = 18;
|
||||||
|
const CARD_SECTION_GAP = 6;
|
||||||
|
const CARD_FONT =
|
||||||
|
'-apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", "PingFang SC", "Microsoft YaHei", sans-serif';
|
||||||
|
const CARD_KEY_FONT =
|
||||||
|
'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", monospace';
|
||||||
|
|
||||||
|
function nodeTheme(type) {
|
||||||
|
switch (type) {
|
||||||
|
case 'target':
|
||||||
|
return { typeLabel: '目标', typeEn: 'TARGET', accent: '#4F46E5', bgEnd: '#F5F3FF', icon: 'target' };
|
||||||
|
case 'finding':
|
||||||
|
return { typeLabel: '发现', typeEn: 'FINDING', accent: '#E11D48', bgEnd: '#FFF1F2', icon: 'finding', cardStyle: 'default' };
|
||||||
|
case 'exploit':
|
||||||
|
return { typeLabel: '利用', typeEn: 'EXPLOIT', accent: '#B45309', bgEnd: '#FFFBEB', icon: 'vulnerability', cardStyle: 'default' };
|
||||||
|
case 'vulnerability':
|
||||||
|
return { typeLabel: '漏洞', typeEn: 'VULN', accent: '#9333EA', bgEnd: '#F5F3FF', icon: 'vuln', cardStyle: 'default' };
|
||||||
|
case 'auth':
|
||||||
|
return { typeLabel: '认证', typeEn: 'AUTH', accent: '#0D9488', bgEnd: '#F0FDFA', icon: 'default' };
|
||||||
|
case 'infra':
|
||||||
|
return { typeLabel: '基础设施', typeEn: 'INFRA', accent: '#64748B', bgEnd: '#F8FAFC', icon: 'default' };
|
||||||
|
case 'chain':
|
||||||
|
return { typeLabel: '攻击链', typeEn: 'CHAIN', accent: '#7C3AED', bgEnd: '#F5F3FF', icon: 'vulnerability' };
|
||||||
|
case 'poc':
|
||||||
|
return { typeLabel: 'POC', typeEn: 'POC', accent: '#C2410C', bgEnd: '#FFEDD5', icon: 'vulnerability' };
|
||||||
|
case 'business':
|
||||||
|
return { typeLabel: '业务', typeEn: 'BUSINESS', accent: '#0369A1', bgEnd: '#F0F9FF', icon: 'default' };
|
||||||
|
case 'missing':
|
||||||
|
return { typeLabel: '缺失', typeEn: 'MISSING', accent: '#CBD5E1', bgEnd: '#F1F5F9', icon: 'default' };
|
||||||
|
default:
|
||||||
|
return { typeLabel: '备注', typeEn: 'NOTE', accent: '#94A3B8', bgEnd: '#F8FAFC', icon: 'default' };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function escapeXml(str) {
|
||||||
|
return String(str)
|
||||||
|
.replace(/&/g, '&')
|
||||||
|
.replace(/</g, '<')
|
||||||
|
.replace(/>/g, '>')
|
||||||
|
.replace(/"/g, '"')
|
||||||
|
.replace(/'/g, ''');
|
||||||
|
}
|
||||||
|
|
||||||
|
function escapeHtml(str) {
|
||||||
|
return escapeXml(str);
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildStatusBadge(confidence) {
|
||||||
|
const conf = (confidence || '').toLowerCase();
|
||||||
|
if (conf === 'tentative') return '待确认';
|
||||||
|
if (conf === 'deprecated') return '已废弃';
|
||||||
|
return '';
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildHeaderText(theme, statusBadge) {
|
||||||
|
const line = (theme.typeEn || '') + ' · ' + (theme.typeLabel || '');
|
||||||
|
return statusBadge ? line + ' · ' + statusBadge : line;
|
||||||
|
}
|
||||||
|
|
||||||
|
function isWideChar(ch) {
|
||||||
|
const code = ch.codePointAt(0) || 0;
|
||||||
|
if (code >= 0x4e00 && code <= 0x9fff) return true;
|
||||||
|
if (code >= 0x3400 && code <= 0x4dbf) return true;
|
||||||
|
if (code >= 0xf900 && code <= 0xfaff) return true;
|
||||||
|
if (code >= 0xff00 && code <= 0xffef) return true;
|
||||||
|
return /[·:,。;!?【】()《》、「」]/.test(ch);
|
||||||
|
}
|
||||||
|
|
||||||
|
function charWidth(ch, fontSize, bold) {
|
||||||
|
const scale = bold ? 1.05 : 1;
|
||||||
|
if (ch === ' ') return fontSize * 0.3 * scale;
|
||||||
|
if (isWideChar(ch)) return fontSize * scale;
|
||||||
|
return fontSize * 0.58 * scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
function lineWidth(text, fontSize, bold) {
|
||||||
|
let width = 0;
|
||||||
|
for (const ch of text) width += charWidth(ch, fontSize, bold);
|
||||||
|
return width;
|
||||||
|
}
|
||||||
|
|
||||||
|
function wrapTextLines(text, maxWidth, fontSize, maxLines, bold) {
|
||||||
|
const raw = String(text || '').replace(/\s+/g, ' ').trim();
|
||||||
|
if (!raw) return ['—'];
|
||||||
|
const safeWidth = Math.max(40, maxWidth - 4);
|
||||||
|
const chars = [...raw];
|
||||||
|
const lines = [];
|
||||||
|
let index = 0;
|
||||||
|
while (index < chars.length && lines.length < maxLines) {
|
||||||
|
let line = '';
|
||||||
|
let width = 0;
|
||||||
|
while (index < chars.length) {
|
||||||
|
const ch = chars[index];
|
||||||
|
const nextWidth = charWidth(ch, fontSize, bold);
|
||||||
|
if (line && width + nextWidth > safeWidth) break;
|
||||||
|
line += ch;
|
||||||
|
width += nextWidth;
|
||||||
|
index += 1;
|
||||||
|
if (width >= safeWidth) break;
|
||||||
|
}
|
||||||
|
if (line) lines.push(line);
|
||||||
|
}
|
||||||
|
if (index < chars.length && lines.length) {
|
||||||
|
let last = lines[lines.length - 1];
|
||||||
|
while (last.length > 1 && lineWidth(last + '…', fontSize, bold) > safeWidth) {
|
||||||
|
last = last.slice(0, -1);
|
||||||
|
}
|
||||||
|
lines[lines.length - 1] = last + '…';
|
||||||
|
}
|
||||||
|
return lines.length ? lines : ['—'];
|
||||||
|
}
|
||||||
|
|
||||||
|
function cardTextWidth(nodeWidth) {
|
||||||
|
return nodeWidth - CARD_TEXT_X - CARD_PAD - CARD_TEXT_PAD_RIGHT;
|
||||||
|
}
|
||||||
|
|
||||||
|
function computeNodeLayout(type, summary, statusBadge, theme, factKey) {
|
||||||
|
const width = type === 'target' ? CARD_TARGET_W : CARD_MIN_W;
|
||||||
|
const textW = cardTextWidth(width);
|
||||||
|
const t = theme || nodeTheme(type);
|
||||||
|
const headerLines = wrapTextLines(buildHeaderText(t, statusBadge), textW, CARD_HEADER_FS, 2, true);
|
||||||
|
const keyText = String(factKey || '').trim();
|
||||||
|
const keyLines = keyText ? wrapTextLines(keyText, textW, CARD_KEY_FS, 2, false) : [];
|
||||||
|
const summaryLines = wrapTextLines(summary, textW, CARD_SUMMARY_FS, keyLines.length ? 3 : 4, true);
|
||||||
|
const keyBlockHeight = keyLines.length
|
||||||
|
? CARD_SECTION_GAP + keyLines.length * CARD_KEY_LH + CARD_SECTION_GAP
|
||||||
|
: CARD_SECTION_GAP;
|
||||||
|
const height = Math.min(
|
||||||
|
CARD_MAX_H,
|
||||||
|
Math.max(
|
||||||
|
CARD_MIN_H,
|
||||||
|
CARD_PAD +
|
||||||
|
headerLines.length * CARD_HEADER_LH +
|
||||||
|
keyBlockHeight +
|
||||||
|
summaryLines.length * CARD_SUMMARY_LH +
|
||||||
|
CARD_PAD,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
return {
|
||||||
|
width,
|
||||||
|
height,
|
||||||
|
headerLines,
|
||||||
|
keyLines,
|
||||||
|
summaryLines,
|
||||||
|
searchLabel: [headerLines.join(' '), keyLines.join(' '), summaryLines.join(' ')]
|
||||||
|
.filter(Boolean)
|
||||||
|
.join('\n'),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
function svgIconGroup(kind, color, x, y) {
|
||||||
|
const scale = (CARD_ICON / 24).toFixed(3);
|
||||||
|
if (kind === 'target') {
|
||||||
|
return (
|
||||||
|
`<g transform="translate(${x}, ${y}) scale(${scale})">` +
|
||||||
|
`<circle cx="12" cy="12" r="6" fill="none" stroke="${color}" stroke-width="2"/>` +
|
||||||
|
`<circle cx="12" cy="12" r="2.5" fill="${color}"/></g>`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (kind === 'finding') {
|
||||||
|
return (
|
||||||
|
`<g transform="translate(${x}, ${y}) scale(${scale})">` +
|
||||||
|
`<circle cx="10" cy="10" r="6" fill="none" stroke="${color}" stroke-width="2"/>` +
|
||||||
|
`<line x1="14.5" y1="14.5" x2="19" y2="19" stroke="${color}" stroke-width="2" stroke-linecap="round"/></g>`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (kind === 'vuln') {
|
||||||
|
return (
|
||||||
|
`<g transform="translate(${x}, ${y}) scale(${scale})">` +
|
||||||
|
`<path d="M12 2.5l7.5 3v6.2c0 4.6-3.1 8.1-7.5 9.3-4.4-1.2-7.5-4.7-7.5-9.3V5.5z" fill="${color}" fill-opacity="0.12" stroke="${color}" stroke-width="2"/>` +
|
||||||
|
`<line x1="12" y1="8.5" x2="12" y2="12.5" stroke="${color}" stroke-width="2" stroke-linecap="round"/>` +
|
||||||
|
`<circle cx="12" cy="15.5" r="1.1" fill="${color}"/></g>`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (kind === 'vulnerability') {
|
||||||
|
return (
|
||||||
|
`<g transform="translate(${x}, ${y}) scale(${scale})">` +
|
||||||
|
`<path d="M12 3l9 16H3z" fill="none" stroke="${color}" stroke-width="2"/>` +
|
||||||
|
`<line x1="12" y1="9" x2="12" y2="13" stroke="${color}" stroke-width="2"/>` +
|
||||||
|
`<circle cx="12" cy="16" r="1" fill="${color}"/></g>`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
`<g transform="translate(${x}, ${y}) scale(${scale})">` +
|
||||||
|
`<circle cx="12" cy="12" r="5" fill="${color}" opacity="0.85"/></g>`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildNodeCardSvgUrl(theme, layout, confidence) {
|
||||||
|
const { width, height, headerLines, keyLines, summaryLines } = layout;
|
||||||
|
const accent = theme.accent;
|
||||||
|
const bgEnd = theme.bgEnd;
|
||||||
|
const conf = (confidence || '').toLowerCase();
|
||||||
|
const isTentative = conf === 'tentative';
|
||||||
|
const isDeprecated = conf === 'deprecated';
|
||||||
|
const iconX = CARD_PAD;
|
||||||
|
const iconY = (height - CARD_ICON) / 2;
|
||||||
|
const headerY = CARD_PAD + CARD_HEADER_FS;
|
||||||
|
const keyY = CARD_PAD + headerLines.length * CARD_HEADER_LH + CARD_SECTION_GAP + CARD_KEY_FS;
|
||||||
|
const summaryY =
|
||||||
|
CARD_PAD +
|
||||||
|
headerLines.length * CARD_HEADER_LH +
|
||||||
|
(keyLines.length
|
||||||
|
? CARD_SECTION_GAP + keyLines.length * CARD_KEY_LH + CARD_SECTION_GAP
|
||||||
|
: CARD_SECTION_GAP) +
|
||||||
|
CARD_SUMMARY_FS;
|
||||||
|
|
||||||
|
const stroke = isTentative
|
||||||
|
? `stroke="${accent}" stroke-width="1.5" stroke-dasharray="8 5" stroke-opacity="0.9"`
|
||||||
|
: `stroke="${accent}" stroke-width="1.5" stroke-opacity="0.72"`;
|
||||||
|
|
||||||
|
const headerSvg = headerLines
|
||||||
|
.map(
|
||||||
|
(line, i) =>
|
||||||
|
`<text x="${CARD_TEXT_X}" y="${headerY + i * CARD_HEADER_LH}" font-size="${CARD_HEADER_FS}" font-weight="700" fill="${accent}" fill-opacity="0.88" font-family='${CARD_FONT}'>${escapeXml(line)}</text>`,
|
||||||
|
)
|
||||||
|
.join('');
|
||||||
|
|
||||||
|
const keySvg = keyLines
|
||||||
|
.map(
|
||||||
|
(line, i) =>
|
||||||
|
`<text x="${CARD_TEXT_X}" y="${keyY + i * CARD_KEY_LH}" font-size="${CARD_KEY_FS}" font-weight="500" fill="#64748b" font-family='${CARD_KEY_FONT}'>${escapeXml(line)}</text>`,
|
||||||
|
)
|
||||||
|
.join('');
|
||||||
|
|
||||||
|
const summarySvg = summaryLines
|
||||||
|
.map(
|
||||||
|
(line, i) =>
|
||||||
|
`<text x="${CARD_TEXT_X}" y="${summaryY + i * CARD_SUMMARY_LH}" font-size="${CARD_SUMMARY_FS}" font-weight="600" fill="#0f172a" font-family='${CARD_FONT}'>${escapeXml(line)}</text>`,
|
||||||
|
)
|
||||||
|
.join('');
|
||||||
|
|
||||||
|
const textClipW = width - CARD_TEXT_X - CARD_PAD - 2;
|
||||||
|
const textClipH = height - CARD_PAD * 2 + 4;
|
||||||
|
|
||||||
|
const svg =
|
||||||
|
`<svg xmlns="http://www.w3.org/2000/svg" width="${width}" height="${height}" viewBox="0 0 ${width} ${height}">` +
|
||||||
|
`<defs><linearGradient id="bg" x1="0%" y1="0%" x2="100%" y2="100%">` +
|
||||||
|
`<stop offset="0%" stop-color="#FFFFFF"/><stop offset="100%" stop-color="${bgEnd}"/></linearGradient>` +
|
||||||
|
`<clipPath id="textClip"><rect x="${CARD_TEXT_X}" y="${CARD_PAD - 2}" width="${textClipW}" height="${textClipH}"/></clipPath></defs>` +
|
||||||
|
`<g${isDeprecated ? ' opacity="0.55"' : ''}>` +
|
||||||
|
`<rect x="0.75" y="0.75" width="${width - 1.5}" height="${height - 1.5}" rx="12" fill="url(#bg)" ${stroke}/>` +
|
||||||
|
svgIconGroup(theme.icon, accent, iconX, iconY) +
|
||||||
|
`<g clip-path="url(#textClip)">${headerSvg}${keySvg}${summarySvg}</g>` +
|
||||||
|
`</g></svg>`;
|
||||||
|
|
||||||
|
try {
|
||||||
|
return 'data:image/svg+xml;base64,' + btoa(unescape(encodeURIComponent(svg)));
|
||||||
|
} catch (e) {
|
||||||
|
return 'data:image/svg+xml;charset=utf-8,' + encodeURIComponent(svg);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function destroy() {
|
||||||
|
if (_resizeObs) {
|
||||||
|
_resizeObs.disconnect();
|
||||||
|
_resizeObs = null;
|
||||||
|
}
|
||||||
|
if (_cy) {
|
||||||
|
_cy.destroy();
|
||||||
|
_cy = null;
|
||||||
|
}
|
||||||
|
_graphData = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
function observeContainerResize(container) {
|
||||||
|
if (_resizeObs) {
|
||||||
|
_resizeObs.disconnect();
|
||||||
|
_resizeObs = null;
|
||||||
|
}
|
||||||
|
if (!container || typeof ResizeObserver === 'undefined') return;
|
||||||
|
_resizeObs = new ResizeObserver(() => {
|
||||||
|
if (_cy) {
|
||||||
|
try {
|
||||||
|
_cy.resize();
|
||||||
|
} catch (e) {
|
||||||
|
console.warn('graph resize', e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
_resizeObs.observe(container);
|
||||||
|
}
|
||||||
|
|
||||||
|
function centerGraph() {
|
||||||
|
if (!_cy) return;
|
||||||
|
try {
|
||||||
|
_cy.resize();
|
||||||
|
_cy.fit(undefined, 56);
|
||||||
|
if (_cy.zoom() < 0.65) {
|
||||||
|
_cy.zoom(0.65);
|
||||||
|
_cy.center();
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.warn('centerGraph', e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ELK 分层(仅影响节点纵向位置,不修改边的 source/target)
|
||||||
|
function pathGraphNodeLayer(type, factKey) {
|
||||||
|
const key = (factKey || '').toLowerCase();
|
||||||
|
if (key.startsWith('vuln:')) return '4';
|
||||||
|
const t = (type || '').toLowerCase();
|
||||||
|
if (t === 'target') return '0';
|
||||||
|
if (t === 'infra' || t === 'auth' || t === 'business') return '1';
|
||||||
|
if (t === 'exploit' || t === 'poc') return '3';
|
||||||
|
if (t === 'vulnerability' || t === 'vuln') return '3';
|
||||||
|
if (t === 'chain' || t === 'finding') return '2';
|
||||||
|
if (t === 'note') return '2';
|
||||||
|
return '2';
|
||||||
|
}
|
||||||
|
|
||||||
|
function applyElkLayout(validEdges, isComplex) {
|
||||||
|
const layoutOptions = {
|
||||||
|
name: 'breadthfirst',
|
||||||
|
directed: true,
|
||||||
|
spacingFactor: isComplex ? 3.0 : 2.5,
|
||||||
|
padding: 40,
|
||||||
|
};
|
||||||
|
const elkInstance = typeof ELK !== 'undefined' ? new ELK() : null;
|
||||||
|
if (!elkInstance) {
|
||||||
|
const layout = _cy.layout(layoutOptions);
|
||||||
|
layout.one('layoutstop', () => setTimeout(centerGraph, 100));
|
||||||
|
layout.run();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const nodeGap = isComplex ? 45 : 60;
|
||||||
|
const layerGap = isComplex ? 70 : 95;
|
||||||
|
const elkGraph = {
|
||||||
|
id: 'root',
|
||||||
|
layoutOptions: {
|
||||||
|
'elk.algorithm': 'layered',
|
||||||
|
'elk.direction': 'DOWN',
|
||||||
|
'elk.spacing.nodeNode': String(nodeGap),
|
||||||
|
'elk.layered.spacing.nodeNodeBetweenLayers': String(layerGap),
|
||||||
|
'elk.layered.nodePlacement.strategy': 'BRANDES_KOEPF',
|
||||||
|
},
|
||||||
|
children: (_graphData.nodes || []).map((node) => {
|
||||||
|
const n = _cy ? _cy.getElementById(node.id) : null;
|
||||||
|
const w = n.length ? n.data('nodeWidth') : node.type === 'target' ? CARD_TARGET_W : CARD_MIN_W;
|
||||||
|
const h = n.length ? n.data('nodeHeight') : CARD_MIN_H;
|
||||||
|
const nodeKey = node.fact_key || node.id;
|
||||||
|
return {
|
||||||
|
id: node.id,
|
||||||
|
width: w,
|
||||||
|
height: h,
|
||||||
|
layoutOptions: {
|
||||||
|
'org.eclipse.elk.layered.layering.layerId': pathGraphNodeLayer(node.type, nodeKey),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}),
|
||||||
|
edges: validEdges.map((edge) => ({
|
||||||
|
id: edge.id,
|
||||||
|
sources: [edge.source],
|
||||||
|
targets: [edge.target],
|
||||||
|
})),
|
||||||
|
};
|
||||||
|
elkInstance
|
||||||
|
.layout(elkGraph)
|
||||||
|
.then((laidOut) => {
|
||||||
|
(laidOut.children || []).forEach((elkNode) => {
|
||||||
|
const cyNode = _cy.getElementById(elkNode.id);
|
||||||
|
if (cyNode.length && elkNode.x != null) {
|
||||||
|
cyNode.position({
|
||||||
|
x: elkNode.x + (elkNode.width || 0) / 2,
|
||||||
|
y: elkNode.y + (elkNode.height || 0) / 2,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
setTimeout(centerGraph, 120);
|
||||||
|
})
|
||||||
|
.catch(() => {
|
||||||
|
const layout = _cy.layout(layoutOptions);
|
||||||
|
layout.one('layoutstop', () => setTimeout(centerGraph, 100));
|
||||||
|
layout.run();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function render(container, graphData, options) {
|
||||||
|
if (!container || typeof cytoscape === 'undefined') {
|
||||||
|
if (container) {
|
||||||
|
container.innerHTML = '<div class="error-message">Cytoscape 未加载</div>';
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
destroy();
|
||||||
|
_graphData = graphData || { nodes: [], edges: [] };
|
||||||
|
_onNodeSelect = options && options.onNodeSelect;
|
||||||
|
_onEdgeSelect = options && options.onEdgeSelect;
|
||||||
|
|
||||||
|
const nodes = _graphData.nodes || [];
|
||||||
|
const edges = _graphData.edges || [];
|
||||||
|
if (!nodes.length) {
|
||||||
|
const title = (options && options.emptyTitle) || '';
|
||||||
|
const hint = (options && options.emptyText) || '暂无事实关系';
|
||||||
|
const steps = (options && options.emptySteps) || [];
|
||||||
|
const actionLabel = options && options.emptyActionLabel;
|
||||||
|
const stepsHtml = steps.length
|
||||||
|
? '<ol class="project-fact-graph-empty-steps">' +
|
||||||
|
steps.map((s) => '<li>' + escapeHtml(String(s)) + '</li>').join('') +
|
||||||
|
'</ol>'
|
||||||
|
: '';
|
||||||
|
const actionHtml =
|
||||||
|
actionLabel && options.onEmptyAction
|
||||||
|
? '<button type="button" class="btn-primary btn-small project-fact-graph-empty-cta">' +
|
||||||
|
escapeHtml(actionLabel) +
|
||||||
|
'</button>'
|
||||||
|
: '';
|
||||||
|
container.innerHTML =
|
||||||
|
'<div class="project-fact-graph-empty">' +
|
||||||
|
'<div class="project-fact-graph-empty-icon" aria-hidden="true">' +
|
||||||
|
'<svg width="48" height="48" viewBox="0 0 24 24" fill="none"><circle cx="6" cy="6" r="2.5" fill="#4F46E5" opacity="0.9"/><circle cx="18" cy="6" r="2.5" fill="#E11D48" opacity="0.9"/><circle cx="12" cy="18" r="2.5" fill="#0D9488" opacity="0.9"/>' +
|
||||||
|
'<path d="M8 7l4 9M16 7l-4 9M8 7h8" stroke="#CBD5E1" stroke-width="1.5" stroke-linecap="round"/></svg>' +
|
||||||
|
'</div>' +
|
||||||
|
(title ? '<h4 class="project-fact-graph-empty-title">' + escapeHtml(title) + '</h4>' : '') +
|
||||||
|
'<p class="project-fact-graph-empty-hint">' + escapeHtml(hint) + '</p>' +
|
||||||
|
stepsHtml +
|
||||||
|
actionHtml +
|
||||||
|
'</div>';
|
||||||
|
const cta = container.querySelector('.project-fact-graph-empty-cta');
|
||||||
|
if (cta && typeof options.onEmptyAction === 'function') {
|
||||||
|
cta.addEventListener('click', options.onEmptyAction);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
container.innerHTML = '';
|
||||||
|
const isComplex = nodes.length > 15 || edges.length > 25;
|
||||||
|
const elements = [];
|
||||||
|
const nodeIds = new Set();
|
||||||
|
|
||||||
|
nodes.forEach((node) => {
|
||||||
|
nodeIds.add(node.id);
|
||||||
|
const visualType = resolveGraphNodeType(node);
|
||||||
|
const theme = nodeTheme(visualType);
|
||||||
|
const factKey = node.fact_key || node.id;
|
||||||
|
const summary = (node.summary || node.label || '').trim() || '—';
|
||||||
|
const statusBadge = buildStatusBadge(node.confidence);
|
||||||
|
const layout = computeNodeLayout(visualType, summary, statusBadge, theme, factKey);
|
||||||
|
elements.push({
|
||||||
|
data: {
|
||||||
|
id: node.id,
|
||||||
|
label: layout.searchLabel,
|
||||||
|
factKey: node.fact_key || node.id,
|
||||||
|
category: node.category || '',
|
||||||
|
type: visualType,
|
||||||
|
typeLabel: theme.typeLabel,
|
||||||
|
typeEn: theme.typeEn,
|
||||||
|
accentColor: theme.accent,
|
||||||
|
statusBadge: statusBadge,
|
||||||
|
confidence: node.confidence || '',
|
||||||
|
nodeWidth: layout.width,
|
||||||
|
nodeHeight: layout.height,
|
||||||
|
cardSvgUrl: buildNodeCardSvgUrl(theme, layout, node.confidence),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
const validEdges = [];
|
||||||
|
edges.forEach((edge, idx) => {
|
||||||
|
if (!nodeIds.has(edge.source) || !nodeIds.has(edge.target)) return;
|
||||||
|
const id = edge.id || 'e-' + idx;
|
||||||
|
validEdges.push({ ...edge, id });
|
||||||
|
elements.push({
|
||||||
|
data: {
|
||||||
|
id,
|
||||||
|
source: edge.source,
|
||||||
|
target: edge.target,
|
||||||
|
type: edge.type || 'leads_to',
|
||||||
|
confidence: edge.confidence || 'confirmed',
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
_cy = cytoscape({
|
||||||
|
container,
|
||||||
|
elements,
|
||||||
|
style: [
|
||||||
|
{
|
||||||
|
selector: 'node',
|
||||||
|
style: {
|
||||||
|
label: '',
|
||||||
|
width: (ele) => ele.data('nodeWidth') || CARD_MIN_W,
|
||||||
|
height: (ele) => ele.data('nodeHeight') || CARD_MIN_H,
|
||||||
|
shape: 'round-rectangle',
|
||||||
|
'background-color': '#ffffff',
|
||||||
|
'background-image': (ele) => ele.data('cardSvgUrl') || 'none',
|
||||||
|
'background-width': (ele) => (ele.data('nodeWidth') || CARD_MIN_W) + 'px',
|
||||||
|
'background-height': (ele) => (ele.data('nodeHeight') || CARD_MIN_H) + 'px',
|
||||||
|
'background-position-x': '50%',
|
||||||
|
'background-position-y': '50%',
|
||||||
|
'background-fit': 'none',
|
||||||
|
'border-width': 0,
|
||||||
|
'background-opacity': 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
selector: 'edge',
|
||||||
|
style: {
|
||||||
|
width: 2.2,
|
||||||
|
'line-color': (ele) => EDGE_COLORS[ele.data('type')] || '#CBD5E1',
|
||||||
|
'target-arrow-color': (ele) => EDGE_COLORS[ele.data('type')] || '#CBD5E1',
|
||||||
|
'target-arrow-shape': 'triangle',
|
||||||
|
'curve-style': 'bezier',
|
||||||
|
opacity: (ele) => (ele.data('confidence') === 'tentative' ? 0.55 : 0.9),
|
||||||
|
'line-style': (ele) => (ele.data('confidence') === 'tentative' ? 'dashed' : 'solid'),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
selector: 'edge:selected',
|
||||||
|
style: {
|
||||||
|
width: 3.5,
|
||||||
|
opacity: 1,
|
||||||
|
'line-color': '#4F46E5',
|
||||||
|
'target-arrow-color': '#4F46E5',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
selector: 'node:selected',
|
||||||
|
style: {
|
||||||
|
'border-width': 3,
|
||||||
|
'border-color': '#4F46E5',
|
||||||
|
'border-opacity': 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
minZoom: 0.35,
|
||||||
|
maxZoom: 3,
|
||||||
|
});
|
||||||
|
|
||||||
|
_cy.on('tap', 'node', (evt) => {
|
||||||
|
const d = evt.target.data();
|
||||||
|
const key = d.factKey || d.id;
|
||||||
|
if (_connectMode && _connectPick) {
|
||||||
|
_connectPick(key);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (typeof _onNodeSelect === 'function') {
|
||||||
|
_onNodeSelect(key, d);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
_cy.on('tap', 'edge', (evt) => {
|
||||||
|
if (_connectMode && _connectPick) return;
|
||||||
|
const d = evt.target.data();
|
||||||
|
if (typeof _onEdgeSelect === 'function') {
|
||||||
|
_onEdgeSelect(d.id, d);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
_cy.on('tap', (evt) => {
|
||||||
|
if (evt.target === _cy) {
|
||||||
|
clearEdgeSelection();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
applyElkLayout(validEdges, isComplex);
|
||||||
|
observeContainerResize(container);
|
||||||
|
return _cy;
|
||||||
|
}
|
||||||
|
|
||||||
|
function filterBySearch(query) {
|
||||||
|
if (!_cy) return;
|
||||||
|
const q = (query || '').trim().toLowerCase();
|
||||||
|
_cy.nodes().forEach((n) => {
|
||||||
|
if (!q) {
|
||||||
|
n.style('opacity', 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const text = (
|
||||||
|
(n.data('label') || '') +
|
||||||
|
' ' +
|
||||||
|
(n.data('factKey') || '') +
|
||||||
|
' ' +
|
||||||
|
(n.data('typeLabel') || '')
|
||||||
|
).toLowerCase();
|
||||||
|
n.style('opacity', text.includes(q) ? 1 : 0.15);
|
||||||
|
});
|
||||||
|
_cy.edges().forEach((e) => {
|
||||||
|
e.style('opacity', q ? 0.12 : 0.9);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let _connectMode = false;
|
||||||
|
let _connectPick = null;
|
||||||
|
|
||||||
|
function selectEdge(edgeId) {
|
||||||
|
if (!_cy || !edgeId) return;
|
||||||
|
_cy.elements().unselect();
|
||||||
|
const edge = _cy.getElementById(edgeId);
|
||||||
|
if (edge.length) edge.select();
|
||||||
|
}
|
||||||
|
|
||||||
|
function clearEdgeSelection() {
|
||||||
|
if (!_cy) return;
|
||||||
|
_cy.elements().unselect();
|
||||||
|
}
|
||||||
|
|
||||||
|
function setConnectMode(enabled, onPick) {
|
||||||
|
_connectMode = !!enabled;
|
||||||
|
_connectPick = typeof onPick === 'function' ? onPick : null;
|
||||||
|
if (_cy) {
|
||||||
|
_cy.userPanningEnabled(!_connectMode);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 与后端 GraphNodeType 一致:优先 category,vuln: 合成节点例外;无 category 时回退 type/key。 */
|
||||||
|
function resolveGraphNodeType(node) {
|
||||||
|
if (!node) return 'note';
|
||||||
|
const key = String(node.fact_key || node.id || '').toLowerCase();
|
||||||
|
if (key.startsWith('vuln:')) return 'vulnerability';
|
||||||
|
const cat = String(node.category || '').toLowerCase();
|
||||||
|
if (cat) {
|
||||||
|
if (cat === 'vuln') return 'vulnerability';
|
||||||
|
if (cat === 'missing') return 'missing';
|
||||||
|
return cat;
|
||||||
|
}
|
||||||
|
const t = String(node.type || '').toLowerCase();
|
||||||
|
if (t === 'vuln') return 'vulnerability';
|
||||||
|
if (t) return t;
|
||||||
|
if (key.startsWith('target/')) return 'target';
|
||||||
|
if (key.startsWith('exploit/') || key.startsWith('evidence/')) return 'exploit';
|
||||||
|
if (key.startsWith('poc/')) return 'poc';
|
||||||
|
if (key.startsWith('chain/')) return 'chain';
|
||||||
|
if (key.startsWith('finding/')) return 'finding';
|
||||||
|
if (key.startsWith('auth/')) return 'auth';
|
||||||
|
if (key.startsWith('infra/') || key.startsWith('business/')) return 'infra';
|
||||||
|
return 'note';
|
||||||
|
}
|
||||||
|
|
||||||
|
global.ProjectFactGraph = {
|
||||||
|
render,
|
||||||
|
destroy,
|
||||||
|
center: centerGraph,
|
||||||
|
filterBySearch,
|
||||||
|
setConnectMode,
|
||||||
|
selectEdge,
|
||||||
|
clearEdgeSelection,
|
||||||
|
nodeTheme,
|
||||||
|
resolveGraphNodeType,
|
||||||
|
};
|
||||||
|
})(typeof window !== 'undefined' ? window : globalThis);
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user