Compare commits

...

83 Commits

Author SHA1 Message Date
公明 97834c162e Update config.yaml 2026-05-23 19:53:40 +08:00
公明 9276f2f144 Add files via upload 2026-05-23 19:49:50 +08:00
公明 a454cada6a Add files via upload 2026-05-23 19:39:03 +08:00
公明 99b53d4fbc Add files via upload 2026-05-23 19:35:30 +08:00
公明 a43a9deaea Add files via upload 2026-05-23 19:33:23 +08:00
公明 ce88da84c9 Add files via upload 2026-05-23 19:31:40 +08:00
公明 15855c7073 Add files via upload 2026-05-23 19:29:49 +08:00
公明 43eb3e546b Add files via upload 2026-05-22 17:23:01 +08:00
公明 2d52c9b6ac Update config.yaml 2026-05-22 17:18:48 +08:00
公明 d5401b8b4c Update config.yaml 2026-05-22 17:17:48 +08:00
公明 5fd4393a2e Add files via upload 2026-05-22 17:14:33 +08:00
公明 a049f6b5c2 Add files via upload 2026-05-22 17:13:55 +08:00
公明 acba8e5a39 Add files via upload 2026-05-22 17:11:34 +08:00
公明 f826b91362 Add files via upload 2026-05-22 17:09:54 +08:00
公明 98c2de2a60 Add files via upload 2026-05-22 17:08:05 +08:00
公明 1c4d4b305b Update config.yaml 2026-05-22 15:15:46 +08:00
公明 f210ac9a03 Add files via upload 2026-05-22 11:36:36 +08:00
公明 6685076dfb Add files via upload 2026-05-22 11:35:02 +08:00
公明 7f322653f6 Add files via upload 2026-05-22 11:32:36 +08:00
公明 66ac2f1357 Add files via upload 2026-05-22 11:30:25 +08:00
公明 c446e22d0c Add files via upload 2026-05-22 11:28:51 +08:00
公明 0358d3a67d Add files via upload 2026-05-22 10:30:19 +08:00
公明 9b82f265fd Add files via upload 2026-05-20 18:24:17 +08:00
公明 3d9cae58e4 Update config.yaml 2026-05-20 17:59:57 +08:00
公明 1f1eadee5e Update config.yaml 2026-05-20 17:58:24 +08:00
公明 0569255189 Add files via upload 2026-05-20 17:54:30 +08:00
公明 8ccf90d067 Add files via upload 2026-05-20 17:52:22 +08:00
公明 b3be89f47d Add files via upload 2026-05-20 17:50:52 +08:00
公明 b9bf8f62d4 Add files via upload 2026-05-20 17:48:42 +08:00
公明 05ca0c1480 Update config.yaml 2026-05-20 16:57:50 +08:00
公明 47a4f3fc5b Add files via upload 2026-05-20 16:52:50 +08:00
公明 a3b378ae9e Add files via upload 2026-05-20 16:49:26 +08:00
公明 a904d26e78 Add files via upload 2026-05-20 16:47:34 +08:00
公明 7ba7476c4f Add files via upload 2026-05-20 16:45:59 +08:00
公明 ae25a243ac Add files via upload 2026-05-20 16:43:38 +08:00
公明 23bd6288ff Add files via upload 2026-05-20 16:39:13 +08:00
公明 fef21d3a24 Add files via upload 2026-05-20 16:36:50 +08:00
公明 933bba4517 Update config.yaml 2026-05-20 16:12:13 +08:00
公明 e1d65437cc Add files via upload 2026-05-20 16:11:10 +08:00
公明 9325aed1eb Add files via upload 2026-05-20 16:09:33 +08:00
公明 dee2b3ab42 Add files via upload 2026-05-20 16:07:33 +08:00
公明 a69bc93fa1 Add files via upload 2026-05-20 16:05:40 +08:00
公明 b1a620bfce Update config.yaml 2026-05-20 14:18:33 +08:00
公明 61b164eec2 Add files via upload 2026-05-20 11:03:38 +08:00
公明 ba77e1837e Update config.yaml 2026-05-19 23:05:52 +08:00
公明 eacad60fd6 Add files via upload 2026-05-19 23:03:04 +08:00
公明 70bf5c93bf Update config.yaml 2026-05-19 19:01:31 +08:00
公明 08bd278d8c Update config.yaml 2026-05-19 18:56:24 +08:00
公明 22746d64a3 Add files via upload 2026-05-19 18:53:46 +08:00
公明 199392a5d5 Add files via upload 2026-05-19 18:52:22 +08:00
公明 aafb4cb584 Add files via upload 2026-05-19 18:50:28 +08:00
公明 96e3dd397c Add files via upload 2026-05-19 18:48:17 +08:00
公明 ec0f17145b Add files via upload 2026-05-19 17:50:38 +08:00
公明 ed53da0999 Delete security directory 2026-05-19 17:49:21 +08:00
公明 dc440fc511 Delete robot directory 2026-05-19 17:49:10 +08:00
公明 009ae59033 Delete logger directory 2026-05-19 17:48:59 +08:00
公明 f348b3245a Delete knowledge directory 2026-05-19 17:48:44 +08:00
公明 0018c5219c Delete config directory 2026-05-19 17:48:33 +08:00
公明 01a3e3677a Delete c2 directory 2026-05-19 17:48:22 +08:00
公明 a12ecdb46f Add files via upload 2026-05-19 17:47:56 +08:00
公明 9f59230d74 Add files via upload 2026-05-19 17:46:33 +08:00
公明 085c6a1c72 Add files via upload 2026-05-19 17:43:45 +08:00
公明 7b3860971f Add files via upload 2026-05-19 17:42:12 +08:00
公明 f6f7b7b237 Add files via upload 2026-05-19 17:40:19 +08:00
公明 d5cf4b3b16 Add files via upload 2026-05-19 16:48:07 +08:00
公明 3e58d8355b Add files via upload 2026-05-19 16:32:38 +08:00
公明 eb01ade63b Add files via upload 2026-05-19 16:29:05 +08:00
公明 d1dc15fa44 Add files via upload 2026-05-19 16:27:29 +08:00
公明 73a39ef868 Add files via upload 2026-05-19 16:25:47 +08:00
公明 a022baef03 Add files via upload 2026-05-19 16:23:21 +08:00
公明 59312d428e Add files via upload 2026-05-19 14:53:07 +08:00
公明 951d14ef14 Update config.yaml 2026-05-18 23:51:19 +08:00
公明 0eb22da6e9 Add files via upload 2026-05-18 23:50:55 +08:00
公明 5fd9ef0514 Add files via upload 2026-05-18 23:47:10 +08:00
公明 9a4f3c7d35 Add files via upload 2026-05-18 17:37:29 +08:00
公明 ead2ce3ecc Add files via upload 2026-05-18 17:28:14 +08:00
公明 8733f3a2d2 Update config.yaml 2026-05-18 11:03:29 +08:00
公明 8642f3ba31 Add files via upload 2026-05-17 17:11:16 +08:00
公明 6a262a7367 Add files via upload 2026-05-17 17:09:16 +08:00
公明 eb9192ddb3 Add files via upload 2026-05-17 17:08:42 +08:00
公明 5587e75628 Add files via upload 2026-05-17 17:06:53 +08:00
公明 74bbb453e2 Add files via upload 2026-05-17 17:05:22 +08:00
公明 66842f6206 Add files via upload 2026-05-17 17:01:48 +08:00
108 changed files with 9984 additions and 772 deletions
+2 -2
View File
@@ -285,7 +285,7 @@ Requirements / tips:
- **Supervisor orchestrator**: fixed name **`orchestrator-supervisor.md`** (plus optional `orchestrator_instruction_supervisor`); requires at least one sub-agent.
- **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only.
- **Management** Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`.
- **Config** `multi_agent` in `config.yaml`: `enabled`, `default_mode`, `robot_use_multi_agent`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
- **Config** `multi_agent` in `config.yaml`: `enabled`, `robot_default_agent_mode`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
- **Details** **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats).
### Skills System (Agent Skills + Eino)
@@ -536,7 +536,7 @@ agents_dir: "agents" # Multi-agent Markdown definitions (orchestrator + sub-age
multi_agent:
enabled: false
default_mode: "single" # single | multi (UI default when multi-agent is enabled)
robot_use_multi_agent: false
robot_default_agent_mode: react
batch_use_multi_agent: false
orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
+2 -2
View File
@@ -283,7 +283,7 @@ go build -o cyberstrike-ai cmd/server/main.go
- **Supervisor 主代理**:固定 **`orchestrator-supervisor.md`**(另可配 `orchestrator_instruction_supervisor`);至少需一名子代理。
- **子代理****deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。
- **界面管理****Agents → Agent 管理**API `/api/multi-agent/markdown-agents`。
- **配置项**`multi_agent``enabled`、`default_mode`、`robot_use_multi_agent`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
- **配置项**`multi_agent``enabled`、`robot_default_agent_mode`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
- **更多细节**[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。
### Skills 技能系统(Agent Skills + Eino
@@ -534,7 +534,7 @@ agents_dir: "agents" # 多代理 Markdown(主代理 orchestrator.md + 子代
multi_agent:
enabled: false
default_mode: "single" # single | multi(开启多代理时的界面默认模式)
robot_use_multi_agent: false
robot_default_agent_mode: react
batch_use_multi_agent: false
orchestrator_instruction: "" # Deeporchestrator.md 正文为空时使用
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
+25 -5
View File
@@ -10,7 +10,7 @@
# ============================================
# 前端显示的版本号(可选,不填则显示默认版本)
version: "v1.6.15"
version: "v1.6.22"
# 服务器配置
server:
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
@@ -34,6 +34,12 @@ auth:
log:
level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误)
output: stdout # 日志输出位置: stdout(标准输出), stderr(标准错误), 或文件路径
# 平台操作审计(系统设置 -> 日志审计;不记录对话正文与每次工具调用)
audit:
enabled: true
retention_days: 15 # 0 表示不自动清理
max_detail_bytes: 8192
auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流
# ============================================
# 对话相关配置
# ============================================
@@ -54,8 +60,8 @@ openai:
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinkingextended thinking),mode: off 关闭
reasoning:
mode: off # auto | on | offoff 时不附加任何推理扩展字段
effort: max # low | medium | high | max;空表示不指定openai_compat 下 auto 且无强度时不发请求扩展)
mode: on # auto | on | offoff 时不附加任何推理扩展字段
effort: high # low | medium | high | max | xhigh(最高档:OpenAI 常用 xhigh,部分网关用 max,原样下发);空表示不指定
allow_client_reasoning: true # false 时忽略对话请求体 reasoning,仅以下方为准
profile: openai_compat # auto | deepseek_compat | openai_compat | output_config_effort
# extra_request_fields: {} # 可选:管理员自定义根级 JSON 片段(高级)
@@ -76,16 +82,18 @@ agent:
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
# system_prompt_path: prompts/single-react.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
system_prompt_path: ""
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
hitl:
# 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 []
tool_whitelist: [read_file, list_dir, glob, grep]
# 多代理(CloudWeGo Eino DeepAgent,与上方单 Agent /api/agent-loop 并存)
# 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct
# 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/streamDeep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人/批量无请求体时固定按 deep
# 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/streamDeep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人按 robot_default_agent_mode
multi_agent:
enabled: true
robot_use_multi_agent: true # true 时企业微信/钉钉/飞书机器人也走 Eino 多代理(成本更高)
robot_default_agent_mode: eino_single # 企微/钉钉/飞书机器人默认对话模式:react | eino_single | deep | plan_execute | supervisor
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件。
@@ -125,6 +133,8 @@ multi_agent:
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接
run_retry_max_attempts: 0 # >0429/5xx/网络抖动时 ADK 运行循环指数退避续跑次数;0=默认 10
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
deep_model_retry_max_retries: 0 # >0ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
@@ -235,6 +245,14 @@ knowledge:
# 用于在手机端通过企业微信/钉钉/飞书与 CyberStrikeAI 对话,无需部署在服务器上也可使用
# 在系统设置 -> 机器人设置 中可配置
robots:
wechat: # 微信 iLink(个人微信 ClawBot,扫码绑定)
enabled: false
bot_token: ""
ilink_bot_id: ""
ilink_user_id: ""
base_url: https://ilinkai.weixin.qq.com
bot_type: "3"
bot_agent: CyberStrikeAI/1.0
wecom: # 企业微信
enabled: false
token: ""
@@ -246,11 +264,13 @@ robots:
enabled: false
client_id: ""
client_secret: ""
allow_conversation_id_fallback: false
lark: # 飞书
enabled: false
app_id: ""
app_secret: ""
verify_token: ""
allow_chat_id_fallback: false
# ============================================
# Skills 相关配置
# ============================================
+2 -1
View File
@@ -27,12 +27,14 @@ require (
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/pkoukk/tiktoken-go v0.1.8
github.com/robfig/cron/v3 v3.0.1
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
go.opentelemetry.io/otel v1.34.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0
go.opentelemetry.io/otel/sdk v1.34.0
go.opentelemetry.io/otel/trace v1.34.0
go.uber.org/zap v1.26.0
golang.org/x/net v0.35.0
golang.org/x/text v0.26.0
golang.org/x/time v0.14.0
gopkg.in/yaml.v3 v3.0.1
@@ -88,7 +90,6 @@ require (
golang.org/x/arch v0.15.0 // indirect
golang.org/x/crypto v0.39.0 // indirect
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
golang.org/x/net v0.35.0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sys v0.33.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f // indirect
+2 -2
View File
@@ -163,6 +163,8 @@ github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtIS
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI=
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg=
github.com/smarty/assertions v1.16.0 h1:EvHNkdRA4QHMrn75NZSoUQ/mAUXAYWfatfB01yTCzfY=
@@ -245,8 +247,6 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
+32 -8
View File
@@ -598,11 +598,17 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
thinkingStreamSeq++
thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq)
thinkingStreamStarted := false
var thinkingWire string
response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error {
if delta == "" {
return nil
}
var deltaOut string
thinkingWire, deltaOut = openai.NormalizeStreamingDelta(thinkingWire, delta)
if deltaOut == "" {
return nil
}
if !thinkingStreamStarted {
thinkingStreamStarted = true
sendProgress("thinking_stream_start", " ", map[string]interface{}{
@@ -611,10 +617,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
"toolStream": false,
})
}
sendProgress("thinking_stream_delta", delta, map[string]interface{}{
sendProgress("thinking_stream_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
"streamId": thinkingStreamId,
"iteration": i + 1,
})
}, thinkingWire))
return nil
})
if err != nil {
@@ -827,10 +833,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
"mcpExecutionIds": result.MCPExecutionIDs,
"messageGeneratedBy": "summary",
})
var summaryWire string
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
sendProgress("response_delta", delta, map[string]interface{}{
var deltaOut string
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
if deltaOut == "" {
return nil
}
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID,
})
}, summaryWire))
return nil
})
if strings.TrimSpace(streamText) != "" {
@@ -874,10 +886,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
"mcpExecutionIds": result.MCPExecutionIDs,
"messageGeneratedBy": "summary",
})
var summaryWire string
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
sendProgress("response_delta", delta, map[string]interface{}{
var deltaOut string
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
if deltaOut == "" {
return nil
}
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID,
})
}, summaryWire))
return nil
})
if strings.TrimSpace(streamText) != "" {
@@ -921,10 +939,16 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
"mcpExecutionIds": result.MCPExecutionIDs,
"messageGeneratedBy": "max_iter_summary",
})
var summaryWire string
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
sendProgress("response_delta", delta, map[string]interface{}{
var deltaOut string
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
if deltaOut == "" {
return nil
}
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID,
})
}, summaryWire))
return nil
})
if strings.TrimSpace(streamText) != "" {
+167
View File
@@ -0,0 +1,167 @@
package agent
import (
"encoding/json"
"strings"
)
// ParseTraceMessages 解析落库的 last_react_inputOpenAI 风格 messages JSON 数组)。
func ParseTraceMessages(traceInputJSON string) ([]ChatMessage, error) {
traceInputJSON = strings.TrimSpace(traceInputJSON)
if traceInputJSON == "" {
return nil, nil
}
var raw []map[string]interface{}
if err := json.Unmarshal([]byte(traceInputJSON), &raw); err != nil {
return nil, err
}
out := make([]ChatMessage, 0, len(raw))
for _, msgMap := range raw {
msg := ChatMessage{}
role, _ := msgMap["role"].(string)
if role == "" {
continue
}
msg.Role = role
if content, ok := msgMap["content"].(string); ok {
msg.Content = content
}
if rc, ok := msgMap["reasoning_content"].(string); ok && strings.TrimSpace(rc) != "" {
msg.ReasoningContent = rc
}
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
if toolCallsArray, ok := toolCallsRaw.([]interface{}); ok {
for _, tcRaw := range toolCallsArray {
tcMap, ok := tcRaw.(map[string]interface{})
if !ok {
continue
}
toolCall := ToolCall{}
if id, ok := tcMap["id"].(string); ok {
toolCall.ID = id
}
if toolType, ok := tcMap["type"].(string); ok {
toolCall.Type = toolType
}
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
toolCall.Function = FunctionCall{}
if name, ok := funcMap["name"].(string); ok {
toolCall.Function.Name = name
}
if argsRaw, ok := funcMap["arguments"]; ok {
if argsStr, ok := argsRaw.(string); ok {
var argsMap map[string]interface{}
if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil {
toolCall.Function.Arguments = argsMap
}
} else if argsMap, ok := argsRaw.(map[string]interface{}); ok {
toolCall.Function.Arguments = argsMap
}
}
}
if toolCall.ID != "" {
msg.ToolCalls = append(msg.ToolCalls, toolCall)
}
}
}
}
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
msg.ToolCallID = toolCallID
}
if tn, ok := msgMap["tool_name"].(string); ok && strings.TrimSpace(tn) != "" {
msg.ToolName = strings.TrimSpace(tn)
} else if tn, ok := msgMap["name"].(string); ok && strings.TrimSpace(tn) != "" && strings.EqualFold(msg.Role, "tool") {
msg.ToolName = strings.TrimSpace(tn)
}
out = append(out, msg)
}
return out, nil
}
// ExtractLastUserTurnMessages 仅保留最后一次 user 提问起的消息(不含更早的用户轮次;跳过 system)。
// 与「继续对话」续跑所用轨迹范围一致:当前任务轮次,而非整段多轮对话历史。
func ExtractLastUserTurnMessages(msgs []ChatMessage) []ChatMessage {
if len(msgs) == 0 {
return msgs
}
lastUser := -1
for i, m := range msgs {
if strings.EqualFold(m.Role, "user") {
lastUser = i
}
}
if lastUser < 0 {
return msgs
}
trimmed := msgs[lastUser:]
out := make([]ChatMessage, 0, len(trimmed))
for _, m := range trimmed {
if strings.EqualFold(m.Role, "system") {
continue
}
out = append(out, m)
}
return out
}
// ExtractLastUserTurnTraceJSON 在 JSON 轨迹上裁剪为最后一次 user 起的片段(供落库格式直接处理)。
func ExtractLastUserTurnTraceJSON(traceInputJSON string) string {
traceInputJSON = strings.TrimSpace(traceInputJSON)
if traceInputJSON == "" {
return traceInputJSON
}
var arr []map[string]interface{}
if err := json.Unmarshal([]byte(traceInputJSON), &arr); err != nil {
return traceInputJSON
}
lastUser := -1
for i, m := range arr {
if r, _ := m["role"].(string); strings.EqualFold(r, "user") {
lastUser = i
}
}
if lastUser <= 0 {
return traceInputJSON
}
trimmed := arr[lastUser:]
b, err := json.Marshal(trimmed)
if err != nil {
return traceInputJSON
}
return string(b)
}
// MergeAssistantTraceOutput 将 last_react_output 合并进轨迹最后一条 assistant(与 loadHistoryFromAgentTrace 一致)。
func MergeAssistantTraceOutput(msgs []ChatMessage, assistantOut string) []ChatMessage {
assistantOut = strings.TrimSpace(assistantOut)
if assistantOut == "" || len(msgs) == 0 {
return msgs
}
out := append([]ChatMessage(nil), msgs...)
last := &out[len(out)-1]
if strings.EqualFold(last.Role, "assistant") && len(last.ToolCalls) == 0 {
last.Content = assistantOut
return out
}
out = append(out, ChatMessage{
Role: "assistant",
Content: assistantOut,
})
return out
}
// MessagesToTraceJSON 将消息带序列化为 JSON(跳过 system)。
func MessagesToTraceJSON(msgs []ChatMessage) (string, error) {
filtered := make([]ChatMessage, 0, len(msgs))
for _, m := range msgs {
if strings.EqualFold(m.Role, "system") {
continue
}
filtered = append(filtered, m)
}
b, err := json.Marshal(filtered)
if err != nil {
return "", err
}
return string(b), nil
}
+57
View File
@@ -0,0 +1,57 @@
package agent
import (
"encoding/json"
"testing"
)
func TestExtractLastUserTurnTraceJSON(t *testing.T) {
raw := []map[string]interface{}{
{"role": "user", "content": "old question"},
{"role": "assistant", "content": "old answer"},
{"role": "user", "content": "new target 1.1.1.1"},
{"role": "assistant", "tool_calls": []interface{}{map[string]interface{}{
"id": "c1", "type": "function",
"function": map[string]interface{}{"name": "nmap", "arguments": "{}"},
}}},
{"role": "tool", "tool_call_id": "c1", "content": "open ports"},
}
b, _ := json.Marshal(raw)
out := ExtractLastUserTurnTraceJSON(string(b))
var trimmed []map[string]interface{}
if err := json.Unmarshal([]byte(out), &trimmed); err != nil {
t.Fatal(err)
}
if len(trimmed) != 3 {
t.Fatalf("expected 3 messages, got %d", len(trimmed))
}
if trimmed[0]["content"] != "new target 1.1.1.1" {
t.Fatalf("unexpected first message: %v", trimmed[0])
}
}
func TestExtractLastUserTurnMessagesSkipsSystem(t *testing.T) {
msgs := []ChatMessage{
{Role: "system", Content: "sys"},
{Role: "user", Content: "q"},
{Role: "assistant", Content: "a"},
}
out := ExtractLastUserTurnMessages(msgs)
if len(out) != 2 {
t.Fatalf("expected 2, got %d", len(out))
}
if out[0].Role != "user" {
t.Fatal("expected user first")
}
}
func TestMergeAssistantTraceOutput(t *testing.T) {
msgs := []ChatMessage{
{Role: "user", Content: "q"},
{Role: "assistant", Content: "draft"},
}
out := MergeAssistantTraceOutput(msgs, "final summary")
if out[len(out)-1].Content != "final summary" {
t.Fatalf("expected merged output, got %q", out[len(out)-1].Content)
}
}
+57 -2
View File
@@ -15,6 +15,7 @@ import (
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/c2"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
@@ -56,10 +57,12 @@ type App struct {
robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel
dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启
larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启
wechatCancel context.CancelFunc // 微信 iLink 长轮询取消函数
c2Manager *c2.Manager // C2 管理器(未启用 C2 时为 nil)
c2Watchdog *c2.SessionWatchdog // C2 会话看门狗
c2WatchdogCancel context.CancelFunc // 看门狗取消函数
c2Handler *handler.C2Handler // C2 REST(与 Manager 生命周期同步)
auditSvc *audit.Service
}
// New 创建新应用
@@ -92,6 +95,11 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
return nil, fmt.Errorf("初始化数据库失败: %w", err)
}
auditSvc := audit.NewService(db, cfg, log.Logger)
audit.RegisterConversationCreateHook(auditSvc)
auditSvc.PurgeExpired()
audit.StartRetentionLoop(auditSvc, log.Logger)
// 创建MCP服务器(带数据库持久化)
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
@@ -221,6 +229,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
// 创建知识库API处理器
knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger)
knowledgeHandler.SetAudit(auditSvc)
log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
// 扫描知识库并建立索引(异步)
@@ -317,31 +326,42 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
log.Logger.Warn("创建 agents 目录失败", zap.String("path", agentsDir), zap.Error(err))
}
markdownAgentsHandler := handler.NewMarkdownAgentsHandler(agentsDir)
markdownAgentsHandler.SetAudit(auditSvc)
log.Logger.Info("多代理 Markdown 子 Agent 目录", zap.String("agentsDir", agentsDir))
// 创建处理器
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
agentHandler.SetAudit(auditSvc)
agentHandler.SetAgentsMarkdownDir(agentsDir)
// 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志
if knowledgeManager != nil {
agentHandler.SetKnowledgeManager(knowledgeManager)
}
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
monitorHandler.SetAudit(auditSvc)
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
groupHandler := handler.NewGroupHandler(db, log.Logger)
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
authHandler.SetAudit(auditSvc)
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
vulnerabilityHandler.SetAudit(auditSvc)
webshellHandler := handler.NewWebShellHandler(log.Logger, db)
webshellHandler.SetAudit(auditSvc)
chatUploadsHandler := handler.NewChatUploadsHandler(log.Logger)
chatUploadsHandler.SetAudit(auditSvc)
registerWebshellTools(mcpServer, db, webshellHandler, log.Logger)
registerWebshellManagementTools(mcpServer, db, webshellHandler, log.Logger)
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
configHandler.SetAudit(auditSvc)
agentHandler.SetHitlToolWhitelistSaver(configHandler)
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
externalMCPHandler.SetAudit(auditSvc)
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
roleHandler.SetAudit(auditSvc)
skillsHandler := handler.NewSkillsHandler(cfg, configPath, log.Logger)
skillsHandler.SetAudit(auditSvc)
fofaHandler := handler.NewFofaHandler(cfg, log.Logger)
terminalHandler := handler.NewTerminalHandler(log.Logger)
if db != nil {
@@ -356,9 +376,12 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
registerC2Tools(mcpServer, c2Manager, log.Logger, cfg.Server.Port)
}
c2Handler := handler.NewC2Handler(c2Manager, log.Logger)
c2Handler.SetAudit(auditSvc)
// 创建OpenAPI处理器
conversationHandler := handler.NewConversationHandler(db, log.Logger)
conversationHandler.SetAudit(auditSvc)
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler)
@@ -384,6 +407,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
c2Watchdog: c2Watchdog,
c2WatchdogCancel: watchdogCancel,
c2Handler: c2Handler,
auditSvc: auditSvc,
}
// 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启
app.startRobotConnections()
@@ -449,9 +473,11 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
configHandler.SetRetrieverUpdater(knowledgeRetriever)
}
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书新配置生效
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书/微信新配置生效
configHandler.SetRobotRestarter(app)
wechatRobotHandler := handler.NewWechatRobotHandler(cfg, configHandler, log.Logger)
configHandler.SetC2Runtime(app)
configHandler.SetC2ToolRegistrar(func() error {
if app.config.C2.EnabledEffective() && app.c2Manager != nil {
@@ -469,6 +495,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
notificationHandler,
conversationHandler,
robotHandler,
wechatRobotHandler,
groupHandler,
configHandler,
externalMCPHandler,
@@ -483,6 +510,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
fofaHandler,
terminalHandler,
app.c2Handler,
auditHandler,
mcpServer,
authManager,
openAPIHandler,
@@ -675,9 +703,14 @@ func (a *App) startRobotConnections() {
a.dingCancel = cancel
go robot.StartDing(ctx, cfg.Robots, a.robotHandler, a.logger.Logger)
}
if cfg.Robots.Wechat.Enabled && cfg.Robots.Wechat.BotToken != "" {
ctx, cancel := context.WithCancel(context.Background())
a.wechatCancel = cancel
go robot.StartWechat(ctx, cfg.Robots, a.robotHandler, cfg.Version, a.logger.Logger)
}
}
// RestartRobotConnections 重启钉钉/飞书长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter
// RestartRobotConnections 重启钉钉/飞书/微信长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter
func (a *App) RestartRobotConnections() {
a.robotMu.Lock()
if a.dingCancel != nil {
@@ -688,6 +721,10 @@ func (a *App) RestartRobotConnections() {
a.larkCancel()
a.larkCancel = nil
}
if a.wechatCancel != nil {
a.wechatCancel()
a.wechatCancel = nil
}
a.robotMu.Unlock()
// 给旧 goroutine 一点时间退出
time.Sleep(200 * time.Millisecond)
@@ -703,6 +740,7 @@ func setupRoutes(
notificationHandler *handler.NotificationHandler,
conversationHandler *handler.ConversationHandler,
robotHandler *handler.RobotHandler,
wechatRobotHandler *handler.WechatRobotHandler,
groupHandler *handler.GroupHandler,
configHandler *handler.ConfigHandler,
externalMCPHandler *handler.ExternalMCPHandler,
@@ -717,6 +755,7 @@ func setupRoutes(
fofaHandler *handler.FofaHandler,
terminalHandler *handler.TerminalHandler,
c2Handler *handler.C2Handler,
auditHandler *handler.AuditHandler,
mcpServer *mcp.Server,
authManager *security.AuthManager,
openAPIHandler *handler.OpenAPIHandler,
@@ -751,6 +790,12 @@ func setupRoutes(
// 机器人测试(需登录):POST /api/robot/testbody: {"platform":"dingtalk","user_id":"test","text":"帮助"},用于验证机器人逻辑
protected.POST("/robot/test", robotHandler.HandleRobotTest)
// 微信 iLink 扫码绑定(需登录)
protected.POST("/robot/wechat/qrcode", wechatRobotHandler.HandleWechatQRCode)
protected.GET("/robot/wechat/qrcode/status", wechatRobotHandler.HandleWechatQRCodeStatus)
protected.POST("/robot/wechat/qrcode/verify", wechatRobotHandler.HandleWechatVerifyCode)
protected.GET("/robot/wechat/status", wechatRobotHandler.HandleWechatStatus)
// Agent Loop
protected.POST("/agent-loop", agentHandler.AgentLoop)
// Agent Loop 流式输出
@@ -847,6 +892,13 @@ func setupRoutes(
protected.POST("/terminal/run/stream", terminalHandler.RunCommandStream)
protected.GET("/terminal/ws", terminalHandler.RunCommandWS)
// 平台审计日志
protected.GET("/audit/meta", auditHandler.Meta)
protected.GET("/audit/summary", auditHandler.Summary)
protected.GET("/audit/logs", auditHandler.ListLogs)
protected.GET("/audit/logs/export", auditHandler.ExportLogs)
protected.GET("/audit/logs/:id", auditHandler.GetLog)
// 外部MCP管理
protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs)
protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats)
@@ -1908,6 +1960,9 @@ func initializeKnowledge(
// 创建知识库API处理器
knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger)
if app != nil && app.auditSvc != nil {
knowledgeHandler.SetAudit(app.auditSvc)
}
logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
// 设置知识库管理器到AgentHandler以便记录检索日志
+85 -66
View File
@@ -82,7 +82,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
}
}
// BuildChainFromConversation 从对话构建攻击链(简化版本:用户输入+最后一轮ReAct输入+大模型输出)
// BuildChainFromConversation 从对话构建攻击链(单次 LLM 调用;输入为当前任务轮次的 last_react 轨迹,与继续对话续跑范围一致)。
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID))
@@ -157,33 +157,34 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
var reactInputFinal string
var dataSource string // 记录数据来源
// 如果成功获取到保存的ReAct数据,直接使用
if reactInputJSON != "" && modelOutput != "" {
// 计算 ReAct 输入的哈希值,用于追踪
hash := sha256.Sum256([]byte(reactInputJSON))
reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识
// 优先使用落库的代理轨迹(与继续对话 loadHistoryFromAgentTrace 同源),并裁剪为「当前任务轮次」
if reactInputJSON != "" {
trimmedJSON := agent.ExtractLastUserTurnTraceJSON(reactInputJSON)
hash := sha256.Sum256([]byte(trimmedJSON))
reactInputHash := hex.EncodeToString(hash[:])[:16]
// 统计消息数量
var messageCount int
var tempMessages []interface{}
if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil {
messageCount = len(tempMessages)
if msgs, parseErr := agent.ParseTraceMessages(trimmedJSON); parseErr == nil {
messageCount = len(msgs)
msgs = agent.MergeAssistantTraceOutput(msgs, modelOutput)
reactInputFinal = b.formatAgentTraceFromChatMessages(msgs)
} else {
b.logger.Warn("解析代理轨迹失败,回退原始 JSON 格式化", zap.Error(parseErr))
reactInputFinal = b.formatAgentTraceInputFromJSON(trimmedJSON)
if strings.TrimSpace(modelOutput) != "" {
reactInputFinal += "\n\n## 助手结论(last_react_output\n\n" + modelOutput
}
}
dataSource = "database_last_agent_trace"
b.logger.Info("使用保存的ReAct数据构建攻击链",
dataSource = "last_user_turn_agent_trace"
b.logger.Info("使用当前任务轮次代理轨迹构建攻击链(与续跑上下文范围一致)",
zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource),
zap.Int("reactInputSize", len(reactInputJSON)),
zap.Int("traceInputSizeBeforeTrim", len(reactInputJSON)),
zap.Int("traceInputSizeAfterTrim", len(trimmedJSON)),
zap.Int("messageCount", messageCount),
zap.String("reactInputHash", reactInputHash),
zap.Int("modelOutputSize", len(modelOutput)))
// 从保存的ReAct输入(JSON格式)中提取用户输入
// userInput = b.extractUserInputFromReActInput(reactInputJSON)
// 将JSON格式的messages转换为可读格式
reactInputFinal = b.formatAgentTraceInputFromJSON(reactInputJSON)
} else {
// 2. 如果没有保存的ReAct数据,从对话消息构建
dataSource = "messages_table"
@@ -243,8 +244,15 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
}
}
// 3. 构建简化的prompt,一次性传递给大模型
prompt := b.buildSimplePrompt(reactInputFinal, modelOutput)
// 3. 按 token 预算压缩输入,再构建 prompt(避免超出模型上下文)
reactInputFinal, modelOutput, _ = b.fitAttackChainPayload(reactInputFinal, modelOutput)
// 4. 构建 prompt 并单次调用大模型(助手结论已并入轨迹时不再重复传入)
promptAssistantOut := modelOutput
if reactInputJSON != "" {
promptAssistantOut = ""
}
prompt := b.buildSimplePrompt(reactInputFinal, promptAssistantOut)
// fmt.Println(prompt)
// 6. 调用AI生成攻击链(一次性,不做任何处理)
chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
@@ -366,10 +374,17 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
return strings.TrimSpace(sb.String())
}
// buildAgentTraceInput 构建最后一轮ReAct的输入(历史消息+当前用户输入)
// buildAgentTraceInput 构建最后一轮 ReAct 的输入(从最后一条 user 消息起,不含更早轮次)。
func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
start := 0
for i := len(messages) - 1; i >= 0; i-- {
if strings.EqualFold(messages[i].Role, "user") {
start = i
break
}
}
var builder strings.Builder
for _, msg := range messages {
for _, msg := range messages[start:] {
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
}
return builder.String()
@@ -396,67 +411,66 @@ func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
// return ""
// }
// formatAgentTraceInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
// formatAgentTraceInputFromJSON 将 JSON 轨迹转为可读文本(会先按当前任务轮次裁剪)。
func (b *Builder) formatAgentTraceInputFromJSON(reactInputJSON string) string {
var messages []map[string]interface{}
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
trimmed := agent.ExtractLastUserTurnTraceJSON(reactInputJSON)
msgs, err := agent.ParseTraceMessages(trimmed)
if err != nil {
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
return reactInputJSON // 如果解析失败,返回原始JSON
return trimmed
}
return b.formatAgentTraceFromChatMessages(msgs)
}
// formatAgentTraceFromChatMessages 将代理消息带格式化为攻击链分析输入(与续跑轨迹字段一致)。
func (b *Builder) formatAgentTraceFromChatMessages(msgs []agent.ChatMessage) string {
var builder strings.Builder
for _, msg := range messages {
role, _ := msg["role"].(string)
content, _ := msg["content"].(string)
for _, msg := range msgs {
role := msg.Role
content := msg.Content
// 处理assistant消息:提取tool_calls信息
if role == "assistant" {
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 {
// 如果有文本内容,先显示
if content != "" {
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
}
// 详细显示每个工具调用
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls)))
for i, toolCall := range toolCalls {
if tc, ok := toolCall.(map[string]interface{}); ok {
toolCallID, _ := tc["id"].(string)
if funcData, ok := tc["function"].(map[string]interface{}); ok {
toolName, _ := funcData["name"].(string)
arguments, _ := funcData["arguments"].(string)
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID))
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName))
builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments))
}
if strings.EqualFold(role, "assistant") && len(msg.ToolCalls) > 0 {
if content != "" {
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
}
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(msg.ToolCalls)))
for i, tc := range msg.ToolCalls {
args := ""
if tc.Function.Arguments != nil {
if b, err := json.Marshal(tc.Function.Arguments); err == nil {
args = string(b)
}
}
builder.WriteString("\n")
continue
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
builder.WriteString(fmt.Sprintf(" ID: %s\n", tc.ID))
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", tc.Function.Name))
builder.WriteString(fmt.Sprintf(" 参数: %s\n", args))
}
builder.WriteString("\n")
continue
}
// 处理tool消息:显示tool_call_id和完整内容
if role == "tool" {
toolCallID, _ := msg["tool_call_id"].(string)
if toolCallID != "" {
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content))
if strings.EqualFold(role, "tool") {
if msg.ToolCallID != "" {
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, msg.ToolCallID, content))
} else {
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
}
continue
}
// 其他消息类型(system, user等)正常显示
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
}
return builder.String()
}
// buildSimplePrompt 构建简化的prompt
func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家你的任务是根据对话记录和工具执行结果构建一个逻辑清晰有教育意义的攻击链图完整展现渗透测试的思维过程和执行路径
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家你的任务是根据**当前任务轮次**的对话记录和工具执行结果一次性输出攻击链 JSON不要分多轮追问
## 输入范围继续对话续跑一致
- 下方ReAct 轨迹仅包含**最后一次用户提问之后**的消息与工具结果last_react 当前任务轮次不含更早的用户提问轮次
- 助手结论为同轮任务的最终输出摘要last_react_output节点须与轨迹中的实际工具执行一致严禁编造
## 核心目标
@@ -618,12 +632,9 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
5. **漏洞确认**如何确认漏洞存在actionvulnerability
6. **攻击路径**完整的攻击路径是什么从target到vulnerability的路径
## 最后一轮ReAct输入
## 当前任务 ReAct 轨迹含工具执行助手结论见轨迹末尾 assistant
%s
## 大模型输出
%s
## 输出格式
@@ -752,7 +763,15 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
9. **不要过度精简**如果实际执行步骤较多可以适当增加节点数量最多20个确保不遗漏关键步骤
10. **输出前验证**在输出JSON前必须验证所有边都满足source < target的条件确保DAG结构正确
现在开始分析并构建攻击链`, reactInput, modelOutput)
现在开始分析并构建攻击链`, reactInput, assistantOutSection(modelOutput))
}
func assistantOutSection(modelOutput string) string {
modelOutput = strings.TrimSpace(modelOutput)
if modelOutput == "" {
return ""
}
return "\n## 助手结论(补充)\n\n" + modelOutput + "\n"
}
// saveChain 保存攻击链到数据库
@@ -812,7 +831,7 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (
},
},
"temperature": 0.3,
"max_completion_tokens": 80000,
"max_completion_tokens": attackChainMaxCompletionTokens(b.maxTokens),
}
var apiResponse struct {
+248
View File
@@ -0,0 +1,248 @@
package attackchain
import (
"strings"
"unicode/utf8"
"go.uber.org/zap"
)
const (
attackChainTruncationMarker = "\n\n...[攻击链输入已截断 / attack chain input truncated]...\n\n"
attackChainSystemReserve = 256
attackChainSafetyReserve = 2048
)
// attackChainMaxCompletionTokens 为攻击链 JSON 输出预留的 completion token 上限。
func attackChainMaxCompletionTokens(maxTotal int) int {
const capTokens = 16384
if maxTotal <= 0 {
return 8192
}
v := maxTotal / 8
if v < 4096 {
v = 4096
}
if v > capTokens {
v = capTokens
}
return v
}
func (b *Builder) modelName() string {
if b.openAIConfig != nil && b.openAIConfig.Model != "" {
return b.openAIConfig.Model
}
return "gpt-4"
}
func (b *Builder) countTokens(text string) int {
if text == "" {
return 0
}
n, err := b.tokenCounter.Count(b.modelName(), text)
if err != nil {
return utf8.RuneCountInString(text) / 4
}
return n
}
// attackChainPayloadTokenBudget 计算 reactInput + modelOutput 可用的 token 预算。
func (b *Builder) attackChainPayloadTokenBudget() int {
maxTotal := b.maxTokens
if maxTotal <= 0 {
maxTotal = 100000
}
templateTok := b.countTokens(b.buildSimplePrompt("", ""))
completion := attackChainMaxCompletionTokens(maxTotal)
reserve := templateTok + attackChainSystemReserve + completion + attackChainSafetyReserve
budget := maxTotal - reserve
minBudget := maxTotal * 35 / 100
if budget < minBudget {
budget = minBudget
}
if budget < 4096 {
budget = 4096
}
return budget
}
// fitAttackChainPayload 在构建最终 prompt 前压缩 ReAct 轨迹与模型输出,避免超出模型上下文。
func (b *Builder) fitAttackChainPayload(reactInput, modelOutput string) (string, string, bool) {
budget := b.attackChainPayloadTokenBudget()
modelBudget := budget * 15 / 100
if modelBudget < 512 {
modelBudget = 512
}
reactBudget := budget - modelBudget
origReactTok := b.countTokens(reactInput)
origModelTok := b.countTokens(modelOutput)
truncated := false
outModel := modelOutput
if origModelTok > modelBudget {
outModel = truncateTextByTokens(b, modelOutput, modelBudget)
truncated = true
}
outReact := reactInput
perToolLimits := []int{12000, 6000, 3000, 1500, 800}
for _, lim := range perToolLimits {
compact := compactFormattedToolBodies(outReact, lim)
if compact != outReact {
outReact = compact
truncated = true
}
if b.countTokens(outReact) <= reactBudget {
break
}
}
if b.countTokens(outReact) > reactBudget {
outReact = truncateTextByTokens(b, outReact, reactBudget)
truncated = true
}
if truncated {
b.logger.Info("攻击链输入已按 token 预算截断",
zap.Int("maxTotalTokens", b.maxTokens),
zap.Int("payloadBudget", budget),
zap.Int("reactBudget", reactBudget),
zap.Int("modelBudget", modelBudget),
zap.Int("reactInputTokensBefore", origReactTok),
zap.Int("reactInputTokensAfter", b.countTokens(outReact)),
zap.Int("modelOutputTokensBefore", origModelTok),
zap.Int("modelOutputTokensAfter", b.countTokens(outModel)),
zap.Int("maxCompletionTokens", attackChainMaxCompletionTokens(b.maxTokens)),
)
}
return outReact, outModel, truncated
}
// compactFormattedToolBodies 缩短格式化 trace 中 [tool] 消息的正文,保留工具头与调用 ID。
func compactFormattedToolBodies(s string, maxRunesPerBody int) string {
if maxRunesPerBody <= 0 || s == "" {
return s
}
const marker = "[tool]"
var out strings.Builder
remaining := s
changed := false
for {
idx := strings.Index(remaining, marker)
if idx < 0 {
out.WriteString(remaining)
break
}
out.WriteString(remaining[:idx])
remaining = remaining[idx:]
nl := strings.IndexByte(remaining, '\n')
if nl < 0 {
out.WriteString(remaining)
break
}
header := remaining[:nl+1]
remaining = remaining[nl+1:]
bodyEnd := strings.Index(remaining, "\n\n[")
var body, rest string
if bodyEnd < 0 {
body = remaining
rest = ""
} else {
body = remaining[:bodyEnd]
rest = remaining[bodyEnd:]
}
if runeLen(body) > maxRunesPerBody {
body = truncateRunesWithNotice(body, maxRunesPerBody)
changed = true
}
out.WriteString(header)
out.WriteString(body)
remaining = rest
if rest == "" {
break
}
}
if !changed {
return s
}
return out.String()
}
func truncateTextByTokens(b *Builder, text string, maxTokens int) string {
if maxTokens <= 0 || text == "" {
return ""
}
if b.countTokens(text) <= maxTokens {
return text
}
markerTok := b.countTokens(attackChainTruncationMarker)
usable := maxTokens - markerTok
if usable < 256 {
usable = maxTokens / 2
}
headBudget := usable * 60 / 100
tailBudget := usable - headBudget
head := takeTokensFromStart(b, text, headBudget)
tail := takeTokensFromEnd(b, text, tailBudget)
return head + attackChainTruncationMarker + tail
}
func takeTokensFromStart(b *Builder, text string, maxTokens int) string {
rs := []rune(text)
if len(rs) == 0 || maxTokens <= 0 {
return ""
}
lo, hi := 0, len(rs)
for lo < hi {
mid := (lo + hi + 1) / 2
if b.countTokens(string(rs[:mid])) <= maxTokens {
lo = mid
} else {
hi = mid - 1
}
}
return string(rs[:lo])
}
func takeTokensFromEnd(b *Builder, text string, maxTokens int) string {
rs := []rune(text)
if len(rs) == 0 || maxTokens <= 0 {
return ""
}
lo, hi := 0, len(rs)
for lo < hi {
mid := (lo + hi) / 2
if b.countTokens(string(rs[mid:])) <= maxTokens {
hi = mid
} else {
lo = mid + 1
}
}
return string(rs[lo:])
}
func truncateRunesWithNotice(s string, maxRunes int) string {
rs := []rune(s)
if len(rs) <= maxRunes {
return s
}
const notice = "\n...[工具输出已截断 / tool output truncated]...\n"
noticeRunes := []rune(notice)
keep := maxRunes - len(noticeRunes)
if keep < 200 {
keep = maxRunes * 2 / 3
}
if keep < 1 {
return notice
}
head := keep * 70 / 100
tail := keep - head
return string(rs[:head]) + notice + string(rs[len(rs)-tail:])
}
func runeLen(s string) int {
return len([]rune(s))
}
+63
View File
@@ -0,0 +1,63 @@
package attackchain
import (
"strings"
"testing"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"go.uber.org/zap"
)
func testBuilder(maxTotal int) *Builder {
return &Builder{
logger: zap.NewNop(),
openAIConfig: &config.OpenAIConfig{Model: "gpt-4"},
tokenCounter: agent.NewTikTokenCounter(),
maxTokens: maxTotal,
}
}
func TestCompactFormattedToolBodies(t *testing.T) {
long := strings.Repeat("x", 20000)
in := "[user]: hi\n\n[tool] (tool_call_id: abc):\n" + long + "\n\n[assistant]: done\n"
out := compactFormattedToolBodies(in, 500)
if strings.Contains(out, strings.Repeat("x", 10000)) {
t.Fatal("expected tool body to be truncated")
}
if !strings.Contains(out, "[user]: hi") {
t.Fatal("expected user header preserved")
}
if !strings.Contains(out, "[assistant]: done") {
t.Fatal("expected assistant header preserved")
}
}
func TestFitAttackChainPayloadWithinBudget(t *testing.T) {
b := testBuilder(32000)
react := strings.Repeat("scan ", 50000)
model := strings.Repeat("result ", 10000)
r, m, truncated := b.fitAttackChainPayload(react, model)
if !truncated {
t.Fatal("expected truncation for large payload")
}
prompt := b.buildSimplePrompt(r, m)
total := b.countTokens(prompt) + attackChainMaxCompletionTokens(b.maxTokens) + attackChainSystemReserve
if total > b.maxTokens+attackChainSafetyReserve {
t.Fatalf("prompt still too large: estimated %d > max %d", total, b.maxTokens)
}
_ = m
}
func TestAttackChainMaxCompletionTokens(t *testing.T) {
if got := attackChainMaxCompletionTokens(120000); got != 15000 && got != 16384 {
// 120000/8 = 15000
if got < 4096 || got > 16384 {
t.Fatalf("unexpected completion cap: %d", got)
}
}
if got := attackChainMaxCompletionTokens(0); got != 8192 {
t.Fatalf("expected default 8192, got %d", got)
}
}
+55
View File
@@ -0,0 +1,55 @@
package audit
import (
"strings"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
)
// RegisterConversationCreateHook records platform audit rows for every new conversation.
func RegisterConversationCreateHook(s *Service) {
if s == nil {
return
}
database.SetConversationCreateHook(func(conv *database.Conversation, meta database.ConversationCreateMeta) {
detail := map[string]interface{}{
"title": conv.Title,
"source": meta.Source,
}
if meta.WebShellConnectionID != "" {
detail["webshell_connection_id"] = meta.WebShellConnectionID
}
s.Record(nil, Entry{
Category: "conversation",
Action: "create",
Result: "success",
Message: "创建对话",
ResourceType: "conversation",
ResourceID: conv.ID,
Detail: detail,
ClientIP: meta.ClientIP,
SessionHint: meta.SessionHint,
})
})
}
// ConversationCreateMeta builds audit metadata for conversation creation.
func ConversationCreateMeta(source string) database.ConversationCreateMeta {
return database.ConversationCreateMeta{Source: strings.TrimSpace(source)}
}
// ConversationCreateMetaFromGin includes client IP and session hint when available.
func ConversationCreateMetaFromGin(c *gin.Context, source string) database.ConversationCreateMeta {
m := ConversationCreateMeta(source)
if c == nil {
return m
}
m.ClientIP = c.ClientIP()
if token := c.GetString(security.ContextAuthTokenKey); token != "" {
m.SessionHint = sessionHint(token)
}
return m
}
+9
View File
@@ -0,0 +1,9 @@
package audit
// RetentionDays returns configured retention; 0 means keep forever.
func (s *Service) RetentionDays() int {
if s == nil || s.cfg == nil {
return 0
}
return s.cfg.Audit.RetentionDaysEffective()
}
+29
View File
@@ -0,0 +1,29 @@
package audit
import "github.com/gin-gonic/gin"
// RecordAction writes a platform audit row with common defaults.
func (s *Service) RecordAction(c *gin.Context, category, action, result, message, resourceType, resourceID string, detail map[string]interface{}) {
if s == nil {
return
}
s.Record(c, Entry{
Category: category,
Action: action,
Result: result,
Message: message,
ResourceType: resourceType,
ResourceID: resourceID,
Detail: detail,
})
}
// RecordOK is a shorthand for successful operations.
func (s *Service) RecordOK(c *gin.Context, category, action, message, resourceType, resourceID string, detail map[string]interface{}) {
s.RecordAction(c, category, action, "success", message, resourceType, resourceID, detail)
}
// RecordFail is a shorthand for failed operations.
func (s *Service) RecordFail(c *gin.Context, category, action, message string, detail map[string]interface{}) {
s.RecordAction(c, category, action, "failure", message, "", "", detail)
}
+86
View File
@@ -0,0 +1,86 @@
package audit
import (
"strings"
"cyberstrike-ai/internal/database"
)
var auditActionsResourceRemoved = map[string]bool{
"delete": true,
"item_delete": true,
"connection_delete": true,
"listener_delete": true,
"session_delete": true,
"task_delete": true,
"execution_delete": true,
"execution_delete_batch": true,
"delete_queue": true,
"delete_batch_task": true,
"markdown_delete": true,
}
// ApplyResourceAvailability sets log.ResourceAvailable when the linked resource can be checked.
func ApplyResourceAvailability(db *database.DB, log *database.AuditLog) {
if log == nil || strings.TrimSpace(log.ResourceID) == "" {
return
}
if auditActionsResourceRemoved[log.Action] {
f := false
log.ResourceAvailable = &f
return
}
if db == nil {
return
}
available, known := resourceStillExists(db, log.ResourceType, log.ResourceID)
if known {
log.ResourceAvailable = &available
}
}
func resourceStillExists(db *database.DB, resourceType, resourceID string) (bool, bool) {
resourceID = strings.TrimSpace(resourceID)
if resourceID == "" {
return false, false
}
t := strings.TrimSpace(resourceType)
if t == "" {
if len(resourceID) > 8 && !strings.HasPrefix(resourceID, "c2_") {
t = "conversation"
} else {
return false, false
}
}
switch t {
case "conversation":
ok, err := db.ConversationExists(resourceID)
return ok, err == nil
case "vulnerability":
_, err := db.GetVulnerability(resourceID)
if err != nil {
return false, strings.Contains(err.Error(), "不存在")
}
return true, true
case "batch_queue":
_, err := db.GetBatchQueue(resourceID)
return err == nil, true
case "c2_listener":
_, err := db.GetC2Listener(resourceID)
return err == nil, true
case "c2_session":
_, err := db.GetC2Session(resourceID)
return err == nil, true
case "c2_task":
_, err := db.GetC2Task(resourceID)
return err == nil, true
case "webshell_connection":
c, err := db.GetWebshellConnection(resourceID)
return err == nil && c != nil, true
case "tool_execution":
_, err := db.GetToolExecution(resourceID)
return err == nil, true
default:
return false, false
}
}
+27
View File
@@ -0,0 +1,27 @@
package audit
import (
"time"
"go.uber.org/zap"
)
// auditRetentionPurgeInterval is how often PurgeExpired runs while the process is up (startup also purges once).
const auditRetentionPurgeInterval = time.Hour
// StartRetentionLoop periodically purges expired audit rows.
func StartRetentionLoop(s *Service, logger *zap.Logger) {
if s == nil {
return
}
go func() {
ticker := time.NewTicker(auditRetentionPurgeInterval)
defer ticker.Stop()
for range ticker.C {
s.PurgeExpired()
if logger != nil {
logger.Debug("audit retention tick completed")
}
}
}()
}
+58
View File
@@ -0,0 +1,58 @@
package audit
import (
"encoding/json"
"strings"
)
var sensitiveKeySubstrings = []string{
"password", "api_key", "apikey", "secret", "token", "authorization",
"credential", "private_key", "access_key",
}
// SanitizeDetail redacts sensitive keys and truncates serialized size.
func SanitizeDetail(detail map[string]interface{}, maxBytes int) map[string]interface{} {
if detail == nil {
return nil
}
if maxBytes <= 0 {
maxBytes = 8192
}
out := sanitizeValue("", detail)
if m, ok := out.(map[string]interface{}); ok {
b, _ := json.Marshal(m)
if len(b) > maxBytes {
return map[string]interface{}{
"_truncated": true,
"_preview": string(b[:maxBytes]),
}
}
return m
}
return map[string]interface{}{"value": out}
}
func sanitizeValue(key string, v interface{}) interface{} {
kl := strings.ToLower(key)
for _, sub := range sensitiveKeySubstrings {
if strings.Contains(kl, sub) {
return "***"
}
}
switch t := v.(type) {
case map[string]interface{}:
m := make(map[string]interface{}, len(t))
for k, val := range t {
m[k] = sanitizeValue(k, val)
}
return m
case []interface{}:
arr := make([]interface{}, len(t))
for i, val := range t {
arr[i] = sanitizeValue(key, val)
}
return arr
default:
return v
}
}
+172
View File
@@ -0,0 +1,172 @@
package audit
import (
"crypto/sha256"
"encoding/hex"
"strings"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Service persists platform audit logs.
type Service struct {
db *database.DB
cfg *config.Config
logger *zap.Logger
failThrottle *failureThrottle
}
// NewService creates an audit service.
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
return &Service{
db: db,
cfg: cfg,
logger: logger,
failThrottle: newFailureThrottle(),
}
}
// Enabled reports whether audit persistence is on.
func (s *Service) Enabled() bool {
if s == nil || s.cfg == nil {
return false
}
return s.cfg.Audit.EnabledEffective()
}
// Record writes one audit row from a Gin request context.
func (s *Service) Record(c *gin.Context, e Entry) {
if s == nil || !s.Enabled() || s.db == nil {
return
}
if strings.TrimSpace(e.Category) == "" || strings.TrimSpace(e.Action) == "" {
return
}
if e.Result == "failure" && !s.allowFailureAudit(c, e) {
return
}
if strings.TrimSpace(e.Result) == "" {
e.Result = "success"
}
if strings.TrimSpace(e.Level) == "" {
if e.Result == "failure" {
e.Level = "warn"
} else {
e.Level = "info"
}
}
if strings.TrimSpace(e.Actor) == "" {
e.Actor = "admin"
}
maxDetail := s.cfg.Audit.MaxDetailBytesEffective()
detail := SanitizeDetail(e.Detail, maxDetail)
sessionHintVal := e.SessionHint
if sessionHintVal == "" && c != nil {
if token := c.GetString(security.ContextAuthTokenKey); token != "" {
sessionHintVal = sessionHint(token)
}
}
clientIPVal := e.ClientIP
if clientIPVal == "" {
clientIPVal = clientIP(c)
}
row := &database.AuditLog{
ID: "audit_" + strings.ReplaceAll(uuid.New().String(), "-", ""),
CreatedAt: time.Now(),
Level: e.Level,
Category: e.Category,
Action: e.Action,
Result: e.Result,
Actor: e.Actor,
SessionHint: sessionHintVal,
ClientIP: clientIPVal,
UserAgent: userAgent(c),
ResourceType: e.ResourceType,
ResourceID: e.ResourceID,
Message: e.Message,
Detail: detail,
}
if err := s.db.AppendAuditLog(row); err != nil && s.logger != nil {
s.logger.Warn("写入审计日志失败",
zap.String("action", e.Action),
zap.Error(err),
)
}
}
// RecordSystem writes an audit row without HTTP context (e.g. retention cleanup).
func (s *Service) RecordSystem(e Entry) {
s.Record(nil, e)
}
// PurgeExpired deletes rows older than retention_days when configured.
func (s *Service) PurgeExpired() {
if s == nil || s.db == nil || s.cfg == nil {
return
}
days := s.cfg.Audit.RetentionDaysEffective()
if days <= 0 {
return
}
cutoff := time.Now().AddDate(0, 0, -days)
n, err := s.db.DeleteAuditLogsBefore(cutoff)
if err != nil {
if s.logger != nil {
s.logger.Warn("清理过期审计日志失败", zap.Error(err))
}
return
}
if n > 0 && s.logger != nil {
s.logger.Info("已清理过期审计日志", zap.Int64("deleted", n))
}
}
// HintFromToken returns a short stable hash prefix for a session token.
func HintFromToken(token string) string {
return sessionHint(token)
}
func sessionHint(token string) string {
token = strings.TrimSpace(token)
if token == "" {
return ""
}
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:4])
}
func (s *Service) allowFailureAudit(c *gin.Context, e Entry) bool {
if !isAuthFailureThrottled(e.Category, e.Action) {
return true
}
cooldown := time.Duration(s.cfg.Audit.AuthFailureCooldownEffective()) * time.Second
key := authFailureThrottleKey(e.Category, e.Action, clientIP(c))
return s.failThrottle.allow(key, cooldown)
}
func clientIP(c *gin.Context) string {
if c == nil {
return ""
}
return c.ClientIP()
}
func userAgent(c *gin.Context) string {
if c == nil {
return ""
}
ua := c.GetHeader("User-Agent")
if len(ua) > 512 {
return ua[:512]
}
return ua
}
+55
View File
@@ -0,0 +1,55 @@
package audit
import (
"sync"
"time"
)
// failureThrottle deduplicates high-frequency failure audit rows (e.g. wrong password).
type failureThrottle struct {
mu sync.Mutex
last map[string]time.Time
}
func newFailureThrottle() *failureThrottle {
return &failureThrottle{last: make(map[string]time.Time)}
}
// allow reports whether a row with the given key may be written now.
func (t *failureThrottle) allow(key string, cooldown time.Duration) bool {
if t == nil || cooldown <= 0 || key == "" {
return true
}
now := time.Now()
t.mu.Lock()
defer t.mu.Unlock()
if prev, ok := t.last[key]; ok && now.Sub(prev) < cooldown {
return false
}
t.last[key] = now
if len(t.last) > 4096 {
for k, ts := range t.last {
if now.Sub(ts) > cooldown*2 {
delete(t.last, k)
}
}
}
return true
}
// authFailureThrottleKey builds a per-IP key for auth failure deduplication.
func authFailureThrottleKey(category, action, clientIP string) string {
return category + ":" + action + ":" + clientIP
}
func isAuthFailureThrottled(category, action string) bool {
if category != "auth" {
return false
}
switch action {
case "login", "change_password":
return true
default:
return false
}
}
+16
View File
@@ -0,0 +1,16 @@
package audit
// Entry describes one platform audit record (not chat/tool execution bodies).
type Entry struct {
Level string
Category string
Action string
Result string // success | failure
Actor string
SessionHint string
ResourceType string
ResourceID string
Message string
Detail map[string]interface{}
ClientIP string // optional when c is nil (robot, batch, DB hook)
}
+4 -2
View File
@@ -239,13 +239,15 @@ func (m *Manager) StartListener(id string) (*database.C2Listener, error) {
}
cfg.ApplyDefaults()
// 通过工厂创建具体实现
// 通过工厂创建具体实现。必须使用 rec 的副本:HTTP handler 在返回 JSON 前会清空
// rec.ImplantToken / EncryptionKey 做脱敏,若 listener 实现持有同一指针会导致 beacon 鉴权永久失败。
listenerRec := *rec
factory := m.registry.Get(rec.Type)
if factory == nil {
return nil, ErrUnsupportedType
}
inst, err := factory(ListenerCreationCtx{
Listener: rec,
Listener: &listenerRec,
Config: cfg,
Manager: m,
Logger: m.logger.With(zap.String("listener_id", rec.ID), zap.String("type", rec.Type)),
+74
View File
@@ -0,0 +1,74 @@
package c2
import (
"io"
"net"
"net/http"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// 回归:StartListener 返回的 rec 被 handler 脱敏清空 ImplantToken 后,运行中的 HTTP listener 仍能鉴权。
func TestStartListener_ImplantTokenSurvivesHandlerRedaction(t *testing.T) {
tmp := t.TempDir()
db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop())
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = db.Close() })
lnPick, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
port := lnPick.Addr().(*net.TCPAddr).Port
_ = lnPick.Close()
mgr := NewManager(db, zap.NewNop(), tmp)
mgr.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
rec, err := mgr.CreateListener(CreateListenerInput{
Name: "t",
Type: string(ListenerTypeHTTPBeacon),
BindHost: "127.0.0.1",
BindPort: port,
})
if err != nil {
t.Fatal(err)
}
token := rec.ImplantToken
rec, err = mgr.StartListener(rec.ID)
if err != nil {
t.Fatal(err)
}
// 模拟 internal/handler/c2.go StartListener 在 JSON 响应前的脱敏
rec.ImplantToken = ""
rec.EncryptionKey = ""
time.Sleep(50 * time.Millisecond)
body := `{"hostname":"n","username":"u","os":"Linux","arch":"amd64","internal_ip":"10.0.0.1","pid":42}`
req, _ := http.NewRequest(http.MethodPost, "http://127.0.0.1:"+strconv.Itoa(port)+"/check_in", strings.NewReader(body))
req.Header.Set("X-Implant-Token", token)
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d body=%s", resp.StatusCode, b)
}
if !strings.Contains(string(b), "session_id") {
t.Fatalf("expected session_id in body: %s", b)
}
_ = mgr.StopListener(rec.ID)
}
+97 -11
View File
@@ -26,6 +26,7 @@ type Config struct {
Security SecurityConfig `yaml:"security"`
Database DatabaseConfig `yaml:"database"`
Auth AuthConfig `yaml:"auth"`
Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"`
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
@@ -39,9 +40,9 @@ type Config struct {
// MultiAgentConfig 基于 CloudWeGo Eino adk/prebuilt 的多代理编排(deep | plan_execute | supervisor,与单 Agent /agent-loop 并存)。
type MultiAgentConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
RobotUseMultiAgent bool `yaml:"robot_use_multi_agent" json:"robot_use_multi_agent"` // 为 true 时钉钉/飞书/企微机器人走 Eino 多代理
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
Enabled bool `yaml:"enabled" json:"enabled"`
RobotDefaultAgentMode string `yaml:"robot_default_agent_mode,omitempty" json:"robot_default_agent_mode,omitempty"` // react | eino_single | deep | plan_execute | supervisor
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
// Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。
Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"`
MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // 主代理 / 执行器最大推理轮次(Deep、Supervisor、plan_execute 的 Executor
@@ -227,6 +228,10 @@ type MultiAgentEinoMiddlewareConfig struct {
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries).
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
// RunRetryMaxAttempts > 0429/5xx/网络抖动时 handler 分段续跑次数;0=默认 10。
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
// TaskToolDescriptionPrefix when non-empty sets deep.Config TaskToolDescriptionGenerator (sub-agent names appended).
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
}
@@ -362,9 +367,9 @@ type MultiAgentSubConfig struct {
// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。
type MultiAgentPublic struct {
Enabled bool `json:"enabled"`
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
Enabled bool `json:"enabled"`
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
SubAgentCount int `json:"sub_agent_count"`
Orchestration string `json:"orchestration,omitempty"`
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
@@ -372,6 +377,18 @@ type MultiAgentPublic struct {
ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"`
}
// NormalizeRobotAgentMode 解析机器人默认对话模式(react | eino_single | deep | plan_execute | supervisor);空值视为 react。
func NormalizeRobotAgentMode(ma MultiAgentConfig) string {
s := strings.TrimSpace(strings.ToLower(ma.RobotDefaultAgentMode))
if s == "" || s == "single" || s == "react" {
return "react"
}
if s == "eino_single" {
return "eino_single"
}
return NormalizeMultiAgentOrchestration(s)
}
// NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。
func NormalizeMultiAgentOrchestration(s string) string {
v := strings.TrimSpace(strings.ToLower(s))
@@ -387,22 +404,35 @@ func NormalizeMultiAgentOrchestration(s string) string {
// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。
type MultiAgentAPIUpdate struct {
Enabled bool `json:"enabled"`
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
Enabled bool `json:"enabled"`
RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
// 指针区分「JSON 未传该字段」与「传空数组要清空」;省略时不应覆盖 YAML 中的常驻工具白名单。
ToolSearchAlwaysVisibleTools *[]string `json:"tool_search_always_visible_tools,omitempty"`
}
// RobotsConfig 机器人配置(企业微信、钉钉、飞书等)
// RobotsConfig 机器人配置(企业微信、钉钉、飞书、微信 iLink 等)
type RobotsConfig struct {
Session RobotSessionConfig `yaml:"session,omitempty" json:"session,omitempty"` // 机器人会话隔离策略
Wechat RobotWechatConfig `yaml:"wechat,omitempty" json:"wechat,omitempty"` // 微信(iLink 扫码绑定)
Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信
Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉
Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书
}
// RobotWechatConfig 微信 iLink 机器人配置(个人微信 ClawBot / iLink 协议)
type RobotWechatConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
BotToken string `yaml:"bot_token,omitempty" json:"bot_token,omitempty"`
ILinkBotID string `yaml:"ilink_bot_id,omitempty" json:"ilink_bot_id,omitempty"`
ILinkUserID string `yaml:"ilink_user_id,omitempty" json:"ilink_user_id,omitempty"`
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://ilinkai.weixin.qq.com
BotType string `yaml:"bot_type,omitempty" json:"bot_type,omitempty"` // get_bot_qrcode 参数,默认 3
BotAgent string `yaml:"bot_agent,omitempty" json:"bot_agent,omitempty"` // base_info.bot_agent
GetUpdatesBuf string `yaml:"get_updates_buf,omitempty" json:"get_updates_buf,omitempty"` // 长轮询游标(运行时)
}
// RobotSessionConfig 机器人会话隔离策略
type RobotSessionConfig struct {
StrictUserIdentity *bool `yaml:"strict_user_identity,omitempty" json:"strict_user_identity,omitempty"` // true 时只允许真实用户标识,不允许会话/群 ID 兜底
@@ -484,7 +514,7 @@ type OpenAIConfig struct {
type OpenAIReasoningConfig struct {
// Mode: auto(默认)| on | off | default(与 auto 相同)。off 时不向模型附加推理扩展字段。
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
// Effort: low | medium | high | max;空表示不单独指定强度(各 profile 行为见 internal/reasoning
// Effort: low | medium | high | max | xhighmax/xhigh 为不同网关最高档命名,原样下发、不互转。空表示不单独指定强度
Effort string `yaml:"effort,omitempty" json:"effort,omitempty"`
// AllowClientReasoning 为 false 时忽略请求体 reasoningnil 或未设置等同于 true。
AllowClientReasoning *bool `yaml:"allow_client_reasoning,omitempty" json:"allow_client_reasoning,omitempty"`
@@ -562,6 +592,51 @@ type AuthConfig struct {
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
}
// AuditConfig platform operation audit log settings (not chat/tool execution bodies).
type AuditConfig struct {
// Enabled nil or true enables persistence; explicit false disables.
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
RetentionDays int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
MaxDetailBytes int `yaml:"max_detail_bytes,omitempty" json:"max_detail_bytes,omitempty"`
// AuthFailureCooldownSeconds: per-IP cooldown for auth login/change_password failure audit rows; -1 disables; 0 uses default 60.
AuthFailureCooldownSeconds int `yaml:"auth_failure_cooldown_seconds,omitempty" json:"auth_failure_cooldown_seconds,omitempty"`
}
// EnabledEffective returns true unless audit.enabled is explicitly false.
func (a AuditConfig) EnabledEffective() bool {
if a.Enabled == nil {
return true
}
return *a.Enabled
}
// RetentionDaysEffective returns retention; 0 means keep forever.
func (a AuditConfig) RetentionDaysEffective() int {
if a.RetentionDays < 0 {
return 0
}
return a.RetentionDays
}
// MaxDetailBytesEffective caps serialized detail JSON size.
func (a AuditConfig) MaxDetailBytesEffective() int {
if a.MaxDetailBytes <= 0 {
return 8192
}
return a.MaxDetailBytes
}
// AuthFailureCooldownEffective returns seconds between duplicate auth-failure audit rows per IP (default 60; -1 disables).
func (a AuditConfig) AuthFailureCooldownEffective() int {
if a.AuthFailureCooldownSeconds < 0 {
return 0
}
if a.AuthFailureCooldownSeconds == 0 {
return 60
}
return a.AuthFailureCooldownSeconds
}
// ExternalMCPConfig 外部MCP配置
type ExternalMCPConfig struct {
Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"`
@@ -654,6 +729,9 @@ func Load(path string) (*Config, error) {
if cfg.Auth.SessionDurationHours <= 0 {
cfg.Auth.SessionDurationHours = 12
}
if cfg.Audit.MaxDetailBytes <= 0 {
cfg.Audit.MaxDetailBytes = 8192
}
if strings.TrimSpace(cfg.Auth.Password) == "" {
password, err := generateStrongPassword(24)
if err != nil {
@@ -1157,6 +1235,14 @@ func Default() *Config {
Auth: AuthConfig{
SessionDurationHours: 12,
},
Audit: func() AuditConfig {
on := true
return AuditConfig{
RetentionDays: 90,
MaxDetailBytes: 8192,
Enabled: &on,
}
}(),
Robots: RobotsConfig{
Session: RobotSessionConfig{
StrictUserIdentity: &strictRobotIdentity,
+210
View File
@@ -0,0 +1,210 @@
package database
import (
"encoding/json"
"errors"
"strings"
"time"
)
// AuditLog platform operation audit record.
type AuditLog struct {
ID string `json:"id"`
CreatedAt time.Time `json:"createdAt"`
Level string `json:"level"`
Category string `json:"category"`
Action string `json:"action"`
Result string `json:"result"`
Actor string `json:"actor"`
SessionHint string `json:"sessionHint,omitempty"`
ClientIP string `json:"clientIp,omitempty"`
UserAgent string `json:"userAgent,omitempty"`
ResourceType string `json:"resourceType,omitempty"`
ResourceID string `json:"resourceId,omitempty"`
ResourceAvailable *bool `json:"resourceAvailable,omitempty"` // API-only: whether linked resource still exists
Message string `json:"message"`
Detail map[string]interface{} `json:"detail,omitempty"`
}
// ListAuditLogsFilter query parameters.
type ListAuditLogsFilter struct {
Level string
Category string
Action string
Result string
Query string
ResourceType string
ResourceID string
Since *time.Time
Until *time.Time
Limit int
Offset int
}
func buildAuditLogsWhere(filter ListAuditLogsFilter) (string, []interface{}) {
conditions := []string{"1=1"}
args := []interface{}{}
if filter.Level != "" {
conditions = append(conditions, "level = ?")
args = append(args, filter.Level)
}
if filter.Category != "" {
conditions = append(conditions, "category = ?")
args = append(args, filter.Category)
}
if filter.Action != "" {
conditions = append(conditions, "action = ?")
args = append(args, filter.Action)
}
if filter.Result != "" {
conditions = append(conditions, "result = ?")
args = append(args, filter.Result)
}
if filter.ResourceType != "" {
conditions = append(conditions, "resource_type = ?")
args = append(args, filter.ResourceType)
}
if filter.ResourceID != "" {
conditions = append(conditions, "resource_id = ?")
args = append(args, filter.ResourceID)
}
if filter.Since != nil {
conditions = append(conditions, "created_at >= ?")
args = append(args, *filter.Since)
}
if filter.Until != nil {
conditions = append(conditions, "created_at <= ?")
args = append(args, *filter.Until)
}
if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + q + "%"
conditions = append(conditions, "(message LIKE ? OR resource_id LIKE ? OR action LIKE ? OR category LIKE ?)")
args = append(args, like, like, like, like)
}
return strings.Join(conditions, " AND "), args
}
// AppendAuditLog inserts one audit row.
func (db *DB) AppendAuditLog(row *AuditLog) error {
if row == nil {
return errors.New("audit log is nil")
}
if strings.TrimSpace(row.ID) == "" {
return errors.New("audit id is required")
}
if row.CreatedAt.IsZero() {
row.CreatedAt = time.Now()
}
if strings.TrimSpace(row.Level) == "" {
row.Level = "info"
}
detailJSON := ""
if len(row.Detail) > 0 {
if b, err := json.Marshal(row.Detail); err == nil {
detailJSON = string(b)
}
}
query := `
INSERT INTO audit_logs (
id, created_at, level, category, action, result, actor, session_hint,
client_ip, user_agent, resource_type, resource_id, message, detail_json
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(query,
row.ID, row.CreatedAt, row.Level, row.Category, row.Action, row.Result,
row.Actor, row.SessionHint, row.ClientIP, row.UserAgent,
row.ResourceType, row.ResourceID, row.Message, detailJSON,
)
return err
}
// GetAuditLogByID returns one row.
func (db *DB) GetAuditLogByID(id string) (*AuditLog, error) {
id = strings.TrimSpace(id)
if id == "" {
return nil, errors.New("id is required")
}
query := `
SELECT id, created_at, level, category, action, result, actor,
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
FROM audit_logs WHERE id = ?
`
var row AuditLog
var detailJSON string
err := db.QueryRow(query, id).Scan(
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
&row.SessionHint, &row.ClientIP, &row.UserAgent,
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
)
if err != nil {
return nil, err
}
if detailJSON != "" {
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
}
return &row, nil
}
// CountAuditLogs counts rows matching filter.
func (db *DB) CountAuditLogs(filter ListAuditLogsFilter) (int64, error) {
where, args := buildAuditLogsWhere(filter)
query := `SELECT COUNT(*) FROM audit_logs WHERE ` + where
var n int64
err := db.QueryRow(query, args...).Scan(&n)
return n, err
}
// ListAuditLogs lists audit rows newest first.
func (db *DB) ListAuditLogs(filter ListAuditLogsFilter) ([]*AuditLog, error) {
where, args := buildAuditLogsWhere(filter)
limit := filter.Limit
if limit <= 0 || limit > 500 {
limit = 50
}
offset := filter.Offset
if offset < 0 {
offset = 0
}
query := `
SELECT id, created_at, level, category, action, result, actor,
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
FROM audit_logs
WHERE ` + where + `
ORDER BY created_at DESC
LIMIT ? OFFSET ?
`
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var list []*AuditLog
for rows.Next() {
var row AuditLog
var detailJSON string
if err := rows.Scan(
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
&row.SessionHint, &row.ClientIP, &row.UserAgent,
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
); err != nil {
continue
}
if detailJSON != "" {
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
}
list = append(list, &row)
}
return list, rows.Err()
}
// DeleteAuditLogsBefore removes rows older than cutoff.
func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) {
res, err := db.Exec(`DELETE FROM audit_logs WHERE created_at < ?`, cutoff)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
+27 -5
View File
@@ -37,12 +37,12 @@ type Message struct {
}
// CreateConversation 创建新对话
func (db *DB) CreateConversation(title string) (*Conversation, error) {
return db.CreateConversationWithWebshell("", title)
func (db *DB) CreateConversation(title string, meta ConversationCreateMeta) (*Conversation, error) {
return db.CreateConversationWithWebshell("", title, meta)
}
// CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话)
func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string) (*Conversation, error) {
func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string, meta ConversationCreateMeta) (*Conversation, error) {
id := uuid.New().String()
now := time.Now()
@@ -62,12 +62,17 @@ func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string)
return nil, fmt.Errorf("创建对话失败: %w", err)
}
return &Conversation{
conv := &Conversation{
ID: id,
Title: title,
CreatedAt: now,
UpdatedAt: now,
}, nil
}
if webshellConnectionID != "" {
meta.WebShellConnectionID = webshellConnectionID
}
notifyConversationCreated(conv, meta)
return conv, nil
}
// GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化)
@@ -182,6 +187,23 @@ func (db *DB) ListConversationsByWebshellConnectionID(connectionID string) ([]We
return list, rows.Err()
}
// ConversationExists reports whether a conversation row exists (lightweight check for audit links).
func (db *DB) ConversationExists(id string) (bool, error) {
id = strings.TrimSpace(id)
if id == "" {
return false, nil
}
var one int
err := db.QueryRow("SELECT 1 FROM conversations WHERE id = ? LIMIT 1", id).Scan(&one)
if err == sql.ErrNoRows {
return false, nil
}
if err != nil {
return false, err
}
return true, nil
}
// GetConversation 获取对话
func (db *DB) GetConversation(id string) (*Conversation, error) {
var conv Conversation
@@ -0,0 +1,29 @@
package database
// ConversationCreateMeta describes how a conversation was created (for audit hooks).
type ConversationCreateMeta struct {
Source string
WebShellConnectionID string
ClientIP string
SessionHint string
}
// ConversationCreateHook is invoked after a conversation row is inserted.
type ConversationCreateHook func(conv *Conversation, meta ConversationCreateMeta)
var conversationCreateHook ConversationCreateHook
// SetConversationCreateHook registers a global hook (e.g. platform audit).
func SetConversationCreateHook(h ConversationCreateHook) {
conversationCreateHook = h
}
func notifyConversationCreated(conv *Conversation, meta ConversationCreateMeta) {
if conversationCreateHook == nil || conv == nil {
return
}
if meta.Source == "" {
meta.Source = "unknown"
}
conversationCreateHook(conv, meta)
}
+26
View File
@@ -387,6 +387,24 @@ func (db *DB) initTables() error {
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);`
createAuditLogsTable := `
CREATE TABLE IF NOT EXISTS audit_logs (
id TEXT PRIMARY KEY,
created_at DATETIME NOT NULL,
level TEXT NOT NULL DEFAULT 'info',
category TEXT NOT NULL,
action TEXT NOT NULL,
result TEXT NOT NULL,
actor TEXT NOT NULL DEFAULT 'admin',
session_hint TEXT,
client_ip TEXT,
user_agent TEXT,
resource_type TEXT,
resource_id TEXT,
message TEXT NOT NULL,
detail_json TEXT
);`
createC2ProfilesTable := `
CREATE TABLE IF NOT EXISTS c2_profiles (
id TEXT PRIMARY KEY,
@@ -445,6 +463,10 @@ func (db *DB) initTables() error {
CREATE INDEX IF NOT EXISTS idx_c2_events_created_at ON c2_events(created_at);
CREATE INDEX IF NOT EXISTS idx_c2_events_category ON c2_events(category);
CREATE INDEX IF NOT EXISTS idx_c2_events_session ON c2_events(session_id);
CREATE INDEX IF NOT EXISTS idx_audit_logs_created_at ON audit_logs(created_at);
CREATE INDEX IF NOT EXISTS idx_audit_logs_category ON audit_logs(category);
CREATE INDEX IF NOT EXISTS idx_audit_logs_action ON audit_logs(action);
CREATE INDEX IF NOT EXISTS idx_audit_logs_result ON audit_logs(result);
`
if _, err := db.Exec(createConversationsTable); err != nil {
@@ -514,6 +536,10 @@ func (db *DB) initTables() error {
return fmt.Errorf("创建webshell_connection_states表失败: %w", err)
}
if _, err := db.Exec(createAuditLogsTable); err != nil {
return fmt.Errorf("创建audit_logs表失败: %w", err)
}
for tableName, ddl := range map[string]string{
"c2_listeners": createC2ListenersTable,
"c2_sessions": createC2SessionsTable,
+78 -69
View File
@@ -3,12 +3,84 @@ package database
import (
"database/sql"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
)
// VulnerabilityListFilter 列表/统计/导出共用的筛选条件
type VulnerabilityListFilter struct {
ID string
Search string // 关键词模糊匹配(标题、描述、类型、目标等)
ConversationID string
Severity string
Status string
TaskID string
ConversationTag string
TaskTag string
}
func escapeVulnerabilityLikePattern(s string) string {
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `%`, `\%`)
s = strings.ReplaceAll(s, `_`, `\_`)
return "%" + s + "%"
}
func (f VulnerabilityListFilter) appendWhere(query string, args []interface{}) (string, []interface{}) {
if f.ID != "" {
query += " AND id = ?"
args = append(args, f.ID)
}
if f.ConversationID != "" {
query += " AND conversation_id = ?"
args = append(args, f.ConversationID)
}
if f.TaskID != "" {
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, f.TaskID, f.TaskID)
}
if f.ConversationTag != "" {
query += " AND conversation_tag = ?"
args = append(args, f.ConversationTag)
}
if f.TaskTag != "" {
query += " AND task_tag = ?"
args = append(args, f.TaskTag)
}
if f.Severity != "" {
query += " AND severity = ?"
args = append(args, f.Severity)
}
if f.Status != "" {
query += " AND status = ?"
args = append(args, f.Status)
}
search := strings.TrimSpace(f.Search)
if search != "" {
pattern := escapeVulnerabilityLikePattern(search)
query += ` AND (
LOWER(id) LIKE LOWER(?) OR
LOWER(title) LIKE LOWER(?) OR
LOWER(COALESCE(description, '')) LIKE LOWER(?) OR
LOWER(COALESCE(vulnerability_type, '')) LIKE LOWER(?) OR
LOWER(COALESCE(target, '')) LIKE LOWER(?) OR
LOWER(COALESCE(proof, '')) LIKE LOWER(?) OR
LOWER(COALESCE(impact, '')) LIKE LOWER(?) OR
LOWER(COALESCE(recommendation, '')) LIKE LOWER(?) OR
LOWER(COALESCE(conversation_id, '')) LIKE LOWER(?) OR
LOWER(COALESCE(conversation_tag, '')) LIKE LOWER(?) OR
LOWER(COALESCE(task_tag, '')) LIKE LOWER(?)
)`
for i := 0; i < 11; i++ {
args = append(args, pattern)
}
}
return query, args
}
// Vulnerability 漏洞
type Vulnerability struct {
ID string `json:"id"`
@@ -97,7 +169,7 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
}
// ListVulnerabilities 列出漏洞
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) {
func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) {
query := `
SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag,
vulnerability_type, target, proof, impact, recommendation,
@@ -108,35 +180,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
WHERE 1=1
`
args := []interface{}{}
if id != "" {
query += " AND id = ?"
args = append(args, id)
}
if conversationID != "" {
query += " AND conversation_id = ?"
args = append(args, conversationID)
}
if taskID != "" {
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, taskID, taskID)
}
if conversationTag != "" {
query += " AND conversation_tag = ?"
args = append(args, conversationTag)
}
if taskTag != "" {
query += " AND task_tag = ?"
args = append(args, taskTag)
}
if severity != "" {
query += " AND severity = ?"
args = append(args, severity)
}
if status != "" {
query += " AND status = ?"
args = append(args, status)
}
query, args = filter.appendWhere(query, args)
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
@@ -168,38 +212,10 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
}
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) {
func (db *DB) CountVulnerabilities(filter VulnerabilityListFilter) (int, error) {
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
args := []interface{}{}
if id != "" {
query += " AND id = ?"
args = append(args, id)
}
if conversationID != "" {
query += " AND conversation_id = ?"
args = append(args, conversationID)
}
if taskID != "" {
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, taskID, taskID)
}
if conversationTag != "" {
query += " AND conversation_tag = ?"
args = append(args, conversationTag)
}
if taskTag != "" {
query += " AND task_tag = ?"
args = append(args, taskTag)
}
if severity != "" {
query += " AND severity = ?"
args = append(args, severity)
}
if status != "" {
query += " AND status = ?"
args = append(args, status)
}
query, args = filter.appendWhere(query, args)
var count int
err := db.QueryRow(query, args...).Scan(&count)
@@ -245,19 +261,12 @@ func (db *DB) DeleteVulnerability(id string) error {
}
// GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致)
func (db *DB) GetVulnerabilityStats(conversationID, taskID string) (map[string]interface{}, error) {
func (db *DB) GetVulnerabilityStats(filter VulnerabilityListFilter) (map[string]interface{}, error) {
stats := make(map[string]interface{})
where := "WHERE 1=1"
args := []interface{}{}
if conversationID != "" {
where += " AND conversation_id = ?"
args = append(args, conversationID)
}
if taskID != "" {
where += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, taskID, taskID)
}
where, args = filter.appendWhere(where, args)
// 总漏洞数
var totalCount int
+191 -49
View File
@@ -17,12 +17,14 @@ import (
"unicode/utf8"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/reasoning"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/multiagent"
"cyberstrike-ai/internal/openai"
"github.com/gin-gonic/gin"
"github.com/robfig/cron/v3"
@@ -130,6 +132,12 @@ type AgentHandler struct {
batchRunning map[string]struct{}
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
hitlWhitelistSaver HitlToolWhitelistSaver
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *AgentHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
@@ -206,7 +214,7 @@ type ChatAttachment struct {
type ChatReasoningRequest struct {
// Mode: default(跟随系统)| off | on | auto
Mode string `json:"mode,omitempty"`
// Effort: low | medium | high | max;空表示不指定(由系统默认与各 profile 决定)
// Effort: low | medium | high | max | xhigh(原样下发;不同网关最高档命名不同)。空表示不指定
Effort string `json:"effort,omitempty"`
}
@@ -552,7 +560,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
conversationID := req.ConversationID
if conversationID == "" {
title := safeTruncateString(req.Message, 50)
conv, err := h.db.CreateConversation(title)
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "agent_loop"))
if err != nil {
h.logger.Error("创建对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -716,11 +724,43 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
})
}
func (h *AgentHandler) finalizeRobotAgentError(ctx context.Context, assistantMessageID, conversationID string, resultMA *multiagent.RunResult, errMA error) (string, string, error) {
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
}
errMsg := "执行失败: " + errMA.Error()
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
}
return "", conversationID, errMA
}
func (h *AgentHandler) finalizeRobotAgentSuccess(assistantMessageID, conversationID string, resultMA *multiagent.RunResult) (string, string, error) {
if assistantMessageID != "" {
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resultMA.Response, resultMA.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(resultMA.LastAgentTraceInput)); errU != nil {
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
}
} else {
if _, err := h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil {
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
}
}
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
}
return resultMA.Response, conversationID, nil
}
// ProcessMessageForRobot 供机器人(企业微信/钉钉/飞书)调用:与 /api/agent-loop/stream 相同执行路径(含 progressCallback、过程详情),仅不发送 SSE,最后返回完整回复
func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationID, message, role string) (response string, convID string, err error) {
func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, conversationID, message, role string) (response string, convID string, err error) {
if conversationID == "" {
title := safeTruncateString(message, 50)
conv, createErr := h.db.CreateConversation(title)
src := "robot"
if strings.TrimSpace(platform) != "" {
src = "robot:" + strings.TrimSpace(platform)
}
conv, createErr := h.db.CreateConversation(title, audit.ConversationCreateMeta(src))
if createErr != nil {
return "", "", fmt.Errorf("创建对话失败: %w", createErr)
}
@@ -768,53 +808,92 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
if assistantMsg != nil {
assistantMessageID = assistantMsg.ID
}
progressCallback := h.createProgressCallback(ctx, nil, conversationID, assistantMessageID, nil)
useRobotMulti := h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.RobotUseMultiAgent
if useRobotMulti {
resultMA, errMA := multiagent.RunDeepAgent(
ctx,
h.config,
&h.config.MultiAgent,
h.agent,
h.logger,
conversationID,
finalMessage,
agentHistoryMessages,
roleTools,
progressCallback,
h.agentsMarkdownDir,
"deep",
nil,
)
if errMA != nil {
if shouldPersistEinoAgentTraceAfterRunError(ctx) {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
}
errMsg := "执行失败: " + errMA.Error()
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
}
return "", conversationID, errMA
// 注册运行中任务并向 taskEventBus 镜像进度事件,供 Web 端 task-events 补流(与 agent-loop/stream 一致)。
taskCtx, cancelWithCause := context.WithCancelCause(ctx)
defer cancelWithCause(nil)
taskStatus := "completed"
defer func() {
h.tasks.FinishTask(conversationID, taskStatus)
}()
if _, err := h.tasks.StartTask(conversationID, message, cancelWithCause); err != nil {
if errors.Is(err, ErrTaskAlreadyRunning) {
return "", conversationID, fmt.Errorf("当前会话已有任务正在执行中,请稍后再试")
}
if assistantMessageID != "" {
if errU := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resultMA.Response, resultMA.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(resultMA.LastAgentTraceInput)); errU != nil {
h.logger.Warn("机器人:更新助手消息失败", zap.Error(errU))
return "", conversationID, fmt.Errorf("无法启动任务: %w", err)
}
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, nil)
robotMode := "react"
if h.config != nil {
robotMode = config.NormalizeRobotAgentMode(h.config.MultiAgent)
}
switch robotMode {
case "eino_single":
curHist := agentHistoryMessages
curMsg := finalMessage
segmentUserMessage := finalMessage
var resultMA *multiagent.RunResult
var errMA error
var transientRunAttempts int
for {
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
conversationID, curMsg, curHist, roleTools, progressCallback, nil,
)
if errMA == nil {
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
transientRunAttempts = 0
break
}
} else {
if _, err = h.db.AddMessage(conversationID, "assistant", resultMA.Response, resultMA.MCPExecutionIDs); err != nil {
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
if handled, _ := h.handleEinoTransientRetryContinue(
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
); handled {
continue
}
taskStatus = "failed"
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
}
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA)
case "deep", "plan_execute", "supervisor":
if h.config == nil || !h.config.MultiAgent.Enabled {
h.logger.Warn("机器人配置为多代理模式但未启用 multi_agent,回退原生 ReAct",
zap.String("robot_mode", robotMode))
break
}
return resultMA.Response, conversationID, nil
curHist := agentHistoryMessages
curMsg := finalMessage
segmentUserMessage := finalMessage
var resultMA *multiagent.RunResult
var errMA error
var transientRunAttempts int
for {
resultMA, errMA = multiagent.RunDeepAgent(
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger,
conversationID, curMsg, curHist, roleTools, progressCallback,
h.agentsMarkdownDir, robotMode, nil,
)
if errMA == nil {
// 成功后重置 transient 重试窗口,下一次分段从第 1 次重试开始。
transientRunAttempts = 0
break
}
if handled, _ := h.handleEinoTransientRetryContinue(
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
); handled {
continue
}
taskStatus = "failed"
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
}
return h.finalizeRobotAgentSuccess(assistantMessageID, conversationID, resultMA)
}
result, err := h.agent.AgentLoopWithProgress(ctx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
result, err := h.agent.AgentLoopWithProgress(taskCtx, finalMessage, agentHistoryMessages, conversationID, progressCallback, roleTools)
if err != nil {
taskStatus = "failed"
errMsg := "执行失败: " + err.Error()
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
@@ -846,6 +925,23 @@ type StreamEvent struct {
Data interface{} `json:"data,omitempty"`
}
// publishProgressToTaskEventBus 将进度事件镜像到 taskEventBus(机器人/无 HTTP SSE 客户端时供 Web task-events 订阅)。
func (h *AgentHandler) publishProgressToTaskEventBus(conversationID, eventType, message string, data interface{}) {
if h == nil || h.taskEventBus == nil || strings.TrimSpace(conversationID) == "" {
return
}
event := StreamEvent{Type: eventType, Message: message, Data: data}
eventJSON, err := json.Marshal(event)
if err != nil {
return
}
sseLine := make([]byte, 0, len(eventJSON)+8)
sseLine = append(sseLine, []byte("data: ")...)
sseLine = append(sseLine, eventJSON...)
sseLine = append(sseLine, '\n', '\n')
h.taskEventBus.Publish(conversationID, sseLine)
}
// createProgressCallback 创建进度回调函数,用于保存processDetails
// sendEventFunc: 可选的流式事件发送函数,如果为nil则不发送流式事件
func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun context.CancelCauseFunc, conversationID, assistantMessageID string, sendEventFunc func(eventType, message string, data interface{})) agent.ProgressCallback {
@@ -955,9 +1051,11 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
}
return func(eventType, message string, data interface{}) {
// 如果提供了sendEventFunc,发送流式事件
// 流式:写 HTTP SSE;非流式(机器人等):镜像到 taskEventBus 供 Web 订阅
if sendEventFunc != nil {
sendEventFunc(eventType, message, data)
} else {
h.publishProgressToTaskEventBus(conversationID, eventType, message, data)
}
// 保存tool_call事件中的参数
@@ -1158,7 +1256,16 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
return
}
if eventType == "response_delta" {
respPlan.b.WriteString(message)
if dataMap, ok := data.(map[string]interface{}); ok {
if acc, okAcc := dataMap[openai.SSEAccumulatedKey].(string); okAcc {
respPlan.b.Reset()
respPlan.b.WriteString(acc)
} else {
respPlan.b.WriteString(message)
}
} else {
respPlan.b.WriteString(message)
}
if dataMap, ok := data.(map[string]interface{}); ok && respPlan.meta == nil {
respPlan.meta = make(map[string]interface{}, len(dataMap))
for k, v := range dataMap {
@@ -1213,8 +1320,12 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
} else if tb.persistAs == "" {
tb.persistAs = persistAs
}
// delta 片段直接拼接
tb.b.WriteString(message)
if acc, okAcc := dataMap[openai.SSEAccumulatedKey].(string); okAcc {
tb.b.Reset()
tb.b.WriteString(acc)
} else {
tb.b.WriteString(message)
}
// 有时 delta 先到 start 未到,补充元信息
for k, v := range dataMap {
tb.meta[k] = v
@@ -1406,10 +1517,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
title := safeTruncateString(req.Message, 50)
var conv *database.Conversation
var err error
meta := audit.ConversationCreateMetaFromGin(c, "agent_loop_stream")
if req.WebShellConnectionID != "" {
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title)
meta.Source = "webshell_chat"
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title, meta)
} else {
conv, err = h.db.CreateConversation(title)
conv, err = h.db.CreateConversation(title, meta)
}
if err != nil {
h.logger.Error("创建对话失败", zap.Error(err))
@@ -2025,6 +2138,11 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
queue = refreshed
}
}
if h.audit != nil {
h.audit.RecordOK(c, "task", "create_queue", "创建批量任务队列", "batch_queue", queue.ID, map[string]interface{}{
"task_count": len(validTasks), "started": started,
})
}
c.JSON(http.StatusOK, gin.H{
"queueId": queue.ID,
"queue": queue,
@@ -2132,6 +2250,9 @@ func (h *AgentHandler) StartBatchQueue(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "task", "start_queue", "启动批量任务队列", "batch_queue", queueID, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "批量任务已开始执行", "queueId": queueID})
}
@@ -2160,6 +2281,9 @@ func (h *AgentHandler) RerunBatchQueue(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "启动失败"})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "task", "rerun_queue", "重跑批量任务队列", "batch_queue", queueID, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "批量任务已重新开始执行", "queueId": queueID})
}
@@ -2171,6 +2295,9 @@ func (h *AgentHandler) PauseBatchQueue(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在或无法暂停"})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "task", "pause_queue", "暂停批量任务队列", "batch_queue", queueID, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "批量任务已暂停"})
}
@@ -2266,6 +2393,16 @@ func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "task",
Action: "delete_queue",
Result: "success",
ResourceType: "batch_queue",
ResourceID: queueID,
Message: "删除批量任务队列",
})
}
c.JSON(http.StatusOK, gin.H{"message": "批量任务队列已删除"})
}
@@ -2351,6 +2488,11 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "task", "delete_batch_task", "删除批量子任务", "batch_task", taskID, map[string]interface{}{
"batch_queue_id": queueID,
})
}
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
}
@@ -2509,7 +2651,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
// 创建新对话
title := safeTruncateString(task.Message, 50)
conv, err := h.db.CreateConversation(title)
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMeta("batch_task"))
var conversationID string
if err != nil {
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
+147
View File
@@ -0,0 +1,147 @@
package handler
import (
"net/http"
"time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AuditHandler serves platform audit log APIs.
type AuditHandler struct {
db *database.DB
audit *audit.Service
logger *zap.Logger
}
// NewAuditHandler creates an audit log handler.
func NewAuditHandler(db *database.DB, auditSvc *audit.Service, logger *zap.Logger) *AuditHandler {
return &AuditHandler{db: db, audit: auditSvc, logger: logger}
}
// Meta GET /api/audit/meta
func (h *AuditHandler) Meta(c *gin.Context) {
enabled := false
retentionDays := 0
if h.audit != nil {
enabled = h.audit.Enabled()
retentionDays = h.audit.RetentionDays()
}
c.JSON(http.StatusOK, gin.H{
"enabled": enabled,
"retention_days": retentionDays,
"default_page_size": 20,
"max_page_size": 100,
"max_export": 5000,
})
}
// Summary GET /api/audit/summary
func (h *AuditHandler) Summary(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
base := auditFilterFromQuery(c)
total, err := h.db.CountAuditLogs(base)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
failFilter := base
failFilter.Result = "failure"
failures, err := h.db.CountAuditLogs(failFilter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
since := time.Now().AddDate(0, 0, -7)
recentFilter := base
recentFilter.Since = &since
recent7d, err := h.db.CountAuditLogs(recentFilter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"total": total,
"failures": failures,
"recent_7d": recent7d,
"has_filters": c.Query("category") != "" || c.Query("action") != "" || c.Query("result") != "" ||
c.Query("q") != "" || c.Query("since") != "" || c.Query("until") != "",
})
}
// ListLogs GET /api/audit/logs
func (h *AuditHandler) ListLogs(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
filter := auditFilterFromQuery(c)
page, pageSize := auditPaginationFromQuery(c)
filter.Limit = pageSize
filter.Offset = (page - 1) * pageSize
logs, err := h.db.ListAuditLogs(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
total, err := h.db.CountAuditLogs(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"logs": logs,
"total": total,
"page": page,
"page_size": pageSize,
})
}
// GetLog GET /api/audit/logs/:id
func (h *AuditHandler) GetLog(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
row, err := h.db.GetAuditLogByID(c.Param("id"))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "审计记录不存在"})
return
}
audit.ApplyResourceAvailability(h.db, row)
c.JSON(http.StatusOK, gin.H{"log": row})
}
// ExportLogs GET /api/audit/logs/export — JSON or CSV (?format=csv), max 5000 rows.
func (h *AuditHandler) ExportLogs(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
filter := auditFilterFromQuery(c)
filter.Limit = 5000
filter.Offset = 0
logs, err := h.db.ListAuditLogs(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if c.Query("format") == "csv" {
writeAuditLogsCSV(c, logs)
return
}
c.Header("Content-Disposition", `attachment; filename="audit-logs.json"`)
c.JSON(http.StatusOK, gin.H{
"exported_at": time.Now().UTC().Format(time.RFC3339),
"logs": logs,
})
}
+42
View File
@@ -0,0 +1,42 @@
package handler
import (
"encoding/csv"
"fmt"
"time"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
)
func writeAuditLogsCSV(c *gin.Context, logs []*database.AuditLog) {
c.Header("Content-Type", "text/csv; charset=utf-8")
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="audit-logs-%s.csv"`, time.Now().Format("20060102")))
w := csv.NewWriter(c.Writer)
_ = w.Write([]string{
"id", "created_at", "level", "category", "action", "result", "actor",
"session_hint", "client_ip", "resource_type", "resource_id", "message",
})
for _, row := range logs {
if row == nil {
continue
}
_ = w.Write([]string{
row.ID,
row.CreatedAt.UTC().Format(time.RFC3339),
row.Level,
row.Category,
row.Action,
row.Result,
row.Actor,
row.SessionHint,
row.ClientIP,
row.ResourceType,
row.ResourceID,
row.Message,
})
}
w.Flush()
}
+48
View File
@@ -0,0 +1,48 @@
package handler
import (
"strconv"
"time"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
)
func auditFilterFromQuery(c *gin.Context) database.ListAuditLogsFilter {
filter := database.ListAuditLogsFilter{
Level: c.Query("level"),
Category: c.Query("category"),
Action: c.Query("action"),
Result: c.Query("result"),
Query: c.Query("q"),
ResourceType: c.Query("resource_type"),
ResourceID: c.Query("resource_id"),
}
if since := c.Query("since"); since != "" {
if t, err := time.Parse(time.RFC3339, since); err == nil {
filter.Since = &t
}
}
if until := c.Query("until"); until != "" {
if t, err := time.Parse(time.RFC3339, until); err == nil {
filter.Until = &t
}
}
return filter
}
func auditPaginationFromQuery(c *gin.Context) (page, pageSize int) {
page = 1
pageSize = 20
if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 {
page = p
}
if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "20")); err == nil && ps > 0 {
pageSize = ps
if pageSize > 100 {
pageSize = 100
}
}
return page, pageSize
}
+55
View File
@@ -5,6 +5,7 @@ import (
"strings"
"time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/security"
@@ -18,6 +19,12 @@ type AuthHandler struct {
config *config.Config
configPath string
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *AuthHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewAuthHandler creates a new AuthHandler.
@@ -49,10 +56,32 @@ func (h *AuthHandler) Login(c *gin.Context) {
token, expiresAt, err := h.manager.Authenticate(req.Password)
if err != nil {
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Level: "warn",
Category: "auth",
Action: "login",
Result: "failure",
Message: "登录失败:密码错误",
})
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"})
return
}
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "auth",
Action: "login",
Result: "success",
SessionHint: audit.HintFromToken(token),
Message: "登录成功",
Detail: map[string]interface{}{
"expires_at": expiresAt.UTC().Format(time.RFC3339),
},
})
}
c.JSON(http.StatusOK, gin.H{
"token": token,
"expires_at": expiresAt.UTC().Format(time.RFC3339),
@@ -73,6 +102,14 @@ func (h *AuthHandler) Logout(c *gin.Context) {
}
h.manager.RevokeToken(token)
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "auth",
Action: "logout",
Result: "success",
Message: "退出登录",
})
}
c.JSON(http.StatusOK, gin.H{"message": "已退出登录"})
}
@@ -103,6 +140,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
}
if !h.manager.CheckPassword(oldPassword) {
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Level: "warn",
Category: "auth",
Action: "change_password",
Result: "failure",
Message: "修改密码失败:当前密码不正确",
})
}
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"})
return
}
@@ -132,6 +178,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
h.logger.Info("登录密码已更新,所有会话已失效")
}
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "auth",
Action: "change_password",
Result: "success",
Message: "登录密码已修改",
})
}
c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"})
}
+37
View File
@@ -13,6 +13,7 @@ import (
"sync/atomic"
"time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/c2"
"cyberstrike-ai/internal/database"
@@ -25,6 +26,12 @@ import (
type C2Handler struct {
mgrPtr atomic.Pointer[c2.Manager]
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *C2Handler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewC2Handler 创建 C2 处理器;manager 可为 nil(功能关闭时)
@@ -104,6 +111,11 @@ func (h *C2Handler) CreateListener(c *gin.Context) {
implantToken := listener.ImplantToken
listener.EncryptionKey = ""
listener.ImplantToken = ""
if h.audit != nil {
h.audit.RecordOK(c, "c2", "listener_create", "创建 C2 监听器", "c2_listener", listener.ID, map[string]interface{}{
"name": listener.Name, "bind": listener.BindHost, "port": listener.BindPort,
})
}
c.JSON(http.StatusOK, gin.H{"listener": listener, "implant_token": implantToken})
}
@@ -205,6 +217,9 @@ func (h *C2Handler) DeleteListener(c *gin.Context) {
c.JSON(code, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "c2", "listener_delete", "删除 C2 监听器", "c2_listener", id, nil)
}
c.JSON(http.StatusOK, gin.H{"deleted": true})
}
@@ -222,6 +237,9 @@ func (h *C2Handler) StartListener(c *gin.Context) {
}
listener.EncryptionKey = ""
listener.ImplantToken = ""
if h.audit != nil {
h.audit.RecordOK(c, "c2", "listener_start", "启动 C2 监听器", "c2_listener", id, nil)
}
c.JSON(http.StatusOK, gin.H{"listener": listener})
}
@@ -236,6 +254,9 @@ func (h *C2Handler) StopListener(c *gin.Context) {
c.JSON(code, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "c2", "listener_stop", "停止 C2 监听器", "c2_listener", id, nil)
}
c.JSON(http.StatusOK, gin.H{"stopped": true})
}
@@ -297,6 +318,9 @@ func (h *C2Handler) DeleteSession(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "c2", "session_delete", "删除 C2 会话", "c2_session", id, nil)
}
c.JSON(http.StatusOK, gin.H{"deleted": true})
}
@@ -407,6 +431,11 @@ func (h *C2Handler) DeleteTasks(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "c2", "task_delete", "批量删除 C2 任务", "c2_task", "", map[string]interface{}{
"count": n, "ids": req.IDs,
})
}
c.JSON(http.StatusOK, gin.H{"deleted": n})
}
@@ -457,6 +486,11 @@ func (h *C2Handler) CreateTask(c *gin.Context) {
c.JSON(code, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "c2", "task_create", "创建 C2 任务", "c2_task", task.ID, map[string]interface{}{
"session_id": req.SessionID, "task_type": req.TaskType,
})
}
c.JSON(http.StatusOK, gin.H{"task": task})
}
@@ -471,6 +505,9 @@ func (h *C2Handler) CancelTask(c *gin.Context) {
c.JSON(code, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "c2", "task_cancel", "取消 C2 任务", "c2_task", id, nil)
}
c.JSON(http.StatusOK, gin.H{"cancelled": true})
}
+16
View File
@@ -12,6 +12,8 @@ import (
"time"
"unicode/utf8"
"cyberstrike-ai/internal/audit"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
@@ -24,6 +26,12 @@ const (
// ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API
type ChatUploadsHandler struct {
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *ChatUploadsHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewChatUploadsHandler 创建处理器
@@ -230,6 +238,9 @@ func (h *ChatUploadsHandler) Delete(c *gin.Context) {
return
}
}
if h.audit != nil {
h.audit.RecordOK(c, "file", "delete", "删除对话附件", "chat_upload", body.Path, nil)
}
c.JSON(http.StatusOK, gin.H{"ok": true})
}
@@ -503,6 +514,11 @@ func (h *ChatUploadsHandler) Upload(c *gin.Context) {
}
rel, _ := filepath.Rel(root, fullPath)
absSaved, _ := filepath.Abs(fullPath)
if h.audit != nil {
h.audit.RecordOK(c, "file", "upload", "上传对话附件", "chat_upload", filepath.ToSlash(rel), map[string]interface{}{
"name": unique,
})
}
c.JSON(http.StatusOK, gin.H{
"ok": true,
"relativePath": filepath.ToSlash(rel),
+72 -4
View File
@@ -14,6 +14,7 @@ import (
"time"
"cyberstrike-ai/internal/agents"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/mcp"
@@ -87,6 +88,7 @@ type ConfigHandler struct {
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
appUpdater AppUpdater // App更新器(可选)
robotRestarter RobotRestarter // 机器人连接重启器(可选),ApplyConfig 时重启钉钉/飞书
audit *audit.Service
logger *zap.Logger
mu sync.RWMutex
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
@@ -206,6 +208,32 @@ func (h *ConfigHandler) SetRobotRestarter(restarter RobotRestarter) {
h.robotRestarter = restarter
}
// SetAudit wires platform audit logging.
func (h *ConfigHandler) SetAudit(s *audit.Service) {
h.mu.Lock()
defer h.mu.Unlock()
h.audit = s
}
// ApplyWechatRobotBinding 微信 iLink 扫码绑定成功后写入配置并重启机器人连接
func (h *ConfigHandler) ApplyWechatRobotBinding(wc config.RobotWechatConfig) error {
h.mu.Lock()
wc.Enabled = true
h.config.Robots.Wechat = wc
h.mu.Unlock()
if err := h.saveConfig(); err != nil {
return err
}
if h.robotRestarter != nil {
h.robotRestarter.RestartRobotConnections()
}
h.logger.Info("微信机器人绑定已保存",
zap.String("ilink_bot_id", wc.ILinkBotID),
zap.Bool("enabled", wc.Enabled),
)
return nil
}
// GetConfigResponse 获取配置响应
type GetConfigResponse struct {
OpenAI config.OpenAIConfig `json:"openai"`
@@ -291,7 +319,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
}
multiPub := config.MultiAgentPublic{
Enabled: h.config.MultiAgent.Enabled,
RobotUseMultiAgent: h.config.MultiAgent.RobotUseMultiAgent,
RobotDefaultAgentMode: config.NormalizeRobotAgentMode(h.config.MultiAgent),
BatchUseMultiAgent: h.config.MultiAgent.BatchUseMultiAgent,
SubAgentCount: subAgentCount,
Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration),
@@ -735,6 +763,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
if req.Robots != nil {
h.config.Robots = *req.Robots
h.logger.Info("更新机器人配置",
zap.Bool("wechat_enabled", h.config.Robots.Wechat.Enabled),
zap.Bool("wecom_enabled", h.config.Robots.Wecom.Enabled),
zap.Bool("dingtalk_enabled", h.config.Robots.Dingtalk.Enabled),
zap.Bool("lark_enabled", h.config.Robots.Lark.Enabled),
@@ -750,8 +779,12 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
// 多代理标量(sub_agents 等仍由 config.yaml 维护)
if req.MultiAgent != nil {
h.config.MultiAgent.Enabled = req.MultiAgent.Enabled
h.config.MultiAgent.RobotUseMultiAgent = req.MultiAgent.RobotUseMultiAgent
h.config.MultiAgent.BatchUseMultiAgent = req.MultiAgent.BatchUseMultiAgent
if mode := strings.TrimSpace(req.MultiAgent.RobotDefaultAgentMode); mode != "" {
h.config.MultiAgent.RobotDefaultAgentMode = mode
} else {
h.config.MultiAgent.RobotDefaultAgentMode = "react"
}
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
}
@@ -760,7 +793,7 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
}
h.logger.Info("更新多代理配置",
zap.Bool("enabled", h.config.MultiAgent.Enabled),
zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent),
zap.String("robot_default_agent_mode", config.NormalizeRobotAgentMode(h.config.MultiAgent)),
zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent),
zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations),
zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)),
@@ -883,6 +916,9 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
return
}
if h.audit != nil {
h.audit.RecordOK(c, "config", "update", "更新内存配置", "config", "", nil)
}
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
}
@@ -1013,6 +1049,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件")
if _, err := knowledgeInitializer(); err != nil {
h.logger.Error("动态初始化知识库失败", zap.Error(err))
if h.audit != nil {
h.audit.RecordFail(c, "config", "apply", "应用配置失败:初始化知识库", map[string]interface{}{"error": err.Error()})
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()})
return
}
@@ -1047,6 +1086,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)")
if _, err := reinitKnowledgeInitializer(); err != nil {
h.logger.Error("重新初始化知识库失败", zap.Error(err))
if h.audit != nil {
h.audit.RecordFail(c, "config", "apply", "应用配置失败:重新初始化知识库", map[string]interface{}{"error": err.Error()})
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()})
return
}
@@ -1060,6 +1102,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
if c2Rt != nil {
if err := c2Rt.ReconcileC2AfterConfigApply(); err != nil {
h.logger.Error("C2 配置应用失败", zap.Error(err))
if h.audit != nil {
h.audit.RecordFail(c, "config", "apply", "应用配置失败:C2", map[string]interface{}{"error": err.Error()})
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "C2 启动失败: " + err.Error()})
return
}
@@ -1201,6 +1246,20 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
zap.Int("tools_count", len(h.config.Security.Tools)),
)
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "config",
Action: "apply",
Result: "success",
Message: "配置已应用",
Detail: map[string]interface{}{
"tools_count": len(h.config.Security.Tools),
"knowledge_enabled": h.config.Knowledge.Enabled,
"c2_enabled": h.config.C2.EnabledEffective(),
},
})
}
c.JSON(http.StatusOK, gin.H{
"message": "配置已应用",
"tools_count": len(h.config.Security.Tools),
@@ -1481,6 +1540,15 @@ func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
setBoolInMap(sessionNode, "strict_user_identity", *cfg.Session.StrictUserIdentity)
}
wechatNode := ensureMap(robotsNode, "wechat")
setBoolInMap(wechatNode, "enabled", cfg.Wechat.Enabled)
setStringInMap(wechatNode, "bot_token", cfg.Wechat.BotToken)
setStringInMap(wechatNode, "ilink_bot_id", cfg.Wechat.ILinkBotID)
setStringInMap(wechatNode, "ilink_user_id", cfg.Wechat.ILinkUserID)
setStringInMap(wechatNode, "base_url", cfg.Wechat.BaseURL)
setStringInMap(wechatNode, "bot_type", cfg.Wechat.BotType)
setStringInMap(wechatNode, "bot_agent", cfg.Wechat.BotAgent)
wecomNode := ensureMap(robotsNode, "wecom")
setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled)
setStringInMap(wecomNode, "token", cfg.Wecom.Token)
@@ -1507,7 +1575,7 @@ func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
root := doc.Content[0]
maNode := ensureMap(root, "multi_agent")
setBoolInMap(maNode, "enabled", cfg.Enabled)
setBoolInMap(maNode, "robot_use_multi_agent", cfg.RobotUseMultiAgent)
setStringInMap(maNode, "robot_default_agent_mode", config.NormalizeRobotAgentMode(cfg))
setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent)
setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations)
mwNode := ensureMap(maNode, "eino_middleware")
+25 -1
View File
@@ -5,6 +5,7 @@ import (
"net/http"
"strconv"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
@@ -14,6 +15,12 @@ import (
type ConversationHandler struct {
db *database.DB
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *ConversationHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewConversationHandler 创建新的对话处理器
@@ -42,7 +49,7 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
title = "新对话"
}
conv, err := h.db.CreateConversation(title)
conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "api"))
if err != nil {
h.logger.Error("创建对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -189,6 +196,17 @@ func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
return
}
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "conversation",
Action: "delete",
Result: "success",
ResourceType: "conversation",
ResourceID: id,
Message: "删除对话",
})
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
@@ -227,6 +245,12 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
return
}
if h.audit != nil {
h.audit.RecordOK(c, "conversation", "delete_turn", "删除对话轮次", "conversation", conversationID, map[string]interface{}{
"message_id": req.MessageID,
"deleted": len(deletedIDs),
})
}
c.JSON(http.StatusOK, gin.H{
"deletedMessageIds": deletedIDs,
"message": "ok",
+122
View File
@@ -0,0 +1,122 @@
package handler
import (
"context"
"errors"
"fmt"
"strings"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/multiagent"
)
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
if h.config != nil {
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
}
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
}
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
}
return 0
}
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
func (h *AgentHandler) applyEinoTraceResumeSegment(
conversationID string,
result *multiagent.RunResult,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
) {
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
*curHistory = hist
}
if segmentUserMessage != "" {
*curFinalMessage = segmentUserMessage
}
}
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
func (h *AgentHandler) applyEinoTransientRetrySegment(
conversationID string,
result *multiagent.RunResult,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
) {
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
*curHistory = hist
}
if s := strings.TrimSpace(segmentUserMessage); s != "" {
*curFinalMessage = segmentUserMessage
}
}
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
func (h *AgentHandler) handleEinoTransientRetryContinue(
baseCtx context.Context,
conversationID string,
result *multiagent.RunResult,
runErr error,
transientAttempts *int,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
progressCallback func(eventType, message string, data interface{}),
sendProgress func(msg string, extra map[string]interface{}),
) (handled bool, fatal error) {
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
return false, nil
}
maxAttempts := h.einoRunRetryMaxAttempts()
*transientAttempts++
if *transientAttempts > maxAttempts {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
return false, errors.New("transient retry exhausted: " + runErr.Error())
}
attemptNo := *transientAttempts
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
if progressCallback != nil {
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": attemptNo,
"maxAttempts": maxAttempts,
"backoffSec": int(backoff.Seconds()),
})
}
select {
case <-baseCtx.Done():
return false, context.Cause(baseCtx)
case <-time.After(backoff):
}
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
if progressCallback != nil {
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": attemptNo,
})
}
if sendProgress != nil {
sendProgress("正在重试…", map[string]interface{}{
"conversationId": conversationID,
"source": "transient_retry",
})
}
return true, nil
}
+67 -5
View File
@@ -90,7 +90,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
zap.String("conversationId", req.ConversationID),
)
prep, err := h.prepareMultiAgentSession(&req)
prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent_stream")
if err != nil {
sendEvent("error", err.Error(), nil)
sendEvent("done", "", nil)
@@ -119,6 +119,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
var cancelWithCause context.CancelCauseFunc
curFinalMessage := prep.FinalMessage
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
curHistory := prep.History
roleTools := prep.RoleTools
@@ -176,9 +177,41 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
taskOwned = true
var cumulativeMCPExecutionIDs []string
var transientRunAttempts int
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
var mainIterationOffset int
for {
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
segmentMainIterationMax := 0
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
progressCallback := func(eventType, message string, data interface{}) {
if eventType == "iteration" {
if m, ok := data.(map[string]interface{}); ok {
if scope, _ := m["einoScope"].(string); scope == "main" {
raw := 0
switch v := m["iteration"].(type) {
case int:
raw = v
case int32:
raw = int(v)
case int64:
raw = int(v)
case float64:
raw = int(v)
case float32:
raw = int(v)
}
if raw > 0 {
if raw > segmentMainIterationMax {
segmentMainIterationMax = raw
}
m["iteration"] = raw + mainIterationOffset
}
}
}
}
rawProgressCallback(eventType, message, data)
}
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
@@ -198,16 +231,36 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
progressCallback,
chatReasoningToClientIntent(req.Reasoning),
)
timeoutCancel()
if result != nil && len(result.MCPExecutionIDs) > 0 {
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
}
if runErr == nil {
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
transientRunAttempts = 0
timeoutCancel()
break
}
handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, conversationID, result, runErr, &transientRunAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if handled {
mainIterationOffset += segmentMainIterationMax
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if fatalErr != nil {
runErr = fatalErr
}
cause := context.Cause(baseCtx)
if errors.Is(cause, multiagent.ErrInterruptContinue) {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
@@ -231,10 +284,14 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
"conversationId": conversationID,
"source": "interrupt_continue",
})
h.tasks.UpdateTaskStatus(conversationID, "running")
mainIterationOffset += segmentMainIterationMax
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
transientRunAttempts = 0
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
@@ -261,6 +318,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
"messageId": assistantMessageID,
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
@@ -278,6 +336,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
"errorType": "timeout",
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
@@ -294,9 +353,12 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
"messageId": assistantMessageID,
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
timeoutCancel()
if assistantMessageID != "" {
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
}
@@ -326,7 +388,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
h.logger.Info("收到 Eino ADK 单代理非流式请求", zap.String("conversationId", req.ConversationID))
prep, err := h.prepareMultiAgentSession(&req)
prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
+27
View File
@@ -6,6 +6,7 @@ import (
"os"
"sync"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
@@ -20,9 +21,15 @@ type ExternalMCPHandler struct {
config *config.Config
configPath string
logger *zap.Logger
audit *audit.Service
mu sync.RWMutex
}
// SetAudit wires platform audit logging.
func (h *ExternalMCPHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewExternalMCPHandler 创建外部MCP处理器
func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler {
return &ExternalMCPHandler{
@@ -180,6 +187,16 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
}
h.logger.Info("外部MCP配置已更新", zap.String("name", name))
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "external_mcp",
Action: "upsert",
Result: "success",
ResourceType: "external_mcp",
ResourceID: name,
Message: "更新外部 MCP 配置",
})
}
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
}
@@ -209,6 +226,16 @@ func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
}
h.logger.Info("外部MCP配置已删除", zap.String("name", name))
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "external_mcp",
Action: "delete",
Result: "success",
ResourceType: "external_mcp",
ResourceID: name,
Message: "删除外部 MCP 配置",
})
}
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
}
+5
View File
@@ -616,6 +616,11 @@ func (h *AgentHandler) DecideHITLInterrupt(c *gin.Context) {
c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "hitl", "decision", "HITL 审批决策", "hitl_interrupt", req.InterruptID, map[string]interface{}{
"decision": req.Decision,
})
}
c.JSON(http.StatusOK, gin.H{"ok": true})
}
+13
View File
@@ -6,6 +6,7 @@ import (
"net/http"
"time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/knowledge"
@@ -20,6 +21,12 @@ type KnowledgeHandler struct {
indexer *knowledge.Indexer
db *database.DB
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *KnowledgeHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewKnowledgeHandler 创建新的知识库处理器
@@ -303,6 +310,9 @@ func (h *KnowledgeHandler) DeleteItem(c *gin.Context) {
return
}
if h.audit != nil {
h.audit.RecordOK(c, "knowledge", "item_delete", "删除知识项", "knowledge_item", id, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
@@ -316,6 +326,9 @@ func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
}
}()
if h.audit != nil {
h.audit.RecordOK(c, "knowledge", "index_rebuild", "重建知识库索引", "knowledge", "", nil)
}
c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"})
}
+17 -1
View File
@@ -9,6 +9,7 @@ import (
"strings"
"cyberstrike-ai/internal/agents"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"github.com/gin-gonic/gin"
@@ -18,7 +19,8 @@ var markdownAgentFilenameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]*\.m
// MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。
type MarkdownAgentsHandler struct {
dir string
dir string
audit *audit.Service
}
// NewMarkdownAgentsHandler dir 须为已解析的绝对路径。
@@ -26,6 +28,11 @@ func NewMarkdownAgentsHandler(dir string) *MarkdownAgentsHandler {
return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)}
}
// SetAudit wires platform audit logging.
func (h *MarkdownAgentsHandler) SetAudit(s *audit.Service) {
h.audit = s
}
func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) {
filename = strings.TrimSpace(filename)
if filename == "" || !markdownAgentFilenameRe.MatchString(filename) {
@@ -227,6 +234,9 @@ func (h *MarkdownAgentsHandler) CreateMarkdownAgent(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "agent", "markdown_create", "创建 Markdown 子代理", "markdown_agent", filepath.Base(path), nil)
}
c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"})
}
@@ -294,6 +304,9 @@ func (h *MarkdownAgentsHandler) UpdateMarkdownAgent(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "agent", "markdown_update", "更新 Markdown 子代理", "markdown_agent", filename, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "已保存"})
}
@@ -313,5 +326,8 @@ func (h *MarkdownAgentsHandler) DeleteMarkdownAgent(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "agent", "markdown_delete", "删除 Markdown 子代理", "markdown_agent", filename, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "已删除"})
}
+17
View File
@@ -9,6 +9,7 @@ import (
"strings"
"time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security"
@@ -23,6 +24,12 @@ type MonitorHandler struct {
executor *security.Executor
db *database.DB
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *MonitorHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewMonitorHandler 创建新的监控处理器
@@ -365,6 +372,11 @@ func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
}
h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName))
if h.audit != nil {
h.audit.RecordOK(c, "tool", "execution_delete", "删除工具执行记录", "tool_execution", id, map[string]interface{}{
"tool_name": exec.ToolName,
})
}
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"})
return
}
@@ -440,6 +452,11 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
}
h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs)))
if h.audit != nil {
h.audit.RecordOK(c, "tool", "execution_delete_batch", "批量删除工具执行记录", "tool_execution", "", map[string]interface{}{
"count": len(request.IDs),
})
}
c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)})
return
}
+67 -5
View File
@@ -107,7 +107,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
zap.String("conversationId", req.ConversationID),
)
prep, err := h.prepareMultiAgentSession(&req)
prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent_stream")
if err != nil {
sendEvent("error", err.Error(), nil)
sendEvent("done", "", nil)
@@ -136,6 +136,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
var cancelWithCause context.CancelCauseFunc
curFinalMessage := prep.FinalMessage
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
curHistory := prep.History
roleTools := prep.RoleTools
orch := strings.TrimSpace(req.Orchestration)
@@ -186,9 +187,41 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
var cumulativeMCPExecutionIDs []string
var transientRunAttempts int
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
var mainIterationOffset int
for {
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
segmentMainIterationMax := 0
rawProgressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
progressCallback := func(eventType, message string, data interface{}) {
if eventType == "iteration" {
if m, ok := data.(map[string]interface{}); ok {
if scope, _ := m["einoScope"].(string); scope == "main" {
raw := 0
switch v := m["iteration"].(type) {
case int:
raw = v
case int32:
raw = int(v)
case int64:
raw = int(v)
case float64:
raw = int(v)
case float32:
raw = int(v)
}
if raw > 0 {
if raw > segmentMainIterationMax {
segmentMainIterationMax = raw
}
m["iteration"] = raw + mainIterationOffset
}
}
}
}
rawProgressCallback(eventType, message, data)
}
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
@@ -210,16 +243,36 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
orch,
chatReasoningToClientIntent(req.Reasoning),
)
timeoutCancel()
if result != nil && len(result.MCPExecutionIDs) > 0 {
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
}
if runErr == nil {
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
transientRunAttempts = 0
timeoutCancel()
break
}
handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, conversationID, result, runErr, &transientRunAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if handled {
mainIterationOffset += segmentMainIterationMax
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if fatalErr != nil {
runErr = fatalErr
}
cause := context.Cause(baseCtx)
if errors.Is(cause, multiagent.ErrInterruptContinue) {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
@@ -243,10 +296,14 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
"conversationId": conversationID,
"source": "interrupt_continue",
})
h.tasks.UpdateTaskStatus(conversationID, "running")
mainIterationOffset += segmentMainIterationMax
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
transientRunAttempts = 0
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
@@ -273,6 +330,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
"messageId": assistantMessageID,
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
@@ -290,6 +348,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
"errorType": "timeout",
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
@@ -306,9 +365,12 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
"messageId": assistantMessageID,
})
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
timeoutCancel()
if assistantMessageID != "" {
_ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
}
@@ -347,7 +409,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID))
prep, err := h.prepareMultiAgentSession(&req)
prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent")
if err != nil {
status, msg := multiAgentHTTPErrorStatus(err)
c.JSON(status, gin.H{"error": msg})
+8 -3
View File
@@ -5,9 +5,11 @@ import (
"strings"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp/builtin"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
@@ -22,7 +24,7 @@ type multiAgentPrepared struct {
UserMessageID string
}
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPrepared, error) {
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context, source string) (*multiAgentPrepared, error) {
if len(req.Attachments) > maxAttachments {
return nil, fmt.Errorf("附件最多 %d 个", maxAttachments)
}
@@ -33,10 +35,13 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
title := safeTruncateString(req.Message, 50)
var conv *database.Conversation
var err error
meta := audit.ConversationCreateMetaFromGin(c, source)
if strings.TrimSpace(req.WebShellConnectionID) != "" {
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title)
meta.Source = source + "_webshell"
meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID)
conv, err = h.db.CreateConversationWithWebshell(meta.WebShellConnectionID, title, meta)
} else {
conv, err = h.db.CreateConversation(title)
conv, err = h.db.CreateConversation(title, meta)
}
if err != nil {
return nil, fmt.Errorf("创建对话失败: %w", err)
+1 -1
View File
@@ -6254,7 +6254,7 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) {
}
// 获取漏洞列表
vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "", "", "", "")
vulnList, err := h.db.ListVulnerabilities(1000, 0, database.VulnerabilityListFilter{ConversationID: conversationID})
if err != nil {
h.logger.Warn("获取漏洞列表失败", zap.Error(err))
vulnList = []*database.Vulnerability{}
+3 -3
View File
@@ -133,7 +133,7 @@ func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (
} else {
t = safeTruncateString(t, 50)
}
conv, err := h.db.CreateConversation(t)
conv, err := h.db.CreateConversation(t, database.ConversationCreateMeta{Source: "robot:" + platform})
if err != nil {
h.logger.Warn("创建机器人会话失败", zap.Error(err))
return "", false
@@ -188,7 +188,7 @@ func (h *RobotHandler) setRole(platform, userID, roleName string) {
// clearConversation 清空当前会话(切换到新对话)
func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) {
title := "新对话 " + time.Now().Format("01-02 15:04")
conv, err := h.db.CreateConversation(title)
conv, err := h.db.CreateConversation(title, database.ConversationCreateMeta{Source: "robot:" + platform + ":new"})
if err != nil {
h.logger.Warn("创建新对话失败", zap.Error(err))
return ""
@@ -242,7 +242,7 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
h.cancelMu.Unlock()
}()
role := h.getRole(platform, userID)
resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, convID, text, role)
resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, platform, convID, text, role)
if err != nil {
h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err))
if errors.Is(err, context.Canceled) {
+16
View File
@@ -8,6 +8,7 @@ import (
"regexp"
"strings"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"gopkg.in/yaml.v3"
@@ -21,6 +22,12 @@ type RoleHandler struct {
config *config.Config
configPath string
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *RoleHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewRoleHandler 创建新的角色处理器
@@ -174,6 +181,9 @@ func (h *RoleHandler) UpdateRole(c *gin.Context) {
}
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
if h.audit != nil {
h.audit.RecordOK(c, "role", "update", "更新角色", "role", finalKey, map[string]interface{}{"name": req.Name})
}
c.JSON(http.StatusOK, gin.H{
"message": "角色已更新",
"role": req,
@@ -219,6 +229,9 @@ func (h *RoleHandler) CreateRole(c *gin.Context) {
}
h.logger.Info("创建角色", zap.String("roleName", req.Name))
if h.audit != nil {
h.audit.RecordOK(c, "role", "create", "创建角色", "role", req.Name, nil)
}
c.JSON(http.StatusOK, gin.H{
"message": "角色已创建",
"role": req,
@@ -287,6 +300,9 @@ func (h *RoleHandler) DeleteRole(c *gin.Context) {
}
h.logger.Info("删除角色", zap.String("roleName", roleName))
if h.audit != nil {
h.audit.RecordOK(c, "role", "delete", "删除角色", "role", roleName, nil)
}
c.JSON(http.StatusOK, gin.H{
"message": "角色已删除",
})
+18
View File
@@ -8,6 +8,7 @@ import (
"regexp"
"strings"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/skillpackage"
@@ -23,6 +24,12 @@ type SkillsHandler struct {
configPath string
logger *zap.Logger
db *database.DB // 数据库连接(遗留统计;MCP list/read 已移除)
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *SkillsHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewSkillsHandler 创建新的Skills处理器
@@ -365,6 +372,9 @@ func (h *SkillsHandler) CreateSkill(c *gin.Context) {
}
h.logger.Info("创建skill成功", zap.String("skill", req.Name))
if h.audit != nil {
h.audit.RecordOK(c, "skill", "create", "创建 Skill", "skill", req.Name, nil)
}
c.JSON(http.StatusOK, gin.H{
"message": "skill已创建",
"skill": map[string]interface{}{
@@ -425,6 +435,9 @@ func (h *SkillsHandler) UpdateSkill(c *gin.Context) {
}
h.logger.Info("更新skill成功", zap.String("skill", skillName))
if h.audit != nil {
h.audit.RecordOK(c, "skill", "update", "更新 Skill", "skill", skillName, nil)
}
c.JSON(http.StatusOK, gin.H{
"message": "skill已更新",
})
@@ -459,6 +472,11 @@ func (h *SkillsHandler) DeleteSkill(c *gin.Context) {
}
h.logger.Info("删除skill成功", zap.String("skill", skillName))
if h.audit != nil {
h.audit.RecordOK(c, "skill", "delete", "删除 Skill", "skill", skillName, map[string]interface{}{
"affected_roles": affectedRoles,
})
}
c.JSON(http.StatusOK, gin.H{
"message": responseMsg,
"affected_roles": affectedRoles,
+1 -1
View File
@@ -253,5 +253,5 @@ func (h *TerminalHandler) RunCommandStream(c *gin.Context) {
flusher.Flush()
}
runCommandStreamImpl(cmd, sendEvent, ctx)
_ = runCommandStreamImpl(cmd, sendEvent, ctx)
}
+3 -2
View File
@@ -15,11 +15,11 @@ const ptyCols = 256
const ptyRows = 40
// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真)
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) {
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows})
if err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return
return -1
}
defer ptmx.Close()
@@ -43,4 +43,5 @@ func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx contex
exitCode = -1
}
sendEvent(streamEvent{T: "exit", C: exitCode})
return exitCode
}
+5 -4
View File
@@ -11,20 +11,20 @@ import (
)
// runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) {
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
stdoutPipe, err := cmd.StdoutPipe()
if err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return
return -1
}
stderrPipe, err := cmd.StderrPipe()
if err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return
return -1
}
if err := cmd.Start(); err != nil {
sendEvent(streamEvent{T: "exit", C: -1})
return
return -1
}
normalize := func(s string) string {
@@ -62,4 +62,5 @@ func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx contex
exitCode = -1
}
sendEvent(streamEvent{T: "exit", C: exitCode})
return exitCode
}
+53 -21
View File
@@ -7,6 +7,7 @@ import (
"strings"
"time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
@@ -16,6 +17,12 @@ import (
type VulnerabilityHandler struct {
db *database.DB
logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *VulnerabilityHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewVulnerabilityHandler 创建新的漏洞处理器
@@ -72,6 +79,11 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
return
}
if h.audit != nil {
h.audit.RecordOK(c, "vulnerability", "create", "创建漏洞记录", "vulnerability", created.ID, map[string]interface{}{
"severity": created.Severity, "title": created.Title,
})
}
c.JSON(http.StatusOK, created)
}
@@ -98,18 +110,29 @@ type ListVulnerabilitiesResponse struct {
TotalPages int `json:"total_pages"`
}
func parseVulnerabilityListFilter(c *gin.Context) database.VulnerabilityListFilter {
q := strings.TrimSpace(c.Query("q"))
if q == "" {
q = strings.TrimSpace(c.Query("search"))
}
return database.VulnerabilityListFilter{
ID: c.Query("id"),
Search: q,
ConversationID: c.Query("conversation_id"),
Severity: c.Query("severity"),
Status: c.Query("status"),
TaskID: c.Query("task_id"),
ConversationTag: c.Query("conversation_tag"),
TaskTag: c.Query("task_tag"),
}
}
// ListVulnerabilities 列出漏洞
func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
limitStr := c.DefaultQuery("limit", "20")
offsetStr := c.DefaultQuery("offset", "0")
pageStr := c.Query("page")
id := c.Query("id")
conversationID := c.Query("conversation_id")
severity := c.Query("severity")
status := c.Query("status")
taskID := c.Query("task_id")
conversationTag := c.Query("conversation_tag")
taskTag := c.Query("task_tag")
filter := parseVulnerabilityListFilter(c)
limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr)
@@ -131,7 +154,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
}
// 获取总数
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
total, err := h.db.CountVulnerabilities(filter)
if err != nil {
h.logger.Error("获取漏洞总数失败", zap.Error(err))
// 继续执行,使用0作为总数
@@ -139,7 +162,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
}
// 获取漏洞列表
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status, taskID, conversationTag, taskTag)
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, filter)
if err != nil {
h.logger.Error("获取漏洞列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -249,6 +272,11 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
return
}
if h.audit != nil {
h.audit.RecordOK(c, "vulnerability", "update", "更新漏洞记录", "vulnerability", id, map[string]interface{}{
"severity": updated.Severity, "status": updated.Status,
})
}
c.JSON(http.StatusOK, updated)
}
@@ -262,15 +290,25 @@ func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
return
}
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "vulnerability",
Action: "delete",
Result: "success",
ResourceType: "vulnerability",
ResourceID: id,
Message: "删除漏洞记录",
})
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
// GetVulnerabilityStats 获取漏洞统计
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
conversationID := c.Query("conversation_id")
taskID := c.Query("task_id")
filter := parseVulnerabilityListFilter(c)
stats, err := h.db.GetVulnerabilityStats(conversationID, taskID)
stats, err := h.db.GetVulnerabilityStats(filter)
if err != nil {
h.logger.Error("获取漏洞统计失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -304,15 +342,9 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
return
}
id := c.Query("id")
conversationID := c.Query("conversation_id")
severity := c.Query("severity")
status := c.Query("status")
taskID := c.Query("task_id")
conversationTag := c.Query("conversation_tag")
taskTag := c.Query("task_tag")
filter := parseVulnerabilityListFilter(c)
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
total, err := h.db.CountVulnerabilities(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -322,7 +354,7 @@ func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
return
}
items, err := h.db.ListVulnerabilities(total, 0, id, conversationID, severity, status, taskID, conversationTag, taskTag)
items, err := h.db.ListVulnerabilities(total, 0, filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
+28 -3
View File
@@ -2,6 +2,7 @@ package handler
import (
"bytes"
"crypto/tls"
"database/sql"
"encoding/base64"
"encoding/json"
@@ -12,6 +13,7 @@ import (
"time"
"unicode/utf8"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
@@ -304,6 +306,12 @@ type WebShellHandler struct {
logger *zap.Logger
client *http.Client
db *database.DB
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *WebShellHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewWebShellHandler 创建 WebShell 处理器,db 可为 nil(连接配置接口将不可用)
@@ -311,8 +319,12 @@ func NewWebShellHandler(logger *zap.Logger, db *database.DB) *WebShellHandler {
return &WebShellHandler{
logger: logger,
client: &http.Client{
Timeout: 30 * time.Second,
Transport: &http.Transport{DisableKeepAlives: false},
Timeout: 30 * time.Second,
Transport: &http.Transport{
DisableKeepAlives: false,
// WebShell 场景常见自签证书或 IP 访问(证书无 IP SAN);默认跳过校验,与蚁剑等客户端一致。
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // intentional for webshell proxy
},
},
db: db,
}
@@ -403,6 +415,15 @@ func (h *WebShellHandler) CreateConnection(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
host := req.URL
if u, err := url.Parse(req.URL); err == nil {
host = u.Host
}
h.audit.RecordOK(c, "webshell", "connection_create", "创建 WebShell 连接", "webshell_connection", conn.ID, map[string]interface{}{
"host": host, "type": shellType,
})
}
c.JSON(http.StatusOK, conn)
}
@@ -485,6 +506,9 @@ func (h *WebShellHandler) DeleteConnection(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if h.audit != nil {
h.audit.RecordOK(c, "webshell", "connection_delete", "删除 WebShell 连接", "webshell_connection", id, nil)
}
c.JSON(http.StatusOK, gin.H{"ok": true})
}
@@ -714,8 +738,9 @@ func (h *WebShellHandler) Exec(c *gin.Context) {
output := decodeWebshellOutput(out, req.Encoding)
httpCode := resp.StatusCode
ok := resp.StatusCode == http.StatusOK
c.JSON(http.StatusOK, ExecResponse{
OK: resp.StatusCode == http.StatusOK,
OK: ok,
Output: output,
HTTPCode: httpCode,
})
+293
View File
@@ -0,0 +1,293 @@
package handler
import (
"context"
"net/http"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/robot/ilink"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
const wechatLoginTTL = 5 * time.Minute
// WechatConfigSaver 绑定成功后写入配置并重启机器人连接
type WechatConfigSaver interface {
ApplyWechatRobotBinding(cfg config.RobotWechatConfig) error
}
type wechatLoginSession struct {
QRCode string
QRCodeImgURL string
PendingVerify string
CurrentBaseURL string
StartedAt time.Time
}
// WechatRobotHandler 微信 iLink 机器人(扫码绑定 + 配置)
type WechatRobotHandler struct {
config *config.Config
configSaver WechatConfigSaver
logger *zap.Logger
mu sync.Mutex
logins map[string]*wechatLoginSession
}
// NewWechatRobotHandler 创建微信机器人处理器
func NewWechatRobotHandler(cfg *config.Config, saver WechatConfigSaver, logger *zap.Logger) *WechatRobotHandler {
return &WechatRobotHandler{
config: cfg,
configSaver: saver,
logger: logger,
logins: make(map[string]*wechatLoginSession),
}
}
func (h *WechatRobotHandler) purgeExpiredLogins() {
now := time.Now()
for k, v := range h.logins {
if now.Sub(v.StartedAt) > wechatLoginTTL {
delete(h.logins, k)
}
}
}
func (h *WechatRobotHandler) ilinkClient(baseURL string) *ilink.Client {
ver := h.config.Version
if ver == "" {
ver = "1.0.0"
}
ver = strings.TrimPrefix(strings.TrimSpace(ver), "v")
ver = strings.TrimPrefix(ver, "V")
wc := h.config.Robots.Wechat
return ilink.NewClient(baseURL, wc.BotToken, wc.BotAgent, ilink.BuildClientVersion(ver))
}
// HandleWechatQRCode POST /api/robot/wechat/qrcode — 生成绑定二维码
func (h *WechatRobotHandler) HandleWechatQRCode(c *gin.Context) {
h.mu.Lock()
h.purgeExpiredLogins()
h.mu.Unlock()
var req struct {
BotType string `json:"bot_type"`
}
_ = c.ShouldBindJSON(&req)
botType := req.BotType
if botType == "" {
botType = h.config.Robots.Wechat.BotType
}
if botType == "" {
botType = ilink.DefaultBotType
}
baseURL := h.config.Robots.Wechat.BaseURL
if baseURL == "" {
baseURL = ilink.DefaultBaseURL
}
var localTokens []string
if t := h.config.Robots.Wechat.BotToken; t != "" {
localTokens = []string{t}
}
client := h.ilinkClient(baseURL)
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Second)
defer cancel()
qr, err := client.GetBotQRCode(ctx, botType, localTokens)
if err != nil {
h.logger.Warn("获取微信二维码失败", zap.Error(err))
c.JSON(http.StatusBadGateway, gin.H{"error": "获取二维码失败: " + err.Error()})
return
}
if qr.QRCode == "" || qr.QRCodeImgContent == "" {
c.JSON(http.StatusBadGateway, gin.H{"error": "微信服务器未返回有效二维码"})
return
}
sessionKey := uuid.New().String()
h.mu.Lock()
h.logins[sessionKey] = &wechatLoginSession{
QRCode: qr.QRCode,
QRCodeImgURL: qr.QRCodeImgContent,
CurrentBaseURL: baseURL,
StartedAt: time.Now(),
}
h.mu.Unlock()
resp := gin.H{
"session_key": sessionKey,
"qrcode": qr.QRCode,
"qrcode_open_url": qr.QRCodeImgContent,
"message": "请使用微信扫描二维码并确认绑定",
}
if dataURL, err := ilink.QRCodeDataURL(qr.QRCodeImgContent, 256); err != nil {
h.logger.Warn("生成二维码图片失败", zap.Error(err))
} else {
resp["qrcode_image_data_url"] = dataURL
}
c.JSON(http.StatusOK, resp)
}
// HandleWechatQRCodeStatus GET /api/robot/wechat/qrcode/status — 轮询扫码状态
func (h *WechatRobotHandler) HandleWechatQRCodeStatus(c *gin.Context) {
sessionKey := c.Query("session_key")
verifyCode := c.Query("verify_code")
if sessionKey == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少 session_key"})
return
}
h.mu.Lock()
sess, ok := h.logins[sessionKey]
h.mu.Unlock()
if !ok {
c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期,请重新生成二维码"})
return
}
if time.Since(sess.StartedAt) > wechatLoginTTL {
h.mu.Lock()
delete(h.logins, sessionKey)
h.mu.Unlock()
c.JSON(http.StatusGone, gin.H{"error": "二维码已过期,请重新生成"})
return
}
baseURL := sess.CurrentBaseURL
if baseURL == "" {
baseURL = ilink.DefaultBaseURL
}
vc := verifyCode
if vc == "" {
vc = sess.PendingVerify
}
client := h.ilinkClient(baseURL)
ctx, cancel := context.WithTimeout(c.Request.Context(), 40*time.Second)
defer cancel()
st, err := client.GetQRCodeStatus(ctx, sess.QRCode, vc)
if err != nil {
h.logger.Warn("轮询微信二维码状态失败", zap.Error(err))
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
switch st.Status {
case "wait", "scaned":
c.JSON(http.StatusOK, gin.H{"status": st.Status})
return
case "need_verifycode":
c.JSON(http.StatusOK, gin.H{
"status": st.Status,
"message": "请在手机微信查看配对数字,并在下方输入",
})
return
case "scaned_but_redirect":
if st.RedirectHost != "" {
h.mu.Lock()
if s, ok := h.logins[sessionKey]; ok {
s.CurrentBaseURL = "https://" + st.RedirectHost
}
h.mu.Unlock()
}
c.JSON(http.StatusOK, gin.H{"status": st.Status})
return
case "binded_redirect":
h.mu.Lock()
delete(h.logins, sessionKey)
h.mu.Unlock()
c.JSON(http.StatusOK, gin.H{
"status": st.Status,
"already_connected": true,
"message": "该微信已绑定过,无需重复绑定",
})
return
case "confirmed":
if st.BotToken == "" || st.ILinkBotID == "" {
c.JSON(http.StatusBadGateway, gin.H{"error": "绑定确认成功但缺少 bot_token"})
return
}
saveBase := st.BaseURL
if saveBase == "" {
saveBase = baseURL
}
wc := h.config.Robots.Wechat
wc.Enabled = true
wc.BotToken = st.BotToken
wc.ILinkBotID = st.ILinkBotID
wc.ILinkUserID = st.ILinkUserID
wc.BaseURL = saveBase
if wc.BotType == "" {
wc.BotType = ilink.DefaultBotType
}
if wc.BotAgent == "" {
wc.BotAgent = ilink.DefaultBotAgent
}
if h.configSaver != nil {
if err := h.configSaver.ApplyWechatRobotBinding(wc); err != nil {
h.logger.Warn("保存微信机器人配置失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "保存配置失败: " + err.Error()})
return
}
} else {
h.config.Robots.Wechat = wc
}
h.mu.Lock()
delete(h.logins, sessionKey)
h.mu.Unlock()
c.JSON(http.StatusOK, gin.H{
"status": "confirmed",
"message": "绑定成功,微信机器人已启用",
"ilink_bot_id": st.ILinkBotID,
"ilink_user_id": st.ILinkUserID,
})
return
default:
c.JSON(http.StatusOK, gin.H{"status": st.Status})
}
}
// HandleWechatVerifyCode POST /api/robot/wechat/qrcode/verify — 提交手机配对数字
func (h *WechatRobotHandler) HandleWechatVerifyCode(c *gin.Context) {
var req struct {
SessionKey string `json:"session_key"`
VerifyCode string `json:"verify_code"`
}
if err := c.ShouldBindJSON(&req); err != nil || req.SessionKey == "" || req.VerifyCode == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "需要 session_key 与 verify_code"})
return
}
h.mu.Lock()
sess, ok := h.logins[req.SessionKey]
if ok {
sess.PendingVerify = req.VerifyCode
}
h.mu.Unlock()
if !ok {
c.JSON(http.StatusNotFound, gin.H{"error": "登录会话不存在或已过期"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "已提交配对码,请继续等待绑定"})
}
// HandleWechatStatus GET /api/robot/wechat/status — 当前绑定状态(供前端展示)
func (h *WechatRobotHandler) HandleWechatStatus(c *gin.Context) {
wc := h.config.Robots.Wechat
bound := wc.BotToken != "" && wc.ILinkBotID != ""
c.JSON(http.StatusOK, gin.H{
"enabled": wc.Enabled,
"bound": bound,
"ilink_bot_id": wc.ILinkBotID,
"ilink_user_id": wc.ILinkUserID,
"base_url": wc.BaseURL,
})
}
+63 -25
View File
@@ -77,6 +77,9 @@ type einoADKRunLoopArgs struct {
StreamsMainAssistant func(agent string) bool
EinoRoleTag func(agent string) string
CheckpointDir string
// RunRetryMaxAttempts / RunRetryMaxBackoffSec429、5xx、网络抖动时的指数退避续跑(0=默认 10 次 / 30s 上限)。
RunRetryMaxAttempts int
RunRetryMaxBackoffSec int
McpIDsMu *sync.Mutex
McpIDs *[]string
@@ -177,6 +180,8 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
var einoMainRound int
var einoLastAgent string
subAgentToolStep := make(map[string]int)
// mainAgentToolStep:主代理每次工具调用批次递增,供 UI 显示「第 N 轮」(单代理无子代理切换时原先会一直停在第 1 轮)。
mainAgentToolStep := make(map[string]int)
pendingByID := make(map[string]toolCallPendingInfo)
pendingQueueByAgent := make(map[string][]string)
markPending := func(tc toolCallPendingInfo) {
@@ -435,6 +440,28 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
return runErr
}
// maybeRetryTransientRun:不在此层 runner.Run/Resume;由 handler 落库 + loadHistoryFromAgentTrace 分段续跑(同中断并继续)。
maybeRetryTransientRun := func(runErr error) (retry bool, fatal error) {
if runErr == nil || !isEinoTransientRunError(runErr) {
return false, handleRunErr(runErr)
}
if logger != nil {
logger.Warn("eino transient error, ending run segment for handler resume",
zap.Error(runErr),
zap.String("orchestration", orchMode))
}
if progress != nil {
progress("eino_run_retry", "遇到临时错误(限流或网络波动),将保存上下文并重试…", map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"orchestration": orchMode,
"error": runErr.Error(),
"resumeKind": "trace_segment",
})
}
return false, ErrTransientRetryContinue
}
takePartial := func(runErr error) (*RunResult, error) {
if len(runAccumulatedMsgs) <= baseAccumulatedCount {
return nil, runErr
@@ -517,7 +544,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
continue
}
if ev.Err != nil {
if retErr := handleRunErr(ev.Err); retErr != nil {
if _, retErr := maybeRetryTransientRun(ev.Err); retErr != nil {
return takePartial(retErr)
}
}
@@ -529,8 +556,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
}
}
if streamsMainAssistant(ev.AgentName) {
mainIterKey := einoMainIterationKey(iterEinoAgent, orchestratorName)
if einoMainRound == 0 {
einoMainRound = 1
mainAgentToolStep[mainIterKey] = 1
progress("iteration", "", map[string]interface{}{
"iteration": 1,
"einoScope": "main",
@@ -540,17 +569,26 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"conversationId": conversationID,
"source": "eino",
})
} else if einoLastAgent != "" && !streamsMainAssistant(einoLastAgent) {
einoMainRound++
progress("iteration", "", map[string]interface{}{
"iteration": einoMainRound,
"einoScope": "main",
"einoRole": "orchestrator",
"einoAgent": iterEinoAgent,
"orchestration": orchMode,
"conversationId": conversationID,
"source": "eino",
})
} else if einoLastAgent != "" {
needBump := false
if !streamsMainAssistant(einoLastAgent) {
needBump = true // 子代理 → 主代理
} else if einoLastAgent != ev.AgentName {
needBump = true // plan_executeplanner ↔ executor 等主代理切换
}
if needBump {
einoMainRound++
mainAgentToolStep[mainIterKey] = einoMainRound
progress("iteration", "", map[string]interface{}{
"iteration": einoMainRound,
"einoScope": "main",
"einoRole": "orchestrator",
"einoAgent": iterEinoAgent,
"orchestration": orchMode,
"conversationId": conversationID,
"source": "eino",
})
}
}
}
einoLastAgent = ev.AgentName
@@ -644,9 +682,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"orchestration": orchMode,
})
}
progress("reasoning_chain_stream_delta", displayDelta, map[string]interface{}{
progress("reasoning_chain_stream_delta", displayDelta, openai.WithSSEAccumulated(map[string]interface{}{
"streamId": reasoningStreamID,
})
}, fullDisplay))
}
}
}
@@ -676,13 +714,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
})
streamHeaderSent = true
}
progress("response_delta", contentDelta, map[string]interface{}{
progress("response_delta", contentDelta, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID,
"mcpExecutionIds": snapshotMCPIDs(),
"einoRole": "orchestrator",
"einoAgent": ev.AgentName,
"orchestration": orchMode,
})
}, mainAssistantBuf))
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, contentDelta)
}
}
@@ -701,10 +739,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"source": "eino",
})
}
progress("eino_agent_reply_stream_delta", subDelta, map[string]interface{}{
progress("eino_agent_reply_stream_delta", subDelta, openai.WithSSEAccumulated(map[string]interface{}{
"streamId": subReplyStreamID,
"conversationId": conversationID,
})
}, subAssistantBuf))
}
}
}
@@ -743,13 +781,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"orchestration": orchMode,
})
}
progress("response_delta", eofTail, map[string]interface{}{
progress("response_delta", eofTail, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID,
"mcpExecutionIds": snapshotMCPIDs(),
"einoRole": "orchestrator",
"einoAgent": ev.AgentName,
"orchestration": orchMode,
})
}, mainAssistantBuf))
mainAssistWireAccum, _ = normalizeStreamingDelta(mainAssistWireAccum, eofTail)
}
}
@@ -791,7 +829,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
lastToolChunk = mergeMessageToolCalls(&schema.Message{ToolCalls: merged})
}
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
// 流式路径此前只把 tool_calls 推给进度 UI,未写入 runAccumulatedMsgs;落库后 loadHistory→RepairOrphan 会删掉全部 tool 结果,表现为「续跑/下轮失忆」。
if lastToolChunk != nil && len(lastToolChunk.ToolCalls) > 0 {
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage("", lastToolChunk.ToolCalls))
@@ -808,7 +846,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"einoRole": einoRoleTag(ev.AgentName),
})
}
if retErr := handleRunErr(streamRecvErr); retErr != nil {
if _, retErr := maybeRetryTransientRun(streamRecvErr); retErr != nil {
return takePartial(retErr)
}
}
@@ -820,7 +858,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
continue
}
runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, orchMode, progress, toolEmitSeen, subAgentToolStep, mainAgentToolStep, markPending)
if mv.Role == schema.Assistant {
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
@@ -859,13 +897,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"einoAgent": ev.AgentName,
"orchestration": orchMode,
})
progress("response_delta", body, map[string]interface{}{
progress("response_delta", body, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID,
"mcpExecutionIds": snapshotMCPIDs(),
"einoRole": "orchestrator",
"einoAgent": ev.AgentName,
"orchestration": orchMode,
})
}, body))
}
lastAssistant = body
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
+3 -2
View File
@@ -18,7 +18,6 @@ import (
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
@@ -213,7 +212,7 @@ func RunEinoSingleChatModelAgent(
}
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
baseMsgs = appendUserMessageIfNeeded(baseMsgs, userMessage)
streamsMainAssistant := func(agent string) bool {
return agent == "" || agent == einoSingleAgentName
@@ -233,6 +232,8 @@ func RunEinoSingleChatModelAgent(
StreamsMainAssistant: streamsMainAssistant,
EinoRoleTag: einoRoleTag,
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
RunRetryMaxAttempts: ma.EinoMiddleware.RunRetryMaxAttempts,
RunRetryMaxBackoffSec: ma.EinoMiddleware.RunRetryMaxBackoffSec,
McpIDsMu: &mcpIDsMu,
McpIDs: &mcpIDs,
FilesystemMonitorAgent: ag,
+173
View File
@@ -0,0 +1,173 @@
package multiagent
import (
"context"
"errors"
"strings"
"time"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
)
const (
defaultEinoRunRetryMaxAttempts = 10
defaultEinoRunRetryMaxBackoff = 30 * time.Second
)
// isEinoTransientRunError 判断 ADK 运行期错误是否适合指数退避续跑(429、5xx、网络抖动等)。
// 用户取消、超时、迭代上限等由 run loop 单独处理,不在此列。
func isEinoTransientRunError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return false
}
if isEinoIterationLimitError(err) {
return false
}
msg := strings.ToLower(strings.TrimSpace(err.Error()))
if msg == "" {
return false
}
transientMarkers := []string{
"406",
"429",
"too many requests",
"rate limit",
"rate_limit",
"ratelimit",
"quota exceeded",
"overloaded",
"capacity",
"temporarily unavailable",
"service unavailable",
"bad gateway",
"gateway timeout",
"internal server error",
"connection reset",
"connection refused",
"connection closed",
"i/o timeout",
"no such host",
"network is unreachable",
"broken pipe",
"eof",
"read tcp",
"write tcp",
"dial tcp",
"tls handshake timeout",
"stream error",
"unexpected eof",
"unexpected end of json",
"status code: 406",
"status code: 502",
"502",
"503",
"504",
"500",
}
for _, m := range transientMarkers {
if strings.Contains(msg, m) {
return true
}
}
return false
}
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
if args != nil && args.RunRetryMaxAttempts > 0 {
return args.RunRetryMaxAttempts
}
return defaultEinoRunRetryMaxAttempts
}
// RunRetryMaxAttemptsFromConfig 供 handler 分段续跑计数(与 eino_middleware.run_retry_max_attempts 一致)。
func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
if mw != nil && mw.RunRetryMaxAttempts > 0 {
return mw.RunRetryMaxAttempts
}
return defaultEinoRunRetryMaxAttempts
}
// TransientRetryBackoff 供 handler 在分段续跑前退避。
func TransientRetryBackoff(attempt int, maxBackoffSec int) time.Duration {
max := defaultEinoRunRetryMaxBackoff
if maxBackoffSec > 0 {
max = time.Duration(maxBackoffSec) * time.Second
}
return einoTransientRetryBackoff(attempt, max)
}
func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration {
if args != nil && args.RunRetryMaxBackoffSec > 0 {
return time.Duration(args.RunRetryMaxBackoffSec) * time.Second
}
return defaultEinoRunRetryMaxBackoff
}
// einoRunRestartContextSource 描述无 checkpoint Resume 时 Run 使用的消息来源(日志/SSE)。
type einoRunRestartContextSource string
const (
einoRestartContextInitial einoRunRestartContextSource = "initial"
einoRestartContextAccumulated einoRunRestartContextSource = "accumulated"
einoRestartContextModelTrace einoRunRestartContextSource = "model_trace"
)
// einoMessagesForRunRestart 在退避后重新 Run 时选用最完整的上下文:
// 1) ModelFacingTrace(与模型实际入参一致) 2) 事件流累积的 runAccumulatedMsgs 3) 初始 msgs。
func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated []adk.Message, baseCount int) ([]adk.Message, einoRunRestartContextSource) {
if trace := persistTraceSource(args, nil); len(trace) > 0 {
return append([]adk.Message(nil), trace...), einoRestartContextModelTrace
}
if len(accumulated) > baseCount {
return append([]adk.Message(nil), accumulated...), einoRestartContextAccumulated
}
return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial
}
// adkMessagesHasUserContent 从尾部向前查找,是否已有与 want 相同的 user 消息(避免重复 append)。
func adkMessagesHasUserContent(msgs []adk.Message, want string) bool {
want = strings.TrimSpace(want)
if want == "" {
return true
}
for i := len(msgs) - 1; i >= 0; i-- {
m := msgs[i]
if m == nil {
continue
}
if m.Role == schema.User {
return strings.TrimSpace(m.Content) == want
}
if m.Role == schema.Assistant || m.Role == schema.Tool {
continue
}
break
}
return false
}
// appendUserMessageIfNeeded 在 history 轨迹之后追加本轮 user 消息(仅当轨迹中尚未包含该句)。
func appendUserMessageIfNeeded(msgs []adk.Message, userMessage string) []adk.Message {
if strings.TrimSpace(userMessage) == "" || adkMessagesHasUserContent(msgs, userMessage) {
return msgs
}
return append(msgs, schema.UserMessage(userMessage))
}
// einoTransientRetryBackoff 指数退避:2s, 4s, 8s… capped by maxBackoff。
func einoTransientRetryBackoff(attempt int, maxBackoff time.Duration) time.Duration {
if attempt < 0 {
attempt = 0
}
backoff := time.Duration(1<<uint(attempt+1)) * time.Second
if maxBackoff > 0 && backoff > maxBackoff {
backoff = maxBackoff
}
return backoff
}
@@ -0,0 +1,104 @@
package multiagent
import (
"context"
"errors"
"testing"
"time"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
)
func TestIsEinoTransientRunError(t *testing.T) {
t.Parallel()
cases := []struct {
name string
err error
want bool
}{
{"nil", nil, false},
{"429", errors.New("HTTP 429 Too Many Requests"), true},
{"rate limit", errors.New(`{"error":"rate limit exceeded"}`), true},
{"connection reset", errors.New("read tcp: connection reset by peer"), true},
{"503", errors.New("upstream returned 503"), true},
{"iteration limit", errors.New("max iteration reached"), false},
{"canceled", context.Canceled, false},
{"deadline", context.DeadlineExceeded, false},
{"auth", errors.New("invalid api key"), false},
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if got := isEinoTransientRunError(tc.err); got != tc.want {
t.Fatalf("isEinoTransientRunError(%v) = %v, want %v", tc.err, got, tc.want)
}
})
}
}
func TestEinoTransientRetryBackoff(t *testing.T) {
t.Parallel()
max := 30 * time.Second
if got := einoTransientRetryBackoff(0, max); got != 2*time.Second {
t.Fatalf("attempt 0: got %v", got)
}
if got := einoTransientRetryBackoff(4, max); got != 30*time.Second {
t.Fatalf("attempt 4 capped: got %v", got)
}
}
func TestEinoMessagesForRunRestart(t *testing.T) {
t.Parallel()
base := []adk.Message{schema.UserMessage("hi")}
acc := append([]adk.Message(nil), base...)
acc = append(acc, schema.AssistantMessage("step1", nil))
got, src := einoMessagesForRunRestart(nil, base, acc, len(base))
if src != einoRestartContextAccumulated || len(got) != 2 {
t.Fatalf("accumulated: src=%s len=%d", src, len(got))
}
holder := newModelFacingTraceHolder()
holder.storeFromState(&adk.ChatModelAgentState{
Messages: []adk.Message{schema.UserMessage("u"), schema.AssistantMessage("model-view", nil)},
})
got2, src2 := einoMessagesForRunRestart(&einoADKRunLoopArgs{ModelFacingTrace: holder}, base, acc, len(base))
if src2 != einoRestartContextModelTrace || len(got2) != 2 {
t.Fatalf("model trace: src=%s len=%d", src2, len(got2))
}
}
func TestEinoRunRetryMaxAttemptsFromArgs(t *testing.T) {
t.Parallel()
if einoRunRetryMaxAttempts(nil) != defaultEinoRunRetryMaxAttempts {
t.Fatal("nil args should use default")
}
if einoRunRetryMaxAttempts(&einoADKRunLoopArgs{RunRetryMaxAttempts: 3}) != 3 {
t.Fatal("custom max attempts")
}
if RunRetryMaxAttemptsFromConfig(nil) != defaultEinoRunRetryMaxAttempts {
t.Fatal("config nil should use default")
}
}
func TestAppendUserMessageIfNeeded(t *testing.T) {
t.Parallel()
msgs := []adk.Message{schema.UserMessage("old task")}
out := appendUserMessageIfNeeded(msgs, "你好,你是谁")
if len(out) != 2 || out[1].Content != "你好,你是谁" {
t.Fatalf("should append user: len=%d", len(out))
}
dup := appendUserMessageIfNeeded(out, "你好,你是谁")
if len(dup) != 2 {
t.Fatalf("should not duplicate user message: len=%d", len(dup))
}
}
func TestErrTransientRetryContinue(t *testing.T) {
t.Parallel()
if !errors.Is(ErrTransientRetryContinue, ErrTransientRetryContinue) {
t.Fatal("sentinel should match")
}
}
+4
View File
@@ -5,3 +5,7 @@ import "errors"
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
// ErrTransientRetryContinue 表示 Run 因 429/网络等临时错误结束,应由 handler 落库轨迹后
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue 同级的「分段续跑」语义)。
var ErrTransientRetryContinue = errors.New("agent transient: retry after persisting trace")
+42 -6
View File
@@ -538,7 +538,7 @@ func RunDeepAgent(
}
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
baseMsgs = appendUserMessageIfNeeded(baseMsgs, userMessage)
streamsMainAssistant := func(agent string) bool {
if orchMode == "plan_execute" {
@@ -566,6 +566,8 @@ func RunDeepAgent(
StreamsMainAssistant: streamsMainAssistant,
EinoRoleTag: einoRoleTag,
CheckpointDir: ma.EinoMiddleware.CheckpointDir,
RunRetryMaxAttempts: ma.EinoMiddleware.RunRetryMaxAttempts,
RunRetryMaxBackoffSec: ma.EinoMiddleware.RunRetryMaxBackoffSec,
McpIDsMu: &mcpIDsMu,
McpIDs: &mcpIDs,
FilesystemMonitorAgent: ag,
@@ -595,6 +597,13 @@ func chatToolCallsToSchema(tcs []agent.ToolCall) []schema.ToolCall {
argsStr = string(b)
}
}
// Some OpenAI-compatible gateways require `function.arguments` to exist
// on every assistant tool_call message. When args are empty, omitempty may
// drop the field during serialization and cause "missing field arguments"
// on the next turn history replay.
if strings.TrimSpace(argsStr) == "" {
argsStr = "{}"
}
typ := tc.Type
if typ == "" {
typ = "function"
@@ -737,12 +746,23 @@ func toolCallsRichSignature(msg *schema.Message) string {
return base + "|" + strings.Join(parts, ";")
}
func einoMainIterationKey(agentName, orchestratorName string) string {
key := strings.TrimSpace(agentName)
if key == "" {
key = strings.TrimSpace(orchestratorName)
}
if key == "" {
return "_main"
}
return key
}
func tryEmitToolCallsOnce(
msg *schema.Message,
agentName, orchestratorName, conversationID string,
agentName, orchestratorName, conversationID, orchMode string,
progress func(string, string, interface{}),
seen map[string]struct{},
subAgentToolStep map[string]int,
subAgentToolStep, mainAgentToolStep map[string]int,
markPending func(toolCallPendingInfo),
) {
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil || seen == nil {
@@ -756,14 +776,14 @@ func tryEmitToolCallsOnce(
return
}
seen[sig] = struct{}{}
emitToolCallsFromMessage(msg, agentName, orchestratorName, conversationID, progress, subAgentToolStep, markPending)
emitToolCallsFromMessage(msg, agentName, orchestratorName, conversationID, orchMode, progress, subAgentToolStep, mainAgentToolStep, markPending)
}
func emitToolCallsFromMessage(
msg *schema.Message,
agentName, orchestratorName, conversationID string,
agentName, orchestratorName, conversationID, orchMode string,
progress func(string, string, interface{}),
subAgentToolStep map[string]int,
subAgentToolStep, mainAgentToolStep map[string]int,
markPending func(toolCallPendingInfo),
) {
if msg == nil || len(msg.ToolCalls) == 0 || progress == nil {
@@ -784,6 +804,22 @@ func emitToolCallsFromMessage(
"conversationId": conversationID,
"source": "eino",
})
} else if mainAgentToolStep != nil {
key := einoMainIterationKey(agentName, orchestratorName)
mainAgentToolStep[key]++
n := mainAgentToolStep[key]
// 第 1 轮已在主代理进入时发出;此后每次工具批次对应新一轮 ReAct(与子代理按工具计步一致)。
if n > 1 {
progress("iteration", "", map[string]interface{}{
"iteration": n,
"einoScope": "main",
"einoRole": "orchestrator",
"einoAgent": agentName,
"orchestration": orchMode,
"conversationId": conversationID,
"source": "eino",
})
}
}
role := "orchestrator"
if isSubToolRound {
+20
View File
@@ -0,0 +1,20 @@
package openai
// SSEAccumulatedKey 为 SSE progress 事件 data 中的服务端权威流式全文快照字段。
// 前端应优先用该字段更新 buffer,避免对 delta 二次 normalize 导致叠字。
const SSEAccumulatedKey = "accumulated"
// WithSSEAccumulated 在 progress data 中附带当前流式累计全文(权威快照)。
func WithSSEAccumulated(data map[string]interface{}, accumulated string) map[string]interface{} {
if data == nil {
data = make(map[string]interface{}, 1)
}
data[SSEAccumulatedKey] = accumulated
return data
}
// NormalizeStreamingDelta 将可能是“累计片段/重发片段”的内容归一化为“纯增量”。
// 与 unexported normalizeStreamingDelta 相同,供 agent / multiagent 等包在发 SSE 前累计正文。
func NormalizeStreamingDelta(current, incoming string) (next, delta string) {
return normalizeStreamingDelta(current, incoming)
}
+9 -4
View File
@@ -149,13 +149,18 @@ func effectiveEffort(sr *config.OpenAIReasoningConfig, client *ClientIntent, all
func normalizeEffort(s string) string {
e := strings.ToLower(strings.TrimSpace(s))
switch e {
case "low", "medium", "high", "max":
case "low", "medium", "high", "max", "xhigh":
return e
default:
return ""
}
}
// usesExtraFieldsReasoningEffort 为 Eino 无枚举的最高档 effort,经 ExtraFields 原样下发(max / xhigh 由网关自行识别,不做互转)。
func usesExtraFieldsReasoningEffort(e string) bool {
return e == "max" || e == "xhigh"
}
func resolveWireProfile(oa *config.OpenAIConfig, sr *config.OpenAIReasoningConfig) wireProfile {
if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") {
return wireClaude
@@ -210,11 +215,11 @@ func applyOpenAICompat(cfg *einoopenai.ChatModelConfig, mode, effort string) {
if e == "" {
return
}
if e == "max" {
if usesExtraFieldsReasoningEffort(e) {
if cfg.ExtraFields == nil {
cfg.ExtraFields = make(map[string]any)
}
cfg.ExtraFields["reasoning_effort"] = "max"
cfg.ExtraFields["reasoning_effort"] = effortStringForAPI(e)
return
}
switch e {
@@ -245,6 +250,6 @@ func applyOutputConfigEffort(cfg *einoopenai.ChatModelConfig, mode, effort strin
}
func effortStringForAPI(e string) string {
// Gateways expect lowercase strings; "max" kept as max.
// 原样透传:OpenAI 官方多为 xhigh,部分兼容网关为 max,由配置/对话 effort 选择。
return strings.ToLower(strings.TrimSpace(e))
}
+66
View File
@@ -0,0 +1,66 @@
package reasoning
import (
"testing"
"cyberstrike-ai/internal/config"
einoopenai "github.com/cloudwego/eino-ext/components/model/openai"
)
func TestEffortStringForAPI_passthrough(t *testing.T) {
cases := map[string]string{
"max": "max",
"xhigh": "xhigh",
"HIGH": "high",
"Medium": "medium",
}
for in, want := range cases {
if got := effortStringForAPI(in); got != want {
t.Fatalf("%q -> %q, want %q", in, got, want)
}
}
}
func TestNormalizeEffort_maxAndXhigh(t *testing.T) {
if normalizeEffort("xhigh") != "xhigh" {
t.Fatal("xhigh not accepted")
}
if normalizeEffort("max") != "max" {
t.Fatal("max not accepted")
}
}
func TestApplyOpenAICompat_xhighExtraField(t *testing.T) {
cfg := &einoopenai.ChatModelConfig{}
oa := &config.OpenAIConfig{
Reasoning: config.OpenAIReasoningConfig{
Profile: "openai_compat",
Mode: "on",
Effort: "xhigh",
},
}
ApplyToEinoChatModelConfig(cfg, oa, nil)
if cfg.ExtraFields == nil {
t.Fatal("expected ExtraFields")
}
if got, _ := cfg.ExtraFields["reasoning_effort"].(string); got != "xhigh" {
t.Fatalf("reasoning_effort=%q", got)
}
}
func TestApplyOpenAICompat_maxPassthrough(t *testing.T) {
cfg := &einoopenai.ChatModelConfig{}
oa := &config.OpenAIConfig{
Reasoning: config.OpenAIReasoningConfig{
Profile: "openai_compat",
Mode: "on",
Effort: "max",
},
}
ApplyToEinoChatModelConfig(cfg, oa, nil)
got, _ := cfg.ExtraFields["reasoning_effort"].(string)
if got != "max" {
t.Fatalf("max effort wire=%q, want max", got)
}
}
+316
View File
@@ -0,0 +1,316 @@
package ilink
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
const (
DefaultBaseURL = "https://ilinkai.weixin.qq.com"
DefaultBotType = "3"
DefaultBotAgent = "CyberStrikeAI/1.0"
ILinkAppID = "bot"
QRLongPollTimeout = 35 * time.Second
APIDefaultTimeout = 15 * time.Second
GetUpdatesTimeout = 35 * time.Second
)
// Client 微信 iLink Bot HTTP 客户端(与 @tencent-weixin/openclaw-weixin 协议兼容)
type Client struct {
BaseURL string
BotToken string
BotAgent string
ClientVersion uint32
HTTP *http.Client
}
func NewClient(baseURL, botToken, botAgent string, clientVersion uint32) *Client {
base := strings.TrimSpace(baseURL)
if base == "" {
base = DefaultBaseURL
}
agent := strings.TrimSpace(botAgent)
if agent == "" {
agent = DefaultBotAgent
}
return &Client{
BaseURL: strings.TrimRight(base, "/"),
BotToken: strings.TrimSpace(botToken),
BotAgent: sanitizeBotAgent(agent),
ClientVersion: clientVersion,
HTTP: &http.Client{Timeout: 0},
}
}
// BuildClientVersion 将 semver 编码为 iLink-App-ClientVersion0x00MMNNPP
func BuildClientVersion(version string) uint32 {
parts := strings.Split(version, ".")
parse := func(i int) int {
if i >= len(parts) {
return 0
}
n, _ := strconv.Atoi(strings.TrimSpace(parts[i]))
if n < 0 {
return 0
}
return n
}
major := parse(0) & 0xff
minor := parse(1) & 0xff
patch := parse(2) & 0xff
return uint32((major << 16) | (minor << 8) | patch)
}
type baseInfo struct {
ChannelVersion string `json:"channel_version"`
BotAgent string `json:"bot_agent"`
}
func (c *Client) buildBaseInfo() baseInfo {
return baseInfo{
ChannelVersion: "1.0.0",
BotAgent: c.BotAgent,
}
}
func randomWechatUIN() string {
var b [4]byte
_, _ = rand.Read(b[:])
u := uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
return base64.StdEncoding.EncodeToString([]byte(strconv.FormatUint(uint64(u), 10)))
}
func (c *Client) commonHeaders() http.Header {
h := http.Header{}
h.Set("iLink-App-Id", ILinkAppID)
h.Set("iLink-App-ClientVersion", strconv.FormatUint(uint64(c.ClientVersion), 10))
return h
}
func (c *Client) authHeaders() http.Header {
h := c.commonHeaders()
h.Set("Content-Type", "application/json")
h.Set("AuthorizationType", "ilink_bot_token")
h.Set("X-WECHAT-UIN", randomWechatUIN())
if c.BotToken != "" {
h.Set("Authorization", "Bearer "+c.BotToken)
}
return h
}
func (c *Client) endpointURL(path string) (string, error) {
u, err := url.Parse(c.BaseURL + "/")
if err != nil {
return "", err
}
ref, err := url.Parse(path)
if err != nil {
return "", err
}
return u.ResolveReference(ref).String(), nil
}
func (c *Client) doRequest(ctx context.Context, method, path string, body []byte, headers http.Header, timeout time.Duration) ([]byte, error) {
reqURL, err := c.endpointURL(path)
if err != nil {
return nil, err
}
var bodyReader io.Reader
if len(body) > 0 {
bodyReader = bytes.NewReader(body)
}
req, err := http.NewRequestWithContext(ctx, method, reqURL, bodyReader)
if err != nil {
return nil, err
}
for k, vs := range headers {
for _, v := range vs {
req.Header.Add(k, v)
}
}
client := c.HTTP
if client == nil {
client = http.DefaultClient
}
if timeout > 0 {
ctx2, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
req = req.WithContext(ctx2)
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
raw, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("ilink %s %s: %d %s", method, path, resp.StatusCode, string(raw))
}
return raw, nil
}
// QRCodeResponse 获取二维码响应
type QRCodeResponse struct {
QRCode string `json:"qrcode"`
QRCodeImgContent string `json:"qrcode_img_content"`
}
// GetBotQRCode 获取绑定二维码
func (c *Client) GetBotQRCode(ctx context.Context, botType string, localTokenList []string) (*QRCodeResponse, error) {
if strings.TrimSpace(botType) == "" {
botType = DefaultBotType
}
body, _ := json.Marshal(map[string]interface{}{
"local_token_list": localTokenList,
})
path := "ilink/bot/get_bot_qrcode?bot_type=" + url.QueryEscape(botType)
raw, err := c.doRequest(ctx, http.MethodPost, path, body, c.authHeaders(), APIDefaultTimeout)
if err != nil {
return nil, err
}
var out QRCodeResponse
if err := json.Unmarshal(raw, &out); err != nil {
return nil, err
}
return &out, nil
}
// QRStatusResponse 二维码状态轮询响应
type QRStatusResponse struct {
Status string `json:"status"`
BotToken string `json:"bot_token"`
ILinkBotID string `json:"ilink_bot_id"`
ILinkUserID string `json:"ilink_user_id"`
BaseURL string `json:"baseurl"`
RedirectHost string `json:"redirect_host"`
}
// GetQRCodeStatus 长轮询二维码扫码状态
func (c *Client) GetQRCodeStatus(ctx context.Context, qrcode, verifyCode string) (*QRStatusResponse, error) {
path := "ilink/bot/get_qrcode_status?qrcode=" + url.QueryEscape(qrcode)
if verifyCode != "" {
path += "&verify_code=" + url.QueryEscape(verifyCode)
}
raw, err := c.doRequest(ctx, http.MethodGet, path, nil, c.commonHeaders(), QRLongPollTimeout)
if err != nil {
if ctx.Err() != nil {
return &QRStatusResponse{Status: "wait"}, nil
}
return &QRStatusResponse{Status: "wait"}, nil
}
var out QRStatusResponse
if err := json.Unmarshal(raw, &out); err != nil {
return nil, err
}
return &out, nil
}
// MessageItem 消息内容项
type MessageItem struct {
Type int `json:"type"`
TextItem *struct {
Text string `json:"text"`
} `json:"text_item,omitempty"`
}
// WeixinMessage 入站消息
type WeixinMessage struct {
FromUserID string `json:"from_user_id"`
MessageType int `json:"message_type"`
MessageState int `json:"message_state"`
ItemList []MessageItem `json:"item_list"`
ContextToken string `json:"context_token"`
}
// GetUpdatesResponse 长轮询消息响应
type GetUpdatesResponse struct {
Ret int `json:"ret"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
Msgs []WeixinMessage `json:"msgs"`
GetUpdatesBuf string `json:"get_updates_buf"`
LongPollingTimeoutMs int `json:"longpolling_timeout_ms"`
}
// GetUpdates 长轮询获取新消息
func (c *Client) GetUpdates(ctx context.Context, getUpdatesBuf string) (*GetUpdatesResponse, error) {
body, _ := json.Marshal(map[string]interface{}{
"get_updates_buf": getUpdatesBuf,
"base_info": c.buildBaseInfo(),
})
raw, err := c.doRequest(ctx, http.MethodPost, "ilink/bot/getupdates", body, c.authHeaders(), GetUpdatesTimeout)
if err != nil {
if ctx.Err() != nil {
return &GetUpdatesResponse{Ret: 0, GetUpdatesBuf: getUpdatesBuf}, nil
}
return &GetUpdatesResponse{Ret: 0, GetUpdatesBuf: getUpdatesBuf}, nil
}
var out GetUpdatesResponse
if err := json.Unmarshal(raw, &out); err != nil {
return nil, err
}
return &out, nil
}
// SendTextMessage 发送文本回复
func (c *Client) SendTextMessage(ctx context.Context, toUserID, contextToken, text, clientID string) error {
if clientID == "" {
clientID = randomClientID()
}
payload := map[string]interface{}{
"msg": map[string]interface{}{
"to_user_id": toUserID,
"client_id": clientID,
"message_type": 2,
"message_state": 2,
"context_token": contextToken,
"item_list": []map[string]interface{}{
{"type": 1, "text_item": map[string]string{"text": text}},
},
},
"base_info": c.buildBaseInfo(),
}
body, _ := json.Marshal(payload)
_, err := c.doRequest(ctx, http.MethodPost, "ilink/bot/sendmessage", body, c.authHeaders(), APIDefaultTimeout)
return err
}
func randomClientID() string {
var b [8]byte
_, _ = rand.Read(b[:])
return fmt.Sprintf("%x", b)
}
func sanitizeBotAgent(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return DefaultBotAgent
}
if len(raw) > 256 {
return raw[:256]
}
return raw
}
// ExtractText 从消息中提取首条文本
func ExtractText(msg WeixinMessage) string {
for _, item := range msg.ItemList {
if item.Type == 1 && item.TextItem != nil {
return strings.TrimSpace(item.TextItem.Text)
}
}
return ""
}
+26
View File
@@ -0,0 +1,26 @@
package ilink
import (
"encoding/base64"
"fmt"
"strings"
"github.com/skip2/go-qrcode"
)
// QRCodeDataURL 将扫码内容(一般为 liteapp 链接)编码为 PNG data URL,供 Web 端展示。
// qrcode_img_content 不是图片直链,不能用作 <img src>。
func QRCodeDataURL(content string, size int) (string, error) {
content = strings.TrimSpace(content)
if content == "" {
return "", fmt.Errorf("empty qr content")
}
if size <= 0 {
size = 256
}
png, err := qrcode.Encode(content, qrcode.Medium, size)
if err != nil {
return "", err
}
return "data:image/png;base64," + base64.StdEncoding.EncodeToString(png), nil
}
+96
View File
@@ -0,0 +1,96 @@
package robot
import (
"context"
"strings"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/robot/ilink"
"go.uber.org/zap"
)
const (
wechatReconnectInitial = 5 * time.Second
wechatReconnectMax = 60 * time.Second
wechatPlatform = "wechat"
)
// StartWechat 启动微信 iLink 长轮询(无需公网回调),收到消息后调用 handler 并回复。
func StartWechat(ctx context.Context, robotsCfg config.RobotsConfig, h MessageHandler, appVersion string, logger *zap.Logger) {
cfg := robotsCfg.Wechat
if !cfg.Enabled || cfg.BotToken == "" {
return
}
go runWechatLoop(ctx, cfg, h, appVersion, logger)
}
func runWechatLoop(ctx context.Context, cfg config.RobotWechatConfig, h MessageHandler, appVersion string, logger *zap.Logger) {
backoff := wechatReconnectInitial
for {
err := runWechatPoll(ctx, cfg, h, appVersion, logger)
if ctx.Err() != nil {
logger.Info("微信 iLink 长轮询已按配置关闭")
return
}
if err != nil {
logger.Warn("微信 iLink 长轮询异常,将自动重连", zap.Error(err), zap.Duration("retry_after", backoff))
}
select {
case <-ctx.Done():
return
case <-time.After(backoff):
if backoff < wechatReconnectMax {
backoff *= 2
if backoff > wechatReconnectMax {
backoff = wechatReconnectMax
}
}
}
}
}
func runWechatPoll(ctx context.Context, cfg config.RobotWechatConfig, h MessageHandler, appVersion string, logger *zap.Logger) error {
client := ilink.NewClient(cfg.BaseURL, cfg.BotToken, cfg.BotAgent, ilink.BuildClientVersion(appVersion))
buf := cfg.GetUpdatesBuf
logger.Info("微信 iLink 长轮询已启动", zap.String("ilink_bot_id", cfg.ILinkBotID))
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
resp, err := client.GetUpdates(ctx, buf)
if err != nil {
return err
}
if resp.ErrCode != 0 && resp.Ret != 0 {
logger.Warn("微信 getUpdates 返回错误", zap.Int("errcode", resp.ErrCode), zap.String("errmsg", resp.ErrMsg))
}
if resp.GetUpdatesBuf != "" {
buf = resp.GetUpdatesBuf
}
for _, msg := range resp.Msgs {
if msg.MessageType != 1 {
continue
}
text := ilink.ExtractText(msg)
if text == "" {
continue
}
userID := strings.TrimSpace(msg.FromUserID)
if userID == "" {
continue
}
logger.Info("微信收到消息", zap.String("from", userID), zap.String("content", text))
reply := h.HandleMessage(wechatPlatform, userID, text)
if strings.TrimSpace(reply) == "" {
continue
}
if err := client.SendTextMessage(ctx, userID, msg.ContextToken, reply, ""); err != nil {
logger.Warn("微信发送回复失败", zap.String("to", userID), zap.Error(err))
}
}
}
}
@@ -4,7 +4,7 @@
### What it does
- Configure **Host / Port / Password** and choose **Single-Agent** or **Multi-Agent**
- Configure **Host / Port / HTTPS / Password** and choose an agent mode
- Click **Validate** to login (`POST /api/auth/login`) and verify token (`GET /api/auth/validate`)
- Right-click any HTTP message in Burp and send it to CyberStrikeAI for **streaming web pentest**
- Keep a **test history sidebar** (searchable) so you can revisit previous runs
@@ -63,6 +63,7 @@ If you already have Gradle available, you can still use `build.gradle` to build.
### Notes
- This extension connects to your CyberStrikeAI server (default is `http://127.0.0.1:8080`).
- Default connection is `https://127.0.0.1:8080` (**HTTPS** checked). Self-signed / local certs are trusted automatically (no import).
- Uncheck **HTTPS** only if your server runs plain HTTP.
- It uses **Bearer Token** authentication obtained from the configured password.
@@ -81,7 +81,8 @@ cd plugins/burp-suite/cyberstrikeai-burp-extension
2) 填写:
- **Host**:例如 `127.0.0.1`
- **Port**:例如 `8080`
- **Password**:你的 CyberStrikeAI 登录密码(对应服务端 `config.yaml` `auth.password`
- **HTTPS**:默认勾选(对接 `config.yaml` `tls_enabled` / 自签证书);插件会自动信任本地自签证书,无需导入
- **Password**:你的 CyberStrikeAI 登录密码(对应服务端 `auth.password`
- **Agent mode**:选择 `Single Agent``Multi Agent`
3) 点击 **Validate**
- 成功:状态显示 `OK (token saved)`
@@ -94,8 +95,9 @@ cd plugins/burp-suite/cyberstrikeai-burp-extension
- **Validate 失败 / 401**
- 确认密码是否正确(服务端 `auth.password`
- 确认 IP/端口是否能访问(例如浏览器能打开 `http://IP:PORT/`
- 服务启用了反向代理/HTTPS,需要把插件里 baseUrl 改成对应协议与端口(当前插件默认使用 `http://`
- 确认 IP/端口是否能访问(例如浏览器能打开 `https://IP:PORT/`
- 服务启用 TLS 时勾选 **HTTPS**(默认已勾选);自签证书无需手动导入
- 若仍为纯 HTTP 部署,取消勾选 **HTTPS**
- **选择 Multi Agent 后提示“多代理未启用”**
- 服务端需要开启:`config.yaml``multi_agent.enabled: true`
@@ -73,15 +73,34 @@ public class BurpExtender implements IBurpExtender, IContextMenuFactory {
public void onEvent(String type, String message, String rawJson) {
if (type == null) type = "";
switch (type) {
case "response_start":
tab.appendProgressToRun(runId, "\n\n[主回复]\n");
break;
case "response_delta":
case "eino_agent_reply_stream_delta":
tab.appendFinalToRun(runId, message);
if (message != null && !message.isEmpty()) {
tab.appendFinalToRun(runId, message);
}
break;
case "response":
tab.appendFinalToRun(runId, "\n\n--- Final Response ---\n");
tab.appendFinalToRun(runId, message);
tab.setFinalResponse(runId, message);
break;
case "eino_agent_reply_stream_start":
tab.appendProgressToRun(runId, "\n\n[子代理回复]\n");
break;
case "eino_agent_reply_stream_delta":
if (message != null && !message.isEmpty()) {
tab.appendProgressToRun(runId, message);
}
break;
case "eino_agent_reply_stream_end":
tab.appendProgressToRun(runId, "\n");
break;
case "eino_agent_reply":
if (message != null && !message.isEmpty()) {
tab.appendProgressToRun(runId, "\n\n[子代理回复]\n" + message + "\n");
}
break;
case "progress":
tab.appendProgressToRun(runId, "\n[progress] " + message + "\n");
tab.setRunStatus(runId, "running");
@@ -94,21 +113,40 @@ public class BurpExtender implements IBurpExtender, IContextMenuFactory {
tab.appendProgressToRun(runId, "\n[error] " + message + "\n");
tab.setRunStatus(runId, "error");
break;
case "reasoning_chain_stream_start":
tab.appendProgressToRun(runId, "\n\n[推理过程]\n");
break;
case "reasoning_chain_stream_delta":
if (message != null && !message.isEmpty()) {
tab.appendProgressToRun(runId, message);
}
break;
case "reasoning_chain_stream_end":
tab.appendProgressToRun(runId, "\n");
break;
case "reasoning_chain":
if (message != null && !message.isEmpty()) {
String streamId = rawJson != null ? SimpleJson.extractStringField(rawJson, "streamId") : "";
if (streamId == null || streamId.isEmpty()) {
tab.appendProgressToRun(runId, "\n\n[推理过程]\n" + message + "\n");
}
}
break;
case "thinking_stream_start":
if (tab.isShowDebugEvents()) {
tab.resetThinkingStream(runId);
}
break;
case "thinking_stream_delta":
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
tab.appendProgressToRun(runId, message);
}
break;
case "tool_call":
case "tool_result":
case "tool_result_delta":
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
if ("thinking_stream_delta".equals(type)) {
tab.appendThinkingDelta(runId, message);
} else {
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
}
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
}
break;
case "conversation":
@@ -125,7 +163,9 @@ public class BurpExtender implements IBurpExtender, IContextMenuFactory {
case "done":
break;
default:
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()) {
if (tab.isShowDebugEvents() && message != null && !message.isEmpty()
&& !type.endsWith("_stream_delta") && !type.endsWith("_stream_start")
&& !type.endsWith("_stream_end")) {
tab.appendProgressToRun(runId, "\n[" + type + "] " + message + "\n");
}
break;
@@ -134,8 +174,9 @@ public class BurpExtender implements IBurpExtender, IContextMenuFactory {
@Override
public void onError(String message, Exception e) {
tab.appendProgressToRun(runId, "\n[error] " + message + "\n");
tab.setRunStatus(runId, "error");
boolean cancelled = message != null && message.toLowerCase().contains("cancel");
tab.appendProgressToRun(runId, cancelled ? "\n[info] " + message + "\n" : "\n[error] " + message + "\n");
tab.setRunStatus(runId, cancelled ? "cancelled" : "error");
callbacks.printError("CyberStrikeAI stream error: " + message);
if (e != null) {
callbacks.printError(e.toString());
@@ -2,17 +2,29 @@ package burp;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.SocketTimeoutException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
final class CyberStrikeAIClient {
private static final int AUTH_CONNECT_TIMEOUT_MS = 4_000;
private static final int AUTH_READ_TIMEOUT_MS = 5_000;
/** login + validate 整段上限,避免两次读超时叠加拖到半分钟 */
private static final int AUTH_OVERALL_TIMEOUT_MS = 10_000;
private static final int DEFAULT_READ_TIMEOUT_MS = 15_000;
private final AtomicReference<HttpURLConnection> activeConnection = new AtomicReference<>();
private final AtomicReference<Thread> activeThread = new AtomicReference<>();
static final class Config {
final String baseUrl; // e.g. http://127.0.0.1:8080
final String password;
@@ -49,15 +61,97 @@ final class CyberStrikeAIClient {
void onDone();
}
boolean hasActiveRequest() {
return activeConnection.get() != null;
}
void cancelActiveRequest() {
HttpURLConnection conn = activeConnection.getAndSet(null);
if (conn != null) {
try {
conn.disconnect();
} catch (Exception ignored) {
}
}
Thread t = activeThread.getAndSet(null);
if (t != null) {
t.interrupt();
}
}
String loginAndValidate(Config cfg) throws IOException {
String token = login(cfg.baseUrl, cfg.password);
validate(cfg.baseUrl, token);
return token;
Thread worker = Thread.currentThread();
java.util.Timer deadline = new java.util.Timer("CyberStrikeAI-AuthDeadline", true);
deadline.schedule(new java.util.TimerTask() {
@Override
public void run() {
worker.interrupt();
cancelActiveRequest();
}
}, AUTH_OVERALL_TIMEOUT_MS);
try {
String token = login(cfg.baseUrl, cfg.password);
if (Thread.interrupted()) {
throw timeoutIOException();
}
validate(cfg.baseUrl, token);
if (Thread.interrupted()) {
throw timeoutIOException();
}
return token;
} catch (SocketTimeoutException e) {
throw timeoutIOException();
} finally {
deadline.cancel();
}
}
private static IOException timeoutIOException() {
return new IOException("Connection timed out (~" + (AUTH_OVERALL_TIMEOUT_MS / 1000)
+ "s). Check host/port and HTTPS checkbox.");
}
private void trackConnection(HttpURLConnection conn) {
activeThread.set(Thread.currentThread());
activeConnection.set(conn);
}
private void releaseConnection(HttpURLConnection conn) {
if (activeConnection.compareAndSet(conn, null)) {
activeThread.set(null);
}
}
private static boolean isCancelled(Throwable e) {
if (e == null) {
return Thread.currentThread().isInterrupted();
}
if (Thread.currentThread().isInterrupted()) {
return true;
}
if (e instanceof InterruptedIOException) {
return true;
}
if (e instanceof SocketTimeoutException) {
return false;
}
Throwable cause = e.getCause();
if (cause != null && cause != e) {
return isCancelled(cause);
}
String msg = e.getMessage();
return msg != null && (
msg.toLowerCase().contains("cancel")
|| msg.toLowerCase().contains("abort")
|| msg.toLowerCase().contains("closed")
);
}
private String login(String baseUrl, String password) throws IOException {
URL url = new URL(baseUrl + "/api/auth/login");
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
HttpURLConnection conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, AUTH_READ_TIMEOUT_MS);
trackConnection(conn);
try {
conn.setRequestMethod("POST");
conn.setDoOutput(true);
conn.setRequestProperty("Content-Type", "application/json");
@@ -92,11 +186,16 @@ final class CyberStrikeAIClient {
throw new IOException("Login response missing token. Check backend address and credentials.");
}
return token;
} finally {
releaseConnection(conn);
}
}
private void validate(String baseUrl, String token) throws IOException {
URL url = new URL(baseUrl + "/api/auth/validate");
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
HttpURLConnection conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, AUTH_READ_TIMEOUT_MS);
trackConnection(conn);
try {
conn.setRequestMethod("GET");
conn.setRequestProperty("Authorization", "Bearer " + token);
int code = conn.getResponseCode();
@@ -104,6 +203,9 @@ final class CyberStrikeAIClient {
if (code < 200 || code >= 300) {
throw new IOException("Validate failed (" + code + "): " + resp);
}
} finally {
releaseConnection(conn);
}
}
void streamTest(Config cfg, String token, String message, StreamListener listener) {
@@ -117,11 +219,12 @@ final class CyberStrikeAIClient {
payload.put("orchestration", cfg.agentMode.orchestration);
}
new Thread(() -> {
Thread worker = new Thread(() -> {
HttpURLConnection conn = null;
try {
URL url = new URL(urlStr);
conn = (HttpURLConnection) url.openConnection();
conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, 0);
trackConnection(conn);
conn.setRequestMethod("POST");
conn.setDoOutput(true);
conn.setRequestProperty("Content-Type", "application/json");
@@ -142,6 +245,9 @@ final class CyberStrikeAIClient {
try (BufferedReader br = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8))) {
String line;
while ((line = br.readLine()) != null) {
if (Thread.currentThread().isInterrupted()) {
break;
}
// SSE format: "data: {json}"
if (line.startsWith("data:")) {
String json = line.substring("data:".length()).trim();
@@ -156,15 +262,25 @@ final class CyberStrikeAIClient {
}
}
}
listener.onDone();
if (Thread.currentThread().isInterrupted()) {
listener.onError("Cancelled.", null);
} else {
listener.onDone();
}
} catch (Exception e) {
listener.onError(e.getMessage(), e);
if (isCancelled(e)) {
listener.onError("Cancelled.", e);
} else {
listener.onError(e.getMessage(), e);
}
} finally {
if (conn != null) {
releaseConnection(conn);
conn.disconnect();
}
}
}, "CyberStrikeAI-Stream").start();
}, "CyberStrikeAI-Stream");
worker.start();
}
void cancelByConversationId(String baseUrl, String token, String conversationId) throws IOException {
@@ -172,7 +288,7 @@ final class CyberStrikeAIClient {
throw new IOException("Missing conversationId.");
}
URL url = new URL(baseUrl + "/api/agent-loop/cancel");
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
HttpURLConnection conn = SslTrustAll.open(url, AUTH_CONNECT_TIMEOUT_MS, AUTH_READ_TIMEOUT_MS);
conn.setRequestMethod("POST");
conn.setDoOutput(true);
conn.setRequestProperty("Content-Type", "application/json");
@@ -14,6 +14,7 @@ final class CyberStrikeAITab implements ITab {
private final JTextField hostField = new JTextField("127.0.0.1");
private final JTextField portField = new JTextField("8080");
private final JCheckBox useHttpsBox = new JCheckBox("HTTPS", true);
private final JPasswordField passwordField = new JPasswordField();
private final JComboBox<String> agentModeBox = new JComboBox<>(new String[]{
"Native ReAct", "Eino Single (ADK)", "Deep (DeepAgent)", "Plan-Execute", "Supervisor"
@@ -29,6 +30,10 @@ final class CyberStrikeAITab implements ITab {
private final JTextArea progressArea = new JTextArea();
private final JTextArea finalRawArea = new JTextArea(); // raw final stream / final response
private JScrollPane progressScrollPane;
private JScrollPane finalRawScrollPane;
/** 距底部在此像素内视为「跟随滚动」,否则用户上拉阅读时不抢滚动条 */
private static final int SCROLL_FOLLOW_THRESHOLD_PX = 48;
private final JEditorPane markdownPane = new JEditorPane("text/html", "");
private final CardLayout outputCardsLayout = new CardLayout();
private final JPanel outputCards = new JPanel(outputCardsLayout);
@@ -41,6 +46,7 @@ final class CyberStrikeAITab implements ITab {
private final CyberStrikeAIClient client = new CyberStrikeAIClient();
private final AtomicReference<String> tokenRef = new AtomicReference<>("");
private final AtomicReference<Thread> validateThreadRef = new AtomicReference<>();
private final DefaultListModel<TestRun> testListModel = new DefaultListModel<>();
private final JList<TestRun> testList = new JList<>(testListModel);
@@ -107,6 +113,8 @@ final class CyberStrikeAITab implements ITab {
row1.add(hostField);
row1.add(new JLabel("Port"));
row1.add(portField);
useHttpsBox.setToolTipText("Use https:// for CyberStrikeAI (self-signed certs are trusted automatically)");
row1.add(useHttpsBox);
row1.add(new JLabel("Password"));
row1.add(passwordField);
row1.add(validateButton);
@@ -186,15 +194,22 @@ final class CyberStrikeAITab implements ITab {
configureTextArea(requestArea, false);
configureTextArea(responseArea, false);
outputCards.add(new JScrollPane(finalRawArea), "raw");
finalRawScrollPane = new JScrollPane(finalRawArea);
finalRawScrollPane.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER);
finalRawScrollPane.getVerticalScrollBar().setUnitIncrement(16);
outputCards.add(finalRawScrollPane, "raw");
outputCards.add(new JScrollPane(markdownPane), "md");
outputRoot.add(buildOutputHeader(), BorderLayout.NORTH);
outputRoot.add(buildOutputBody(), BorderLayout.CENTER);
rightTabs.addTab("Output", outputRoot);
rightTabs.addTab("Request", new JScrollPane(requestArea));
rightTabs.addTab("Response", new JScrollPane(responseArea));
JScrollPane requestScroll = new JScrollPane(requestArea);
requestScroll.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER);
rightTabs.addTab("Request", requestScroll);
JScrollPane responseScroll = new JScrollPane(responseArea);
responseScroll.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER);
rightTabs.addTab("Response", responseScroll);
return rightTabs;
}
@@ -210,12 +225,13 @@ final class CyberStrikeAITab implements ITab {
}
private JComponent buildOutputBody() {
JScrollPane progressScroll = new JScrollPane(progressArea);
progressScroll.setBorder(BorderFactory.createTitledBorder("Progress"));
progressScroll.getVerticalScrollBar().setUnitIncrement(16);
progressScrollPane = new JScrollPane(progressArea);
progressScrollPane.setBorder(BorderFactory.createTitledBorder("Progress"));
progressScrollPane.setHorizontalScrollBarPolicy(ScrollPaneConstants.HORIZONTAL_SCROLLBAR_NEVER);
progressScrollPane.getVerticalScrollBar().setUnitIncrement(16);
JPanel empty = new JPanel();
progressContainer.add(progressScroll, "show");
progressContainer.add(progressScrollPane, "show");
progressContainer.add(empty, "hide");
((CardLayout) progressContainer.getLayout()).show(progressContainer, "show");
@@ -259,10 +275,27 @@ final class CyberStrikeAITab implements ITab {
return split;
}
private static boolean isScrollNearBottom(JScrollPane scrollPane) {
if (scrollPane == null) {
return true;
}
JScrollBar bar = scrollPane.getVerticalScrollBar();
int max = Math.max(0, bar.getMaximum() - bar.getVisibleAmount());
return bar.getValue() >= max - SCROLL_FOLLOW_THRESHOLD_PX;
}
private static void scrollPaneToBottom(JScrollPane scrollPane) {
if (scrollPane == null) {
return;
}
JScrollBar bar = scrollPane.getVerticalScrollBar();
bar.setValue(bar.getMaximum());
}
private static void configureTextArea(JTextArea area, boolean monospaced) {
area.setEditable(false);
area.setLineWrap(false);
area.setWrapStyleWord(false);
area.setLineWrap(true);
area.setWrapStyleWord(true);
if (monospaced) {
area.setFont(new Font(Font.MONOSPACED, Font.PLAIN, 12));
} else {
@@ -381,24 +414,44 @@ final class CyberStrikeAITab implements ITab {
private void wireActions() {
validateButton.addActionListener(e -> {
validateButton.setEnabled(false);
if ("Cancel".equals(validateButton.getText())) {
cancelValidateInProgress();
return;
}
validateButton.setText("Cancel");
validateButton.setEnabled(true);
stopButton.setEnabled(true);
statusLabel.setText("Validating...");
log("Validating connection...");
new Thread(() -> {
log("Validating connection... (max ~10s; click Cancel or Stop to abort)");
Thread worker = new Thread(() -> {
try {
CyberStrikeAIClient.Config cfg = currentConfig();
String token = client.loginAndValidate(cfg);
if (Thread.currentThread().isInterrupted()) {
return;
}
tokenRef.set(token);
SwingUtilities.invokeLater(() -> statusLabel.setText("OK (token saved)"));
log("Validation OK.");
} catch (Exception ex) {
tokenRef.set("");
SwingUtilities.invokeLater(() -> statusLabel.setText("Failed: " + ex.getMessage()));
log("Validation failed: " + ex.getMessage());
if (Thread.currentThread().isInterrupted()) {
SwingUtilities.invokeLater(() -> statusLabel.setText("Cancelled"));
log("Validation cancelled.");
} else {
SwingUtilities.invokeLater(() -> statusLabel.setText("Failed: " + ex.getMessage()));
log("Validation failed: " + ex.getMessage());
}
} finally {
SwingUtilities.invokeLater(() -> validateButton.setEnabled(true));
validateThreadRef.set(null);
SwingUtilities.invokeLater(() -> {
validateButton.setText("Validate");
validateButton.setEnabled(true);
});
}
}, "CyberStrikeAI-Validate").start();
}, "CyberStrikeAI-Validate");
validateThreadRef.set(worker);
worker.start();
});
clearButton.addActionListener(e -> {
@@ -435,10 +488,23 @@ final class CyberStrikeAITab implements ITab {
});
stopButton.addActionListener(e -> {
if ("Cancel".equals(validateButton.getText())) {
cancelValidateInProgress();
return;
}
String runId = selectedRunId;
if (runId != null && client.hasActiveRequest()) {
client.cancelActiveRequest();
appendProgressToRun(runId, "\n[info] Stream stopped.\n");
setRunStatus(runId, "cancelled");
return;
}
if (runId == null) return;
TestRun run = runs.get(runId);
if (run == null) return;
String token = getToken();
if (token == null || token.trim().isEmpty()) {
appendProgressToRun(runId, "\n[error] Not validated.\n");
@@ -483,7 +549,8 @@ final class CyberStrikeAITab implements ITab {
String host = hostField.getText().trim();
String port = portField.getText().trim();
String password = new String(passwordField.getPassword());
String baseUrl = "http://" + host + ":" + port;
String scheme = useHttpsBox.isSelected() ? "https" : "http";
String baseUrl = scheme + "://" + host + ":" + port;
int idx = agentModeBox.getSelectedIndex();
CyberStrikeAIClient.AgentMode mode = (idx >= 0 && idx < AGENT_MODES.length)
? AGENT_MODES[idx]
@@ -567,10 +634,31 @@ final class CyberStrikeAITab implements ITab {
run.progressBuffer.append(s);
}
if (runId.equals(selectedRunId)) {
SwingUtilities.invokeLater(() -> {
progressArea.append(s);
progressArea.setCaretPosition(progressArea.getDocument().getLength());
});
SwingUtilities.invokeLater(() -> appendProgressUi(s, false));
}
}
private void appendProgressUi(String s, boolean forceFollow) {
JScrollBar bar = progressScrollPane != null ? progressScrollPane.getVerticalScrollBar() : null;
int scrollBefore = bar != null ? bar.getValue() : 0;
boolean follow = forceFollow || isScrollNearBottom(progressScrollPane);
progressArea.append(s);
if (follow) {
scrollPaneToBottom(progressScrollPane);
} else if (bar != null) {
bar.setValue(scrollBefore);
}
}
private void appendFinalUi(String s, boolean forceFollow) {
JScrollBar bar = finalRawScrollPane != null ? finalRawScrollPane.getVerticalScrollBar() : null;
int scrollBefore = bar != null ? bar.getValue() : 0;
boolean follow = forceFollow || isScrollNearBottom(finalRawScrollPane);
finalRawArea.append(s);
if (follow) {
scrollPaneToBottom(finalRawScrollPane);
} else if (bar != null) {
bar.setValue(scrollBefore);
}
}
@@ -620,10 +708,7 @@ final class CyberStrikeAITab implements ITab {
run.finalBuffer.append(s);
}
if (runId.equals(selectedRunId)) {
SwingUtilities.invokeLater(() -> {
finalRawArea.append(s);
finalRawArea.setCaretPosition(finalRawArea.getDocument().getLength());
});
SwingUtilities.invokeLater(() -> appendFinalUi(s, false));
}
}
@@ -656,9 +741,9 @@ final class CyberStrikeAITab implements ITab {
}
SwingUtilities.invokeLater(() -> {
progressArea.setText(progress);
progressArea.setCaretPosition(progressArea.getDocument().getLength());
scrollPaneToBottom(progressScrollPane);
finalRawArea.setText(fin);
finalRawArea.setCaretPosition(finalRawArea.getDocument().getLength());
scrollPaneToBottom(finalRawScrollPane);
requestArea.setText(run.requestRaw == null ? "" : run.requestRaw);
responseArea.setText(run.responseRaw == null ? "" : run.responseRaw);
refreshOutputView();
@@ -682,25 +767,36 @@ final class CyberStrikeAITab implements ITab {
void clearAndShowStreamHeader(String title) {
SwingUtilities.invokeLater(() -> {
progressArea.setText("");
finalRawArea.setText(title + "\n\n");
progressArea.setText("[*] " + title + "\n\n");
finalRawArea.setText("");
markdownPane.setText("");
});
}
// Legacy helpers kept for Validate logging
void appendStreamLine(String s) {
if (s == null) return;
SwingUtilities.invokeLater(() -> {
progressArea.append(s);
progressArea.append("\n");
progressArea.setCaretPosition(progressArea.getDocument().getLength());
});
SwingUtilities.invokeLater(() -> appendProgressUi(s + "\n", false));
}
private void log(String s) {
appendStreamLine("[*] " + s);
}
private void cancelValidateInProgress() {
client.cancelActiveRequest();
Thread t = validateThreadRef.getAndSet(null);
if (t != null) {
t.interrupt();
}
SwingUtilities.invokeLater(() -> {
statusLabel.setText("Cancelled");
validateButton.setText("Validate");
validateButton.setEnabled(true);
});
log("Validation cancelled.");
}
private void applyFilter() {
String q = searchField.getText();
if (q == null) q = "";
@@ -0,0 +1,149 @@
package burp;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.URL;
import java.security.cert.X509Certificate;
/**
* Opens HTTPS connections without validating server certificates (self-signed / local dev).
* Applied per-connection only; does not change JVM-wide defaults for other Burp components.
*/
final class SslTrustAll {
private static volatile SSLSocketFactory socketFactory;
private static final HostnameVerifier TRUST_ALL_HOSTS = (hostname, session) -> true;
private SslTrustAll() {
}
static HttpURLConnection open(URL url) throws IOException {
return open(url, 5_000, 30_000);
}
static HttpURLConnection open(URL url, int connectTimeoutMs, int readTimeoutMs) throws IOException {
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
conn.setConnectTimeout(connectTimeoutMs);
conn.setReadTimeout(readTimeoutMs);
if (conn instanceof HttpsURLConnection) {
HttpsURLConnection https = (HttpsURLConnection) conn;
https.setSSLSocketFactory(new TimeoutSslSocketFactory(socketFactory(), connectTimeoutMs, readTimeoutMs));
https.setHostnameVerifier(TRUST_ALL_HOSTS);
}
return conn;
}
private static SSLSocketFactory socketFactory() {
SSLSocketFactory sf = socketFactory;
if (sf != null) {
return sf;
}
synchronized (SslTrustAll.class) {
sf = socketFactory;
if (sf != null) {
return sf;
}
try {
TrustManager[] trustAll = new TrustManager[]{
new X509TrustManager() {
@Override
public X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[0];
}
@Override
public void checkClientTrusted(X509Certificate[] chain, String authType) {
}
@Override
public void checkServerTrusted(X509Certificate[] chain, String authType) {
}
}
};
SSLContext ctx = SSLContext.getInstance("TLS");
ctx.init(null, trustAll, new java.security.SecureRandom());
sf = ctx.getSocketFactory();
socketFactory = sf;
return sf;
} catch (Exception e) {
throw new RuntimeException("Failed to initialize trust-all TLS", e);
}
}
}
/** Ensures TCP connect + socket read respect timeouts (plain HttpURLConnection SSL can hang longer). */
private static final class TimeoutSslSocketFactory extends SSLSocketFactory {
private final SSLSocketFactory delegate;
private final int connectTimeoutMs;
private final int readTimeoutMs;
TimeoutSslSocketFactory(SSLSocketFactory delegate, int connectTimeoutMs, int readTimeoutMs) {
this.delegate = delegate;
this.connectTimeoutMs = connectTimeoutMs;
this.readTimeoutMs = readTimeoutMs;
}
@Override
public String[] getDefaultCipherSuites() {
return delegate.getDefaultCipherSuites();
}
@Override
public String[] getSupportedCipherSuites() {
return delegate.getSupportedCipherSuites();
}
@Override
public Socket createSocket() throws IOException {
return tune(delegate.createSocket());
}
@Override
public Socket createSocket(Socket s, String host, int port, boolean autoClose) throws IOException {
return tune(delegate.createSocket(s, host, port, autoClose));
}
@Override
public Socket createSocket(String host, int port) throws IOException {
Socket plain = new Socket();
plain.connect(new InetSocketAddress(host, port), connectTimeoutMs);
return tune(delegate.createSocket(plain, host, port, true));
}
@Override
public Socket createSocket(String host, int port, java.net.InetAddress localHost, int localPort) throws IOException {
Socket plain = new Socket();
plain.bind(new InetSocketAddress(localHost, localPort));
plain.connect(new InetSocketAddress(host, port), connectTimeoutMs);
return tune(delegate.createSocket(plain, host, port, true));
}
@Override
public Socket createSocket(java.net.InetAddress host, int port) throws IOException {
Socket plain = new Socket();
plain.connect(new InetSocketAddress(host, port), connectTimeoutMs);
return tune(delegate.createSocket(plain, host.getHostName(), port, true));
}
@Override
public Socket createSocket(java.net.InetAddress address, int port, java.net.InetAddress localAddress, int localPort) throws IOException {
Socket plain = new Socket();
plain.bind(new InetSocketAddress(localAddress, localPort));
plain.connect(new InetSocketAddress(address, port), connectTimeoutMs);
return tune(delegate.createSocket(plain, address.getHostName(), port, true));
}
private Socket tune(Socket socket) throws IOException {
socket.setSoTimeout(readTimeoutMs);
return socket;
}
}
}
@@ -1,12 +1,16 @@
burp/SslTrustAll.class
burp/SslTrustAll$TimeoutSslSocketFactory.class
burp/CyberStrikeAIClient$StreamListener.class
burp/CyberStrikeAIClient$Config.class
burp/CyberStrikeAIClient$AgentMode.class
burp/MarkdownRenderer.class
burp/SimpleJson.class
burp/CyberStrikeAIClient.class
burp/CyberStrikeAIClient$1.class
burp/CyberStrikeAITab$DotIcon.class
burp/CyberStrikeAITab.class
burp/CyberStrikeAITab$1.class
burp/SslTrustAll$1.class
burp/BurpExtender$1.class
burp/BurpExtender.class
burp/CyberStrikeAITab$TestRun.class
@@ -4,3 +4,4 @@
/Users/temp/Downloads/CyberStrikeAI-main/plugins/burp-suite/cyberstrikeai-burp-extension/src/main/java/burp/HttpMessageFormatter.java
/Users/temp/Downloads/CyberStrikeAI-main/plugins/burp-suite/cyberstrikeai-burp-extension/src/main/java/burp/MarkdownRenderer.java
/Users/temp/Downloads/CyberStrikeAI-main/plugins/burp-suite/cyberstrikeai-burp-extension/src/main/java/burp/SimpleJson.java
/Users/temp/Downloads/CyberStrikeAI-main/plugins/burp-suite/cyberstrikeai-burp-extension/src/main/java/burp/SslTrustAll.java
+28 -3
View File
@@ -857,10 +857,35 @@
background: var(--c2-surface);
border-radius: var(--c2-radius);
border: 1px solid var(--c2-border);
overflow: hidden;
overflow-x: auto;
overflow-y: visible;
}
.c2-task-table { width: 100%; border-collapse: collapse; }
/* 操作列:仅占按钮宽度,避免 100% 表格把余白摊到最右列 */
.c2-task-table th.c2-task-table-col-actions,
.c2-task-table td.c2-task-table-col-actions {
width: 1%;
white-space: nowrap;
text-align: right;
vertical-align: middle;
}
.c2-task-table-actions {
display: inline-flex;
align-items: center;
justify-content: flex-end;
gap: 6px;
flex-wrap: nowrap;
}
.c2-task-table-actions .btn-small,
.c2-task-table-actions .btn-sm {
min-height: 30px;
min-width: 52px;
justify-content: center;
}
.c2-task-table { width: 100%; border-collapse: collapse; table-layout: auto; }
.c2-task-table th {
text-align: left;
@@ -1261,7 +1286,7 @@
display: flex;
align-items: center;
justify-content: center;
z-index: 1000;
z-index: 10050;
padding: 24px;
animation: c2-fade-in 0.15s ease-out;
}
+2024 -24
View File
File diff suppressed because it is too large Load Diff
+172 -3
View File
@@ -820,6 +820,7 @@
"robots": "Bots",
"terminal": "Terminal",
"security": "Security",
"audit": "Audit logs",
"infocollect": "Recon"
},
"infocollect": {
@@ -836,7 +837,32 @@
},
"robots": {
"title": "Bot settings",
"description": "Configure WeCom, DingTalk and Lark bots so you can chat with CyberStrikeAI on your phone without opening the web UI.",
"description": "Configure WeChat (iLink), WeCom, DingTalk and Lark bots so you can chat with CyberStrikeAI on your phone without opening the web UI.",
"wechat": {
"title": "WeChat / iLink",
"subtitle": "Bind personal WeChat via QR code and chat with CyberStrikeAI on your phone",
"statusIdle": "Not bound",
"statusBound": "Connected",
"statusScanning": "Binding…",
"step1": "Generate QR",
"step2": "Scan in WeChat",
"step3": "Confirm",
"enabled": "Enable WeChat bot",
"bindButton": "Generate QR code and bind",
"bindHint": "Scan with WeChat to confirm; settings are saved automatically.",
"qrLoading": "Generating QR code…",
"verifyCodeLabel": "Code on your phone (only if WeChat asks)",
"rebindButton": "Re-bind",
"boundBotId": "Bound Bot ID: ",
"verifyCodeSubmit": "Submit",
"advanced": "Advanced settings",
"baseUrl": "API Base URL",
"botType": "Bot Type",
"botAgent": "Bot Agent",
"ilinkBotId": "iLink Bot ID (filled after bind)",
"boundSuccess": "Binding successful. WeChat bot is enabled.",
"openLink": "QR not showing? Open link in WeChat on your phone"
},
"wecom": {
"title": "WeCom",
"enabled": "Enable WeCom bot",
@@ -1306,6 +1332,35 @@
"noCallsYet": "No calls yet",
"unknownTool": "Unknown tool",
"successFailedRate": "Success {{success}} / Failed {{failed}} · {{rate}}% success rate",
"topToolsTitle": "Top {{n}} tools by calls",
"barVolumeLegend": "Bar length: relative call volume; green/red: success vs failure share",
"clickToFilterTool": "Click a row to filter records below",
"toolRowAriaLabel": "{{name}}, {{total}} calls, {{rate}}% success rate. Click to filter records.",
"successRateAria": "Success rate {{rate}}%",
"filterByToolTitle": "Filtered by: {{tool}}",
"clearToolFilter": "Clear tool filter",
"successCount": "Success {{n}}",
"failedCount": "Failed {{n}}",
"rateHealthy": "Running smoothly",
"rateWarning": "Some failures detected",
"rateCritical": "High failure rate",
"statsSubtitle": "Refreshed {{time}} · {{count}} tools",
"distTitle": "Call distribution",
"distLegend": "Slice area shows share of all calls",
"distClickHint": "Click legend or slice to filter records",
"distHeaderHint": "{{n}} total calls",
"distSegmentAria": "{{name}}, {{pct}}% of calls, {{calls}} times",
"distOthersNoFilter": "Other tools cannot be filtered individually",
"distTotalCalls": "{{n}} total calls",
"distTop6Share": "Top {{n}} share of all calls",
"distOthers": "Other tools",
"distCallsUnit": "{{n}} calls",
"riskTitle": "Failure alerts",
"riskNone": "No recent failures",
"riskItem": "{{name}}: {{failed}} / {{total}} failed",
"selectedToolTitle": "Active filter",
"selectedToolEmpty": "Click a tool on the left to filter records below",
"selectedToolStats": "{{total}} calls · {{success}} ok · {{failed}} failed · {{rate}}% success",
"columnTool": "Tool",
"columnStatus": "Status",
"columnStartTime": "Start time",
@@ -1476,6 +1531,7 @@
"confirmDelete": "Delete this file?",
"editTitle": "Edit file",
"renameTitle": "Rename",
"renameCurrentFile": "Current file",
"newFileName": "New file name",
"empty": "No chat uploads yet",
"errorLoad": "Failed to load",
@@ -1486,6 +1542,15 @@
},
"vulnerabilityPage": {
"statTotal": "Total",
"statClickAll": "View all (clear severity filter)",
"statClickFilter": "Click to filter by this severity; click again to clear",
"advancedFilters": "Advanced filters",
"moreFilters": "More filters",
"applyFilters": "Apply",
"clearAdvanced": "Clear",
"clearAll": "Reset all",
"activeFilters": "Active filters",
"chipRemove": "Remove",
"filter": "Filter",
"clear": "Clear",
"vulnId": "Vuln ID",
@@ -1503,6 +1568,10 @@
"statusFixed": "Fixed",
"statusFalsePositive": "False positive",
"searchVulnId": "Search vuln ID",
"searchKeyword": "Search title, description, type, target…",
"searchKeywordShort": "Keyword",
"filterExactId": "Exact vuln ID",
"filterEnterHint": "Press Enter to filter",
"filterConversation": "Filter by conversation",
"loading": "Loading...",
"loadListFailed": "Failed to load",
@@ -1630,8 +1699,8 @@
"multiAgentPeLoop": "plan_execute outer loop limit",
"multiAgentPeLoopPlaceholder": "0 uses Eino default (10)",
"multiAgentPeLoopHint": "Only for plan_execute; max execute↔replan rounds.",
"multiAgentRobotUse": "Use multi-agent for WeCom / DingTalk / Lark bots",
"multiAgentRobotUseHint": "Requires 'Enable multi-agent' to be checked; usage and cost will be higher.",
"multiAgentRobotMode": "Default conversation mode for bots",
"multiAgentRobotModeHint": "Execution mode for WeCom / DingTalk / Lark bot messages. Deep / Plan-Execute / Supervisor require multi-agent to be enabled.",
"multiAgentBatchUse": "Use multi-agent for batch task queues",
"multiAgentBatchUseHint": "When enabled, each sub-task executed by queue in Task Management will run through Eino DeepAgent (requires multi-agent).",
"enableKnowledge": "Enable knowledge retrieval",
@@ -1723,6 +1792,106 @@
"close": "×",
"newTerminal": "+"
},
"settingsAudit": {
"title": "Audit logs",
"description": "Platform admin actions (login, config, deletes). Does not log chat content, per-command terminal/WebShell runs, or per-tool invocations.",
"filterCategory": "Category",
"filterAction": "Action",
"filterEvent": "Event type",
"filterAllCategories": "All categories",
"filterAllActions": "All actions",
"filterCascadeHint": "Select a category to filter by action",
"filterResult": "Result",
"pageSize": "Per page",
"statTotal": "Filtered total",
"statFailures": "Failures",
"statRecent7d": "Last 7 days",
"retentionHint": "Audit records are kept for {{days}} days, then purged automatically.",
"disabledHint": "Audit logging is disabled; new actions are not written.",
"filterSince": "From",
"filterUntil": "Until",
"filterQuery": "Keyword",
"filterQueryPlaceholder": "Message / resource ID / action",
"cat": {
"auth": "Auth",
"config": "Config",
"terminal": "Terminal",
"c2": "C2",
"webshell": "WebShell",
"knowledge": "Knowledge",
"conversation": "Conversation",
"vulnerability": "Vulnerability",
"externalMcp": "External MCP",
"task": "Tasks",
"tool": "Tools",
"file": "Files",
"hitl": "HITL",
"role": "Roles",
"skill": "Skills",
"agent": "Sub-agents"
},
"act": {
"login": "Login",
"logout": "Logout",
"login_failed": "Login failed",
"password_change": "Password change",
"change_password": "Change password",
"apply": "Apply config",
"update": "Update",
"exec": "Terminal exec",
"exec_stream": "Terminal stream",
"listener_create": "Create listener",
"listener_delete": "Delete listener",
"listener_start": "Start listener",
"listener_stop": "Stop listener",
"session_delete": "Delete session",
"task_create": "Create task",
"task_cancel": "Cancel task",
"task_delete": "Delete task",
"connection_create": "Create connection",
"connection_delete": "Delete connection",
"item_delete": "Delete knowledge item",
"index_rebuild": "Rebuild index",
"delete": "Delete",
"delete_turn": "Delete turn",
"create": "Create",
"upsert": "Upsert external MCP",
"create_queue": "Create batch queue",
"start_queue": "Start batch queue",
"delete_queue": "Delete batch queue",
"pause_queue": "Pause batch queue",
"rerun_queue": "Rerun batch queue",
"delete_batch_task": "Delete batch subtask",
"execution_delete": "Delete execution",
"execution_delete_batch": "Batch delete executions",
"upload": "Upload",
"decision": "HITL decision",
"markdown_create": "Create sub-agent",
"markdown_update": "Update sub-agent",
"markdown_delete": "Delete sub-agent"
},
"openResource": "Open linked resource",
"openResourceChat": "Open linked resource (chat)",
"resourceIdLabel": "Resource ID",
"resourceRemoved": "(resource no longer exists)",
"filterAll": "All",
"filterBtn": "Filter",
"resetBtn": "Reset",
"exportBtn": "Export",
"exportJson": "Export JSON",
"export": "Export JSON",
"exportCsv": "Export CSV",
"exportDone": "Export complete",
"loading": "Loading...",
"empty": "No audit records",
"paginationShow": "{{start}}-{{end}} of {{total}}",
"detailTitle": "Audit detail",
"detailTime": "Time",
"detailCategory": "Category",
"detailResult": "Result",
"detailMessage": "Message",
"detailSession": "Session"
},
"settingsSecurity": {
"changePasswordTitle": "Change password",
"changePasswordDesc": "After changing password, sign in again with the new password.",
+173 -4
View File
@@ -809,6 +809,7 @@
"robots": "机器人设置",
"terminal": "终端",
"security": "安全设置",
"audit": "日志审计",
"infocollect": "信息收集"
},
"infocollect": {
@@ -825,7 +826,32 @@
},
"robots": {
"title": "机器人设置",
"description": "配置企业微信、钉钉、飞书等机器人,在手机端直接与 CyberStrikeAI 对话,无需在服务器上打开网页。",
"description": "配置微信、企业微信、钉钉、飞书等机器人,在手机端直接与 CyberStrikeAI 对话,无需在服务器上打开网页。",
"wechat": {
"title": "微信 / iLink",
"subtitle": "扫码绑定个人微信,在手机端直接与 CyberStrikeAI 对话",
"statusIdle": "未绑定",
"statusBound": "已连接",
"statusScanning": "绑定中…",
"step1": "生成二维码",
"step2": "微信扫码",
"step3": "确认绑定",
"enabled": "启用微信机器人",
"bindButton": "生成二维码并绑定",
"bindHint": "用微信扫码确认后会自动保存并启用。",
"qrLoading": "正在生成二维码…",
"verifyCodeLabel": "手机显示的数字(仅部分账号需要)",
"rebindButton": "重新绑定",
"boundBotId": "已绑定 Bot ID",
"verifyCodeSubmit": "提交",
"advanced": "高级设置",
"baseUrl": "API Base URL",
"botType": "Bot Type",
"botAgent": "Bot Agent",
"ilinkBotId": "iLink Bot ID(绑定后自动填充)",
"boundSuccess": "绑定成功,微信机器人已启用。",
"openLink": "无法显示二维码?点击用手机微信打开链接"
},
"wecom": {
"title": "企业微信",
"enabled": "启用企业微信机器人",
@@ -1295,6 +1321,35 @@
"noCallsYet": "暂无调用",
"unknownTool": "未知工具",
"successFailedRate": "成功 {{success}} / 失败 {{failed}} · 成功率 {{rate}}%",
"topToolsTitle": "工具调用 Top {{n}}",
"barVolumeLegend": "条长表示相对调用量,条内绿/红为成功/失败占比",
"clickToFilterTool": "点击行筛选下方执行记录",
"toolRowAriaLabel": "{{name}}{{total}} 次调用,成功率 {{rate}}%,点击查看执行记录",
"successRateAria": "成功率 {{rate}}%",
"filterByToolTitle": "筛选工具:{{tool}}",
"clearToolFilter": "清除工具筛选",
"successCount": "成功 {{n}}",
"failedCount": "失败 {{n}}",
"rateHealthy": "运行平稳",
"rateWarning": "存在失败调用",
"rateCritical": "失败率偏高",
"statsSubtitle": "最后刷新 {{time}} · 共 {{count}} 个工具",
"distTitle": "调用分布",
"distLegend": "扇区面积为占全部调用比例",
"distClickHint": "点击图例或扇区筛选执行记录",
"distHeaderHint": "共 {{n}} 次调用",
"distSegmentAria": "{{name}},占 {{pct}}%{{calls}} 次",
"distOthersNoFilter": "其他工具无法单独筛选",
"distTotalCalls": "共 {{n}} 次调用",
"distTop6Share": "Top {{n}} 占全部调用",
"distOthers": "其他工具",
"distCallsUnit": "{{n}} 次",
"riskTitle": "失败提醒",
"riskNone": "近期无失败调用",
"riskItem": "{{name}}:失败 {{failed}} / {{total}} 次",
"selectedToolTitle": "当前筛选",
"selectedToolEmpty": "点击左侧工具行,可筛选下方执行记录",
"selectedToolStats": "调用 {{total}} 次 · 成功 {{success}} · 失败 {{failed}} · 成功率 {{rate}}%",
"columnTool": "工具",
"columnStatus": "状态",
"columnStartTime": "开始时间",
@@ -1465,6 +1520,7 @@
"confirmDelete": "确定删除该文件?",
"editTitle": "编辑文件",
"renameTitle": "重命名",
"renameCurrentFile": "当前文件",
"newFileName": "新文件名",
"empty": "暂无对话附件",
"errorLoad": "加载失败",
@@ -1475,6 +1531,15 @@
},
"vulnerabilityPage": {
"statTotal": "总漏洞数",
"statClickAll": "查看全部(清除严重度筛选)",
"statClickFilter": "点击按此严重度筛选;再次点击清除",
"advancedFilters": "高级筛选",
"moreFilters": "更多筛选",
"applyFilters": "应用",
"clearAdvanced": "清空",
"clearAll": "重置全部",
"activeFilters": "已选条件",
"chipRemove": "移除",
"filter": "筛选",
"clear": "清除",
"vulnId": "漏洞ID",
@@ -1491,7 +1556,11 @@
"statusConfirmed": "已确认",
"statusFixed": "已修复",
"statusFalsePositive": "误报",
"searchVulnId": "搜索漏洞ID",
"searchVulnId": "搜索漏洞 ID",
"searchKeyword": "搜索标题、描述、类型、目标…",
"searchKeywordShort": "关键词",
"filterExactId": "精确匹配漏洞 ID",
"filterEnterHint": "回车筛选",
"filterConversation": "筛选特定会话",
"loading": "加载中...",
"loadListFailed": "加载失败",
@@ -1619,8 +1688,8 @@
"multiAgentPeLoop": "plan_execute 外层循环上限",
"multiAgentPeLoopPlaceholder": "0 表示 Eino 默认 10",
"multiAgentPeLoopHint": "仅 plan_execute 有效;execute 与 replan 之间的最大轮次。",
"multiAgentRobotUse": "企业微信 / 钉钉 / 飞书机器人也使用多代理",
"multiAgentRobotUseHint": "需同时勾选「启用多代理」;调用量与成本更高。",
"multiAgentRobotMode": "机器人默认对话模式",
"multiAgentRobotModeHint": "企业微信 / 钉钉 / 飞书机器人每条消息使用的执行模式;Deep / Plan-Execute / Supervisor 需启用多代理。",
"multiAgentBatchUse": "批量任务队列也使用多代理",
"multiAgentBatchUseHint": "开启后,任务管理中按队列执行的每个子任务将走 Eino DeepAgent(需启用多代理)。",
"enableKnowledge": "启用知识检索功能",
@@ -1712,6 +1781,106 @@
"close": "×",
"newTerminal": "+"
},
"settingsAudit": {
"title": "日志审计",
"description": "记录平台管理类操作(登录、配置、删除等),不记录对话正文、终端/WebShell 每次命令与工具调用明细。",
"filterCategory": "类别",
"filterAction": "操作",
"filterEvent": "事件类型",
"filterAllCategories": "全部类别",
"filterAllActions": "全部操作",
"filterCascadeHint": "选择类别后可筛选具体操作",
"filterResult": "结果",
"pageSize": "每页",
"statTotal": "当前筛选",
"statFailures": "失败",
"statRecent7d": "近 7 天",
"retentionHint": "审计记录保留 {{days}} 天,超期自动清理。",
"disabledHint": "审计功能已关闭,新操作不会写入审计表。",
"filterSince": "开始时间",
"filterUntil": "结束时间",
"filterQuery": "关键词",
"filterQueryPlaceholder": "消息 / 资源 ID / 操作名",
"cat": {
"auth": "认证",
"config": "配置",
"terminal": "终端",
"c2": "C2",
"webshell": "WebShell",
"knowledge": "知识库",
"conversation": "对话",
"vulnerability": "漏洞",
"externalMcp": "外部 MCP",
"task": "任务",
"tool": "工具",
"file": "文件",
"hitl": "人机协同",
"role": "角色",
"skill": "Skill",
"agent": "子代理"
},
"act": {
"login": "登录",
"logout": "登出",
"login_failed": "登录失败",
"password_change": "修改密码",
"change_password": "修改密码",
"apply": "应用配置",
"update": "更新",
"exec": "终端执行",
"exec_stream": "终端流式执行",
"listener_create": "创建监听器",
"listener_delete": "删除监听器",
"listener_start": "启动监听器",
"listener_stop": "停止监听器",
"session_delete": "删除会话",
"task_create": "创建任务",
"task_cancel": "取消任务",
"task_delete": "删除任务",
"connection_create": "创建连接",
"connection_delete": "删除连接",
"item_delete": "删除知识项",
"index_rebuild": "重建索引",
"delete": "删除",
"delete_turn": "删除轮次",
"create": "创建",
"upsert": "保存外部 MCP",
"create_queue": "创建批量队列",
"start_queue": "启动批量队列",
"delete_queue": "删除批量队列",
"pause_queue": "暂停批量队列",
"rerun_queue": "重跑批量队列",
"delete_batch_task": "删除批量子任务",
"execution_delete": "删除执行记录",
"execution_delete_batch": "批量删除执行",
"upload": "上传",
"decision": "HITL 决策",
"markdown_create": "创建子代理",
"markdown_update": "更新子代理",
"markdown_delete": "删除子代理"
},
"openResource": "打开关联资源",
"openResourceChat": "打开关联资源(chat",
"resourceIdLabel": "资源 ID",
"resourceRemoved": "(关联对象已删除)",
"filterAll": "全部",
"filterBtn": "筛选",
"resetBtn": "重置",
"exportBtn": "导出",
"exportJson": "导出 JSON",
"export": "导出 JSON",
"exportCsv": "导出 CSV",
"exportDone": "导出完成",
"loading": "加载中...",
"empty": "暂无审计记录",
"paginationShow": "显示 {{start}}-{{end}} / 共 {{total}} 条",
"detailTitle": "审计详情",
"detailTime": "时间",
"detailCategory": "类别",
"detailResult": "结果",
"detailMessage": "说明",
"detailSession": "会话"
},
"settingsSecurity": {
"changePasswordTitle": "修改密码",
"changePasswordDesc": "修改登录密码后,需要使用新密码重新登录。",
+598
View File
@@ -0,0 +1,598 @@
/**
* 系统设置 - 平台操作审计日志
*/
let auditLogsPage = 1;
let auditLogsPageSize = 20;
let auditLogsTotal = 0;
const AUDIT_PAGE_SIZE_KEY = 'cyberstrike_audit_page_size';
/** 按类别列出的操作(用于 datalist 提示,避免超长下拉) */
const AUDIT_ACTIONS_BY_CATEGORY = {
auth: ['login', 'logout', 'change_password'],
config: ['apply', 'update'],
c2: ['listener_create', 'listener_delete', 'listener_start', 'listener_stop',
'session_delete', 'task_create', 'task_cancel', 'task_delete'],
webshell: ['connection_create', 'connection_delete'],
knowledge: ['item_delete', 'index_rebuild'],
conversation: ['create', 'delete', 'delete_turn'],
vulnerability: ['create', 'update', 'delete'],
external_mcp: ['upsert', 'delete'],
task: ['create_queue', 'start_queue', 'delete_queue', 'pause_queue', 'rerun_queue', 'delete_batch_task'],
tool: ['execution_delete', 'execution_delete_batch'],
file: ['upload', 'delete'],
hitl: ['decision'],
role: ['create', 'update', 'delete'],
skill: ['create', 'update', 'delete'],
agent: ['markdown_create', 'markdown_update', 'markdown_delete']
};
function auditT(key, opts, fallback) {
if (typeof t === 'function') {
const v = t(key, opts);
if (v && v !== key) return v;
}
return fallback != null ? fallback : key;
}
function auditCategoryI18nKey(category) {
if (!category) return '';
if (category === 'external_mcp') return 'externalMcp';
return category;
}
function auditCategoryLabel(category) {
if (!category) return '';
const key = 'settingsAudit.cat.' + auditCategoryI18nKey(category);
return auditT(key, null, category);
}
function auditActionLabel(action) {
if (!action) return '';
return auditT('settingsAudit.act.' + action, null, action);
}
function formatAuditTime(iso) {
if (!iso) return '';
try {
const d = new Date(iso);
if (Number.isNaN(d.getTime())) return iso;
return d.toLocaleString();
} catch (_) {
return iso;
}
}
function auditDatetimeLocalToRFC3339(value) {
if (!value || !value.trim()) return '';
const d = new Date(value);
if (Number.isNaN(d.getTime())) return '';
return d.toISOString();
}
function initAuditPageSizeFromStorage() {
try {
const saved = parseInt(localStorage.getItem(AUDIT_PAGE_SIZE_KEY), 10);
if ([10, 20, 50, 100].indexOf(saved) >= 0) {
auditLogsPageSize = saved;
}
} catch (_) { /* ignore */ }
const sel = document.getElementById('audit-page-size');
if (sel) sel.value = String(auditLogsPageSize);
}
function onAuditPageSizeChange() {
const sel = document.getElementById('audit-page-size');
if (!sel) return;
const n = parseInt(sel.value, 10);
if ([10, 20, 50, 100].indexOf(n) < 0) return;
auditLogsPageSize = n;
try {
localStorage.setItem(AUDIT_PAGE_SIZE_KEY, String(n));
} catch (_) { /* ignore */ }
auditLogsPage = 1;
loadAuditLogs(1);
}
function rebuildAuditActionSelect() {
const catEl = document.getElementById('audit-filter-category');
const actEl = document.getElementById('audit-filter-action');
if (!actEl) return;
const category = catEl ? catEl.value : '';
const prev = actEl.value;
const allLabel = auditT('settingsAudit.filterAllActions', null, '全部操作');
const hint = auditT('settingsAudit.filterCascadeHint', null, '选择类别后可筛选具体操作');
actEl.innerHTML = '';
const allOpt = document.createElement('option');
allOpt.value = '';
allOpt.textContent = allLabel;
actEl.appendChild(allOpt);
if (!category) {
actEl.disabled = true;
actEl.value = '';
actEl.title = hint;
return;
}
actEl.disabled = false;
actEl.title = '';
const actions = AUDIT_ACTIONS_BY_CATEGORY[category] || [];
actions.forEach(function (action) {
const opt = document.createElement('option');
opt.value = action;
opt.textContent = auditActionLabel(action);
actEl.appendChild(opt);
});
if (prev && Array.prototype.some.call(actEl.options, function (o) { return o.value === prev; })) {
actEl.value = prev;
}
}
function onAuditCategoryFilterChange() {
rebuildAuditActionSelect();
}
function buildAuditQueryParams(forExport) {
const params = new URLSearchParams();
if (!forExport) {
params.set('page', String(auditLogsPage));
params.set('page_size', String(auditLogsPageSize));
}
const cat = document.getElementById('audit-filter-category');
const act = document.getElementById('audit-filter-action');
const res = document.getElementById('audit-filter-result');
const q = document.getElementById('audit-filter-q');
const since = document.getElementById('audit-filter-since');
const until = document.getElementById('audit-filter-until');
if (cat && cat.value) params.set('category', cat.value);
if (act && !act.disabled && act.value) params.set('action', act.value);
if (res && res.value) params.set('result', res.value);
if (q && q.value.trim()) params.set('q', q.value.trim());
const sinceISO = since ? auditDatetimeLocalToRFC3339(since.value) : '';
const untilISO = until ? auditDatetimeLocalToRFC3339(until.value) : '';
if (sinceISO) params.set('since', sinceISO);
if (untilISO) params.set('until', untilISO);
return params.toString();
}
async function loadAuditMeta() {
if (typeof apiFetch !== 'function') return;
const hint = document.getElementById('audit-retention-hint');
try {
const r = await apiFetch('/api/audit/meta');
if (!r.ok) return;
const data = await r.json();
if (!hint) return;
if (!data.enabled) {
hint.hidden = false;
hint.textContent = auditT('settingsAudit.disabledHint', null, '审计功能已关闭,新操作不会写入审计表。');
return;
}
const days = data.retention_days;
if (days > 0) {
hint.hidden = false;
hint.textContent = auditT('settingsAudit.retentionHint', { days: days },
'审计记录保留 ' + days + ' 天,超期自动清理。');
} else {
hint.hidden = true;
}
} catch (_) { /* ignore */ }
}
async function loadAuditSummary() {
if (typeof apiFetch !== 'function') return;
const wrap = document.getElementById('audit-summary-stats');
try {
const r = await apiFetch('/api/audit/summary?' + buildAuditQueryParams(true));
if (!r.ok) return;
const data = await r.json();
if (wrap) wrap.hidden = false;
const elTotal = document.getElementById('audit-stat-total');
const elFail = document.getElementById('audit-stat-failures');
const elRecent = document.getElementById('audit-stat-recent');
if (elTotal) elTotal.textContent = String(data.total != null ? data.total : 0);
if (elFail) elFail.textContent = String(data.failures != null ? data.failures : 0);
if (elRecent) elRecent.textContent = String(data.recent_7d != null ? data.recent_7d : 0);
} catch (_) { /* ignore */ }
}
async function loadAuditLogs(page) {
if (typeof apiFetch !== 'function') return;
auditLogsPage = page != null ? page : auditLogsPage;
const listEl = document.getElementById('audit-log-list');
if (listEl) {
listEl.innerHTML = '<div class="loading-spinner">' + (typeof escapeHtml === 'function' ? escapeHtml(auditT('settingsAudit.loading', null, '加载中...')) : '加载中...') + '</div>';
}
try {
const qs = buildAuditQueryParams(false);
const r = await apiFetch('/api/audit/logs?' + qs);
if (!r.ok) {
const err = await r.json().catch(function () { return {}; });
throw new Error(err.error || r.statusText);
}
const data = await r.json();
renderAuditLogs(data.logs || []);
auditLogsTotal = typeof data.total === 'number' ? data.total : 0;
const maxPage = Math.max(1, Math.ceil(auditLogsTotal / auditLogsPageSize));
if (auditLogsPage > maxPage) {
loadAuditLogs(maxPage);
return;
}
renderAuditLogsPagination();
loadAuditSummary();
} catch (e) {
if (listEl) {
const msg = typeof escapeHtml === 'function' ? escapeHtml(e.message || String(e)) : (e.message || String(e));
listEl.innerHTML = '<div class="monitor-empty">' + msg + '</div>';
}
if (typeof showToast === 'function') {
showToast(e.message || String(e), 'error');
}
}
}
function renderAuditLogs(logs) {
const listEl = document.getElementById('audit-log-list');
if (!listEl) return;
const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); };
if (!logs.length) {
listEl.innerHTML = '<div class="c2-empty">' + esc(auditT('settingsAudit.empty', null, '暂无审计记录')) + '</div>';
return;
}
listEl.innerHTML = logs.map(function (log) {
const lvl = log.result === 'failure' ? 'warn' : (log.level || 'info');
const catLabel = esc(auditCategoryLabel(log.category || ''));
const actionLabel = esc(auditActionLabel(log.action || ''));
const msg = esc(log.message || '');
const ip = esc(log.clientIp || '');
const when = esc(formatAuditTime(log.createdAt));
const res = esc(log.result || '');
const rid = log.resourceId || '';
const meta = rid ? (' · ' + esc(rid)) : '';
const eid = esc(log.id || '');
return (
'<div class="c2-event-item audit-log-item" role="button" tabindex="0" ' +
'onclick="showAuditLogDetail(\'' + eid + '\')" ' +
'onkeydown="if(event.key===\'Enter\'||event.key===\' \'){event.preventDefault();showAuditLogDetail(\'' + eid + '\')}">' +
'<div class="c2-event-level ' + esc(lvl) + '"></div>' +
'<div class="c2-event-content">' +
'<div class="c2-event-message">' + msg + '</div>' +
'<div class="c2-event-meta">' + when + ' · ' + catLabel + '/' + actionLabel + ' · ' + res + meta +
(ip ? ' · IP ' + ip : '') +
'</div></div></div>'
);
}).join('');
if (typeof applyTranslations === 'function') {
applyTranslations(listEl);
}
}
function renderAuditLogsPagination() {
const container = document.getElementById('audit-logs-pagination');
if (!container) return;
const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); };
const total = auditLogsTotal || 0;
const currentPage = auditLogsPage || 1;
const pageSize = auditLogsPageSize || 20;
const totalPages = Math.max(1, Math.ceil(total / pageSize));
const start = total === 0 ? 0 : (currentPage - 1) * pageSize + 1;
const end = total === 0 ? 0 : Math.min(currentPage * pageSize, total);
const infoText = auditT('mcpMonitor.paginationInfo', { start: start, end: end, total: total },
'显示 ' + start + '-' + end + ' / 共 ' + total + ' 条记录');
const perPageLabel = auditT('mcpMonitor.perPageLabel', null, '每页显示');
const firstPageLabel = auditT('mcp.firstPage', null, '首页');
const prevPageLabel = auditT('mcp.prevPage', null, '上一页');
const pageInfoText = auditT('mcp.pageInfo', { page: currentPage, total: totalPages },
'第 ' + currentPage + ' / ' + totalPages + ' 页');
const nextPageLabel = auditT('mcp.nextPage', null, '下一页');
const lastPageLabel = auditT('mcp.lastPage', null, '末页');
const disabledFirst = currentPage === 1 || total === 0;
const disabledLast = currentPage >= totalPages || total === 0;
let html = '<div class="monitor-pagination">';
html += '<div class="pagination-info">';
html += '<span>' + esc(infoText) + '</span>';
html += '<label class="pagination-page-size">' + esc(perPageLabel);
html += '<select id="audit-page-size" onchange="onAuditPageSizeChange()">';
[10, 20, 50, 100].forEach(function (n) {
html += '<option value="' + n + '"' + (pageSize === n ? ' selected' : '') + '>' + n + '</option>';
});
html += '</select></label></div>';
html += '<div class="pagination-controls">';
html += '<button type="button" class="btn-secondary" onclick="goAuditLogsPage(1)"' + (disabledFirst ? ' disabled' : '') + '>' + esc(firstPageLabel) + '</button>';
html += '<button type="button" class="btn-secondary" onclick="goAuditLogsPage(' + (currentPage - 1) + ')"' + (disabledFirst ? ' disabled' : '') + '>' + esc(prevPageLabel) + '</button>';
html += '<span class="pagination-page">' + esc(pageInfoText) + '</span>';
html += '<button type="button" class="btn-secondary" onclick="goAuditLogsPage(' + (currentPage + 1) + ')"' + (disabledLast ? ' disabled' : '') + '>' + esc(nextPageLabel) + '</button>';
html += '<button type="button" class="btn-secondary" onclick="goAuditLogsPage(' + totalPages + ')"' + (disabledLast ? ' disabled' : '') + '>' + esc(lastPageLabel) + '</button>';
html += '</div></div>';
container.innerHTML = html;
}
function goAuditLogsPage(p) {
const totalPages = Math.max(1, Math.ceil((auditLogsTotal || 0) / (auditLogsPageSize || 20)));
if (p < 1 || p > totalPages) return;
loadAuditLogs(p);
}
function filterAuditLogs() {
auditLogsPage = 1;
loadAuditLogs(1);
}
function resetAuditLogFilters() {
const cat = document.getElementById('audit-filter-category');
const act = document.getElementById('audit-filter-action');
const res = document.getElementById('audit-filter-result');
const q = document.getElementById('audit-filter-q');
const since = document.getElementById('audit-filter-since');
const until = document.getElementById('audit-filter-until');
if (cat) cat.value = '';
if (res) res.value = '';
if (q) q.value = '';
if (since) since.value = '';
if (until) until.value = '';
rebuildAuditActionSelect();
filterAuditLogs();
}
/** 资源已被删除/移除的审计操作,不再提供「打开关联资源」 */
const AUDIT_ACTIONS_RESOURCE_REMOVED = {
delete: true,
item_delete: true,
connection_delete: true,
listener_delete: true,
session_delete: true,
task_delete: true,
execution_delete: true,
execution_delete_batch: true,
delete_queue: true,
delete_batch_task: true,
markdown_delete: true
};
function auditResourceWasRemoved(log) {
if (!log || !log.action) return false;
return !!AUDIT_ACTIONS_RESOURCE_REMOVED[log.action];
}
/** 删除类操作,或关联资源已不存在(由详情 API resourceAvailable 判定) */
function auditResourceUnavailable(log) {
if (!log) return false;
if (auditResourceWasRemoved(log)) return true;
return log.resourceAvailable === false;
}
function auditResourceMeta(log) {
if (!log || !log.resourceId) return '';
const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); };
const id = esc(log.resourceId);
if (auditResourceUnavailable(log)) {
const idLabel = esc(auditT('settingsAudit.resourceIdLabel', null, '资源 ID'));
const removed = esc(auditT('settingsAudit.resourceRemoved', null, '(关联对象已删除)'));
return '<p class="audit-resource-meta"><strong>' + idLabel + ':</strong> <code>' + id +
'</code> <span class="audit-resource-removed">' + removed + '</span></p>';
}
const link = auditResourceLink(log);
return link || ('<p><strong>ID:</strong> ' + id + '</p>');
}
async function auditOpenConversationChat(conversationId) {
const id = String(conversationId || '').trim();
if (!id) return;
if (typeof apiFetch === 'function') {
try {
const r = await apiFetch('/api/conversations/' + encodeURIComponent(id));
if (!r.ok) {
if (typeof showToast === 'function') {
showToast(auditT('settingsAudit.resourceRemoved', null, '(关联对象已删除)'), 'warning');
}
return;
}
} catch (_) {
return;
}
}
closeAuditDetailModal();
if (typeof switchPage === 'function') {
switchPage('chat');
}
if (typeof loadConversation === 'function') {
void loadConversation(id);
}
}
window.auditOpenConversationChat = auditOpenConversationChat;
function auditResourceLink(log) {
if (!log || auditResourceUnavailable(log)) return '';
const type = log.resourceType || '';
const id = log.resourceId || '';
if (!id) return '';
const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); };
const label = esc(auditT('settingsAudit.openResource', null, '打开关联资源'));
if (type === 'conversation' || (type === '' && id.length > 8 && !id.startsWith('c2_'))) {
const chatLabel = esc(auditT('settingsAudit.openResourceChat', null, '打开关联资源(chat'));
return '<p><button type="button" class="btn-secondary btn-small audit-open-chat-btn" data-conversation-id="' +
esc(id) + '">' + chatLabel + '</button></p>';
}
if (type === 'vulnerability' || type === 'batch_queue') {
const page = type === 'batch_queue' ? 'tasks' : 'vulnerabilities';
return '<p><button type="button" class="btn-secondary btn-small" onclick="closeAuditDetailModal();if(typeof switchPage===\'function\'){switchPage(\'' + page + '\');}">' + label + '</button></p>';
}
if (type === 'c2_listener' || type === 'c2_session' || type === 'c2_task') {
const page = type === 'c2_listener' ? 'c2-listeners' : (type === 'c2_session' ? 'c2-sessions' : 'c2-tasks');
return '<p><button type="button" class="btn-secondary btn-small" onclick="closeAuditDetailModal();if(typeof switchPage===\'function\'){switchPage(\'' + page + '\');}">' + label + '</button></p>';
}
if (type === 'webshell_connection') {
return '<p><button type="button" class="btn-secondary btn-small" onclick="closeAuditDetailModal();if(typeof switchPage===\'function\'){switchPage(\'webshell\');}">' + label + '</button></p>';
}
if (type === 'knowledge_item') {
return '<p><button type="button" class="btn-secondary btn-small" onclick="closeAuditDetailModal();if(typeof switchPage===\'function\'){switchPage(\'knowledge-management\');}">' + label + '</button></p>';
}
if (type === 'chat_upload') {
return '<p><button type="button" class="btn-secondary btn-small" onclick="closeAuditDetailModal();if(typeof switchPage===\'function\'){switchPage(\'chat-files\');}">' + label + '</button></p>';
}
if (type === 'tool_execution') {
return '<p><button type="button" class="btn-secondary btn-small" onclick="closeAuditDetailModal();if(typeof switchPage===\'function\'){switchPage(\'mcp-monitor\');}">' + label + '</button></p>';
}
if (type === 'role' || type === 'skill' || type === 'markdown_agent') {
return '<p><button type="button" class="btn-secondary btn-small" onclick="closeAuditDetailModal();if(typeof switchSettingsSection===\'function\'){switchPage(\'settings\');switchSettingsSection(\'roles\');}">' + label + '</button></p>';
}
return '';
}
function refreshAuditLogs() {
loadAuditLogs(auditLogsPage);
}
async function downloadAuditExport(url, filename) {
const r = await apiFetch(url);
if (!r.ok) {
const err = await r.json().catch(function () { return {}; });
throw new Error(err.error || r.statusText);
}
const blob = await r.blob();
const objectUrl = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = objectUrl;
a.download = filename;
a.click();
URL.revokeObjectURL(objectUrl);
}
function closeAuditExportMenu() {
const menu = document.getElementById('audit-export-menu');
const trigger = document.getElementById('audit-export-trigger');
if (menu) menu.hidden = true;
if (trigger) trigger.setAttribute('aria-expanded', 'false');
}
function toggleAuditExportMenu(ev) {
if (ev && ev.stopPropagation) ev.stopPropagation();
const menu = document.getElementById('audit-export-menu');
const trigger = document.getElementById('audit-export-trigger');
if (!menu) return;
const willOpen = menu.hidden;
if (willOpen) {
menu.hidden = false;
if (trigger) trigger.setAttribute('aria-expanded', 'true');
if (!window._auditExportMenuDocBound) {
window._auditExportMenuDocBound = true;
document.addEventListener('click', function () {
closeAuditExportMenu();
});
}
} else {
closeAuditExportMenu();
}
}
async function runAuditExport(format) {
closeAuditExportMenu();
if (format === 'csv') {
await exportAuditLogsCsv();
} else {
await exportAuditLogs();
}
}
async function exportAuditLogs() {
if (typeof apiFetch !== 'function') return;
try {
await downloadAuditExport(
'/api/audit/logs/export?' + buildAuditQueryParams(true),
'audit-logs-' + new Date().toISOString().slice(0, 10) + '.json'
);
if (typeof showToast === 'function') {
showToast(auditT('settingsAudit.exportDone', null, '导出完成'), 'success');
}
} catch (e) {
if (typeof showToast === 'function') {
showToast(e.message || String(e), 'error');
}
}
}
async function exportAuditLogsCsv() {
if (typeof apiFetch !== 'function') return;
try {
const qs = buildAuditQueryParams(true);
await downloadAuditExport(
'/api/audit/logs/export?' + (qs ? qs + '&' : '') + 'format=csv',
'audit-logs-' + new Date().toISOString().slice(0, 10) + '.csv'
);
if (typeof showToast === 'function') {
showToast(auditT('settingsAudit.exportDone', null, '导出完成'), 'success');
}
} catch (e) {
if (typeof showToast === 'function') {
showToast(e.message || String(e), 'error');
}
}
}
function closeAuditDetailModal() {
const el = document.getElementById('audit-detail-modal');
if (el) el.remove();
}
async function showAuditLogDetail(id) {
if (!id || typeof apiFetch !== 'function') return;
const esc = typeof escapeHtml === 'function' ? escapeHtml : function (s) { return String(s || ''); };
try {
const r = await apiFetch('/api/audit/logs/' + encodeURIComponent(id));
if (!r.ok) throw new Error('not found');
const data = await r.json();
const log = data.log || {};
const detail = log.detail ? JSON.stringify(log.detail, null, 2) : '';
closeAuditDetailModal();
const overlay = document.createElement('div');
overlay.id = 'audit-detail-modal';
overlay.className = 'modal';
overlay.style.display = 'block';
const catAction = esc(auditCategoryLabel(log.category || '')) + ' / ' + esc(auditActionLabel(log.action || ''));
overlay.innerHTML =
'<div class="modal-content" style="max-width: 720px;">' +
'<div class="modal-header">' +
'<h2>' + esc(auditT('settingsAudit.detailTitle', null, '审计详情')) + '</h2>' +
'<span class="modal-close" onclick="closeAuditDetailModal()">&times;</span>' +
'</div>' +
'<div class="modal-body audit-detail-body">' +
'<p><strong>' + esc(auditT('settingsAudit.detailTime', null, '时间')) + ':</strong> ' + esc(formatAuditTime(log.createdAt)) + '</p>' +
'<p><strong>' + esc(auditT('settingsAudit.detailCategory', null, '类别')) + ':</strong> ' + catAction + '</p>' +
'<p><strong>' + esc(auditT('settingsAudit.detailResult', null, '结果')) + ':</strong> ' + esc(log.result || '') + '</p>' +
'<p><strong>' + esc(auditT('settingsAudit.detailMessage', null, '说明')) + ':</strong> ' + esc(log.message || '') + '</p>' +
(log.clientIp ? '<p><strong>IP:</strong> ' + esc(log.clientIp) + '</p>' : '') +
(log.sessionHint ? '<p><strong>' + esc(auditT('settingsAudit.detailSession', null, '会话')) + ':</strong> ' + esc(log.sessionHint) + '</p>' : '') +
(log.userAgent ? '<p><strong>UA:</strong> ' + esc(log.userAgent) + '</p>' : '') +
auditResourceMeta(log) +
(detail ? '<pre class="audit-detail-pre">' + esc(detail) + '</pre>' : '') +
'</div>' +
'<div class="modal-footer"><button type="button" class="btn-secondary" onclick="closeAuditDetailModal()">' +
esc(auditT('common.close', null, '关闭')) + '</button></div>' +
'</div>';
document.body.appendChild(overlay);
const chatBtn = overlay.querySelector('.audit-open-chat-btn');
if (chatBtn) {
chatBtn.addEventListener('click', function () {
auditOpenConversationChat(chatBtn.getAttribute('data-conversation-id'));
});
}
overlay.addEventListener('click', function (ev) {
if (ev.target === overlay) closeAuditDetailModal();
});
} catch (e) {
if (typeof showToast === 'function') {
showToast(e.message || String(e), 'error');
}
}
}
function initAuditLogsSection() {
if (!document.getElementById('audit-log-list')) return;
initAuditPageSizeFromStorage();
rebuildAuditActionSelect();
loadAuditMeta();
loadAuditLogs(1);
}
+7
View File
@@ -282,6 +282,13 @@ async function submitLogin(event) {
}
async function refreshAppData(showTaskErrors = false) {
if (typeof initChatAgentModeFromConfig === 'function') {
try {
await initChatAgentModeFromConfig();
} catch (error) {
console.warn('刷新对话模式配置失败:', error);
}
}
await Promise.allSettled([
loadConversations(),
loadActiveTasks(showTaskErrors),
+48 -20
View File
@@ -151,6 +151,25 @@
return div.innerHTML;
}
/** 任务列表操作按钮(查看/取消/删除)— 事件委托 */
function bindC2TaskActionDelegation() {
if (document.documentElement.dataset.c2TaskActionsBound === '1') return;
document.documentElement.dataset.c2TaskActionsBound = '1';
document.addEventListener('click', function(e) {
const btn = e.target.closest('[data-c2-task-action]');
if (!btn) return;
e.preventDefault();
e.stopPropagation();
const action = btn.getAttribute('data-c2-task-action');
const id = btn.getAttribute('data-task-id');
if (!id) return;
if (action === 'view') C2.viewTask(id);
else if (action === 'cancel') C2.cancelTask(id);
else if (action === 'delete') C2.deleteTaskById(id);
});
}
bindC2TaskActionDelegation();
/** 监听器表单:Malleable Profile 下拉选项 HTMLvalue / 文本已转义) */
function listenerProfileSelectHtml(selectedProfileId) {
const sel = selectedProfileId ? String(selectedProfileId) : '';
@@ -1293,14 +1312,17 @@
return;
}
container.innerHTML = tasks.map(t => `
container.innerHTML = tasks.map(t => {
const rawId = t.id || '';
return `
<div class="c2-task-item-compact">
<span class="c2-task-status-dot ${t.status}"></span>
<span class="c2-task-type">${t.taskType}</span>
<span class="c2-task-status-dot ${escapeHtml(t.status || '')}"></span>
<span class="c2-task-type">${escapeHtml(t.taskType || '')}</span>
<span class="c2-task-meta">${escapeHtml(taskStatusLabel(t.status))} | ${formatDuration(t.durationMs)}</span>
<button class="btn-ghost btn-sm" onclick="C2.viewTask('${t.id}')">${escapeHtml(c2t('c2.tasks.view'))}</button>
<button type="button" class="btn-secondary btn-small" data-c2-task-action="view" data-task-id="${escapeHtml(rawId)}">${escapeHtml(c2t('c2.tasks.view'))}</button>
</div>
`).join('');
`;
}).join('');
});
};
@@ -1334,13 +1356,12 @@
<th>${escapeHtml(c2t('c2.tasks.colStatus'))}</th>
<th>${escapeHtml(c2t('c2.tasks.colDuration'))}</th>
<th>${escapeHtml(c2t('c2.tasks.colCreated'))}</th>
<th>${escapeHtml(c2t('c2.tasks.colActions'))}</th>
<th class="c2-task-table-col-actions">${escapeHtml(c2t('c2.tasks.colActions'))}</th>
</tr>
</thead>
<tbody>
${C2.tasks.map(t => {
const rawId = t.id || '';
const idJson = JSON.stringify(rawId);
const shortTaskId = rawId.length > 14 ? escapeHtml(rawId.substring(0, 12)) + '\u2026' : escapeHtml(rawId);
const sid = t.sessionId ? escapeHtml(String(t.sessionId).substring(0, 8)) + '\u2026' : '-';
return `
@@ -1356,12 +1377,14 @@
<td><span class="c2-status-badge ${escapeHtml(t.status || '')}">${escapeHtml(taskStatusLabel(t.status))}</span></td>
<td>${formatDuration(t.durationMs)}</td>
<td>${formatTime(t.createdAt)}</td>
<td>
<button type="button" class="btn-ghost btn-sm" onclick="C2.viewTask(${idJson})">${escapeHtml(c2t('c2.tasks.view'))}</button>
<td class="c2-task-table-col-actions">
<div class="c2-task-table-actions">
<button type="button" class="btn-secondary btn-small" data-c2-task-action="view" data-task-id="${escapeHtml(rawId)}">${escapeHtml(c2t('c2.tasks.view'))}</button>
${t.status === 'queued' || t.status === 'sent'
? `<button type="button" class="btn-danger btn-sm" onclick="C2.cancelTask(${idJson})">${escapeHtml(c2t('c2.tasks.cancelBtn'))}</button>`
? `<button type="button" class="btn-danger btn-small" data-c2-task-action="cancel" data-task-id="${escapeHtml(rawId)}">${escapeHtml(c2t('c2.tasks.cancelBtn'))}</button>`
: ''}
<button type="button" class="btn-secondary btn-sm c2-task-row-delete" onclick="C2.deleteTaskById(${idJson})" title="${delTitle}" aria-label="${delTitle}">${escapeHtml(c2t('c2.tasks.deleteBtn'))}</button>
<button type="button" class="btn-danger btn-small" data-c2-task-action="delete" data-task-id="${escapeHtml(rawId)}" title="${delTitle}" aria-label="${delTitle}">${escapeHtml(c2t('c2.tasks.deleteBtn'))}</button>
</div>
</td>
</tr>
`;
@@ -1387,10 +1410,10 @@
</div>
<div class="c2-modal-body">
<div class="c2-task-detail">
<div><strong>${escapeHtml(c2t('c2.tasks.labelId'))}:</strong> ${t.id}</div>
<div><strong>${escapeHtml(c2t('c2.tasks.labelSession'))}:</strong> ${t.sessionId}</div>
<div><strong>${escapeHtml(c2t('c2.tasks.labelType'))}:</strong> ${t.taskType}</div>
<div><strong>${escapeHtml(c2t('c2.tasks.labelStatus'))}:</strong> <span class="c2-status-badge ${t.status}">${escapeHtml(taskStatusLabel(t.status))}</span></div>
<div><strong>${escapeHtml(c2t('c2.tasks.labelId'))}:</strong> ${escapeHtml(t.id || '')}</div>
<div><strong>${escapeHtml(c2t('c2.tasks.labelSession'))}:</strong> ${escapeHtml(t.sessionId || '')}</div>
<div><strong>${escapeHtml(c2t('c2.tasks.labelType'))}:</strong> ${escapeHtml(t.taskType || '')}</div>
<div><strong>${escapeHtml(c2t('c2.tasks.labelStatus'))}:</strong> <span class="c2-status-badge ${escapeHtml(t.status || '')}">${escapeHtml(taskStatusLabel(t.status))}</span></div>
<div><strong>${escapeHtml(c2t('c2.tasks.labelCreated'))}:</strong> ${formatTime(t.createdAt)}</div>
<div><strong>${escapeHtml(c2t('c2.tasks.labelSent'))}:</strong> ${formatTime(t.sentAt)}</div>
<div><strong>${escapeHtml(c2t('c2.tasks.labelCompleted'))}:</strong> ${formatTime(t.completedAt)}</div>
@@ -1416,19 +1439,24 @@
renderTaskModal(local);
return;
}
apiRequest('GET', `${API_BASE}/tasks/${id}`).then(data => {
apiRequest('GET', `${API_BASE}/tasks/${encodeURIComponent(id)}`).then(data => {
if (data.error) {
showToast(String(data.error), 'error');
return;
}
if (data.task) renderTaskModal(data.task);
});
else showToast(c2t('c2.tasks.emptyAll'), 'warn');
}).catch(err => showToast(err.message || String(err), 'error'));
};
C2.cancelTask = function(id) {
apiRequest('POST', `${API_BASE}/tasks/${id}/cancel`, {}).then(data => {
if (data.error) showToast(data.error, 'error');
apiRequest('POST', `${API_BASE}/tasks/${encodeURIComponent(id)}/cancel`, {}).then(data => {
if (data.error) showToast(String(data.error), 'error');
else {
showToast(c2t('c2.tasks.toastCancelled'), 'success');
C2.loadTasks(C2.tasksPage || 1);
}
});
}).catch(err => showToast(err.message || String(err), 'error'));
};
// ============================================================================

Some files were not shown because too many files have changed in this diff Show More