mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-24 06:49:59 +02:00
Compare commits
75 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e28ae39b9a | |||
| df34ceda68 | |||
| 3e69a50f87 | |||
| 53325ce07d | |||
| d85de3461b | |||
| 9306303d99 | |||
| 1e8f72ed74 | |||
| 0198f50314 | |||
| 560d0dca43 | |||
| 47486a49c2 | |||
| 476727933d | |||
| 8bb50e8323 | |||
| e74f2a2292 | |||
| 4799d0dba7 | |||
| 1db917061d | |||
| 41cd7db30f | |||
| 68b3265f3f | |||
| 05dc4395a1 | |||
| 637a35748b | |||
| 5d77a99236 | |||
| e84d936f85 | |||
| e748201ae8 | |||
| 7a3c67458c | |||
| 6e9e43eec8 | |||
| bca86e48ae | |||
| 3f3b8b4db4 | |||
| b366dc0287 | |||
| a52452ceea | |||
| 5b87667782 | |||
| 4f0e812d37 | |||
| 79691c021f | |||
| 5a8309a015 | |||
| 6244197339 | |||
| eb14aca05a | |||
| 091e8a4da8 | |||
| 48ce0c519e | |||
| afc37051c0 | |||
| 2964247361 | |||
| 02919df476 | |||
| c3294d96a2 | |||
| c8b8b41bda | |||
| 9a4c333b90 | |||
| 8e21ae290a | |||
| b9d102d046 | |||
| 8c85494a05 | |||
| c3d2a41301 | |||
| 1a2e282d46 | |||
| 8129f2147f | |||
| 4a9889f0af | |||
| 732d47a965 | |||
| e22382aab0 | |||
| b6ff80adf2 | |||
| 51f1cfde2f | |||
| b2c8913014 | |||
| ae98288b62 | |||
| 9955e856a0 | |||
| 018544e5f9 | |||
| c1c86e4632 | |||
| 08d77bc12b | |||
| ce73a7b3e4 | |||
| f78f424aab | |||
| e19d8e39bd | |||
| ecf594a25b | |||
| d5759f6d83 | |||
| 81b3f64b15 | |||
| 0e0f1352f0 | |||
| ffba311afd | |||
| d9ed36cfb1 | |||
| b7f80b78ee | |||
| 8f8e5cfff5 | |||
| 120f860640 | |||
| 90cd119a83 | |||
| 56d597e0c5 | |||
| 11ab5cde8f | |||
| 46a7d338a4 |
@@ -112,7 +112,7 @@ CyberStrikeAI is an **AI-native security testing platform** built in Go. It inte
|
||||
- 🔒 Password-protected web UI, audit logs, and SQLite persistence
|
||||
- 📚 Knowledge base (RAG) with embedding-based vector retrieval (cosine similarity), optional **Eino Compose** indexing pipeline, and configurable post-retrieval budgets / reranking hooks
|
||||
- 📁 Conversation grouping with pinning, rename, and batch management
|
||||
- 📂 **Project management**: group conversations and vulnerabilities by project; **shared facts** (project blackboard) persist cross-session context (targets, env, auth notes) with auto-injection for agents and MCP tools (`upsert_project_fact`, `get_project_fact`, …)
|
||||
- 📂 **Project management**: shared facts (blackboard) across sessions, `upsert_project_fact` + `links` to chain paths; attack-chain and project fact graph views
|
||||
- 🛡️ Vulnerability management with CRUD operations, severity tracking, status workflow, and statistics
|
||||
- 📋 Batch task management: create task queues, add multiple tasks, and execute them sequentially
|
||||
- 🎭 Role-based testing: predefined security testing roles (Penetration Testing, CTF, Web App Scanning, etc.) with custom prompts and tool restrictions
|
||||
@@ -551,6 +551,11 @@ multi_agent:
|
||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor optional
|
||||
# eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill }
|
||||
# eino_middleware: plantask_enable, checkpoint_dir, deep_model_retry_max_retries, deep_output_key, ...
|
||||
project:
|
||||
enabled: true # Enable project blackboard & fact MCP tools
|
||||
fact_index_max_runes: 65000
|
||||
fact_summary_max_runes: 24000
|
||||
default_inject_deprecated: false
|
||||
```
|
||||
|
||||
### Tool Definition Example (`tools/nmap.yaml`)
|
||||
|
||||
+6
-1
@@ -111,7 +111,7 @@ CyberStrikeAI 是一款 **AI 原生安全测试平台**,基于 Go 构建,集
|
||||
- 🔒 Web 登录保护、审计日志、SQLite 持久化
|
||||
- 📚 知识库(RAG):向量嵌入与余弦相似度检索(与 Eino `retriever.Retriever` 语义一致),可选 **Eino Compose** 索引流水线及检索后处理(预算、重排等配置项)
|
||||
- 📁 对话分组管理:支持分组创建、置顶、重命名、删除等操作
|
||||
- 📂 **项目管理**:按项目归类对话与漏洞;**共享事实**(项目黑板)在多会话间沉淀目标/环境/认证等认知,自动注入 Agent 上下文,支持 MCP 工具读写(`upsert_project_fact`、`get_project_fact` 等)
|
||||
- 📂 **项目管理**:共享事实(黑板)跨会话沉淀认知,`upsert_project_fact` + `links` 串联攻击路径;聊天攻击链与项目事实图可视化
|
||||
- 🛡️ 漏洞管理功能:完整的漏洞 CRUD 操作,支持严重程度分级、状态流转、按对话/严重程度/状态过滤,以及统计看板
|
||||
- 📋 批量任务管理:创建任务队列,批量添加任务,依次顺序执行,支持任务编辑与状态跟踪
|
||||
- 🎭 角色化测试:预设安全测试角色(渗透测试、CTF、Web 应用扫描等),支持自定义提示词和工具限制
|
||||
@@ -549,6 +549,11 @@ multi_agent:
|
||||
# orchestrator_instruction_plan_execute / orchestrator_instruction_supervisor 可选
|
||||
# eino_skills: { disable: false, filesystem_tools: true, skill_tool_name: skill }
|
||||
# eino_middleware: plantask_enable、checkpoint_dir、deep_model_retry_max_retries、deep_output_key 等
|
||||
project:
|
||||
enabled: true # 启用项目黑板与事实 MCP 工具
|
||||
fact_index_max_runes: 65000
|
||||
fact_summary_max_runes: 24000
|
||||
default_inject_deprecated: false
|
||||
```
|
||||
|
||||
### 工具模版示例(`tools/nmap.yaml`)
|
||||
|
||||
+11
-6
@@ -10,7 +10,7 @@
|
||||
# ============================================
|
||||
|
||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||
version: "v1.6.41"
|
||||
version: "v1.6.45"
|
||||
# 服务器配置
|
||||
server:
|
||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||
@@ -40,6 +40,9 @@ audit:
|
||||
retention_days: 15 # 0 表示不自动清理
|
||||
max_detail_bytes: 8192
|
||||
auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流
|
||||
# MCP 状态监控执行记录保留(tool_executions 表)
|
||||
monitor:
|
||||
retention_days: 90 # 省略时默认 90;0 表示不自动清理
|
||||
# ============================================
|
||||
# 对话相关配置
|
||||
# ============================================
|
||||
@@ -142,10 +145,10 @@ multi_agent:
|
||||
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
|
||||
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
|
||||
checkpoint_dir: data/eino-checkpoints # P0:进程崩溃/OOM 后同会话自动 ADK Resume;正常结束会删 .ckpt;与「中断并继续」(last_react_*) 是两套机制
|
||||
run_retry_max_attempts: 0 # 429/5xx/网络抖动时整轮 Run 指数退避续跑;0=默认 10(与 deep_model_retry 互补,建议保持默认)
|
||||
run_retry_max_attempts: 0 # 429/5xx/网络抖动时可退避重试次数(run loop + summarization 共用 isEinoTransientRunError);0=默认 10
|
||||
run_retry_max_backoff_sec: 0 # 单次退避上限秒数;0=默认 30
|
||||
deep_output_key: final_answer # P0:Eino session 写入最终助手结论(框架内部;Deep/Supervisor 主/eino_single)
|
||||
deep_model_retry_max_retries: 3 # P0:单次 ChatModel API 失败时框架自动重试(超时/502 等);子代理模型不受此项影响
|
||||
deep_model_retry_max_retries: 0 # 已废弃,请用 run_retry_max_attempts;保留字段仅为兼容旧配置
|
||||
task_tool_description_prefix: "" # 非空:仅 Deep 的 task 工具使用自定义描述前缀,运行时会拼接子代理名称;空则走 Eino 默认生成逻辑
|
||||
# Eino callbacks + OpenTelemetry:框架级 span(与 Zap 对齐);默认不向终端用户 UI 推 eino_trace_*(见 sse_trace_to_client)
|
||||
eino_callbacks:
|
||||
@@ -308,7 +311,9 @@ roles_dir: roles # 角色配置文件目录(相对于配置文件所在目录
|
||||
project:
|
||||
enabled: true
|
||||
# default_project_id: "" # 可选:机器人/批量任务创建对话时的默认项目 ID
|
||||
fact_index_max_runes: 6500
|
||||
fact_summary_max_runes: 2400
|
||||
fact_index_max_runes: 65000
|
||||
# 事实关系速览段预算(从索引总预算中预留)
|
||||
fact_index_path_max_runes: 10000
|
||||
fact_summary_max_runes: 24000
|
||||
default_inject_deprecated: false
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"cyberstrike-ai/internal/project"
|
||||
"cyberstrike-ai/internal/projectprompt"
|
||||
)
|
||||
|
||||
// DefaultSingleAgentSystemPrompt 单代理(Eino ADK / MCP)内置系统提示;可通过 agent.system_prompt_path 覆盖为文件。
|
||||
@@ -107,7 +107,7 @@ func DefaultSingleAgentSystemPrompt() string {
|
||||
- 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。
|
||||
- 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。
|
||||
|
||||
` + project.FactRecordingBlackboardSection(false) + `
|
||||
` + projectprompt.FactRecordingBlackboardSection(false) + `
|
||||
|
||||
## 技能库(Skills)与知识库
|
||||
|
||||
|
||||
+14
-1
@@ -25,6 +25,7 @@ import (
|
||||
"cyberstrike-ai/internal/logger"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/monitor"
|
||||
"cyberstrike-ai/internal/robot"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"cyberstrike-ai/internal/skillpackage"
|
||||
@@ -99,6 +100,10 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
auditSvc.PurgeExpired()
|
||||
audit.StartRetentionLoop(auditSvc, log.Logger)
|
||||
|
||||
monitorRetention := monitor.NewService(db, cfg, log.Logger)
|
||||
monitorRetention.PurgeExpired()
|
||||
monitor.StartRetentionLoop(monitorRetention, log.Logger)
|
||||
|
||||
// 创建MCP服务器(带数据库持久化)
|
||||
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
||||
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
||||
@@ -298,7 +303,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
plantaskBase := filepath.Join(skillsDir, plantaskRel)
|
||||
// Match eino_adk_run_loop: checkpoint_dir is used as configured (relative to process CWD when not absolute).
|
||||
checkpointBase := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.CheckpointDir)
|
||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase)
|
||||
reductionRoot := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.ReductionRootDir)
|
||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot)
|
||||
agent.SetPromptBaseDir(configDir)
|
||||
|
||||
agentsDir := cfg.AgentsDir
|
||||
@@ -325,6 +331,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
}
|
||||
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
|
||||
monitorHandler.SetAudit(auditSvc)
|
||||
monitorHandler.SetMonitorRetention(monitorRetention)
|
||||
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
||||
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
||||
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
||||
@@ -368,6 +375,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
// 创建OpenAPI处理器
|
||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||
conversationHandler.SetAudit(auditSvc)
|
||||
conversationHandler.SetTaskStopper(agentHandler)
|
||||
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
|
||||
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
|
||||
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, conversationHandler, agentHandler)
|
||||
@@ -1069,6 +1077,11 @@ func setupRoutes(
|
||||
protected.GET("/projects/:id", projectHandler.GetProject)
|
||||
protected.PUT("/projects/:id", projectHandler.UpdateProject)
|
||||
protected.DELETE("/projects/:id", projectHandler.DeleteProject)
|
||||
protected.GET("/projects/:id/fact-graph", projectHandler.GetFactGraph)
|
||||
protected.GET("/projects/:id/fact-edges", projectHandler.ListFactEdges)
|
||||
protected.POST("/projects/:id/fact-edges", projectHandler.CreateFactEdge)
|
||||
protected.DELETE("/projects/:id/fact-edges/:edgeId", projectHandler.DeleteFactEdge)
|
||||
protected.POST("/projects/:id/promote-attack-chain/:conversationId", projectHandler.PromoteAttackChain)
|
||||
protected.GET("/projects/:id/facts", projectHandler.ListFacts)
|
||||
protected.POST("/projects/:id/facts", projectHandler.CreateFact)
|
||||
protected.PUT("/projects/:id/facts/:factId", projectHandler.UpdateFact)
|
||||
|
||||
@@ -89,6 +89,28 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi
|
||||
"type": "string",
|
||||
"description": "可选:关联的漏洞记录 ID",
|
||||
},
|
||||
"links": map[string]interface{}{
|
||||
"type": "array",
|
||||
"description": "可选:关系边(from → 当前 fact)。finding 至少 1 条 {from:target/*, type:discovered_on};finding 上记录 exploit 用 {from:exploit/*, type:exploits}。省略保留已有边;传 [] 清空全部关系边。",
|
||||
"items": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"from": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "来源 fact_key:存储为 from → 当前 fact",
|
||||
},
|
||||
"type": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "depends_on | leads_to | enables | exploits | discovered_on | contains | part_of | supports",
|
||||
},
|
||||
"confidence": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "confirmed | tentative | deprecated",
|
||||
},
|
||||
},
|
||||
"required": []string{"from", "type"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": []string{"fact_key", "summary"},
|
||||
},
|
||||
@@ -124,7 +146,26 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
if _, hasLinks := args["links"]; hasLinks {
|
||||
linkInputs, err := project.ParseFactLinkInputs(args["links"])
|
||||
if err != nil {
|
||||
return textResult("错误: "+err.Error(), true), nil
|
||||
}
|
||||
convID := agent.ConversationIDFromContext(ctx)
|
||||
if err := project.PersistFactLinksFromParsed(db, projectID, created.FactKey, convID, linkInputs, true); err != nil {
|
||||
return textResult("错误: 保存关系边失败: "+err.Error(), true), nil
|
||||
}
|
||||
created, _ = db.GetProjectFactByKey(projectID, created.FactKey)
|
||||
} else if parsed := project.ParseLinksFromBody(created.Body); len(parsed) > 0 {
|
||||
if err := project.PersistFactIncomingLinks(db, projectID, created.FactKey, parsed, true); err != nil {
|
||||
return textResult("错误: 从 body 解析边失败: "+err.Error(), true), nil
|
||||
}
|
||||
created, _ = db.GetProjectFactByKey(projectID, created.FactKey)
|
||||
}
|
||||
msg := fmt.Sprintf("事实已保存。\nfact_key: %s\nid: %s\nconfidence: %s", created.FactKey, created.ID, created.Confidence)
|
||||
if in, _ := db.ListIncomingProjectFactEdges(projectID, created.FactKey); len(in) > 0 {
|
||||
msg += "\n关系边: " + project.FormatFactLinksText(in)
|
||||
}
|
||||
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
||||
msg += warn
|
||||
}
|
||||
@@ -164,6 +205,18 @@ func registerProjectFactTools(mcpServer *mcp.Server, db *database.DB, cfg *confi
|
||||
if f.SourceConversationID != "" {
|
||||
msg += fmt.Sprintf("\nsource_conversation_id: %s", f.SourceConversationID)
|
||||
}
|
||||
if in, _ := db.ListIncomingProjectFactEdges(projectID, f.FactKey); len(in) > 0 {
|
||||
msg += "\n关系边(from → 本 fact):\n"
|
||||
for _, e := range in {
|
||||
msg += fmt.Sprintf("- %s ← %s (%s)\n", e.EdgeType, e.SourceFactKey, e.Confidence)
|
||||
}
|
||||
}
|
||||
if out, _ := db.ListOutgoingProjectFactEdges(projectID, f.FactKey); len(out) > 0 {
|
||||
msg += "指向其他事实:\n"
|
||||
for _, e := range out {
|
||||
msg += fmt.Sprintf("- %s → %s (%s)\n", e.EdgeType, e.TargetFactKey, e.Confidence)
|
||||
}
|
||||
}
|
||||
msg += "\n\n--- body ---\n" + f.Body
|
||||
if warn := project.SparseBodyWarningIfNeeded(f.Category, f.FactKey, f.Body); warn != "" {
|
||||
msg += warn
|
||||
|
||||
@@ -0,0 +1,203 @@
|
||||
package attackchain
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/project"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var promoteSlugSanitizer = regexp.MustCompile(`[^a-z0-9._/-]+`)
|
||||
|
||||
// PromoteToProjectResult 攻击链沉淀结果。
|
||||
type PromoteToProjectResult struct {
|
||||
FactsCreated int `json:"facts_created"`
|
||||
FactsUpdated int `json:"facts_updated"`
|
||||
EdgesCreated int `json:"edges_created"`
|
||||
FactKeys []string `json:"fact_keys"`
|
||||
Graph *database.ProjectFactGraph `json:"graph,omitempty"`
|
||||
}
|
||||
|
||||
// PromoteToProject 将对话攻击链沉淀为项目事实与边。
|
||||
func PromoteToProject(db *database.DB, projectID, conversationID string) (*PromoteToProjectResult, error) {
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("database 未初始化")
|
||||
}
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if projectID == "" || conversationID == "" {
|
||||
return nil, fmt.Errorf("project_id 与 conversation_id 必填")
|
||||
}
|
||||
if _, err := db.GetProject(projectID); err != nil {
|
||||
return nil, fmt.Errorf("项目不存在")
|
||||
}
|
||||
conv, err := db.GetConversation(conversationID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("对话不存在")
|
||||
}
|
||||
if pid := strings.TrimSpace(conv.ProjectID); pid != "" && pid != projectID {
|
||||
return nil, fmt.Errorf("对话已绑定其他项目")
|
||||
}
|
||||
|
||||
nodes, err := db.LoadAttackChainNodes(conversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
edges, err := db.LoadAttackChainEdges(conversationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nodes) == 0 {
|
||||
return nil, fmt.Errorf("该对话尚无攻击链,请先在对话中生成攻击链")
|
||||
}
|
||||
|
||||
res := &PromoteToProjectResult{}
|
||||
nodeToKey := make(map[string]string, len(nodes))
|
||||
usedKeys := map[string]int{}
|
||||
|
||||
for _, node := range nodes {
|
||||
key := allocatePromoteFactKey(node, usedKeys)
|
||||
nodeToKey[node.ID] = key
|
||||
category := mapPromoteNodeCategory(node.Type)
|
||||
existing, getErr := db.GetProjectFactByKey(projectID, key)
|
||||
f := &database.ProjectFact{
|
||||
ProjectID: projectID,
|
||||
FactKey: key,
|
||||
Category: category,
|
||||
Summary: strings.TrimSpace(node.Label),
|
||||
Body: formatPromotedFactBody(node, conversationID),
|
||||
Confidence: "tentative",
|
||||
SourceConversationID: conversationID,
|
||||
}
|
||||
if getErr == nil && existing != nil {
|
||||
f.ID = existing.ID
|
||||
f.CreatedAt = existing.CreatedAt
|
||||
if strings.TrimSpace(f.Summary) == "" {
|
||||
f.Summary = existing.Summary
|
||||
}
|
||||
if _, err := db.UpsertProjectFact(f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res.FactsUpdated++
|
||||
} else {
|
||||
if _, err := db.UpsertProjectFact(f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res.FactsCreated++
|
||||
}
|
||||
res.FactKeys = append(res.FactKeys, key)
|
||||
}
|
||||
|
||||
for _, edge := range edges {
|
||||
srcKey, ok1 := nodeToKey[edge.Source]
|
||||
tgtKey, ok2 := nodeToKey[edge.Target]
|
||||
if !ok1 || !ok2 || srcKey == tgtKey {
|
||||
continue
|
||||
}
|
||||
edgeType := mapPromoteEdgeType(edge.Type)
|
||||
incoming, _ := db.ListIncomingProjectFactEdges(projectID, tgtKey)
|
||||
merged := project.MergeLinkFromInputsUnique(promoteFromEdgeInputsFromDB(incoming), []database.ProjectFactEdgeFromInput{{From: srcKey, Type: edgeType}})
|
||||
if err := db.ReplaceIncomingProjectFactEdges(projectID, tgtKey, merged); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res.EdgesCreated++
|
||||
if fact, err := db.GetProjectFactByKey(projectID, tgtKey); err == nil {
|
||||
in, _ := db.ListIncomingProjectFactEdges(projectID, tgtKey)
|
||||
fact.Body = project.SyncBodyLinksSection(fact.Body, in)
|
||||
_, _ = db.UpsertProjectFact(fact)
|
||||
}
|
||||
}
|
||||
|
||||
graph, _ := project.BuildProjectFactGraph(db, projectID, "full", true)
|
||||
res.Graph = graph
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func promoteFromEdgeInputsFromDB(edges []*database.ProjectFactEdge) []database.ProjectFactEdgeFromInput {
|
||||
out := make([]database.ProjectFactEdgeFromInput, 0, len(edges))
|
||||
for _, e := range edges {
|
||||
out = append(out, database.ProjectFactEdgeFromInput{From: e.SourceFactKey, Type: e.EdgeType, Confidence: e.Confidence})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mapPromoteNodeCategory(nodeType string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(nodeType)) {
|
||||
case "target":
|
||||
return project.FactCategoryTarget
|
||||
case "vulnerability":
|
||||
return project.FactCategoryFinding
|
||||
case "action":
|
||||
return project.FactCategoryChain
|
||||
default:
|
||||
return project.FactCategoryNote
|
||||
}
|
||||
}
|
||||
|
||||
func mapPromoteEdgeType(t string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(t)) {
|
||||
case "discovers", "discovered_on", "targets":
|
||||
return "discovered_on"
|
||||
case "exploits":
|
||||
return "exploits"
|
||||
case "enables":
|
||||
return "enables"
|
||||
case "depends_on":
|
||||
return "depends_on"
|
||||
default:
|
||||
return "leads_to"
|
||||
}
|
||||
}
|
||||
|
||||
func allocatePromoteFactKey(node Node, used map[string]int) string {
|
||||
prefix := "chain/"
|
||||
switch strings.ToLower(strings.TrimSpace(node.Type)) {
|
||||
case "target":
|
||||
prefix = "target/"
|
||||
case "vulnerability":
|
||||
prefix = "finding/"
|
||||
case "action":
|
||||
prefix = "chain/"
|
||||
}
|
||||
base := promoteSlugify(node.Label)
|
||||
if base == "" {
|
||||
base = promoteSlugify(node.ID)
|
||||
}
|
||||
if base == "" {
|
||||
base = uuid.New().String()[:8]
|
||||
}
|
||||
key := prefix + base
|
||||
if n, ok := used[key]; ok {
|
||||
n++
|
||||
used[key] = n
|
||||
key = fmt.Sprintf("%s-%d", key, n)
|
||||
} else {
|
||||
used[key] = 1
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
func promoteSlugify(s string) string {
|
||||
s = strings.ToLower(strings.TrimSpace(s))
|
||||
s = strings.NewReplacer(" ", "-", "—", "-", "–", "-", "/", "-").Replace(s)
|
||||
s = promoteSlugSanitizer.ReplaceAllString(s, "-")
|
||||
s = strings.Trim(s, "-")
|
||||
if len(s) > 64 {
|
||||
s = s[:64]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func formatPromotedFactBody(node Node, conversationID string) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("## 来源\n")
|
||||
b.WriteString(fmt.Sprintf("- 对话攻击链沉淀\n- source_conversation_id: %s\n- node_id: %s\n- node_type: %s\n\n", conversationID, node.ID, node.Type))
|
||||
b.WriteString("## 摘要\n")
|
||||
b.WriteString(strings.TrimSpace(node.Label))
|
||||
b.WriteString("\n\n## 关联\n- 结构化关系边(自动同步):\n (见项目攻击路径图)\n")
|
||||
return b.String()
|
||||
}
|
||||
@@ -27,6 +27,7 @@ type Config struct {
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"`
|
||||
Monitor MonitorConfig `yaml:"monitor,omitempty" json:"monitor,omitempty"`
|
||||
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
||||
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
||||
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
|
||||
@@ -45,6 +46,7 @@ type ProjectConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"`
|
||||
DefaultProjectID string `yaml:"default_project_id,omitempty" json:"default_project_id,omitempty"` // 机器人/批量等无显式项目时绑定的默认项目
|
||||
FactIndexMaxRunes int `yaml:"fact_index_max_runes,omitempty" json:"fact_index_max_runes,omitempty"`
|
||||
FactIndexPathMaxRunes int `yaml:"fact_index_path_max_runes,omitempty" json:"fact_index_path_max_runes,omitempty"`
|
||||
FactSummaryMaxRunes int `yaml:"fact_summary_max_runes,omitempty" json:"fact_summary_max_runes,omitempty"`
|
||||
DefaultInjectDeprecated bool `yaml:"default_inject_deprecated,omitempty" json:"default_inject_deprecated,omitempty"`
|
||||
}
|
||||
@@ -57,6 +59,14 @@ func (c ProjectConfig) FactIndexMaxRunesEffective() int {
|
||||
return c.FactIndexMaxRunes
|
||||
}
|
||||
|
||||
// FactIndexPathMaxRunesEffective 攻击路径速览段的最大 rune 数(从 fact_index_max_runes 预算中预留)。
|
||||
func (c ProjectConfig) FactIndexPathMaxRunesEffective() int {
|
||||
if c.FactIndexPathMaxRunes <= 0 {
|
||||
return 1000
|
||||
}
|
||||
return c.FactIndexPathMaxRunes
|
||||
}
|
||||
|
||||
// FactSummaryMaxRunesEffective upsert 时 summary 最大 rune 数(索引一行,宜含验证要点)。
|
||||
func (c ProjectConfig) FactSummaryMaxRunesEffective() int {
|
||||
if c.FactSummaryMaxRunes <= 0 {
|
||||
@@ -240,7 +250,7 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
||||
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
||||
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
||||
// SummarizationRetryMaxAttempts is extra retries after the first summarization Generate attempt; 0 = default 3.
|
||||
// SummarizationRetryMaxAttempts 已废弃:summarization 与 run loop 共用 run_retry_max_attempts 及 isEinoTransientRunError。
|
||||
SummarizationRetryMaxAttempts int `yaml:"summarization_retry_max_attempts,omitempty" json:"summarization_retry_max_attempts,omitempty"`
|
||||
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
||||
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
||||
@@ -254,9 +264,9 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"`
|
||||
// DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off.
|
||||
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
|
||||
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries).
|
||||
// DeepModelRetryMaxRetries 已废弃:临时错误统一由 run loop 内 isEinoTransientRunError + run_retry_max_attempts 处理。
|
||||
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
|
||||
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时 handler 分段续跑次数;0=默认 10。
|
||||
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时可退避重试次数(run loop 与 summarization 共用);0=默认 10。
|
||||
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
||||
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
||||
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
||||
@@ -614,6 +624,23 @@ type AuthConfig struct {
|
||||
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// MonitorConfig MCP 状态监控(tool_executions)保留策略。
|
||||
type MonitorConfig struct {
|
||||
// RetentionDays 执行记录保留天数;省略时默认 90;0 表示不自动清理。
|
||||
RetentionDays *int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
|
||||
}
|
||||
|
||||
// RetentionDaysEffective returns retention; 0 means keep forever; omitted defaults to 90.
|
||||
func (m MonitorConfig) RetentionDaysEffective() int {
|
||||
if m.RetentionDays == nil {
|
||||
return 90
|
||||
}
|
||||
if *m.RetentionDays < 0 {
|
||||
return 0
|
||||
}
|
||||
return *m.RetentionDays
|
||||
}
|
||||
|
||||
// AuditConfig platform operation audit log settings (not chat/tool execution bodies).
|
||||
type AuditConfig struct {
|
||||
// Enabled nil or true enables persistence; explicit false disables.
|
||||
@@ -1265,6 +1292,10 @@ func Default() *Config {
|
||||
Enabled: &on,
|
||||
}
|
||||
}(),
|
||||
Monitor: func() MonitorConfig {
|
||||
days := 90
|
||||
return MonitorConfig{RetentionDays: &days}
|
||||
}(),
|
||||
Robots: RobotsConfig{
|
||||
Session: RobotSessionConfig{
|
||||
StrictUserIdentity: &strictRobotIdentity,
|
||||
|
||||
@@ -23,6 +23,7 @@ type BatchTaskQueueRow struct {
|
||||
LastScheduleError sql.NullString
|
||||
LastRunError sql.NullString
|
||||
ProjectID sql.NullString
|
||||
Concurrency sql.NullInt64
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
StartedAt sql.NullTime
|
||||
@@ -53,6 +54,7 @@ func (db *DB) CreateBatchQueue(
|
||||
cronExpr string,
|
||||
nextRunAt *time.Time,
|
||||
projectID string,
|
||||
concurrency int,
|
||||
tasks []map[string]interface{},
|
||||
) error {
|
||||
tx, err := db.Begin()
|
||||
@@ -72,8 +74,8 @@ func (db *DB) CreateBatchQueue(
|
||||
projectIDVal = strings.TrimSpace(projectID)
|
||||
}
|
||||
_, err = tx.Exec(
|
||||
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0,
|
||||
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, concurrency, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, concurrency, "pending", now, 0,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建批量任务队列失败: %w", err)
|
||||
@@ -102,14 +104,16 @@ func (db *DB) CreateBatchQueue(
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
const batchQueueSelectColumns = `id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, concurrency, status, created_at, started_at, completed_at, current_index`
|
||||
|
||||
// GetBatchQueue 获取批量任务队列
|
||||
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
|
||||
"SELECT "+batchQueueSelectColumns+" FROM batch_task_queues WHERE id = ?",
|
||||
queueID,
|
||||
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -133,7 +137,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
||||
// GetAllBatchQueues 获取所有批量任务队列
|
||||
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
|
||||
"SELECT "+batchQueueSelectColumns+" FROM batch_task_queues ORDER BY created_at DESC",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
||||
@@ -144,7 +148,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
for rows.Next() {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||
}
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
@@ -164,7 +168,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
|
||||
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
|
||||
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
|
||||
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
|
||||
query := "SELECT " + batchQueueSelectColumns + " FROM batch_task_queues WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
|
||||
// 状态筛选
|
||||
@@ -192,7 +196,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat
|
||||
for rows.Next() {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||
}
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
@@ -358,11 +362,11 @@ func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色和代理模式
|
||||
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string) error {
|
||||
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色、代理模式和并发数
|
||||
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string, concurrency int) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ? WHERE id = ?",
|
||||
title, role, agentMode, queueID,
|
||||
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ?, concurrency = ? WHERE id = ?",
|
||||
title, role, agentMode, concurrency, queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新批量任务队列元数据失败: %w", err)
|
||||
|
||||
@@ -352,8 +352,8 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
|
||||
|
||||
conv.Pinned = pinned != 0
|
||||
|
||||
// 加载消息(不加载 process_details)
|
||||
messages, err := db.GetMessages(id)
|
||||
// 加载消息(不加载 process_details / reasoning_content,减少历史会话切换 payload)
|
||||
messages, err := db.GetMessagesLite(id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("加载消息失败: %w", err)
|
||||
}
|
||||
@@ -585,12 +585,14 @@ func (db *DB) DeleteConversation(id string) error {
|
||||
// 不返回错误,继续删除对话
|
||||
}
|
||||
|
||||
projectID, _ := db.GetConversationProjectID(id)
|
||||
|
||||
// 删除对话(外键CASCADE会自动删除其他相关数据)
|
||||
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除对话失败: %w", err)
|
||||
}
|
||||
db.removeConversationScopedDirs(id)
|
||||
db.removeConversationScopedDirs(id, projectID)
|
||||
|
||||
db.logger.Info("对话已删除(漏洞记录已保留)", zap.String("conversationId", id))
|
||||
return nil
|
||||
@@ -628,13 +630,35 @@ func (db *DB) removeConversationScopedDir(base, conversationID, label string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) removeConversationScopedDirs(conversationID string) {
|
||||
// summarization transcript, reduction files, etc.
|
||||
func (db *DB) einoReductionBaseDir() string {
|
||||
if db == nil {
|
||||
return ""
|
||||
}
|
||||
if base := strings.TrimSpace(db.einoReductionRootDir); base != "" {
|
||||
return base
|
||||
}
|
||||
return filepath.Join("tmp", "reduction")
|
||||
}
|
||||
|
||||
func (db *DB) removeConversationScopedDirs(conversationID, projectID string) {
|
||||
// summarization transcript, etc.
|
||||
db.removeConversationScopedDir(db.conversationArtifactsDir, conversationID, "conversation_artifacts")
|
||||
// Eino plantask JSON boards (skills_dir/.eino/plantask/<id>/).
|
||||
db.removeConversationScopedDir(db.einoPlantaskBaseDir, conversationID, "plantask")
|
||||
// Eino ADK runner checkpoints (checkpoint_dir/<id>/).
|
||||
db.removeConversationScopedDir(db.einoCheckpointBaseDir, conversationID, "eino_checkpoint")
|
||||
// Eino reduction persisted tool outputs (tmp/reduction/conversations/<id>/).
|
||||
// Project-bound sessions share projects/<id>/ — skip on single conversation delete.
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
reductionBase := filepath.Join(db.einoReductionBaseDir(), "conversations")
|
||||
db.removeConversationScopedDir(reductionBase, conversationID, "reduction")
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) removeProjectScopedDirs(projectID string) {
|
||||
// Eino reduction persisted tool outputs (tmp/reduction/projects/<id>/).
|
||||
reductionBase := filepath.Join(db.einoReductionBaseDir(), "projects")
|
||||
db.removeConversationScopedDir(reductionBase, projectID, "reduction")
|
||||
}
|
||||
|
||||
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
|
||||
@@ -811,6 +835,62 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) {
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// GetMessagesLite 获取对话消息(不含 reasoning_content),用于历史会话快速切换。
|
||||
func (db *DB) GetMessagesLite(conversationID string) ([]Message, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC, rowid ASC",
|
||||
conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询消息失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []Message
|
||||
for rows.Next() {
|
||||
var msg Message
|
||||
var mcpIDsJSON sql.NullString
|
||||
var createdAt string
|
||||
var updatedAt sql.NullString
|
||||
|
||||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
|
||||
return nil, fmt.Errorf("扫描消息失败: %w", err)
|
||||
}
|
||||
|
||||
var err error
|
||||
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if err != nil {
|
||||
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if err != nil {
|
||||
msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
if updatedAt.Valid && strings.TrimSpace(updatedAt.String) != "" {
|
||||
msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt.String)
|
||||
if err != nil {
|
||||
msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05", updatedAt.String)
|
||||
}
|
||||
if err != nil {
|
||||
msg.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String)
|
||||
}
|
||||
}
|
||||
if msg.UpdatedAt.IsZero() {
|
||||
msg.UpdatedAt = msg.CreatedAt
|
||||
}
|
||||
|
||||
if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
|
||||
if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
|
||||
db.logger.Warn("解析MCP执行ID失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// turnSliceRange 根据任意一条消息 ID 定位「一轮对话」在 msgs 中的 [start, end) 下标区间(msgs 须已按时间升序,与 GetMessages 一致)。
|
||||
// 一轮 = 从某条 user 消息起,至下一条 user 之前(含中间所有 assistant)。
|
||||
func turnSliceRange(msgs []Message, anchorID string) (start, end int, err error) {
|
||||
@@ -979,6 +1059,107 @@ func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) {
|
||||
return details, nil
|
||||
}
|
||||
|
||||
// ProcessDetailsSummary 过程详情摘要(用于折叠态展示,避免全量加载)。
|
||||
type ProcessDetailsSummary struct {
|
||||
Total int `json:"total"`
|
||||
IterationCount int `json:"iterationCount"`
|
||||
MaxIteration int `json:"maxIteration"`
|
||||
}
|
||||
|
||||
// GetProcessDetailsSummary 统计消息的过程详情数量与迭代轮次。
|
||||
func (db *DB) GetProcessDetailsSummary(messageID string) (*ProcessDetailsSummary, error) {
|
||||
var total int
|
||||
if err := db.QueryRow(
|
||||
"SELECT COUNT(*) FROM process_details WHERE message_id = ?",
|
||||
messageID,
|
||||
).Scan(&total); err != nil {
|
||||
return nil, fmt.Errorf("统计过程详情失败: %w", err)
|
||||
}
|
||||
|
||||
summary := &ProcessDetailsSummary{Total: total}
|
||||
if total == 0 {
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
rows, err := db.Query(
|
||||
"SELECT data FROM process_details WHERE message_id = ? AND event_type = 'iteration' ORDER BY created_at ASC, rowid ASC",
|
||||
messageID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询迭代详情失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
maxIter := 0
|
||||
iterCount := 0
|
||||
for rows.Next() {
|
||||
var dataJSON string
|
||||
if err := rows.Scan(&dataJSON); err != nil {
|
||||
return nil, fmt.Errorf("扫描迭代详情失败: %w", err)
|
||||
}
|
||||
iterCount++
|
||||
if dataJSON == "" {
|
||||
continue
|
||||
}
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(dataJSON), &payload); err != nil {
|
||||
continue
|
||||
}
|
||||
if n, ok := payload["iteration"].(float64); ok && int(n) > maxIter {
|
||||
maxIter = int(n)
|
||||
}
|
||||
}
|
||||
summary.IterationCount = iterCount
|
||||
summary.MaxIteration = maxIter
|
||||
return summary, nil
|
||||
}
|
||||
|
||||
// GetProcessDetailsPage 分页获取消息的过程详情(按时间升序)。
|
||||
func (db *DB) GetProcessDetailsPage(messageID string, limit, offset int) ([]ProcessDetail, int, error) {
|
||||
var total int
|
||||
if err := db.QueryRow(
|
||||
"SELECT COUNT(*) FROM process_details WHERE message_id = ?",
|
||||
messageID,
|
||||
).Scan(&total); err != nil {
|
||||
return nil, 0, fmt.Errorf("统计过程详情失败: %w", err)
|
||||
}
|
||||
if total == 0 || offset >= total {
|
||||
return nil, total, nil
|
||||
}
|
||||
|
||||
rows, err := db.Query(
|
||||
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC, rowid ASC LIMIT ? OFFSET ?",
|
||||
messageID, limit, offset,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("查询过程详情失败: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var details []ProcessDetail
|
||||
for rows.Next() {
|
||||
var detail ProcessDetail
|
||||
var createdAt string
|
||||
|
||||
if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil {
|
||||
return nil, 0, fmt.Errorf("扫描过程详情失败: %w", err)
|
||||
}
|
||||
|
||||
var parseErr error
|
||||
detail.CreatedAt, parseErr = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
|
||||
if parseErr != nil {
|
||||
detail.CreatedAt, parseErr = time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
}
|
||||
if parseErr != nil {
|
||||
detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
|
||||
}
|
||||
|
||||
details = append(details, detail)
|
||||
}
|
||||
|
||||
return details, total, nil
|
||||
}
|
||||
|
||||
// GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组)
|
||||
func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) {
|
||||
rows, err := db.Query(
|
||||
|
||||
@@ -19,7 +19,8 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
||||
|
||||
plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask")
|
||||
checkpointBase := filepath.Join(tmp, "eino-checkpoints")
|
||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase)
|
||||
reductionBase := filepath.Join(tmp, "reduction")
|
||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionBase)
|
||||
|
||||
conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{})
|
||||
if err != nil {
|
||||
@@ -34,6 +35,7 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
||||
{db.conversationArtifactsDir, "transcript.txt"},
|
||||
{plantaskBase, "task-1.json"},
|
||||
{checkpointBase, "runner-deep.ckpt"},
|
||||
{filepath.Join(reductionBase, "conversations"), "tool-output.txt"},
|
||||
} {
|
||||
dir := filepath.Join(base.root, seg)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
@@ -48,10 +50,45 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
||||
t.Fatalf("DeleteConversation: %v", err)
|
||||
}
|
||||
|
||||
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase} {
|
||||
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase, filepath.Join(reductionBase, "conversations")} {
|
||||
dir := filepath.Join(base, seg)
|
||||
if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteProjectRemovesReductionDir(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
dbPath := filepath.Join(tmp, "conversations.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
reductionBase := filepath.Join(tmp, "reduction")
|
||||
db.SetEinoConversationDirs("", "", reductionBase)
|
||||
|
||||
project, err := db.CreateProject(&Project{Name: "cleanup test"})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProject: %v", err)
|
||||
}
|
||||
seg := sanitizeConversationPathSegment(project.ID)
|
||||
reductionDir := filepath.Join(reductionBase, "projects", seg, "clear")
|
||||
if err := os.MkdirAll(reductionDir, 0o755); err != nil {
|
||||
t.Fatalf("mkdir %s: %v", reductionDir, err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(reductionDir, "call-1.txt"), []byte("x"), 0o644); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
if err := db.DeleteProject(project.ID); err != nil {
|
||||
t.Fatalf("DeleteProject: %v", err)
|
||||
}
|
||||
|
||||
projectReductionDir := filepath.Join(reductionBase, "projects", seg)
|
||||
if _, statErr := os.Stat(projectReductionDir); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("expected removed dir %s, stat err=%v", projectReductionDir, statErr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,6 +51,7 @@ type DB struct {
|
||||
conversationArtifactsDir string
|
||||
einoPlantaskBaseDir string // skills_dir + plantask_rel_dir (per-conversation subdirs)
|
||||
einoCheckpointBaseDir string // checkpoint_dir root (per-conversation subdirs)
|
||||
einoReductionRootDir string // reduction_root_dir or default tmp/reduction (conversations/<id> subdirs)
|
||||
checkpointLoopName string
|
||||
checkpointStop chan struct{}
|
||||
checkpointDone chan struct{}
|
||||
@@ -159,12 +160,14 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
|
||||
// SetEinoConversationDirs configures best-effort filesystem cleanup on DeleteConversation.
|
||||
// plantaskBase is skills_root/plantask_rel (no conversation id); checkpointBase is checkpoint_dir root.
|
||||
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase string) {
|
||||
// reductionRoot is reduction_root_dir from config; empty uses tmp/reduction (conversation-scoped subdirs only).
|
||||
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot string) {
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
db.einoPlantaskBaseDir = strings.TrimSpace(plantaskBase)
|
||||
db.einoCheckpointBaseDir = strings.TrimSpace(checkpointBase)
|
||||
db.einoReductionRootDir = strings.TrimSpace(reductionRoot)
|
||||
}
|
||||
|
||||
// initTables 初始化数据库表
|
||||
@@ -353,6 +356,22 @@ func (db *DB) initTables() error {
|
||||
UNIQUE(project_id, fact_key)
|
||||
);`
|
||||
|
||||
// 项目事实关系边(黑板 DAG)
|
||||
createProjectFactEdgesTable := `
|
||||
CREATE TABLE IF NOT EXISTS project_fact_edges (
|
||||
id TEXT PRIMARY KEY,
|
||||
project_id TEXT NOT NULL,
|
||||
source_fact_key TEXT NOT NULL,
|
||||
target_fact_key TEXT NOT NULL,
|
||||
edge_type TEXT NOT NULL,
|
||||
confidence TEXT NOT NULL DEFAULT 'tentative',
|
||||
source_conversation_id TEXT,
|
||||
created_at DATETIME NOT NULL,
|
||||
updated_at DATETIME NOT NULL,
|
||||
FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
|
||||
UNIQUE(project_id, source_fact_key, target_fact_key, edge_type)
|
||||
);`
|
||||
|
||||
// 创建漏洞表
|
||||
createVulnerabilitiesTable := `
|
||||
CREATE TABLE IF NOT EXISTS vulnerabilities (
|
||||
@@ -389,6 +408,8 @@ func (db *DB) initTables() error {
|
||||
last_schedule_trigger_at DATETIME,
|
||||
last_schedule_error TEXT,
|
||||
last_run_error TEXT,
|
||||
project_id TEXT,
|
||||
concurrency INTEGER NOT NULL DEFAULT 1,
|
||||
status TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL,
|
||||
started_at DATETIME,
|
||||
@@ -591,6 +612,9 @@ func (db *DB) initTables() error {
|
||||
CREATE INDEX IF NOT EXISTS idx_project_facts_project_id ON project_facts(project_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_project_facts_confidence ON project_facts(confidence);
|
||||
CREATE INDEX IF NOT EXISTS idx_project_facts_related_vuln ON project_facts(related_vulnerability_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_project_fact_edges_project ON project_fact_edges(project_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_project_fact_edges_source ON project_fact_edges(project_id, source_fact_key);
|
||||
CREATE INDEX IF NOT EXISTS idx_project_fact_edges_target ON project_fact_edges(project_id, target_fact_key);
|
||||
CREATE INDEX IF NOT EXISTS idx_conversations_project_id ON conversations(project_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_project_id ON vulnerabilities(project_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_batch_tasks_queue_id ON batch_tasks(queue_id);
|
||||
@@ -672,6 +696,10 @@ func (db *DB) initTables() error {
|
||||
return fmt.Errorf("创建project_facts表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createProjectFactEdgesTable); err != nil {
|
||||
return fmt.Errorf("创建project_fact_edges表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(createVulnerabilitiesTable); err != nil {
|
||||
return fmt.Errorf("创建vulnerabilities表失败: %w", err)
|
||||
}
|
||||
@@ -1111,6 +1139,21 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
|
||||
}
|
||||
}
|
||||
|
||||
var concurrencyCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='concurrency'").Scan(&concurrencyCount)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if concurrencyCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); err != nil {
|
||||
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -410,6 +410,76 @@ func (db *DB) GetToolExecutionsByIds(ids []string) ([]*mcp.ToolExecution, error)
|
||||
return executions, nil
|
||||
}
|
||||
|
||||
type toolExecutionStatDelta struct {
|
||||
totalCalls int
|
||||
successCalls int
|
||||
failedCalls int
|
||||
}
|
||||
|
||||
// PurgeToolExecutionsBefore deletes executions older than cutoff and adjusts tool_stats.
|
||||
func (db *DB) PurgeToolExecutionsBefore(cutoff time.Time) (int64, error) {
|
||||
query := `
|
||||
SELECT tool_name, status, COUNT(*) AS cnt
|
||||
FROM tool_executions
|
||||
WHERE ` + sqliteEpochGE("start_time", "<") + `
|
||||
GROUP BY tool_name, status
|
||||
`
|
||||
rows, err := db.Query(query, formatSQLiteUTC(cutoff))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
deltas := make(map[string]*toolExecutionStatDelta)
|
||||
for rows.Next() {
|
||||
var toolName, status string
|
||||
var count int
|
||||
if err := rows.Scan(&toolName, &status, &count); err != nil {
|
||||
db.logger.Warn("读取待清理执行记录统计失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
toolName = strings.TrimSpace(toolName)
|
||||
if toolName == "" || count <= 0 {
|
||||
continue
|
||||
}
|
||||
delta := deltas[toolName]
|
||||
if delta == nil {
|
||||
delta = &toolExecutionStatDelta{}
|
||||
deltas[toolName] = delta
|
||||
}
|
||||
delta.totalCalls += count
|
||||
switch status {
|
||||
case "failed", "cancelled":
|
||||
delta.failedCalls += count
|
||||
case "completed":
|
||||
delta.successCalls += count
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
res, err := db.Exec(`DELETE FROM tool_executions WHERE `+sqliteEpochGE("start_time", "<"), formatSQLiteUTC(cutoff))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
deleted, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for toolName, delta := range deltas {
|
||||
if err := db.DecreaseToolStats(toolName, delta.totalCalls, delta.successCalls, delta.failedCalls); err != nil {
|
||||
db.logger.Warn("清理过期执行记录后更新统计失败",
|
||||
zap.Error(err),
|
||||
zap.String("toolName", toolName),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// SaveToolStats 保存工具统计信息
|
||||
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
|
||||
var lastCallTime sql.NullTime
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestPurgeToolExecutionsBefore(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
oldStart := time.Now().AddDate(0, 0, -100)
|
||||
newStart := time.Now().AddDate(0, 0, -1)
|
||||
|
||||
oldExec := &mcp.ToolExecution{
|
||||
ID: "old-completed",
|
||||
ToolName: "nmap::scan",
|
||||
Arguments: map[string]interface{}{"target": "127.0.0.1"},
|
||||
Status: "completed",
|
||||
StartTime: oldStart,
|
||||
}
|
||||
oldFailed := &mcp.ToolExecution{
|
||||
ID: "old-failed",
|
||||
ToolName: "nmap::scan",
|
||||
Arguments: map[string]interface{}{"target": "127.0.0.1"},
|
||||
Status: "failed",
|
||||
Error: "timeout",
|
||||
StartTime: oldStart,
|
||||
}
|
||||
newExec := &mcp.ToolExecution{
|
||||
ID: "new-completed",
|
||||
ToolName: "nmap::scan",
|
||||
Arguments: map[string]interface{}{"target": "127.0.0.1"},
|
||||
Status: "completed",
|
||||
StartTime: newStart,
|
||||
}
|
||||
for _, exec := range []*mcp.ToolExecution{oldExec, oldFailed, newExec} {
|
||||
if err := db.SaveToolExecution(exec); err != nil {
|
||||
t.Fatalf("SaveToolExecution(%s): %v", exec.ID, err)
|
||||
}
|
||||
}
|
||||
if err := db.UpdateToolStats("nmap::scan", 3, 2, 1, &newStart); err != nil {
|
||||
t.Fatalf("UpdateToolStats: %v", err)
|
||||
}
|
||||
|
||||
cutoff := time.Now().AddDate(0, 0, -90)
|
||||
deleted, err := db.PurgeToolExecutionsBefore(cutoff)
|
||||
if err != nil {
|
||||
t.Fatalf("PurgeToolExecutionsBefore: %v", err)
|
||||
}
|
||||
if deleted != 2 {
|
||||
t.Fatalf("deleted = %d, want 2", deleted)
|
||||
}
|
||||
|
||||
if _, err := db.GetToolExecution("old-completed"); err == nil {
|
||||
t.Fatal("old-completed should be deleted")
|
||||
}
|
||||
if _, err := db.GetToolExecution("old-failed"); err == nil {
|
||||
t.Fatal("old-failed should be deleted")
|
||||
}
|
||||
if _, err := db.GetToolExecution("new-completed"); err != nil {
|
||||
t.Fatalf("new-completed should remain: %v", err)
|
||||
}
|
||||
|
||||
stats, err := db.LoadToolStats()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadToolStats: %v", err)
|
||||
}
|
||||
stat := stats["nmap::scan"]
|
||||
if stat == nil {
|
||||
t.Fatal("expected stats for nmap::scan")
|
||||
}
|
||||
if stat.TotalCalls != 1 || stat.SuccessCalls != 1 || stat.FailedCalls != 0 {
|
||||
t.Fatalf("stats after purge = %+v, want total=1 success=1 failed=0", stat)
|
||||
}
|
||||
|
||||
total, err := db.CountToolExecutions("", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CountToolExecutions: %v", err)
|
||||
}
|
||||
if total != 1 {
|
||||
t.Fatalf("remaining executions = %d, want 1", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPurgeToolExecutionsBefore_zeroRetentionSkipsViaService(t *testing.T) {
|
||||
// RetentionDaysEffective: 0 means no purge at service layer; DB method still works when called directly.
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
exec := &mcp.ToolExecution{
|
||||
ID: "ancient",
|
||||
ToolName: "curl::get",
|
||||
Arguments: map[string]interface{}{},
|
||||
Status: "completed",
|
||||
StartTime: time.Now().AddDate(-1, 0, 0),
|
||||
}
|
||||
if err := db.SaveToolExecution(exec); err != nil {
|
||||
t.Fatalf("SaveToolExecution: %v", err)
|
||||
}
|
||||
|
||||
deleted, err := db.PurgeToolExecutionsBefore(time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("PurgeToolExecutionsBefore: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("deleted = %d, want 1", deleted)
|
||||
}
|
||||
}
|
||||
@@ -195,6 +195,7 @@ func (db *DB) DeleteProject(id string) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除项目失败: %w", err)
|
||||
}
|
||||
db.removeProjectScopedDirs(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -389,7 +390,7 @@ func (db *DB) UpsertProjectFact(f *ProjectFact) (*ProjectFact, error) {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// DeprecateProjectFact 将事实标记为 deprecated。
|
||||
// DeprecateProjectFact 将事实标记为 deprecated(关联边同步 deprecated)。
|
||||
func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
|
||||
res, err := db.Exec(
|
||||
`UPDATE project_facts SET confidence = 'deprecated', updated_at = ? WHERE project_id = ? AND fact_key = ?`,
|
||||
@@ -402,7 +403,7 @@ func (db *DB) DeprecateProjectFact(projectID, factKey string) error {
|
||||
if n == 0 {
|
||||
return fmt.Errorf("事实不存在")
|
||||
}
|
||||
return nil
|
||||
return db.DeprecateProjectFactEdgesForKey(projectID, factKey)
|
||||
}
|
||||
|
||||
// RestoreProjectFact 将已废弃事实恢复为 tentative 或 confirmed(重新参与黑板索引)。
|
||||
@@ -430,9 +431,16 @@ func (db *DB) RestoreProjectFact(projectID, factKey, confidence string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteProjectFact 删除事实。
|
||||
// DeleteProjectFact 删除事实(级联删除相关边)。
|
||||
func (db *DB) DeleteProjectFact(id string) error {
|
||||
_, err := db.Exec(`DELETE FROM project_facts WHERE id = ?`, id)
|
||||
f, err := db.GetProjectFact(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := db.DeleteProjectFactEdgesForKey(f.ProjectID, f.FactKey); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = db.Exec(`DELETE FROM project_facts WHERE id = ?`, id)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,410 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// ValidProjectFactEdgeTypes 项目事实图允许的边类型。
|
||||
var ValidProjectFactEdgeTypes = map[string]struct{}{
|
||||
"depends_on": {},
|
||||
"leads_to": {},
|
||||
"enables": {},
|
||||
"exploits": {},
|
||||
"discovered_on": {},
|
||||
"contains": {},
|
||||
"part_of": {},
|
||||
"supports": {},
|
||||
}
|
||||
|
||||
// ProjectFactEdge 项目事实关系边(source → target)。
|
||||
type ProjectFactEdge struct {
|
||||
ID string `json:"id"`
|
||||
ProjectID string `json:"project_id"`
|
||||
SourceFactKey string `json:"source_fact_key"`
|
||||
TargetFactKey string `json:"target_fact_key"`
|
||||
EdgeType string `json:"edge_type"`
|
||||
Confidence string `json:"confidence"` // confirmed | tentative | deprecated
|
||||
SourceConversationID string `json:"source_conversation_id,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// ProjectFactEdgeInput 写入边时的输入(出边:source → To)。
|
||||
type ProjectFactEdgeInput struct {
|
||||
To string `json:"to"`
|
||||
Type string `json:"type"`
|
||||
Confidence string `json:"confidence,omitempty"`
|
||||
}
|
||||
|
||||
// ProjectFactEdgeFromInput 写入入边时的输入(From → 当前事实)。
|
||||
type ProjectFactEdgeFromInput struct {
|
||||
From string `json:"from"`
|
||||
Type string `json:"type"`
|
||||
Confidence string `json:"confidence,omitempty"`
|
||||
}
|
||||
|
||||
// ProjectFactGraphNode 图 API 节点。
|
||||
type ProjectFactGraphNode struct {
|
||||
ID string `json:"id"`
|
||||
FactKey string `json:"fact_key"`
|
||||
Category string `json:"category"`
|
||||
Label string `json:"label"` // 图节点短标签(截断)
|
||||
Summary string `json:"summary"` // 完整摘要(侧栏等详情用)
|
||||
Confidence string `json:"confidence"`
|
||||
Type string `json:"type"`
|
||||
Pinned bool `json:"pinned"`
|
||||
}
|
||||
|
||||
// ProjectFactGraphEdge 图 API 边。
|
||||
type ProjectFactGraphEdge struct {
|
||||
ID string `json:"id"`
|
||||
Source string `json:"source"`
|
||||
Target string `json:"target"`
|
||||
Type string `json:"type"`
|
||||
Confidence string `json:"confidence"`
|
||||
}
|
||||
|
||||
// ProjectFactGraph 项目事实图。
|
||||
type ProjectFactGraph struct {
|
||||
Nodes []ProjectFactGraphNode `json:"nodes"`
|
||||
Edges []ProjectFactGraphEdge `json:"edges"`
|
||||
}
|
||||
|
||||
// ValidateProjectFactEdgeType 校验边类型。
|
||||
func ValidateProjectFactEdgeType(edgeType string) error {
|
||||
edgeType = strings.TrimSpace(strings.ToLower(edgeType))
|
||||
if edgeType == "" {
|
||||
return fmt.Errorf("edge type 不能为空")
|
||||
}
|
||||
if _, ok := ValidProjectFactEdgeTypes[edgeType]; !ok {
|
||||
return fmt.Errorf("无效的 edge type: %s", edgeType)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func normalizeEdgeConfidence(confidence string) string {
|
||||
confidence = strings.TrimSpace(strings.ToLower(confidence))
|
||||
switch confidence {
|
||||
case "confirmed", "deprecated":
|
||||
return confidence
|
||||
default:
|
||||
return "tentative"
|
||||
}
|
||||
}
|
||||
|
||||
// ListProjectFactEdgesByProject 列出项目全部边。
|
||||
func (db *DB) ListProjectFactEdgesByProject(projectID string) ([]*ProjectFactEdge, error) {
|
||||
rows, err := db.Query(
|
||||
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||
FROM project_fact_edges
|
||||
WHERE project_id = ?
|
||||
ORDER BY created_at ASC, rowid ASC`,
|
||||
projectID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanProjectFactEdges(rows)
|
||||
}
|
||||
|
||||
// ListOutgoingProjectFactEdges 列出某事实的全部出边。
|
||||
func (db *DB) ListOutgoingProjectFactEdges(projectID, sourceFactKey string) ([]*ProjectFactEdge, error) {
|
||||
rows, err := db.Query(
|
||||
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||
FROM project_fact_edges
|
||||
WHERE project_id = ? AND source_fact_key = ?
|
||||
ORDER BY created_at ASC, rowid ASC`,
|
||||
projectID, sourceFactKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanProjectFactEdges(rows)
|
||||
}
|
||||
|
||||
// ListIncomingProjectFactEdges 列出某事实的全部入边。
|
||||
func (db *DB) ListIncomingProjectFactEdges(projectID, targetFactKey string) ([]*ProjectFactEdge, error) {
|
||||
rows, err := db.Query(
|
||||
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||
FROM project_fact_edges
|
||||
WHERE project_id = ? AND target_fact_key = ?
|
||||
ORDER BY created_at ASC, rowid ASC`,
|
||||
projectID, targetFactKey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanProjectFactEdges(rows)
|
||||
}
|
||||
|
||||
// ReplaceOutgoingProjectFactEdges 替换某事实的全部出边(links 省略时不调用)。
|
||||
func (db *DB) ReplaceOutgoingProjectFactEdges(projectID, sourceFactKey, sourceConversationID string, inputs []ProjectFactEdgeInput) error {
|
||||
sourceFactKey = strings.TrimSpace(sourceFactKey)
|
||||
if sourceFactKey == "" {
|
||||
return fmt.Errorf("source_fact_key 不能为空")
|
||||
}
|
||||
if _, err := db.Exec(
|
||||
`DELETE FROM project_fact_edges WHERE project_id = ? AND source_fact_key = ?`,
|
||||
projectID, sourceFactKey,
|
||||
); err != nil {
|
||||
return fmt.Errorf("清除旧边失败: %w", err)
|
||||
}
|
||||
for _, in := range inputs {
|
||||
target := strings.TrimSpace(in.To)
|
||||
if target == "" {
|
||||
continue
|
||||
}
|
||||
if err := ValidateFactKey(target); err != nil {
|
||||
return fmt.Errorf("target fact_key 无效 (%s): %w", target, err)
|
||||
}
|
||||
if target == sourceFactKey {
|
||||
return fmt.Errorf("边不能指向自身: %s", sourceFactKey)
|
||||
}
|
||||
if err := ValidateProjectFactEdgeType(in.Type); err != nil {
|
||||
return err
|
||||
}
|
||||
edge := &ProjectFactEdge{
|
||||
ID: uuid.New().String(),
|
||||
ProjectID: projectID,
|
||||
SourceFactKey: sourceFactKey,
|
||||
TargetFactKey: target,
|
||||
EdgeType: strings.ToLower(strings.TrimSpace(in.Type)),
|
||||
Confidence: normalizeEdgeConfidence(in.Confidence),
|
||||
SourceConversationID: sourceConversationID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
if err := db.insertProjectFactEdge(edge); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReplaceIncomingProjectFactEdges 替换某事实的全部入边(From 为来源 fact_key)。
|
||||
func (db *DB) ReplaceIncomingProjectFactEdges(projectID, targetFactKey string, inputs []ProjectFactEdgeFromInput) error {
|
||||
targetFactKey = strings.TrimSpace(targetFactKey)
|
||||
if targetFactKey == "" {
|
||||
return fmt.Errorf("target_fact_key 不能为空")
|
||||
}
|
||||
if _, err := db.Exec(
|
||||
`DELETE FROM project_fact_edges WHERE project_id = ? AND target_fact_key = ?`,
|
||||
projectID, targetFactKey,
|
||||
); err != nil {
|
||||
return fmt.Errorf("清除旧入边失败: %w", err)
|
||||
}
|
||||
for _, in := range inputs {
|
||||
source := strings.TrimSpace(in.From)
|
||||
if source == "" {
|
||||
continue
|
||||
}
|
||||
if err := ValidateFactKey(source); err != nil {
|
||||
return fmt.Errorf("source fact_key 无效 (%s): %w", source, err)
|
||||
}
|
||||
if source == targetFactKey {
|
||||
return fmt.Errorf("边不能指向自身: %s", targetFactKey)
|
||||
}
|
||||
if err := ValidateProjectFactEdgeType(in.Type); err != nil {
|
||||
return err
|
||||
}
|
||||
sourceConversationID := ""
|
||||
if srcFact, err := db.GetProjectFactByKey(projectID, source); err == nil && srcFact != nil {
|
||||
sourceConversationID = srcFact.SourceConversationID
|
||||
}
|
||||
edge := &ProjectFactEdge{
|
||||
ID: uuid.New().String(),
|
||||
ProjectID: projectID,
|
||||
SourceFactKey: source,
|
||||
TargetFactKey: targetFactKey,
|
||||
EdgeType: strings.ToLower(strings.TrimSpace(in.Type)),
|
||||
Confidence: normalizeEdgeConfidence(in.Confidence),
|
||||
SourceConversationID: sourceConversationID,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
if err := db.insertProjectFactEdge(edge); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProjectFactEdge 按 ID 获取边。
|
||||
func (db *DB) GetProjectFactEdge(edgeID string) (*ProjectFactEdge, error) {
|
||||
var e ProjectFactEdge
|
||||
var createdAt, updatedAt string
|
||||
err := db.QueryRow(
|
||||
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||
FROM project_fact_edges WHERE id = ?`, edgeID,
|
||||
).Scan(&e.ID, &e.ProjectID, &e.SourceFactKey, &e.TargetFactKey, &e.EdgeType, &e.Confidence,
|
||||
&e.SourceConversationID, &createdAt, &updatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("边不存在")
|
||||
}
|
||||
e.CreatedAt = parseDBTime(createdAt)
|
||||
e.UpdatedAt = parseDBTime(updatedAt)
|
||||
return &e, nil
|
||||
}
|
||||
|
||||
// AddProjectFactEdge 新增单条边(已存在则更新 confidence)。
|
||||
func (db *DB) AddProjectFactEdge(projectID string, in ProjectFactEdgeInput, sourceFactKey, sourceConversationID string) (*ProjectFactEdge, error) {
|
||||
sourceFactKey = strings.TrimSpace(sourceFactKey)
|
||||
target := strings.TrimSpace(in.To)
|
||||
if sourceFactKey == "" || target == "" {
|
||||
return nil, fmt.Errorf("source 与 target 必填")
|
||||
}
|
||||
if sourceFactKey == target {
|
||||
return nil, fmt.Errorf("边不能指向自身")
|
||||
}
|
||||
if err := ValidateProjectFactEdgeType(in.Type); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := ValidateFactKey(target); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
now := time.Now()
|
||||
e := &ProjectFactEdge{
|
||||
ID: uuid.New().String(),
|
||||
ProjectID: projectID,
|
||||
SourceFactKey: sourceFactKey,
|
||||
TargetFactKey: target,
|
||||
EdgeType: strings.ToLower(strings.TrimSpace(in.Type)),
|
||||
Confidence: normalizeEdgeConfidence(in.Confidence),
|
||||
SourceConversationID: sourceConversationID,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
_, err := db.Exec(
|
||||
`INSERT INTO project_fact_edges (
|
||||
id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||
source_conversation_id, created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(project_id, source_fact_key, target_fact_key, edge_type)
|
||||
DO UPDATE SET confidence = excluded.confidence, updated_at = excluded.updated_at`,
|
||||
e.ID, e.ProjectID, e.SourceFactKey, e.TargetFactKey, e.EdgeType, e.Confidence,
|
||||
nullIfEmpty(e.SourceConversationID), e.CreatedAt, e.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("添加边失败: %w", err)
|
||||
}
|
||||
// 返回最新
|
||||
rows, err := db.Query(
|
||||
`SELECT id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||
COALESCE(source_conversation_id,''), created_at, updated_at
|
||||
FROM project_fact_edges
|
||||
WHERE project_id = ? AND source_fact_key = ? AND target_fact_key = ? AND edge_type = ?`,
|
||||
projectID, sourceFactKey, target, e.EdgeType,
|
||||
)
|
||||
if err != nil {
|
||||
return e, nil
|
||||
}
|
||||
defer rows.Close()
|
||||
list, err := scanProjectFactEdges(rows)
|
||||
if err != nil || len(list) == 0 {
|
||||
return e, nil
|
||||
}
|
||||
return list[0], nil
|
||||
}
|
||||
|
||||
// DeleteProjectFactEdge 删除单条边。
|
||||
func (db *DB) DeleteProjectFactEdge(edgeID string) error {
|
||||
res, err := db.Exec(`DELETE FROM project_fact_edges WHERE id = ?`, edgeID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return fmt.Errorf("边不存在")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) insertProjectFactEdge(e *ProjectFactEdge) error {
|
||||
_, err := db.Exec(
|
||||
`INSERT INTO project_fact_edges (
|
||||
id, project_id, source_fact_key, target_fact_key, edge_type, confidence,
|
||||
source_conversation_id, created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
e.ID, e.ProjectID, e.SourceFactKey, e.TargetFactKey, e.EdgeType, e.Confidence,
|
||||
nullIfEmpty(e.SourceConversationID), e.CreatedAt, e.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("写入边失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RenameProjectFactKeyEdges 事实 key 变更时同步边上的引用。
|
||||
func (db *DB) RenameProjectFactKeyEdges(projectID, oldKey, newKey string) error {
|
||||
oldKey = strings.TrimSpace(oldKey)
|
||||
newKey = strings.TrimSpace(newKey)
|
||||
if oldKey == "" || newKey == "" || oldKey == newKey {
|
||||
return nil
|
||||
}
|
||||
now := time.Now()
|
||||
if _, err := db.Exec(
|
||||
`UPDATE project_fact_edges SET source_fact_key = ?, updated_at = ?
|
||||
WHERE project_id = ? AND source_fact_key = ?`,
|
||||
newKey, now, projectID, oldKey,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := db.Exec(
|
||||
`UPDATE project_fact_edges SET target_fact_key = ?, updated_at = ?
|
||||
WHERE project_id = ? AND target_fact_key = ?`,
|
||||
newKey, now, projectID, oldKey,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteProjectFactEdgesForKey 删除与某 fact_key 相关的全部边。
|
||||
func (db *DB) DeleteProjectFactEdgesForKey(projectID, factKey string) error {
|
||||
_, err := db.Exec(
|
||||
`DELETE FROM project_fact_edges
|
||||
WHERE project_id = ? AND (source_fact_key = ? OR target_fact_key = ?)`,
|
||||
projectID, factKey, factKey,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeprecateProjectFactEdgesForKey 将关联边标记为 deprecated。
|
||||
func (db *DB) DeprecateProjectFactEdgesForKey(projectID, factKey string) error {
|
||||
now := time.Now()
|
||||
_, err := db.Exec(
|
||||
`UPDATE project_fact_edges SET confidence = 'deprecated', updated_at = ?
|
||||
WHERE project_id = ? AND (source_fact_key = ? OR target_fact_key = ?)
|
||||
AND confidence != 'deprecated'`,
|
||||
now, projectID, factKey, factKey,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func scanProjectFactEdges(rows *sql.Rows) ([]*ProjectFactEdge, error) {
|
||||
var out []*ProjectFactEdge
|
||||
for rows.Next() {
|
||||
var e ProjectFactEdge
|
||||
var createdAt, updatedAt string
|
||||
if err := rows.Scan(
|
||||
&e.ID, &e.ProjectID, &e.SourceFactKey, &e.TargetFactKey, &e.EdgeType, &e.Confidence,
|
||||
&e.SourceConversationID, &createdAt, &updatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
e.CreatedAt = parseDBTime(createdAt)
|
||||
e.UpdatedAt = parseDBTime(updatedAt)
|
||||
out = append(out, &e)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
+70
-427
@@ -21,7 +21,6 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/reasoning"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
@@ -178,8 +177,6 @@ type AgentHandler struct {
|
||||
}
|
||||
agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并)
|
||||
batchCronParser cron.Parser
|
||||
batchRunnerMu sync.Mutex
|
||||
batchRunning map[string]struct{}
|
||||
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
|
||||
hitlWhitelistSaver HitlToolWhitelistSaver
|
||||
audit *audit.Service
|
||||
@@ -190,6 +187,21 @@ func (h *AgentHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent).
|
||||
func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
|
||||
if h == nil || conversationID == "" || h.tasks == nil {
|
||||
return
|
||||
}
|
||||
if execID := h.tasks.ActiveMCPExecutionID(conversationID); execID != "" {
|
||||
h.agent.CancelMCPToolExecutionWithNote(execID, "")
|
||||
}
|
||||
if ok, err := h.tasks.CancelTask(conversationID, ErrTaskCancelled); ok {
|
||||
h.logger.Info("已取消会话运行中任务", zap.String("conversationId", conversationID))
|
||||
} else if err != nil {
|
||||
h.logger.Warn("取消会话运行中任务失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
||||
type HitlToolWhitelistSaver interface {
|
||||
MergeHitlToolWhitelistIntoConfig(add []string) error
|
||||
@@ -218,7 +230,6 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
|
||||
config: cfg,
|
||||
hitlManager: NewHITLManager(db, logger),
|
||||
batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor),
|
||||
batchRunning: make(map[string]struct{}),
|
||||
}
|
||||
if err := handler.hitlManager.EnsureSchema(); err != nil {
|
||||
logger.Warn("初始化 HITL 表失败", zap.Error(err))
|
||||
@@ -631,40 +642,11 @@ func (h *AgentHandler) runRobotEinoSingleWithRetry(
|
||||
assistantMessageID string,
|
||||
taskStatus *string,
|
||||
) (string, string, error) {
|
||||
curHist := history
|
||||
curMsg := finalMessage
|
||||
segmentUserMessage := finalMessage
|
||||
var resultMA *multiagent.RunResult
|
||||
var errMA error
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
for {
|
||||
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||
conversationID, h.conversationProjectID(conversationID), curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &emptyResponseAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
errMA = nil
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
continue
|
||||
}
|
||||
if errMA == nil {
|
||||
transientRunAttempts = 0
|
||||
emptyResponseAttempts = 0
|
||||
break
|
||||
}
|
||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
}
|
||||
resultMA, errMA := multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
if errMA != nil {
|
||||
*taskStatus = "failed"
|
||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||
}
|
||||
@@ -680,41 +662,12 @@ func (h *AgentHandler) runRobotMultiAgentWithRetry(
|
||||
assistantMessageID string,
|
||||
taskStatus *string,
|
||||
) (string, string, error) {
|
||||
curHist := history
|
||||
curMsg := finalMessage
|
||||
segmentUserMessage := finalMessage
|
||||
var resultMA *multiagent.RunResult
|
||||
var errMA error
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
for {
|
||||
resultMA, errMA = multiagent.RunDeepAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||
conversationID, h.conversationProjectID(conversationID), curMsg, curHist, roleTools, progressCallback,
|
||||
h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &emptyResponseAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
errMA = nil
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
continue
|
||||
}
|
||||
if errMA == nil {
|
||||
transientRunAttempts = 0
|
||||
emptyResponseAttempts = 0
|
||||
break
|
||||
}
|
||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
}
|
||||
resultMA, errMA := multiagent.RunDeepAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback,
|
||||
h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
if errMA != nil {
|
||||
*taskStatus = "failed"
|
||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||
}
|
||||
@@ -1309,7 +1262,10 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun
|
||||
|
||||
// 保存过程详情到数据库(排除 response/done;response 正文已在 messages 表)
|
||||
// response_start/response_delta 已聚合为 planning,不落逐条。
|
||||
// [Eino] agent 心跳 progress 仅用于实时进度标题,不落库以免时间线刷屏。
|
||||
skipEinoAgentHeartbeat := eventType == "progress" && strings.HasPrefix(strings.TrimSpace(message), "[Eino] ")
|
||||
if assistantMessageID != "" &&
|
||||
!skipEinoAgentHeartbeat &&
|
||||
eventType != "response" &&
|
||||
eventType != "done" &&
|
||||
eventType != "response_start" &&
|
||||
@@ -1376,6 +1332,21 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if h.tasks.AbortActiveEinoExecute(req.ConversationID, note) {
|
||||
h.logger.Info("对话页仅终止当前 Eino execute",
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
zap.Bool("hasNote", note != ""),
|
||||
)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "tool_abort_requested",
|
||||
"conversationId": req.ConversationID,
|
||||
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
|
||||
"continueAfter": true,
|
||||
"interruptWithNote": note != "",
|
||||
"continueWithoutTool": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
|
||||
h.tasks.SetInterruptContinueNote(req.ConversationID, note)
|
||||
ok, err := h.tasks.CancelTask(req.ConversationID, multiagent.ErrInterruptContinue)
|
||||
@@ -1495,6 +1466,7 @@ type BatchTaskRequest struct {
|
||||
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
|
||||
ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false)
|
||||
ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选)
|
||||
Concurrency int `json:"concurrency,omitempty"` // 同时执行的子任务数,默认 1,最大 8
|
||||
}
|
||||
|
||||
// batchQueueWantsEino 队列是否配置为走 Eino 多代理。
|
||||
@@ -1554,7 +1526,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
|
||||
nextRunAt = &next
|
||||
}
|
||||
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks)
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, req.Concurrency, validTasks)
|
||||
if createErr != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()})
|
||||
return
|
||||
@@ -1744,15 +1716,16 @@ func (h *AgentHandler) PauseBatchQueue(c *gin.Context) {
|
||||
func (h *AgentHandler) UpdateBatchQueueMetadata(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
var req struct {
|
||||
Title string `json:"title"`
|
||||
Role string `json:"role"`
|
||||
AgentMode string `json:"agentMode"`
|
||||
Title string `json:"title"`
|
||||
Role string `json:"role"`
|
||||
AgentMode string `json:"agentMode"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode); err != nil {
|
||||
if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode, req.Concurrency); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -1827,9 +1800,17 @@ func (h *AgentHandler) SetBatchQueueScheduleEnabled(c *gin.Context) {
|
||||
// DeleteBatchQueue 删除批量任务队列
|
||||
func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
|
||||
queueID := c.Param("queueId")
|
||||
success := h.batchTaskManager.DeleteQueue(queueID)
|
||||
if !success {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
if err := h.batchTaskManager.DeleteQueue(queueID); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ErrBatchQueueNotFound):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
|
||||
case errors.Is(err, ErrBatchQueueExecutorActive):
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "队列执行器仍在运行,请稍后再删除"})
|
||||
case errors.Is(err, ErrBatchQueueStillRunning):
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "队列正在运行中,无法删除"})
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
if h.audit != nil {
|
||||
@@ -1923,7 +1904,7 @@ func (h *AgentHandler) RunSingleBatchTask(c *gin.Context) {
|
||||
|
||||
// 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动
|
||||
if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused {
|
||||
h.forceUnmarkBatchQueueRunning(queueID)
|
||||
h.batchTaskManager.ForceUnmarkQueueExecutor(queueID)
|
||||
}
|
||||
|
||||
autoStarted := true
|
||||
@@ -1982,26 +1963,6 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
|
||||
}
|
||||
|
||||
func (h *AgentHandler) markBatchQueueRunning(queueID string) bool {
|
||||
h.batchRunnerMu.Lock()
|
||||
defer h.batchRunnerMu.Unlock()
|
||||
if _, exists := h.batchRunning[queueID]; exists {
|
||||
return false
|
||||
}
|
||||
h.batchRunning[queueID] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) {
|
||||
h.batchRunnerMu.Lock()
|
||||
defer h.batchRunnerMu.Unlock()
|
||||
delete(h.batchRunning, queueID)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) forceUnmarkBatchQueueRunning(queueID string) {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
|
||||
expr := strings.TrimSpace(cronExpr)
|
||||
if expr == "" {
|
||||
@@ -2017,43 +1978,43 @@ func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*ti
|
||||
|
||||
func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) {
|
||||
// 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断
|
||||
if !h.markBatchQueueRunning(queueID) {
|
||||
if !h.batchTaskManager.TryMarkQueueExecutor(queueID) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if scheduled {
|
||||
if queue.ScheduleMode != "cron" {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||
err := fmt.Errorf("队列未启用 cron 调度")
|
||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||
return true, err
|
||||
}
|
||||
if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||
err := fmt.Errorf("当前队列状态不允许被调度执行")
|
||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||
return true, err
|
||||
}
|
||||
if !h.batchTaskManager.ResetQueueForRerun(queueID) {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||
err := fmt.Errorf("重置队列失败")
|
||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||
return true, err
|
||||
}
|
||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||
} else if queue.Status != "pending" && queue.Status != "paused" {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||
return true, fmt.Errorf("队列状态不允许启动")
|
||||
}
|
||||
|
||||
if queue != nil && batchQueueWantsEino(queue.AgentMode) && (h.config == nil || !h.config.MultiAgent.Enabled) {
|
||||
h.unmarkBatchQueueRunning(queueID)
|
||||
h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||
err := fmt.Errorf("当前队列配置为 Eino 多代理,但系统未启用多代理")
|
||||
if scheduled {
|
||||
h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
|
||||
@@ -2105,324 +2066,6 @@ func (h *AgentHandler) batchQueueSchedulerLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
// executeBatchQueue 执行批量任务队列
|
||||
func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
defer h.unmarkBatchQueueRunning(queueID)
|
||||
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID))
|
||||
|
||||
for {
|
||||
// 检查队列状态
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists || queue.Status == "cancelled" || queue.Status == "completed" || queue.Status == "paused" {
|
||||
break
|
||||
}
|
||||
|
||||
// 获取下一个任务
|
||||
task, hasNext := h.batchTaskManager.GetNextTask(queueID)
|
||||
if !hasNext {
|
||||
// 所有任务完成:汇总子任务失败信息便于排障
|
||||
q, ok := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
lastRunErr := ""
|
||||
if ok {
|
||||
for _, t := range q.Tasks {
|
||||
if t.Status == "failed" && t.Error != "" {
|
||||
lastRunErr = t.Error
|
||||
}
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, "completed")
|
||||
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
|
||||
break
|
||||
}
|
||||
|
||||
// 更新任务状态为运行中
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "running", "", "")
|
||||
|
||||
// 创建新对话
|
||||
title := safeTruncateString(task.Message, 50)
|
||||
batchMeta := audit.ConversationCreateMeta("batch_task")
|
||||
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
|
||||
conv, err := h.db.CreateConversation(title, batchMeta)
|
||||
var conversationID string
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
|
||||
h.batchTaskManager.MoveToNextTask(queueID)
|
||||
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
conversationID = conv.ID
|
||||
|
||||
// 保存conversationId到任务中(即使是运行中状态也要保存,以便查看对话)
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "running", "", "", conversationID)
|
||||
|
||||
// 应用角色用户提示词和工具配置
|
||||
finalMessage := task.Message
|
||||
var roleTools []string // 角色配置的工具列表
|
||||
if queue.Role != "" && queue.Role != "默认" {
|
||||
if h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
|
||||
// 应用用户提示词
|
||||
if role.UserPrompt != "" {
|
||||
finalMessage = role.UserPrompt + "\n\n" + task.Message
|
||||
h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role))
|
||||
}
|
||||
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
|
||||
if len(role.Tools) > 0 {
|
||||
roleTools = role.Tools
|
||||
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存用户消息(保存原始消息,不包含角色提示词)
|
||||
_, err = h.db.AddMessage(conversationID, "user", task.Message, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
|
||||
// 预先创建助手消息,以便关联过程详情
|
||||
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
|
||||
if err != nil {
|
||||
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
// 如果创建失败,继续执行但不保存过程详情
|
||||
assistantMsg = nil
|
||||
}
|
||||
|
||||
// 创建进度回调函数,复用统一逻辑(批量任务不需要流式事件,所以传入nil)
|
||||
var assistantMessageID string
|
||||
if assistantMsg != nil {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
// 注意:批量任务没有前端直连的 POST /stream,因此若要支持「刷新后补流」,
|
||||
// 需要把进度事件镜像到 TaskEventBus(GET /api/agent-loop/task-events 会订阅这里)。
|
||||
// progressCallback 将在子任务的 IIFE 内创建,以便拿到 taskCtx/cancelWithCause 与 sendEvent。
|
||||
var progressCallback func(eventType, message string, data interface{})
|
||||
|
||||
// 执行任务(使用包含角色提示词的finalMessage和角色工具列表)
|
||||
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
|
||||
|
||||
func() {
|
||||
// 与对话流式接口一致:同 conversationId 仅允许一个运行中任务,并支持 /api/agent-loop/cancel 与会话锁对齐。
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||
// 单个子任务超时:6 小时(与原先 WithTimeout(Background) 一致)
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
|
||||
|
||||
registered := false
|
||||
finishStatus := "completed"
|
||||
|
||||
defer func() {
|
||||
h.batchTaskManager.SetTaskCancel(queueID, nil)
|
||||
timeoutCancel()
|
||||
if registered {
|
||||
// 与流式接口保持一致:结束前补一个 done,便于前端 task-events 侧及时收口 UI。
|
||||
if h.taskEventBus != nil {
|
||||
ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
|
||||
if b, err := json.Marshal(ev); err == nil {
|
||||
h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
|
||||
}
|
||||
}
|
||||
h.tasks.FinishTask(conversationID, finishStatus)
|
||||
}
|
||||
cancelWithCause(nil)
|
||||
}()
|
||||
|
||||
// 事件镜像:只发布到 TaskEventBus,不直接写 HTTP Response(用于刷新后的补流)。
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
if h.taskEventBus == nil {
|
||||
return
|
||||
}
|
||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
b, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
b = []byte(`{"type":"error","message":"marshal failed"}`)
|
||||
}
|
||||
line := make([]byte, 0, len(b)+8)
|
||||
line = append(line, []byte("data: ")...)
|
||||
line = append(line, b...)
|
||||
line = append(line, '\n', '\n')
|
||||
h.taskEventBus.Publish(conversationID, line)
|
||||
}
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
|
||||
h.logger.Warn("批量队列子任务注册会话运行状态失败",
|
||||
zap.String("queueId", queueID),
|
||||
zap.String("taskId", task.ID),
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Error(err))
|
||||
failMsg := err.Error()
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", failMsg)
|
||||
return
|
||||
}
|
||||
registered = true
|
||||
// 存储取消函数:暂停队列时取消子任务 context(与原先语义一致)
|
||||
h.batchTaskManager.SetTaskCancel(queueID, timeoutCancel)
|
||||
|
||||
// 创建进度回调函数:写 DB + 镜像到 task-events,支持刷新后继续流式展示。
|
||||
progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
||||
|
||||
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
|
||||
useBatchMulti := false
|
||||
batchOrch := "deep"
|
||||
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
|
||||
if am == "multi" {
|
||||
am = "deep"
|
||||
}
|
||||
if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
|
||||
useBatchMulti = true
|
||||
batchOrch = config.NormalizeMultiAgentOrchestration(am)
|
||||
} else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
|
||||
// 兼容历史数据:未配置队列代理模式时,沿用旧的系统级开关
|
||||
useBatchMulti = true
|
||||
batchOrch = "deep"
|
||||
}
|
||||
var resultMA *multiagent.RunResult
|
||||
var runErr error
|
||||
switch {
|
||||
case useBatchMulti:
|
||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
|
||||
default:
|
||||
if h.config == nil {
|
||||
runErr = fmt.Errorf("服务器配置未加载")
|
||||
} else {
|
||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
if runErr != nil {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
}
|
||||
errStr := runErr.Error()
|
||||
partialResp := ""
|
||||
if resultMA != nil {
|
||||
partialResp = resultMA.Response
|
||||
}
|
||||
isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
|
||||
errors.Is(runErr, context.Canceled) ||
|
||||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
|
||||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
|
||||
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
|
||||
isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)
|
||||
|
||||
if isTimeout {
|
||||
finishStatus = "timeout"
|
||||
} else if isCancelled {
|
||||
finishStatus = "cancelled"
|
||||
} else {
|
||||
finishStatus = "failed"
|
||||
}
|
||||
|
||||
if isCancelled {
|
||||
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
// 如果执行结果中有更具体的取消消息,使用它
|
||||
if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) {
|
||||
cancelMsg = partialResp
|
||||
}
|
||||
// 更新助手消息内容
|
||||
if assistantMessageID != "" {
|
||||
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
|
||||
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
}
|
||||
// 保存取消详情到数据库
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
|
||||
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
} else {
|
||||
// 如果没有预先创建的助手消息,创建一个新的
|
||||
_, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil)
|
||||
if errMsg != nil {
|
||||
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
|
||||
} else {
|
||||
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
|
||||
errorMsg := "执行失败: " + runErr.Error()
|
||||
// 更新助手消息内容
|
||||
if assistantMessageID != "" {
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
|
||||
errorMsg,
|
||||
time.Now(), assistantMessageID,
|
||||
); updateErr != nil {
|
||||
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
}
|
||||
// 保存错误详情到数据库
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
|
||||
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", runErr.Error())
|
||||
}
|
||||
} else {
|
||||
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
|
||||
resText := resultMA.Response
|
||||
mcpIDs := resultMA.MCPExecutionIDs
|
||||
lastIn := resultMA.LastAgentTraceInput
|
||||
lastOut := resultMA.LastAgentTraceOutput
|
||||
|
||||
// 更新助手消息内容
|
||||
if assistantMessageID != "" {
|
||||
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
|
||||
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
// 如果更新失败,尝试创建新消息
|
||||
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
|
||||
if err != nil {
|
||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 如果没有预先创建的助手消息,创建一个新的
|
||||
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
|
||||
if err != nil {
|
||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 保存代理轨迹
|
||||
if lastIn != "" || lastOut != "" {
|
||||
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
|
||||
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
// 保存结果
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "completed", resText, "", conversationID)
|
||||
}
|
||||
}()
|
||||
|
||||
// 移动到下一个任务
|
||||
h.batchTaskManager.MoveToNextTask(queueID)
|
||||
|
||||
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
|
||||
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
|
||||
break
|
||||
}
|
||||
|
||||
// 检查是否被取消或暂停
|
||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if queue.Status == "cancelled" || queue.Status == "paused" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
|
||||
// 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。
|
||||
func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
|
||||
|
||||
@@ -0,0 +1,352 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const batchQueueWorkerIdlePoll = 200 * time.Millisecond
|
||||
|
||||
// executeBatchQueue 使用并发 worker 池执行批量任务队列。
|
||||
func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
defer h.batchTaskManager.UnmarkQueueExecutor(queueID)
|
||||
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
concurrency := normalizeBatchQueueConcurrency(queue.Concurrency)
|
||||
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID), zap.Int("concurrency", concurrency))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
h.runBatchQueueWorker(queueID)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
h.tryFinalizeBatchQueue(queueID)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) runBatchQueueWorker(queueID string) {
|
||||
for {
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if batchQueueExecutionShouldStop(queue, exists) {
|
||||
return
|
||||
}
|
||||
|
||||
task, ok := h.batchTaskManager.ClaimNextPendingTask(queueID)
|
||||
if !ok {
|
||||
if !h.batchTaskManager.HasRunningTasks(queueID) {
|
||||
return
|
||||
}
|
||||
time.Sleep(batchQueueWorkerIdlePoll)
|
||||
continue
|
||||
}
|
||||
|
||||
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if queue == nil {
|
||||
return
|
||||
}
|
||||
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusRunning, "", "")
|
||||
h.executeOneBatchSubTask(queueID, queue, task)
|
||||
|
||||
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusPaused)
|
||||
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
|
||||
return
|
||||
}
|
||||
|
||||
queue, exists = h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if batchQueueExecutionShouldStop(queue, exists) {
|
||||
if !exists {
|
||||
h.logger.Warn("批量队列在执行收尾时已不存在,安全退出", zap.String("queueId", queueID))
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AgentHandler) tryFinalizeBatchQueue(queueID string) {
|
||||
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
|
||||
if !exists || queue == nil {
|
||||
return
|
||||
}
|
||||
if queue.Status != BatchQueueStatusRunning {
|
||||
return
|
||||
}
|
||||
if h.batchTaskManager.HasPendingOrRunningTasks(queueID) {
|
||||
return
|
||||
}
|
||||
|
||||
lastRunErr := ""
|
||||
for _, t := range queue.Tasks {
|
||||
if t != nil && t.Status == BatchTaskStatusFailed && t.Error != "" {
|
||||
lastRunErr = t.Error
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
|
||||
h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusCompleted)
|
||||
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
|
||||
}
|
||||
|
||||
// executeOneBatchSubTask 执行单条批量子任务(各自独立会话)。
|
||||
func (h *AgentHandler) executeOneBatchSubTask(queueID string, queue *BatchTaskQueue, task *BatchTask) {
|
||||
title := safeTruncateString(task.Message, 50)
|
||||
batchMeta := audit.ConversationCreateMeta("batch_task")
|
||||
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
|
||||
conv, err := h.db.CreateConversation(title, batchMeta)
|
||||
if err != nil {
|
||||
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "创建对话失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
conversationID := conv.ID
|
||||
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusRunning, "", "", conversationID)
|
||||
|
||||
finalMessage := task.Message
|
||||
var roleTools []string
|
||||
if queue.Role != "" && queue.Role != "默认" {
|
||||
if h.config.Roles != nil {
|
||||
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
|
||||
if role.UserPrompt != "" {
|
||||
finalMessage = role.UserPrompt + "\n\n" + task.Message
|
||||
h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role))
|
||||
}
|
||||
if len(role.Tools) > 0 {
|
||||
roleTools = role.Tools
|
||||
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, err = h.db.AddMessage(conversationID, "user", task.Message, nil); err != nil {
|
||||
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
|
||||
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
|
||||
if err != nil {
|
||||
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
assistantMsg = nil
|
||||
}
|
||||
|
||||
var assistantMessageID string
|
||||
if assistantMsg != nil {
|
||||
assistantMessageID = assistantMsg.ID
|
||||
}
|
||||
|
||||
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
|
||||
|
||||
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
|
||||
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
|
||||
|
||||
registered := false
|
||||
finishStatus := "completed"
|
||||
|
||||
defer func() {
|
||||
h.batchTaskManager.SetTaskCancel(queueID, task.ID, nil)
|
||||
timeoutCancel()
|
||||
if registered {
|
||||
if h.taskEventBus != nil {
|
||||
ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
|
||||
if b, err := json.Marshal(ev); err == nil {
|
||||
h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
|
||||
}
|
||||
}
|
||||
h.tasks.FinishTask(conversationID, finishStatus)
|
||||
}
|
||||
cancelWithCause(nil)
|
||||
}()
|
||||
|
||||
sendEvent := func(eventType, message string, data interface{}) {
|
||||
if h.taskEventBus == nil {
|
||||
return
|
||||
}
|
||||
ev := StreamEvent{Type: eventType, Message: message, Data: data}
|
||||
b, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
b = []byte(`{"type":"error","message":"marshal failed"}`)
|
||||
}
|
||||
line := make([]byte, 0, len(b)+8)
|
||||
line = append(line, []byte("data: ")...)
|
||||
line = append(line, b...)
|
||||
line = append(line, '\n', '\n')
|
||||
h.taskEventBus.Publish(conversationID, line)
|
||||
}
|
||||
|
||||
if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
|
||||
h.logger.Warn("批量队列子任务注册会话运行状态失败",
|
||||
zap.String("queueId", queueID),
|
||||
zap.String("taskId", task.ID),
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Error(err))
|
||||
failMsg := err.Error()
|
||||
if errors.Is(err, ErrTaskAlreadyRunning) {
|
||||
failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", failMsg)
|
||||
return
|
||||
}
|
||||
registered = true
|
||||
h.batchTaskManager.SetTaskCancel(queueID, task.ID, timeoutCancel)
|
||||
|
||||
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
||||
taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks)
|
||||
|
||||
useBatchMulti := false
|
||||
batchOrch := "deep"
|
||||
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
|
||||
if am == "multi" {
|
||||
am = "deep"
|
||||
}
|
||||
if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
|
||||
useBatchMulti = true
|
||||
batchOrch = config.NormalizeMultiAgentOrchestration(am)
|
||||
} else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
|
||||
useBatchMulti = true
|
||||
batchOrch = "deep"
|
||||
}
|
||||
|
||||
var resultMA *multiagent.RunResult
|
||||
var runErr error
|
||||
switch {
|
||||
case useBatchMulti:
|
||||
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
|
||||
default:
|
||||
if h.config == nil {
|
||||
runErr = fmt.Errorf("服务器配置未加载")
|
||||
} else {
|
||||
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
if runErr != nil {
|
||||
h.handleBatchSubTaskRunError(queueID, task, conversationID, assistantMessageID, baseCtx, taskCtx, resultMA, runErr, &finishStatus)
|
||||
return
|
||||
}
|
||||
|
||||
if resultMA == nil {
|
||||
h.logger.Error("批量任务执行成功但无结果对象",
|
||||
zap.String("queueId", queueID),
|
||||
zap.String("taskId", task.ID),
|
||||
zap.String("conversationId", conversationID))
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "内部错误:无执行结果")
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
|
||||
resText := resultMA.Response
|
||||
mcpIDs := resultMA.MCPExecutionIDs
|
||||
lastIn := resultMA.LastAgentTraceInput
|
||||
lastOut := resultMA.LastAgentTraceOutput
|
||||
|
||||
if assistantMessageID != "" {
|
||||
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
|
||||
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil {
|
||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
} else if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil {
|
||||
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
|
||||
if lastIn != "" || lastOut != "" {
|
||||
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
|
||||
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCompleted, resText, "", conversationID)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) handleBatchSubTaskRunError(
|
||||
queueID string,
|
||||
task *BatchTask,
|
||||
conversationID, assistantMessageID string,
|
||||
baseCtx, taskCtx context.Context,
|
||||
resultMA *multiagent.RunResult,
|
||||
runErr error,
|
||||
finishStatus *string,
|
||||
) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
}
|
||||
errStr := runErr.Error()
|
||||
partialResp := ""
|
||||
if resultMA != nil {
|
||||
partialResp = resultMA.Response
|
||||
}
|
||||
isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
|
||||
errors.Is(runErr, context.Canceled) ||
|
||||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
|
||||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
|
||||
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
|
||||
isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)
|
||||
|
||||
if isTimeout {
|
||||
*finishStatus = "timeout"
|
||||
} else if isCancelled {
|
||||
*finishStatus = "cancelled"
|
||||
} else {
|
||||
*finishStatus = "failed"
|
||||
}
|
||||
|
||||
if isCancelled {
|
||||
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
cancelMsg := "任务已被用户取消,后续操作已停止。"
|
||||
if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) {
|
||||
cancelMsg = partialResp
|
||||
}
|
||||
if assistantMessageID != "" {
|
||||
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
|
||||
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
}
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
|
||||
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
} else if _, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil); errMsg != nil {
|
||||
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCancelled, cancelMsg, "", conversationID)
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
|
||||
errorMsg := "执行失败: " + runErr.Error()
|
||||
if assistantMessageID != "" {
|
||||
if _, updateErr := h.db.Exec(
|
||||
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
|
||||
errorMsg,
|
||||
time.Now(), assistantMessageID,
|
||||
); updateErr != nil {
|
||||
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
|
||||
}
|
||||
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
|
||||
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", runErr.Error())
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -17,6 +18,15 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrBatchQueueNotFound 队列不存在或已从内存卸载。
|
||||
ErrBatchQueueNotFound = errors.New("batch queue not found")
|
||||
// ErrBatchQueueExecutorActive executeBatchQueue 协程仍在收尾,禁止删除。
|
||||
ErrBatchQueueExecutorActive = errors.New("batch queue executor is still active")
|
||||
// ErrBatchQueueStillRunning 队列状态仍为 running(无活跃执行器时的兜底保护)。
|
||||
ErrBatchQueueStillRunning = errors.New("batch queue is still running")
|
||||
)
|
||||
|
||||
// 批量任务状态常量
|
||||
const (
|
||||
BatchQueueStatusPending = "pending"
|
||||
@@ -39,6 +49,12 @@ const (
|
||||
|
||||
// MaxBatchQueueRoleLen 角色名最大长度
|
||||
MaxBatchQueueRoleLen = 100
|
||||
|
||||
// DefaultBatchQueueConcurrency 批量队列默认并发数(串行)
|
||||
DefaultBatchQueueConcurrency = 1
|
||||
|
||||
// MaxBatchQueueConcurrency 批量队列最大并发数
|
||||
MaxBatchQueueConcurrency = 8
|
||||
)
|
||||
|
||||
// BatchTask 批量任务项
|
||||
@@ -67,6 +83,7 @@ type BatchTaskQueue struct {
|
||||
LastScheduleError string `json:"lastScheduleError,omitempty"`
|
||||
LastRunError string `json:"lastRunError,omitempty"`
|
||||
ProjectID string `json:"projectId,omitempty"`
|
||||
Concurrency int `json:"concurrency"` // 同时执行的子任务数,默认 1
|
||||
Tasks []*BatchTask `json:"tasks"`
|
||||
Status string `json:"status"` // pending, running, paused, completed, cancelled
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
@@ -80,8 +97,9 @@ type BatchTaskManager struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
queues map[string]*BatchTaskQueue
|
||||
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数
|
||||
taskCancels map[string]map[string]context.CancelFunc // queueID -> taskID -> 取消函数
|
||||
singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列
|
||||
queueExecutors map[string]struct{} // executeBatchQueue 协程活跃标记(与队列 status 解耦)
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -93,11 +111,56 @@ func NewBatchTaskManager(logger *zap.Logger) *BatchTaskManager {
|
||||
return &BatchTaskManager{
|
||||
logger: logger,
|
||||
queues: make(map[string]*BatchTaskQueue),
|
||||
taskCancels: make(map[string]context.CancelFunc),
|
||||
taskCancels: make(map[string]map[string]context.CancelFunc),
|
||||
singleRunTasks: make(map[string]string),
|
||||
queueExecutors: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// batchQueueExecutionShouldStop 判断 executeBatchQueue 主循环是否应退出。
|
||||
func batchQueueExecutionShouldStop(queue *BatchTaskQueue, exists bool) bool {
|
||||
if !exists || queue == nil {
|
||||
return true
|
||||
}
|
||||
switch queue.Status {
|
||||
case BatchQueueStatusCancelled, BatchQueueStatusCompleted, BatchQueueStatusPaused:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// TryMarkQueueExecutor 标记队列执行协程已启动;若已有执行协程则返回 false。
|
||||
func (m *BatchTaskManager) TryMarkQueueExecutor(queueID string) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if _, exists := m.queueExecutors[queueID]; exists {
|
||||
return false
|
||||
}
|
||||
m.queueExecutors[queueID] = struct{}{}
|
||||
return true
|
||||
}
|
||||
|
||||
// UnmarkQueueExecutor 清除队列执行协程标记(executeBatchQueue defer 调用)。
|
||||
func (m *BatchTaskManager) UnmarkQueueExecutor(queueID string) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.queueExecutors, queueID)
|
||||
}
|
||||
|
||||
// ForceUnmarkQueueExecutor 强制清除执行协程标记(暂停态单条重跑等场景回收陈旧槽位)。
|
||||
func (m *BatchTaskManager) ForceUnmarkQueueExecutor(queueID string) {
|
||||
m.UnmarkQueueExecutor(queueID)
|
||||
}
|
||||
|
||||
// IsQueueExecutorActive 队列 executeBatchQueue 协程是否仍在运行。
|
||||
func (m *BatchTaskManager) IsQueueExecutorActive(queueID string) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
_, ok := m.queueExecutors[queueID]
|
||||
return ok
|
||||
}
|
||||
|
||||
// SetDB 设置数据库连接
|
||||
func (m *BatchTaskManager) SetDB(db *database.DB) {
|
||||
m.mu.Lock()
|
||||
@@ -105,10 +168,22 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
|
||||
m.db = db
|
||||
}
|
||||
|
||||
// normalizeBatchQueueConcurrency 规范化队列并发数。
|
||||
func normalizeBatchQueueConcurrency(n int) int {
|
||||
if n < 1 {
|
||||
return DefaultBatchQueueConcurrency
|
||||
}
|
||||
if n > MaxBatchQueueConcurrency {
|
||||
return MaxBatchQueueConcurrency
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// CreateBatchQueue 创建批量任务队列
|
||||
func (m *BatchTaskManager) CreateBatchQueue(
|
||||
title, role, agentMode, scheduleMode, cronExpr, projectID string,
|
||||
nextRunAt *time.Time,
|
||||
concurrency int,
|
||||
tasks []string,
|
||||
) (*BatchTaskQueue, error) {
|
||||
// 输入校验
|
||||
@@ -136,6 +211,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
||||
CronExpr: strings.TrimSpace(cronExpr),
|
||||
NextRunAt: nextRunAt,
|
||||
ScheduleEnabled: true,
|
||||
Concurrency: normalizeBatchQueueConcurrency(concurrency),
|
||||
Tasks: make([]*BatchTask, 0, len(tasks)),
|
||||
Status: BatchQueueStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
@@ -177,6 +253,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
|
||||
queue.CronExpr,
|
||||
queue.NextRunAt,
|
||||
queue.ProjectID,
|
||||
queue.Concurrency,
|
||||
dbTasks,
|
||||
); err != nil {
|
||||
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
|
||||
@@ -272,6 +349,7 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
|
||||
if queueRow.ProjectID.Valid {
|
||||
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||
}
|
||||
queue.Concurrency = batchQueueConcurrencyFromRow(queueRow)
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
@@ -511,6 +589,7 @@ func (m *BatchTaskManager) LoadFromDB() error {
|
||||
if queueRow.ProjectID.Valid {
|
||||
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
|
||||
}
|
||||
queue.Concurrency = batchQueueConcurrencyFromRow(queueRow)
|
||||
if queueRow.StartedAt.Valid {
|
||||
queue.StartedAt = &queueRow.StartedAt.Time
|
||||
}
|
||||
@@ -651,8 +730,16 @@ func (m *BatchTaskManager) UpdateQueueSchedule(queueID, scheduleMode, cronExpr s
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateQueueMetadata 更新队列标题、角色和代理模式(非 running 时可用)
|
||||
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string) error {
|
||||
// batchQueueConcurrencyFromRow 从数据库行读取并发数(缺省为 1)。
|
||||
func batchQueueConcurrencyFromRow(row *database.BatchTaskQueueRow) int {
|
||||
if row == nil || !row.Concurrency.Valid {
|
||||
return DefaultBatchQueueConcurrency
|
||||
}
|
||||
return normalizeBatchQueueConcurrency(int(row.Concurrency.Int64))
|
||||
}
|
||||
|
||||
// UpdateQueueMetadata 更新队列标题、角色、代理模式和并发数(非 running 时可用)
|
||||
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string, concurrency *int) error {
|
||||
if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen {
|
||||
return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen)
|
||||
}
|
||||
@@ -680,9 +767,12 @@ func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode s
|
||||
queue.Title = title
|
||||
queue.Role = role
|
||||
queue.AgentMode = agentMode
|
||||
if concurrency != nil {
|
||||
queue.Concurrency = normalizeBatchQueueConcurrency(*concurrency)
|
||||
}
|
||||
|
||||
if m.db != nil {
|
||||
if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode); err != nil {
|
||||
if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode, queue.Concurrency); err != nil {
|
||||
m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
@@ -868,7 +958,6 @@ func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask,
|
||||
|
||||
// PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引
|
||||
func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
||||
var cancelFunc context.CancelFunc
|
||||
var siblingRunningIDs []string
|
||||
|
||||
m.mu.Lock()
|
||||
@@ -898,11 +987,9 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
||||
}
|
||||
|
||||
// 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项
|
||||
var cancelFuncs []context.CancelFunc
|
||||
if queue.Status == BatchQueueStatusPaused {
|
||||
if c, ok := m.taskCancels[queueID]; ok {
|
||||
cancelFunc = c
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
cancelFuncs = m.drainTaskCancelsLocked(queueID)
|
||||
for _, t := range queue.Tasks {
|
||||
if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning {
|
||||
siblingRunningIDs = append(siblingRunningIDs, t.ID)
|
||||
@@ -914,8 +1001,10 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
|
||||
resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled
|
||||
m.mu.Unlock()
|
||||
|
||||
if cancelFunc != nil {
|
||||
cancelFunc()
|
||||
for _, c := range cancelFuncs {
|
||||
if c != nil {
|
||||
c()
|
||||
}
|
||||
}
|
||||
const staleRunMsg = "为单条执行其它任务,已中止"
|
||||
for _, sid := range siblingRunningIDs {
|
||||
@@ -1089,7 +1178,90 @@ func queueAllowsSingleTaskRunLocked(queue *BatchTaskQueue, task *BatchTask) bool
|
||||
}
|
||||
}
|
||||
|
||||
// GetNextTask 获取下一个待执行的任务
|
||||
// ClaimNextPendingTask 原子领取下一个待执行子任务(并发 worker 安全)。
|
||||
func (m *BatchTaskManager) ClaimNextPendingTask(queueID string) (*BatchTask, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists || queue == nil {
|
||||
return nil, false
|
||||
}
|
||||
if queue.Status == BatchQueueStatusCancelled || queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusPaused {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
onlyTaskID := ""
|
||||
if m.singleRunTasks != nil {
|
||||
onlyTaskID = m.singleRunTasks[queueID]
|
||||
}
|
||||
|
||||
for i, task := range queue.Tasks {
|
||||
if task == nil || task.Status != BatchTaskStatusPending {
|
||||
continue
|
||||
}
|
||||
if onlyTaskID != "" && task.ID != onlyTaskID {
|
||||
continue
|
||||
}
|
||||
task.Status = BatchTaskStatusRunning
|
||||
queue.CurrentIndex = i
|
||||
return task, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// HasRunningTasks 队列是否仍有 running 状态的子任务。
|
||||
func (m *BatchTaskManager) HasRunningTasks(queueID string) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists || queue == nil {
|
||||
return false
|
||||
}
|
||||
for _, task := range queue.Tasks {
|
||||
if task != nil && task.Status == BatchTaskStatusRunning {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasPendingOrRunningTasks 队列是否仍有未完成的子任务。
|
||||
func (m *BatchTaskManager) HasPendingOrRunningTasks(queueID string) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists || queue == nil {
|
||||
return false
|
||||
}
|
||||
for _, task := range queue.Tasks {
|
||||
if task == nil {
|
||||
continue
|
||||
}
|
||||
if task.Status == BatchTaskStatusPending || task.Status == BatchTaskStatusRunning {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// drainTaskCancelsLocked 取出并清空队列下所有子任务取消函数(调用方须已持 m.mu)。
|
||||
func (m *BatchTaskManager) drainTaskCancelsLocked(queueID string) []context.CancelFunc {
|
||||
taskMap, ok := m.taskCancels[queueID]
|
||||
if !ok || len(taskMap) == 0 {
|
||||
return nil
|
||||
}
|
||||
cancels := make([]context.CancelFunc, 0, len(taskMap))
|
||||
for _, c := range taskMap {
|
||||
if c != nil {
|
||||
cancels = append(cancels, c)
|
||||
}
|
||||
}
|
||||
delete(m.taskCancels, queueID)
|
||||
return cancels
|
||||
}
|
||||
|
||||
// GetNextTask 获取下一个待执行的任务(串行兼容,优先使用 ClaimNextPendingTask)
|
||||
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
@@ -1130,20 +1302,28 @@ func (m *BatchTaskManager) MoveToNextTask(queueID string) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetTaskCancel 设置当前任务的取消函数
|
||||
func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFunc) {
|
||||
// SetTaskCancel 设置子任务的取消函数
|
||||
func (m *BatchTaskManager) SetTaskCancel(queueID, taskID string, cancel context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if cancel != nil {
|
||||
m.taskCancels[queueID] = cancel
|
||||
} else {
|
||||
delete(m.taskCancels, queueID)
|
||||
if cancel == nil {
|
||||
if taskMap, ok := m.taskCancels[queueID]; ok {
|
||||
delete(taskMap, taskID)
|
||||
if len(taskMap) == 0 {
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if m.taskCancels[queueID] == nil {
|
||||
m.taskCancels[queueID] = make(map[string]context.CancelFunc)
|
||||
}
|
||||
m.taskCancels[queueID][taskID] = cancel
|
||||
}
|
||||
|
||||
// PauseQueue 暂停队列
|
||||
func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
||||
var cancelFunc context.CancelFunc
|
||||
var cancelFuncs []context.CancelFunc
|
||||
|
||||
m.mu.Lock()
|
||||
queue, exists := m.queues[queueID]
|
||||
@@ -1168,17 +1348,11 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
||||
}
|
||||
|
||||
queue.Status = BatchQueueStatusPaused
|
||||
|
||||
// 取消当前正在执行的任务(通过取消context)
|
||||
if cancel, ok := m.taskCancels[queueID]; ok {
|
||||
cancelFunc = cancel
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
cancelFuncs = m.drainTaskCancelsLocked(queueID)
|
||||
m.mu.Unlock()
|
||||
|
||||
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁)
|
||||
if cancelFunc != nil {
|
||||
cancelFunc()
|
||||
for _, c := range cancelFuncs {
|
||||
c()
|
||||
}
|
||||
|
||||
return true
|
||||
@@ -1187,7 +1361,7 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
|
||||
// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue)
|
||||
func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
||||
now := time.Now()
|
||||
var cancelFunc context.CancelFunc
|
||||
var cancelFuncs []context.CancelFunc
|
||||
|
||||
m.mu.Lock()
|
||||
queue, exists := m.queues[queueID]
|
||||
@@ -1228,34 +1402,33 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// 取消当前正在执行的任务
|
||||
if cancel, ok := m.taskCancels[queueID]; ok {
|
||||
cancelFunc = cancel
|
||||
delete(m.taskCancels, queueID)
|
||||
}
|
||||
cancelFuncs = m.drainTaskCancelsLocked(queueID)
|
||||
m.mu.Unlock()
|
||||
|
||||
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁)
|
||||
if cancelFunc != nil {
|
||||
cancelFunc()
|
||||
for _, c := range cancelFuncs {
|
||||
c()
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// DeleteQueue 删除队列(运行中的队列不允许删除)
|
||||
func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
|
||||
// DeleteQueue 删除队列。执行协程活跃或 status 为 running 时拒绝删除,避免 executeBatchQueue 空指针 panic。
|
||||
func (m *BatchTaskManager) DeleteQueue(queueID string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
queue, exists := m.queues[queueID]
|
||||
if !exists {
|
||||
return false
|
||||
return ErrBatchQueueNotFound
|
||||
}
|
||||
|
||||
if _, exec := m.queueExecutors[queueID]; exec {
|
||||
return ErrBatchQueueExecutorActive
|
||||
}
|
||||
|
||||
// 运行中的队列不允许删除,防止孤儿协程和数据丢失
|
||||
if queue.Status == BatchQueueStatusRunning {
|
||||
return false
|
||||
return ErrBatchQueueStillRunning
|
||||
}
|
||||
|
||||
// 清理取消函数
|
||||
@@ -1269,7 +1442,7 @@ func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
|
||||
}
|
||||
|
||||
delete(m.queues, queueID)
|
||||
return true
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateShortID 生成短ID
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestNormalizeBatchQueueConcurrency(t *testing.T) {
|
||||
if got := normalizeBatchQueueConcurrency(0); got != DefaultBatchQueueConcurrency {
|
||||
t.Fatalf("expected default %d, got %d", DefaultBatchQueueConcurrency, got)
|
||||
}
|
||||
if got := normalizeBatchQueueConcurrency(99); got != MaxBatchQueueConcurrency {
|
||||
t.Fatalf("expected max %d, got %d", MaxBatchQueueConcurrency, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaimNextPendingTaskParallel(t *testing.T) {
|
||||
m := NewBatchTaskManager(zap.NewNop())
|
||||
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 3, []string{"a", "b", "c"})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateBatchQueue: %v", err)
|
||||
}
|
||||
m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning)
|
||||
|
||||
t1, ok1 := m.ClaimNextPendingTask(queue.ID)
|
||||
t2, ok2 := m.ClaimNextPendingTask(queue.ID)
|
||||
if !ok1 || !ok2 || t1.ID == t2.ID {
|
||||
t.Fatalf("expected two distinct claims, got ok1=%v ok2=%v t1=%v t2=%v", ok1, ok2, t1, t2)
|
||||
}
|
||||
if t1.Status != BatchTaskStatusRunning || t2.Status != BatchTaskStatusRunning {
|
||||
t.Fatalf("claimed tasks should be running")
|
||||
}
|
||||
t3, ok3 := m.ClaimNextPendingTask(queue.ID)
|
||||
if !ok3 {
|
||||
t.Fatal("expected third claim")
|
||||
}
|
||||
_, ok4 := m.ClaimNextPendingTask(queue.ID)
|
||||
if ok4 {
|
||||
t.Fatal("expected no fourth pending task")
|
||||
}
|
||||
_ = t3
|
||||
}
|
||||
|
||||
func TestBatchQueueExecutionShouldStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !batchQueueExecutionShouldStop(nil, false) {
|
||||
t.Fatal("expected stop when queue missing")
|
||||
}
|
||||
if !batchQueueExecutionShouldStop(nil, true) {
|
||||
t.Fatal("expected stop when queue is nil but exists=true")
|
||||
}
|
||||
q := &BatchTaskQueue{Status: BatchQueueStatusRunning}
|
||||
if batchQueueExecutionShouldStop(q, true) {
|
||||
t.Fatal("expected continue when running")
|
||||
}
|
||||
q.Status = BatchQueueStatusCancelled
|
||||
if !batchQueueExecutionShouldStop(q, true) {
|
||||
t.Fatal("expected stop when cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteQueueBlockedWhileExecutorActive(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewBatchTaskManager(zap.NewNop())
|
||||
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateBatchQueue: %v", err)
|
||||
}
|
||||
if !m.TryMarkQueueExecutor(queue.ID) {
|
||||
t.Fatal("expected to mark executor")
|
||||
}
|
||||
m.UpdateQueueStatus(queue.ID, BatchQueueStatusCancelled)
|
||||
|
||||
err = m.DeleteQueue(queue.ID)
|
||||
if !errors.Is(err, ErrBatchQueueExecutorActive) {
|
||||
t.Fatalf("expected ErrBatchQueueExecutorActive, got %v", err)
|
||||
}
|
||||
if _, ok := m.GetBatchQueue(queue.ID); !ok {
|
||||
t.Fatal("queue should still exist while executor active")
|
||||
}
|
||||
|
||||
m.UnmarkQueueExecutor(queue.ID)
|
||||
if err := m.DeleteQueue(queue.ID); err != nil {
|
||||
t.Fatalf("expected delete after executor unmarked, got %v", err)
|
||||
}
|
||||
if _, ok := m.GetBatchQueue(queue.ID); ok {
|
||||
t.Fatal("queue should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteQueueBlockedWhileRunning(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewBatchTaskManager(zap.NewNop())
|
||||
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateBatchQueue: %v", err)
|
||||
}
|
||||
m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning)
|
||||
|
||||
err = m.DeleteQueue(queue.ID)
|
||||
if !errors.Is(err, ErrBatchQueueStillRunning) {
|
||||
t.Fatalf("expected ErrBatchQueueStillRunning, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryMarkQueueExecutorDedupes(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewBatchTaskManager(zap.NewNop())
|
||||
if !m.TryMarkQueueExecutor("q-1") {
|
||||
t.Fatal("first mark should succeed")
|
||||
}
|
||||
if m.TryMarkQueueExecutor("q-1") {
|
||||
t.Fatal("second mark should fail")
|
||||
}
|
||||
m.UnmarkQueueExecutor("q-1")
|
||||
if !m.TryMarkQueueExecutor("q-1") {
|
||||
t.Fatal("mark after unmark should succeed")
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package handler
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -181,6 +182,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
"type": "string",
|
||||
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id)",
|
||||
},
|
||||
"concurrency": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "同时执行的子任务数,默认 1(串行),最大 8。含扫描类工具时建议 1-2。",
|
||||
},
|
||||
},
|
||||
},
|
||||
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
|
||||
@@ -210,7 +215,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
executeNow = false
|
||||
}
|
||||
projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks)
|
||||
concurrency := int(mcpArgFloat(args, "concurrency"))
|
||||
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, concurrency, tasks)
|
||||
if createErr != nil {
|
||||
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
|
||||
}
|
||||
@@ -365,8 +371,17 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
if qid == "" {
|
||||
return batchMCPTextResult("queue_id 不能为空", true), nil
|
||||
}
|
||||
if !h.batchTaskManager.DeleteQueue(qid) {
|
||||
return batchMCPTextResult("删除失败:队列不存在", true), nil
|
||||
if err := h.batchTaskManager.DeleteQueue(qid); err != nil {
|
||||
switch {
|
||||
case errors.Is(err, ErrBatchQueueNotFound):
|
||||
return batchMCPTextResult("删除失败:队列不存在", true), nil
|
||||
case errors.Is(err, ErrBatchQueueExecutorActive):
|
||||
return batchMCPTextResult("删除失败:队列执行器仍在运行,请稍后再试", true), nil
|
||||
case errors.Is(err, ErrBatchQueueStillRunning):
|
||||
return batchMCPTextResult("删除失败:队列正在运行中", true), nil
|
||||
default:
|
||||
return batchMCPTextResult("删除失败:"+err.Error(), true), nil
|
||||
}
|
||||
}
|
||||
logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
|
||||
return batchMCPTextResult("队列已删除。", false), nil
|
||||
@@ -397,6 +412,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
"description": "代理模式:eino_single、deep、plan_execute、supervisor",
|
||||
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
|
||||
},
|
||||
"concurrency": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "同时执行的子任务数,默认 1,最大 8",
|
||||
},
|
||||
},
|
||||
"required": []string{"queue_id"},
|
||||
},
|
||||
@@ -408,7 +427,12 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
|
||||
title := mcpArgString(args, "title")
|
||||
role := mcpArgString(args, "role")
|
||||
agentMode := mcpArgString(args, "agent_mode")
|
||||
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode); err != nil {
|
||||
var concurrency *int
|
||||
if raw, ok := args["concurrency"]; ok && raw != nil {
|
||||
v := int(mcpArgFloat(args, "concurrency"))
|
||||
concurrency = &v
|
||||
}
|
||||
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode, concurrency); err != nil {
|
||||
return batchMCPTextResult(err.Error(), true), nil
|
||||
}
|
||||
updated, _ := h.batchTaskManager.GetBatchQueue(qid)
|
||||
@@ -652,6 +676,7 @@ type batchTaskQueueMCPListItem struct {
|
||||
StartedAt *time.Time `json:"startedAt,omitempty"`
|
||||
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||
CurrentIndex int `json:"currentIndex"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
TaskTotal int `json:"task_total"`
|
||||
TaskCounts map[string]int `json:"task_counts"`
|
||||
Tasks []batchTaskMCPListSummary `json:"tasks"`
|
||||
@@ -715,6 +740,7 @@ func toBatchTaskQueueMCPListItem(q *BatchTaskQueue) batchTaskQueueMCPListItem {
|
||||
StartedAt: q.StartedAt,
|
||||
CompletedAt: q.CompletedAt,
|
||||
CurrentIndex: q.CurrentIndex,
|
||||
Concurrency: q.Concurrency,
|
||||
TaskTotal: len(tasks),
|
||||
TaskCounts: counts,
|
||||
Tasks: tasks,
|
||||
|
||||
@@ -12,11 +12,17 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ConversationTaskStopper cancels in-flight agent work when a conversation is removed.
|
||||
type ConversationTaskStopper interface {
|
||||
CancelRunningTaskForConversation(conversationID string)
|
||||
}
|
||||
|
||||
// ConversationHandler 对话处理器
|
||||
type ConversationHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
taskStopper ConversationTaskStopper
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
@@ -24,6 +30,11 @@ func (h *ConversationHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// SetTaskStopper wires cancellation of in-flight agent tasks on conversation delete.
|
||||
func (h *ConversationHandler) SetTaskStopper(stopper ConversationTaskStopper) {
|
||||
h.taskStopper = stopper
|
||||
}
|
||||
|
||||
// NewConversationHandler 创建新的对话处理器
|
||||
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
|
||||
return &ConversationHandler{
|
||||
@@ -165,6 +176,9 @@ func (h *ConversationHandler) GetConversation(c *gin.Context) {
|
||||
}
|
||||
|
||||
// GetMessageProcessDetails 获取指定消息的过程详情(按需加载)
|
||||
// 查询参数:
|
||||
// - summary=1:仅返回摘要(total / iterationCount / maxIteration)
|
||||
// - limit + offset:分页返回 processDetails(未指定 limit 时保持全量兼容)
|
||||
func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
||||
messageID := c.Param("id")
|
||||
if messageID == "" {
|
||||
@@ -172,6 +186,51 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
summaryStr := strings.TrimSpace(c.Query("summary"))
|
||||
if summaryStr == "1" || strings.EqualFold(summaryStr, "true") || strings.EqualFold(summaryStr, "yes") {
|
||||
summary, err := h.db.GetProcessDetailsSummary(messageID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取过程详情摘要失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"summary": summary})
|
||||
return
|
||||
}
|
||||
|
||||
limitStr := strings.TrimSpace(c.Query("limit"))
|
||||
if limitStr != "" {
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err != nil || limit <= 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid limit"})
|
||||
return
|
||||
}
|
||||
if limit > 500 {
|
||||
limit = 500
|
||||
}
|
||||
offset, _ := strconv.Atoi(strings.TrimSpace(c.Query("offset")))
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
details, total, err := h.db.GetProcessDetailsPage(messageID, limit, offset)
|
||||
if err != nil {
|
||||
h.logger.Error("分页获取过程详情失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
details = database.DedupeConsecutiveProcessDetails(details)
|
||||
out := processDetailsToJSON(h.logger, details)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"processDetails": out,
|
||||
"total": total,
|
||||
"offset": offset,
|
||||
"limit": limit,
|
||||
"hasMore": offset+len(out) < total,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
details, err := h.db.GetProcessDetails(messageID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取过程详情失败", zap.Error(err))
|
||||
@@ -180,14 +239,17 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
||||
}
|
||||
|
||||
details = database.DedupeConsecutiveProcessDetails(details)
|
||||
out := processDetailsToJSON(h.logger, details)
|
||||
c.JSON(http.StatusOK, gin.H{"processDetails": out, "total": len(out)})
|
||||
}
|
||||
|
||||
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致)
|
||||
func processDetailsToJSON(logger *zap.Logger, details []database.ProcessDetail) []map[string]interface{} {
|
||||
out := make([]map[string]interface{}, 0, len(details))
|
||||
for _, d := range details {
|
||||
var data interface{}
|
||||
if d.Data != "" {
|
||||
if err := json.Unmarshal([]byte(d.Data), &data); err != nil {
|
||||
h.logger.Warn("解析过程详情数据失败", zap.Error(err))
|
||||
logger.Warn("解析过程详情数据失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
out = append(out, map[string]interface{}{
|
||||
@@ -200,8 +262,7 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
|
||||
"createdAt": d.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"processDetails": out})
|
||||
return out
|
||||
}
|
||||
|
||||
// UpdateConversationRequest 更新对话请求
|
||||
@@ -245,6 +306,10 @@ func (h *ConversationHandler) UpdateConversation(c *gin.Context) {
|
||||
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if h.taskStopper != nil {
|
||||
h.taskStopper.CancelRunningTaskForConversation(id)
|
||||
}
|
||||
|
||||
if err := h.db.DeleteConversation(id); err != nil {
|
||||
h.logger.Error("删除对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestConversationHandlerDeleteConversationCancelsRunningTask(t *testing.T) {
|
||||
tm := NewAgentTaskManager()
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
_, err := tm.StartTask("conv-1", "hello", cancel)
|
||||
if err != nil {
|
||||
t.Fatalf("StartTask: %v", err)
|
||||
}
|
||||
|
||||
h := &AgentHandler{tasks: tm, logger: zap.NewNop()}
|
||||
h.CancelRunningTaskForConversation("conv-1")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("task context was not cancelled")
|
||||
}
|
||||
if cause := context.Cause(ctx); cause != ErrTaskCancelled {
|
||||
t.Fatalf("expected ErrTaskCancelled, got %v", cause)
|
||||
}
|
||||
}
|
||||
@@ -2,31 +2,11 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
|
||||
if h.config != nil {
|
||||
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
|
||||
}
|
||||
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
|
||||
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
|
||||
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
|
||||
func (h *AgentHandler) applyEinoTraceResumeSegment(
|
||||
conversationID string,
|
||||
@@ -45,136 +25,3 @@ func (h *AgentHandler) applyEinoTraceResumeSegment(
|
||||
*curFinalMessage = segmentUserMessage
|
||||
}
|
||||
}
|
||||
|
||||
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
|
||||
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
|
||||
func (h *AgentHandler) applyEinoTransientRetrySegment(
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
*curHistory = hist
|
||||
}
|
||||
if s := strings.TrimSpace(segmentUserMessage); s != "" {
|
||||
*curFinalMessage = segmentUserMessage
|
||||
}
|
||||
}
|
||||
|
||||
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
|
||||
func (h *AgentHandler) handleEinoTransientRetryContinue(
|
||||
baseCtx context.Context,
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
runErr error,
|
||||
transientAttempts *int,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
progressCallback func(eventType, message string, data interface{}),
|
||||
sendProgress func(msg string, extra map[string]interface{}),
|
||||
) (handled bool, fatal error) {
|
||||
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
|
||||
return false, nil
|
||||
}
|
||||
maxAttempts := h.einoRunRetryMaxAttempts()
|
||||
*transientAttempts++
|
||||
if *transientAttempts > maxAttempts {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
return false, errors.New("transient retry exhausted: " + runErr.Error())
|
||||
}
|
||||
attemptNo := *transientAttempts
|
||||
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
"maxAttempts": maxAttempts,
|
||||
"backoffSec": int(backoff.Seconds()),
|
||||
})
|
||||
}
|
||||
select {
|
||||
case <-baseCtx.Done():
|
||||
return false, context.Cause(baseCtx)
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
})
|
||||
}
|
||||
if sendProgress != nil {
|
||||
sendProgress("正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "transient_retry",
|
||||
})
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// handleEinoEmptyResponseContinue 在 SSE 任务循环内处理「正常结束但无助手正文」;返回 exhausted=true 时由外层按成功结束(保留占位文案)。
|
||||
// 与临时错误重试一致:仅恢复轨迹并保留本请求原始 user 文案,不向模型注入续跑说明。
|
||||
func (h *AgentHandler) handleEinoEmptyResponseContinue(
|
||||
baseCtx context.Context,
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
runErr error,
|
||||
emptyResponseAttempts *int,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
progressCallback func(eventType, message string, data interface{}),
|
||||
sendProgress func(msg string, extra map[string]interface{}),
|
||||
) (handled bool, exhausted bool) {
|
||||
if !errors.Is(runErr, multiagent.ErrEmptyResponseContinue) {
|
||||
return false, false
|
||||
}
|
||||
maxAttempts := h.einoRunRetryMaxAttempts()
|
||||
*emptyResponseAttempts++
|
||||
if *emptyResponseAttempts > maxAttempts {
|
||||
if h.logger != nil {
|
||||
h.logger.Warn("eino empty response auto resume exhausted",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("maxAttempts", maxAttempts))
|
||||
}
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
return false, true
|
||||
}
|
||||
attemptNo := *emptyResponseAttempts
|
||||
if h.logger != nil {
|
||||
h.logger.Info("eino empty response, auto resume from trace",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("attempt", attemptNo),
|
||||
zap.Int("maxAttempts", maxAttempts))
|
||||
}
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_empty_response_continue", fmt.Sprintf("未捕获到助手正文,正在基于轨迹自动续跑(%d/%d)…", attemptNo, maxAttempts), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
"maxAttempts": maxAttempts,
|
||||
"resumeKind": "trace_segment",
|
||||
})
|
||||
}
|
||||
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
||||
if sendProgress != nil {
|
||||
sendProgress("已恢复上下文,正在继续推理…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "empty_response_continue",
|
||||
})
|
||||
}
|
||||
return true, false
|
||||
}
|
||||
|
||||
@@ -119,7 +119,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
|
||||
@@ -177,8 +176,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
taskOwned = true
|
||||
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
|
||||
@@ -215,6 +212,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = mcp.WithEinoExecuteRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
@@ -240,54 +238,11 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||
transientRunAttempts = 0
|
||||
emptyResponseAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
|
||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if handled {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
@@ -312,8 +267,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
@@ -448,8 +401,6 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
curMsg := prep.FinalMessage
|
||||
var result *multiagent.RunResult
|
||||
var runErr error
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
for {
|
||||
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx,
|
||||
@@ -467,28 +418,9 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(prep.ConversationID),
|
||||
)
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
continue
|
||||
}
|
||||
if runErr == nil {
|
||||
break
|
||||
}
|
||||
if handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
} else if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||
}
|
||||
|
||||
+37
-20
@@ -10,8 +10,10 @@ import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/monitor"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@@ -19,12 +21,18 @@ import (
|
||||
|
||||
// MonitorHandler 监控处理器
|
||||
type MonitorHandler struct {
|
||||
mcpServer *mcp.Server
|
||||
externalMCPMgr *mcp.ExternalMCPManager
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
mcpServer *mcp.Server
|
||||
externalMCPMgr *mcp.ExternalMCPManager
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
monitorRetention *monitor.Service
|
||||
}
|
||||
|
||||
// SetMonitorRetention wires MCP execution retention settings.
|
||||
func (h *MonitorHandler) SetMonitorRetention(s *monitor.Service) {
|
||||
h.monitorRetention = s
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
@@ -50,13 +58,14 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
|
||||
|
||||
// MonitorResponse 监控响应
|
||||
type MonitorResponse struct {
|
||||
Executions []*mcp.ToolExecution `json:"executions"`
|
||||
Stats map[string]*mcp.ToolStats `json:"stats"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Total int `json:"total,omitempty"`
|
||||
Page int `json:"page,omitempty"`
|
||||
PageSize int `json:"page_size,omitempty"`
|
||||
TotalPages int `json:"total_pages,omitempty"`
|
||||
Executions []*mcp.ToolExecution `json:"executions"`
|
||||
Stats map[string]*mcp.ToolStats `json:"stats"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Total int `json:"total,omitempty"`
|
||||
Page int `json:"page,omitempty"`
|
||||
PageSize int `json:"page_size,omitempty"`
|
||||
TotalPages int `json:"total_pages,omitempty"`
|
||||
RetentionDays int `json:"retention_days,omitempty"`
|
||||
}
|
||||
|
||||
// Monitor 获取监控信息
|
||||
@@ -89,16 +98,24 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, MonitorResponse{
|
||||
Executions: executions,
|
||||
Stats: stats,
|
||||
Timestamp: time.Now(),
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
Executions: executions,
|
||||
Stats: stats,
|
||||
Timestamp: time.Now(),
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
RetentionDays: h.monitorRetentionDays(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) monitorRetentionDays() int {
|
||||
if h.monitorRetention != nil {
|
||||
return h.monitorRetention.RetentionDays()
|
||||
}
|
||||
return config.MonitorConfig{}.RetentionDaysEffective()
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution {
|
||||
executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "")
|
||||
return executions
|
||||
|
||||
@@ -136,7 +136,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
orch := strings.TrimSpace(req.Orchestration)
|
||||
@@ -187,8 +186,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
|
||||
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
|
||||
@@ -225,6 +222,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = mcp.WithEinoExecuteRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
@@ -252,54 +250,11 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||
transientRunAttempts = 0
|
||||
emptyResponseAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
|
||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if handled {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
@@ -324,8 +279,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
@@ -460,8 +413,6 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
curMsg := prep.FinalMessage
|
||||
var result *multiagent.RunResult
|
||||
var runErr error
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
for {
|
||||
result, runErr = multiagent.RunDeepAgent(
|
||||
taskCtx,
|
||||
@@ -481,28 +432,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(prep.ConversationID),
|
||||
)
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
continue
|
||||
}
|
||||
if runErr == nil {
|
||||
break
|
||||
}
|
||||
if handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
} else if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||
}
|
||||
|
||||
@@ -2464,17 +2464,108 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
{"name": "fact_key", "in": "query", "schema": map[string]interface{}{"type": "string"}},
|
||||
{"name": "include_links", "in": "query", "schema": map[string]interface{}{"type": "boolean"}},
|
||||
{"name": "include_link_counts", "in": "query", "schema": map[string]interface{}{"type": "boolean"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条"}},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "事实列表或单条(可含 link_counts / outgoing_links)"}},
|
||||
},
|
||||
"post": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "创建/更新事实", "operationId": "upsertProjectFactREST",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"requestBody": map[string]interface{}{
|
||||
"required": true,
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"fact_key": map[string]interface{}{"type": "string"},
|
||||
"summary": map[string]interface{}{"type": "string"},
|
||||
"links": map[string]interface{}{
|
||||
"type": "array",
|
||||
"items": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"to": map[string]interface{}{"type": "string"},
|
||||
"type": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"links_text": map[string]interface{}{"type": "string", "description": "type: fact_key 每行一条"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "成功"}},
|
||||
},
|
||||
},
|
||||
"/api/projects/{id}/fact-graph": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "获取项目事实攻击路径图", "operationId": "getProjectFactGraph",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
{"name": "view", "in": "query", "schema": map[string]interface{}{"type": "string", "enum": []string{"path", "full"}, "default": "path"}},
|
||||
{"name": "exclude_deprecated", "in": "query", "schema": map[string]interface{}{"type": "boolean", "default": true}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "nodes + edges"}},
|
||||
},
|
||||
},
|
||||
"/api/projects/{id}/fact-edges": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "列出项目全部事实边", "operationId": "listProjectFactEdges",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "边列表"}},
|
||||
},
|
||||
"post": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "添加事实边", "operationId": "createProjectFactEdge",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"requestBody": map[string]interface{}{
|
||||
"required": true,
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"required": []string{"source_fact_key", "target_fact_key", "edge_type"},
|
||||
"properties": map[string]interface{}{
|
||||
"source_fact_key": map[string]interface{}{"type": "string"},
|
||||
"target_fact_key": map[string]interface{}{"type": "string"},
|
||||
"edge_type": map[string]interface{}{"type": "string"},
|
||||
"confidence": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "边已创建"}},
|
||||
},
|
||||
},
|
||||
"/api/projects/{id}/fact-edges/{edgeId}": map[string]interface{}{
|
||||
"delete": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "删除事实边", "operationId": "deleteProjectFactEdge",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
{"name": "edgeId", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "删除成功"}},
|
||||
},
|
||||
},
|
||||
"/api/projects/{id}/promote-attack-chain/{conversationId}": map[string]interface{}{
|
||||
"post": map[string]interface{}{
|
||||
"tags": []string{"项目管理"}, "summary": "将对话攻击链沉淀到项目事实图", "operationId": "promoteAttackChainToProject",
|
||||
"parameters": []map[string]interface{}{
|
||||
{"name": "id", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
{"name": "conversationId", "in": "path", "required": true, "schema": map[string]interface{}{"type": "string"}},
|
||||
},
|
||||
"responses": map[string]interface{}{"200": map[string]interface{}{"description": "沉淀结果(facts/edges/graph)"}},
|
||||
},
|
||||
},
|
||||
"/api/vulnerabilities": map[string]interface{}{
|
||||
"get": map[string]interface{}{
|
||||
"tags": []string{"漏洞管理"},
|
||||
|
||||
+255
-21
@@ -1,10 +1,12 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/attackchain"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/project"
|
||||
|
||||
@@ -223,26 +225,102 @@ func (h *ProjectHandler) DeleteProject(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
type factLinkRequest struct {
|
||||
From string `json:"from"`
|
||||
Type string `json:"type"`
|
||||
Confidence string `json:"confidence,omitempty"`
|
||||
}
|
||||
|
||||
type upsertFactRequest struct {
|
||||
FactKey string `json:"fact_key" binding:"required"`
|
||||
Category string `json:"category"`
|
||||
Summary string `json:"summary" binding:"required"`
|
||||
Body string `json:"body"`
|
||||
Confidence string `json:"confidence"`
|
||||
Pinned bool `json:"pinned"`
|
||||
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
|
||||
FactKey string `json:"fact_key" binding:"required"`
|
||||
Category string `json:"category"`
|
||||
Summary string `json:"summary" binding:"required"`
|
||||
Body string `json:"body"`
|
||||
Confidence string `json:"confidence"`
|
||||
Pinned bool `json:"pinned"`
|
||||
RelatedVulnerabilityID string `json:"related_vulnerability_id"`
|
||||
Links []factLinkRequest `json:"links"`
|
||||
LinksText *string `json:"links_text"`
|
||||
}
|
||||
|
||||
// updateFactRequest 部分更新事实;指针字段省略=不修改,body 传 "" 可清空(仍走 merge 逻辑见 Upsert)。
|
||||
type updateFactRequest struct {
|
||||
FactKey *string `json:"fact_key"`
|
||||
Category *string `json:"category"`
|
||||
Summary *string `json:"summary"`
|
||||
Body *string `json:"body"`
|
||||
Confidence *string `json:"confidence"`
|
||||
Pinned *bool `json:"pinned"`
|
||||
RelatedVulnerabilityID *string `json:"related_vulnerability_id"`
|
||||
ClearBody bool `json:"clear_body"`
|
||||
FactKey *string `json:"fact_key"`
|
||||
Category *string `json:"category"`
|
||||
Summary *string `json:"summary"`
|
||||
Body *string `json:"body"`
|
||||
Confidence *string `json:"confidence"`
|
||||
Pinned *bool `json:"pinned"`
|
||||
RelatedVulnerabilityID *string `json:"related_vulnerability_id"`
|
||||
ClearBody bool `json:"clear_body"`
|
||||
Links *[]factLinkRequest `json:"links"`
|
||||
LinksText *string `json:"links_text"`
|
||||
}
|
||||
|
||||
func factLinksFromRequest(links []factLinkRequest, linksText *string) (*project.ParsedFactLinks, error) {
|
||||
if len(links) > 0 {
|
||||
parsed := &project.ParsedFactLinks{}
|
||||
for i, l := range links {
|
||||
from := strings.TrimSpace(l.From)
|
||||
edgeType := strings.TrimSpace(l.Type)
|
||||
if from == "" {
|
||||
return nil, fmt.Errorf("links[%d] 须含 from", i)
|
||||
}
|
||||
if edgeType == "" {
|
||||
return nil, fmt.Errorf("links[%d] 须含 type", i)
|
||||
}
|
||||
parsed.Incoming = append(parsed.Incoming, database.ProjectFactEdgeFromInput{
|
||||
From: from, Type: edgeType, Confidence: strings.TrimSpace(l.Confidence),
|
||||
})
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
if linksText != nil {
|
||||
in, err := project.ParseFactLinksText(*linksText)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &project.ParsedFactLinks{Incoming: in}, nil
|
||||
}
|
||||
return &project.ParsedFactLinks{Incoming: []database.ProjectFactEdgeFromInput{}}, nil
|
||||
}
|
||||
|
||||
type factWithLinksResponse struct {
|
||||
*database.ProjectFact
|
||||
OutgoingLinks []*database.ProjectFactEdge `json:"outgoing_links,omitempty"`
|
||||
IncomingLinks []*database.ProjectFactEdge `json:"incoming_links,omitempty"`
|
||||
LinkCounts *project.LinkCounts `json:"link_counts,omitempty"`
|
||||
}
|
||||
|
||||
func (h *ProjectHandler) applyFactLinksAfterUpsert(projectID string, fact *database.ProjectFact, links []factLinkRequest, linksText *string, explicitLinks, parseBody bool) error {
|
||||
if explicitLinks {
|
||||
parsed, err := factLinksFromRequest(links, linksText)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return project.PersistFactLinksFromParsed(h.db, projectID, fact.FactKey, fact.SourceConversationID, parsed, true)
|
||||
}
|
||||
if parseBody {
|
||||
inputs := project.ParseLinksFromBody(fact.Body)
|
||||
if inputs == nil {
|
||||
return nil
|
||||
}
|
||||
return project.PersistFactIncomingLinks(h.db, projectID, fact.FactKey, inputs, true)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *ProjectHandler) factResponseWithLinks(projectID string, f *database.ProjectFact, includeLinks bool) interface{} {
|
||||
if !includeLinks || f == nil {
|
||||
return f
|
||||
}
|
||||
out, _ := h.db.ListOutgoingProjectFactEdges(projectID, f.FactKey)
|
||||
in, _ := h.db.ListIncomingProjectFactEdges(projectID, f.FactKey)
|
||||
return &factWithLinksResponse{
|
||||
ProjectFact: f,
|
||||
OutgoingLinks: out,
|
||||
IncomingLinks: in,
|
||||
}
|
||||
}
|
||||
|
||||
// ListFacts GET /api/projects/:id/facts (fact_key 查询参数可获取单条详情)
|
||||
@@ -254,7 +332,8 @@ func (h *ProjectHandler) ListFacts(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, f)
|
||||
includeLinks := c.Query("include_links") == "1" || c.Query("include_links") == "true"
|
||||
c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, f, includeLinks))
|
||||
return
|
||||
}
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "100"))
|
||||
@@ -285,7 +364,52 @@ func (h *ProjectHandler) ListFacts(c *gin.Context) {
|
||||
}
|
||||
list = filtered
|
||||
}
|
||||
c.JSON(http.StatusOK, list)
|
||||
includeLinkCounts := c.Query("include_link_counts") == "1" || c.Query("include_link_counts") == "true"
|
||||
if !includeLinkCounts {
|
||||
c.JSON(http.StatusOK, list)
|
||||
return
|
||||
}
|
||||
counts, err := project.LoadProjectFactLinkCounts(h.db, projectID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
out := make([]factWithLinksResponse, 0, len(list))
|
||||
for _, f := range list {
|
||||
item := factWithLinksResponse{ProjectFact: f}
|
||||
if c, ok := counts[f.FactKey]; ok {
|
||||
cc := c
|
||||
item.LinkCounts = &cc
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
c.JSON(http.StatusOK, out)
|
||||
}
|
||||
|
||||
// GetFactGraph GET /api/projects/:id/fact-graph?view=path|full
|
||||
func (h *ProjectHandler) GetFactGraph(c *gin.Context) {
|
||||
projectID := c.Param("id")
|
||||
if _, err := h.db.GetProject(projectID); err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "项目不存在"})
|
||||
return
|
||||
}
|
||||
view := c.DefaultQuery("view", "path")
|
||||
excludeDeprecated := true
|
||||
if v := c.Query("exclude_deprecated"); v == "0" || v == "false" {
|
||||
excludeDeprecated = false
|
||||
}
|
||||
graph, err := project.BuildProjectFactGraph(h.db, projectID, view, excludeDeprecated)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if graph.Nodes == nil {
|
||||
graph.Nodes = []database.ProjectFactGraphNode{}
|
||||
}
|
||||
if graph.Edges == nil {
|
||||
graph.Edges = []database.ProjectFactGraphEdge{}
|
||||
}
|
||||
c.JSON(http.StatusOK, graph)
|
||||
}
|
||||
|
||||
// CreateFact POST /api/projects/:id/facts
|
||||
@@ -295,8 +419,9 @@ func (h *ProjectHandler) CreateFact(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
projectID := c.Param("id")
|
||||
f := &database.ProjectFact{
|
||||
ProjectID: c.Param("id"),
|
||||
ProjectID: projectID,
|
||||
FactKey: req.FactKey,
|
||||
Category: req.Category,
|
||||
Summary: req.Summary,
|
||||
@@ -310,16 +435,24 @@ func (h *ProjectHandler) CreateFact(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, created)
|
||||
explicitLinks := req.Links != nil || req.LinksText != nil
|
||||
if err := h.applyFactLinksAfterUpsert(projectID, created, req.Links, req.LinksText, explicitLinks, !explicitLinks); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
created, _ = h.db.GetProjectFactByKey(projectID, created.FactKey)
|
||||
c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, created, true))
|
||||
}
|
||||
|
||||
// UpdateFact PUT /api/projects/:id/facts/:factId
|
||||
func (h *ProjectHandler) UpdateFact(c *gin.Context) {
|
||||
projectID := c.Param("id")
|
||||
existing, err := h.db.GetProjectFact(c.Param("factId"))
|
||||
if err != nil || existing.ProjectID != c.Param("id") {
|
||||
if err != nil || existing.ProjectID != projectID {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "事实不存在"})
|
||||
return
|
||||
}
|
||||
oldFactKey := existing.FactKey
|
||||
var req updateFactRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -355,7 +488,29 @@ func (h *ProjectHandler) UpdateFact(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, updated)
|
||||
if oldFactKey != updated.FactKey {
|
||||
if err := h.db.RenameProjectFactKeyEdges(projectID, oldFactKey, updated.FactKey); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
if req.Links != nil || req.LinksText != nil {
|
||||
var links []factLinkRequest
|
||||
if req.Links != nil {
|
||||
links = *req.Links
|
||||
}
|
||||
if err := h.applyFactLinksAfterUpsert(projectID, updated, links, req.LinksText, true, false); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
} else if req.ClearBody || req.Body != nil {
|
||||
if err := h.applyFactLinksAfterUpsert(projectID, updated, nil, nil, false, true); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
updated, _ = h.db.GetProjectFactByKey(projectID, updated.FactKey)
|
||||
c.JSON(http.StatusOK, h.factResponseWithLinks(projectID, updated, true))
|
||||
}
|
||||
|
||||
// DeleteFact DELETE /api/projects/:id/facts/:factId
|
||||
@@ -408,3 +563,82 @@ func (h *ProjectHandler) RestoreFact(c *gin.Context) {
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
type createFactEdgeRequest struct {
|
||||
SourceFactKey string `json:"source_fact_key" binding:"required"`
|
||||
TargetFactKey string `json:"target_fact_key" binding:"required"`
|
||||
EdgeType string `json:"edge_type" binding:"required"`
|
||||
Confidence string `json:"confidence"`
|
||||
}
|
||||
|
||||
// ListFactEdges GET /api/projects/:id/fact-edges
|
||||
func (h *ProjectHandler) ListFactEdges(c *gin.Context) {
|
||||
projectID := c.Param("id")
|
||||
edges, err := h.db.ListProjectFactEdgesByProject(projectID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if edges == nil {
|
||||
edges = []*database.ProjectFactEdge{}
|
||||
}
|
||||
c.JSON(http.StatusOK, edges)
|
||||
}
|
||||
|
||||
// CreateFactEdge POST /api/projects/:id/fact-edges
|
||||
func (h *ProjectHandler) CreateFactEdge(c *gin.Context) {
|
||||
projectID := c.Param("id")
|
||||
var req createFactEdgeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
edge, err := h.db.AddProjectFactEdge(projectID, database.ProjectFactEdgeInput{
|
||||
To: req.TargetFactKey,
|
||||
Type: req.EdgeType,
|
||||
Confidence: req.Confidence,
|
||||
}, req.SourceFactKey, "")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if f, err := h.db.GetProjectFactByKey(projectID, req.TargetFactKey); err == nil {
|
||||
in, _ := h.db.ListIncomingProjectFactEdges(projectID, req.TargetFactKey)
|
||||
f.Body = project.SyncBodyLinksSection(f.Body, in)
|
||||
_, _ = h.db.UpsertProjectFact(f)
|
||||
}
|
||||
c.JSON(http.StatusOK, edge)
|
||||
}
|
||||
|
||||
// DeleteFactEdge DELETE /api/projects/:id/fact-edges/:edgeId
|
||||
func (h *ProjectHandler) DeleteFactEdge(c *gin.Context) {
|
||||
projectID := c.Param("id")
|
||||
edgeID := c.Param("edgeId")
|
||||
edge, err := h.db.GetProjectFactEdge(edgeID)
|
||||
if err != nil || edge.ProjectID != projectID {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "边不存在"})
|
||||
return
|
||||
}
|
||||
if err := h.db.DeleteProjectFactEdge(edgeID); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if f, err := h.db.GetProjectFactByKey(projectID, edge.TargetFactKey); err == nil {
|
||||
in, _ := h.db.ListIncomingProjectFactEdges(projectID, edge.TargetFactKey)
|
||||
f.Body = project.SyncBodyLinksSection(f.Body, in)
|
||||
_, _ = h.db.UpsertProjectFact(f)
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// PromoteAttackChain POST /api/projects/:id/promote-attack-chain/:conversationId
|
||||
func (h *ProjectHandler) PromoteAttackChain(c *gin.Context) {
|
||||
projectID := c.Param("id")
|
||||
conversationID := c.Param("conversationId")
|
||||
result, err := attackchain.PromoteToProject(h.db, projectID, conversationID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, result)
|
||||
}
|
||||
|
||||
@@ -594,6 +594,9 @@ func (h *RobotHandler) cmdDelete(platform, userID, convID string) string {
|
||||
h.mu.Unlock()
|
||||
h.deleteSessionBinding(sk)
|
||||
}
|
||||
if h.agentHandler != nil {
|
||||
h.agentHandler.CancelRunningTaskForConversation(convID)
|
||||
}
|
||||
if err := h.db.DeleteConversation(convID); err != nil {
|
||||
return "删除失败: " + err.Error()
|
||||
}
|
||||
|
||||
@@ -37,6 +37,11 @@ type AgentTask struct {
|
||||
// InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空)
|
||||
InterruptContinueNote string `json:"-"`
|
||||
|
||||
// activeEinoExecuteCancel 当前进行中的 Eino filesystem execute 取消函数(与 MCP 工具并行,供中断并继续)
|
||||
activeEinoExecuteCancel context.CancelFunc
|
||||
// activeEinoExecuteAbortNote AbortActiveEinoExecute 写入的用户说明,由 execute 收尾时合并进工具结果
|
||||
activeEinoExecuteAbortNote string
|
||||
|
||||
cancel func(error)
|
||||
}
|
||||
|
||||
@@ -70,6 +75,69 @@ func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID str
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterActiveEinoExecute 登记进行中的 Eino filesystem execute(每会话同时仅一条)。
|
||||
func (m *AgentTaskManager) RegisterActiveEinoExecute(conversationID string, cancel context.CancelFunc) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" || cancel == nil {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.activeEinoExecuteCancel = cancel
|
||||
t.activeEinoExecuteAbortNote = ""
|
||||
}
|
||||
}
|
||||
|
||||
// UnregisterActiveEinoExecute execute 正常结束或已取消后清除登记。
|
||||
func (m *AgentTaskManager) UnregisterActiveEinoExecute(conversationID string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.activeEinoExecuteCancel = nil
|
||||
t.activeEinoExecuteAbortNote = ""
|
||||
}
|
||||
}
|
||||
|
||||
// AbortActiveEinoExecute 终止当前 Eino execute 并暂存用户说明(与 MCP 工具终止一致)。
|
||||
func (m *AgentTaskManager) AbortActiveEinoExecute(conversationID, note string) bool {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return false
|
||||
}
|
||||
m.mu.Lock()
|
||||
t, ok := m.tasks[conversationID]
|
||||
if !ok || t == nil || t.activeEinoExecuteCancel == nil {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
t.activeEinoExecuteAbortNote = strings.TrimSpace(note)
|
||||
cancel := t.activeEinoExecuteCancel
|
||||
m.mu.Unlock()
|
||||
cancel()
|
||||
return true
|
||||
}
|
||||
|
||||
// TakeEinoExecuteAbortNote 读取并清空 execute 终止说明(execute 收尾时调用一次)。
|
||||
func (m *AgentTaskManager) TakeEinoExecuteAbortNote(conversationID string) string {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
n := t.activeEinoExecuteAbortNote
|
||||
t.activeEinoExecuteAbortNote = ""
|
||||
return n
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。
|
||||
func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAbortActiveEinoExecute(t *testing.T) {
|
||||
m := NewAgentTaskManager()
|
||||
conv := "conv-eino-exec-abort"
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
_, err := m.StartTask(conv, "test", func(error) {})
|
||||
if err != nil {
|
||||
t.Fatalf("StartTask: %v", err)
|
||||
}
|
||||
m.RegisterActiveEinoExecute(conv, cancel)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
if !m.AbortActiveEinoExecute(conv, "跳过域名收集") {
|
||||
t.Fatal("expected abort to succeed")
|
||||
}
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("execute cancel did not propagate")
|
||||
}
|
||||
if got := m.TakeEinoExecuteAbortNote(conv); got != "跳过域名收集" {
|
||||
t.Fatalf("abort note = %q, want 跳过域名收集", got)
|
||||
}
|
||||
m.UnregisterActiveEinoExecute(conv)
|
||||
if m.AbortActiveEinoExecute(conv, "") {
|
||||
t.Fatal("second abort should fail when no active execute")
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,16 @@ type ToolRunRegistry interface {
|
||||
UnregisterRunningTool(conversationID, executionID string)
|
||||
}
|
||||
|
||||
// EinoExecuteRunRegistry 登记进行中的 Eino filesystem execute,供「中断并继续」终止 amass 等长命令。
|
||||
type EinoExecuteRunRegistry interface {
|
||||
RegisterActiveEinoExecute(conversationID string, cancel context.CancelFunc)
|
||||
UnregisterActiveEinoExecute(conversationID string)
|
||||
AbortActiveEinoExecute(conversationID, note string) bool
|
||||
TakeEinoExecuteAbortNote(conversationID string) string
|
||||
}
|
||||
|
||||
type toolRunRegistryCtxKey struct{}
|
||||
type einoExecuteRunRegistryCtxKey struct{}
|
||||
type mcpConversationIDCtxKey struct{}
|
||||
|
||||
// WithToolRunRegistry 将登记器注入 ctx(Eino / 原生 Agent 任务 ctx)。
|
||||
@@ -31,6 +40,23 @@ func ToolRunRegistryFromContext(ctx context.Context) ToolRunRegistry {
|
||||
return v
|
||||
}
|
||||
|
||||
// WithEinoExecuteRunRegistry 将 Eino execute 取消登记器注入 ctx。
|
||||
func WithEinoExecuteRunRegistry(ctx context.Context, reg EinoExecuteRunRegistry) context.Context {
|
||||
if ctx == nil || reg == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, einoExecuteRunRegistryCtxKey{}, reg)
|
||||
}
|
||||
|
||||
// EinoExecuteRunRegistryFromContext 取出 Eino execute 登记器(无则 nil)。
|
||||
func EinoExecuteRunRegistryFromContext(ctx context.Context) EinoExecuteRunRegistry {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
v, _ := ctx.Value(einoExecuteRunRegistryCtxKey{}).(EinoExecuteRunRegistry)
|
||||
return v
|
||||
}
|
||||
|
||||
// WithMCPConversationID 将对话 ID 注入 ctx,供 CallTool 内与 executionId 关联。
|
||||
func WithMCPConversationID(ctx context.Context, conversationID string) context.Context {
|
||||
if ctx == nil {
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const retentionPurgeInterval = time.Hour
|
||||
|
||||
// Service manages MCP tool execution monitor retention.
|
||||
type Service struct {
|
||||
db *database.DB
|
||||
cfg *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewService creates a monitor retention service.
|
||||
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
|
||||
return &Service{db: db, cfg: cfg, logger: logger}
|
||||
}
|
||||
|
||||
// RetentionDays returns configured retention; 0 means keep forever.
|
||||
func (s *Service) RetentionDays() int {
|
||||
if s == nil || s.cfg == nil {
|
||||
return config.MonitorConfig{}.RetentionDaysEffective()
|
||||
}
|
||||
return s.cfg.Monitor.RetentionDaysEffective()
|
||||
}
|
||||
|
||||
// PurgeExpired deletes tool execution rows older than retention_days when configured.
|
||||
func (s *Service) PurgeExpired() {
|
||||
if s == nil || s.db == nil || s.cfg == nil {
|
||||
return
|
||||
}
|
||||
days := s.cfg.Monitor.RetentionDaysEffective()
|
||||
if days <= 0 {
|
||||
return
|
||||
}
|
||||
cutoff := time.Now().AddDate(0, 0, -days)
|
||||
n, err := s.db.PurgeToolExecutionsBefore(cutoff)
|
||||
if err != nil {
|
||||
if s.logger != nil {
|
||||
s.logger.Warn("清理过期 MCP 执行记录失败", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
if n > 0 && s.logger != nil {
|
||||
s.logger.Info("已清理过期 MCP 执行记录", zap.Int64("deleted", n), zap.Int("retention_days", days))
|
||||
}
|
||||
}
|
||||
|
||||
// StartRetentionLoop periodically purges expired tool execution rows.
|
||||
func StartRetentionLoop(s *Service, logger *zap.Logger) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(retentionPurgeInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
s.PurgeExpired()
|
||||
if logger != nil {
|
||||
logger.Debug("monitor retention tick completed")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestServicePurgeExpired_respectsZeroRetention(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
exec := &mcp.ToolExecution{
|
||||
ID: "ancient",
|
||||
ToolName: "curl::get",
|
||||
Arguments: map[string]interface{}{},
|
||||
Status: "completed",
|
||||
StartTime: mustParseTime(t, "2020-01-01T00:00:00Z"),
|
||||
}
|
||||
if err := db.SaveToolExecution(exec); err != nil {
|
||||
t.Fatalf("SaveToolExecution: %v", err)
|
||||
}
|
||||
|
||||
zero := 0
|
||||
svc := NewService(db, &config.Config{
|
||||
Monitor: config.MonitorConfig{RetentionDays: &zero},
|
||||
}, zap.NewNop())
|
||||
svc.PurgeExpired()
|
||||
|
||||
if _, err := db.GetToolExecution("ancient"); err != nil {
|
||||
t.Fatalf("record should remain when retention_days=0: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServicePurgeExpired_deletesOldRows(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
exec := &mcp.ToolExecution{
|
||||
ID: "ancient",
|
||||
ToolName: "curl::get",
|
||||
Arguments: map[string]interface{}{},
|
||||
Status: "completed",
|
||||
StartTime: mustParseTime(t, "2020-01-01T00:00:00Z"),
|
||||
}
|
||||
if err := db.SaveToolExecution(exec); err != nil {
|
||||
t.Fatalf("SaveToolExecution: %v", err)
|
||||
}
|
||||
|
||||
days := 90
|
||||
svc := NewService(db, &config.Config{
|
||||
Monitor: config.MonitorConfig{RetentionDays: &days},
|
||||
}, zap.NewNop())
|
||||
svc.PurgeExpired()
|
||||
|
||||
if _, err := db.GetToolExecution("ancient"); err == nil {
|
||||
t.Fatal("record should be purged when older than retention_days")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetentionDaysEffective_defaults(t *testing.T) {
|
||||
got := config.MonitorConfig{}.RetentionDaysEffective()
|
||||
if got != 90 {
|
||||
t.Fatalf("default = %d, want 90", got)
|
||||
}
|
||||
zero := 0
|
||||
cfg := config.MonitorConfig{RetentionDays: &zero}
|
||||
if cfg.RetentionDaysEffective() != 0 {
|
||||
t.Fatalf("zero = %d, want 0", cfg.RetentionDaysEffective())
|
||||
}
|
||||
}
|
||||
|
||||
func mustParseTime(t *testing.T, value string) time.Time {
|
||||
t.Helper()
|
||||
parsed, err := time.Parse(time.RFC3339, value)
|
||||
if err != nil {
|
||||
t.Fatalf("parse time: %v", err)
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// continuationSessionMarker matches Cursor / IDE session-resume user injections.
|
||||
const continuationSessionMarker = "This session is being continued from a previous conversation"
|
||||
|
||||
// continuationUserDedupMiddleware keeps only the latest session-resume user message when
|
||||
// multiple continuation injections were stacked (e.g. after repeated out-of-context resumes).
|
||||
type continuationUserDedupMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
logger *zap.Logger
|
||||
phase string
|
||||
}
|
||||
|
||||
func newContinuationUserDedupMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware {
|
||||
return &continuationUserDedupMiddleware{logger: logger, phase: phase}
|
||||
}
|
||||
|
||||
func (m *continuationUserDedupMiddleware) BeforeModelRewriteState(
|
||||
ctx context.Context,
|
||||
state *adk.ChatModelAgentState,
|
||||
mc *adk.ModelContext,
|
||||
) (context.Context, *adk.ChatModelAgentState, error) {
|
||||
_ = mc
|
||||
if m == nil || state == nil || len(state.Messages) == 0 {
|
||||
return ctx, state, nil
|
||||
}
|
||||
deduped, dropped := dedupContinuationUserMessages(state.Messages)
|
||||
if dropped == 0 {
|
||||
return ctx, state, nil
|
||||
}
|
||||
if m.logger != nil {
|
||||
m.logger.Info("eino continuation user messages deduplicated",
|
||||
zap.String("phase", m.phase),
|
||||
zap.Int("dropped", dropped),
|
||||
zap.Int("messages_before", len(state.Messages)),
|
||||
zap.Int("messages_after", len(deduped)),
|
||||
)
|
||||
}
|
||||
out := *state
|
||||
out.Messages = deduped
|
||||
return ctx, &out, nil
|
||||
}
|
||||
|
||||
func adkUserMessageText(msg adk.Message) string {
|
||||
if msg == nil {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
if s := strings.TrimSpace(msg.Content); s != "" {
|
||||
b.WriteString(s)
|
||||
}
|
||||
for _, part := range msg.UserInputMultiContent {
|
||||
if part.Type == schema.ChatMessagePartTypeText {
|
||||
if s := strings.TrimSpace(part.Text); s != "" {
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func isContinuationUserMessage(msg adk.Message) bool {
|
||||
if msg == nil || msg.Role != schema.User {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(adkUserMessageText(msg), continuationSessionMarker)
|
||||
}
|
||||
|
||||
func dedupContinuationUserMessages(msgs []adk.Message) ([]adk.Message, int) {
|
||||
lastIdx := -1
|
||||
contCount := 0
|
||||
for i, msg := range msgs {
|
||||
if !isContinuationUserMessage(msg) {
|
||||
continue
|
||||
}
|
||||
contCount++
|
||||
lastIdx = i
|
||||
}
|
||||
if contCount <= 1 {
|
||||
return msgs, 0
|
||||
}
|
||||
out := make([]adk.Message, 0, len(msgs)-(contCount-1))
|
||||
dropped := 0
|
||||
for i, msg := range msgs {
|
||||
if isContinuationUserMessage(msg) && i != lastIdx {
|
||||
dropped++
|
||||
continue
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out, dropped
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func continuationUser(text string) adk.Message {
|
||||
return &schema.Message{
|
||||
Role: schema.User,
|
||||
UserInputMultiContent: []schema.MessageInputPart{
|
||||
{Type: schema.ChatMessagePartTypeText, Text: continuationSessionMarker + "\n" + text},
|
||||
{Type: schema.ChatMessagePartTypeText, Text: "Please continue the conversation from where we left it off."},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestDedupContinuationUserMessages_KeepsLatest(t *testing.T) {
|
||||
msgs := []adk.Message{
|
||||
continuationUser("summary old"),
|
||||
schema.UserMessage("real task"),
|
||||
continuationUser("summary new"),
|
||||
}
|
||||
out, dropped := dedupContinuationUserMessages(msgs)
|
||||
if dropped != 1 {
|
||||
t.Fatalf("dropped=%d want 1", dropped)
|
||||
}
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("len=%d want 2", len(out))
|
||||
}
|
||||
if out[0].Role != schema.User || adkUserMessageText(out[0]) != "real task" {
|
||||
t.Fatalf("first should remain real task, got %q", adkUserMessageText(out[0]))
|
||||
}
|
||||
if !strings.Contains(adkUserMessageText(out[1]), "summary new") {
|
||||
t.Fatalf("latest continuation not kept: %q", adkUserMessageText(out[1]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDedupContinuationUserMessages_NoOpSingle(t *testing.T) {
|
||||
msgs := []adk.Message{continuationUser("only"), schema.UserMessage("task")}
|
||||
out, dropped := dedupContinuationUserMessages(msgs)
|
||||
if dropped != 0 || len(out) != 2 {
|
||||
t.Fatalf("unexpected change dropped=%d len=%d", dropped, len(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestContinuationUserDedupMiddleware(t *testing.T) {
|
||||
mw := newContinuationUserDedupMiddleware(nil, "test")
|
||||
state := &adk.ChatModelAgentState{Messages: []adk.Message{
|
||||
continuationUser("old"),
|
||||
continuationUser("new"),
|
||||
schema.UserMessage("task"),
|
||||
}}
|
||||
_, out, err := mw.(*continuationUserDedupMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(out.Messages) != 2 {
|
||||
t.Fatalf("want 2 messages after dedup, got %d", len(out.Messages))
|
||||
}
|
||||
}
|
||||
@@ -90,7 +90,7 @@ type einoADKRunLoopArgs struct {
|
||||
FilesystemMonitorRecord einomcp.ExecutionRecorder
|
||||
MCPExecutionBinder *MCPExecutionBinder
|
||||
|
||||
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,MCP 桥 Fire 以补全 tool_result。
|
||||
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,execute/MCP 桥 Fire 时立即推送 tool_result(ADK 晚到经 toolResultSent 去重)。
|
||||
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
|
||||
|
||||
DA adk.Agent
|
||||
@@ -341,8 +341,22 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
if args.ToolInvokeNotify != nil {
|
||||
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||
removePendingByID(strings.TrimSpace(toolCallID))
|
||||
// tool_result 仅由下方 ADK schema.Tool 事件推送,正文与送入模型的上下文一致(含 reduction 截断)。
|
||||
// Eino execute / MCP 桥在工具返回时 Fire;若 ADK schema.Tool 事件迟迟不到,此处立即推送
|
||||
// tool_result 解除 UI「执行中」。tryEmitToolResultProgress 经 toolResultSent 去重,ADK 晚到不重复。
|
||||
isErr := !success || invokeErr != nil
|
||||
body := content
|
||||
if strings.HasPrefix(body, einomcp.ToolErrorPrefix) {
|
||||
isErr = true
|
||||
body = strings.TrimPrefix(body, einomcp.ToolErrorPrefix)
|
||||
}
|
||||
if tail := friendlyEinoExecuteInvokeTail(invokeErr); tail != "" {
|
||||
if body == "" {
|
||||
body = tail
|
||||
} else if !strings.Contains(body, tail) {
|
||||
body = strings.TrimSpace(body) + "\n\n" + tail
|
||||
}
|
||||
}
|
||||
tryEmitToolResultProgress(toolName, body, toolCallID, isErr, einoAgent)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -383,6 +397,12 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
}
|
||||
runner := adk.NewRunner(ctx, runnerCfg)
|
||||
startRunnerIter := func(runMsgs []adk.Message) *adk.AsyncIterator[*adk.AgentEvent] {
|
||||
if checkPointID != "" {
|
||||
return runner.Run(ctx, runMsgs, adk.WithCheckPointID(checkPointID))
|
||||
}
|
||||
return runner.Run(ctx, runMsgs)
|
||||
}
|
||||
var iter *adk.AsyncIterator[*adk.AgentEvent]
|
||||
if cpStore != nil && checkPointID != "" {
|
||||
if _, existed, getErr := cpStore.Get(ctx, checkPointID); getErr != nil {
|
||||
@@ -422,12 +442,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
}
|
||||
if iter == nil {
|
||||
if checkPointID != "" {
|
||||
iter = runner.Run(ctx, msgs, adk.WithCheckPointID(checkPointID))
|
||||
} else {
|
||||
iter = runner.Run(ctx, msgs)
|
||||
}
|
||||
iter = startRunnerIter(msgs)
|
||||
}
|
||||
transientRetrier := newEinoTransientRunRetrier(einoTransientRunRetryPolicyFromArgs(args))
|
||||
handleRunErr := func(runErr error) error {
|
||||
if runErr == nil {
|
||||
return nil
|
||||
@@ -480,26 +497,60 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
return runErr
|
||||
}
|
||||
|
||||
// maybeRetryTransientRun:不在此层 runner.Run/Resume;由 handler 落库 + loadHistoryFromAgentTrace 分段续跑(同中断并继续)。
|
||||
maybeRetryTransientRun := func(runErr error) (retry bool, fatal error) {
|
||||
if runErr == nil || !isEinoTransientRunError(runErr) {
|
||||
maybeRetryTransientRun := func(runErr error) (restarted bool, fatal error) {
|
||||
if runErr == nil {
|
||||
return false, nil
|
||||
}
|
||||
if !isEinoTransientRunError(runErr) {
|
||||
return false, handleRunErr(runErr)
|
||||
}
|
||||
restarted, restartMsgs, ctxSource, backoff, retErr := transientRetrier.tryRetry(
|
||||
ctx, runErr, args, baseMsgs, runAccumulatedMsgs, baseAccumulatedCount,
|
||||
)
|
||||
if retErr != nil {
|
||||
flushAllPendingAsFailed(runErr)
|
||||
if logger != nil {
|
||||
logger.Warn("eino transient retry exhausted",
|
||||
zap.Error(retErr),
|
||||
zap.String("orchestration", orchMode),
|
||||
zap.Int("maxAttempts", transientRetrier.maxAttempts()))
|
||||
}
|
||||
return false, retErr
|
||||
}
|
||||
if !restarted {
|
||||
return false, nil
|
||||
}
|
||||
attemptNo := transientRetrier.attempt()
|
||||
maxAttempts := transientRetrier.maxAttempts()
|
||||
if logger != nil {
|
||||
logger.Warn("eino transient error, ending run segment for handler resume",
|
||||
logger.Warn("eino transient error, retrying after backoff",
|
||||
zap.Error(runErr),
|
||||
zap.String("orchestration", orchMode))
|
||||
zap.String("orchestration", orchMode),
|
||||
zap.Int("attempt", attemptNo),
|
||||
zap.Int("maxAttempts", maxAttempts),
|
||||
zap.Duration("backoff", backoff))
|
||||
}
|
||||
if progress != nil {
|
||||
progress("eino_run_retry", "遇到临时错误(限流或网络波动),将保存上下文并重试…", map[string]interface{}{
|
||||
progress("eino_run_retry", fmt.Sprintf("遇到临时错误(限流或网络波动),%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"orchestration": orchMode,
|
||||
"error": runErr.Error(),
|
||||
"resumeKind": "trace_segment",
|
||||
"attempt": attemptNo,
|
||||
"maxAttempts": maxAttempts,
|
||||
"backoffSec": int(backoff.Seconds()),
|
||||
})
|
||||
progress("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"orchestration": orchMode,
|
||||
"attempt": attemptNo,
|
||||
"contextSource": string(ctxSource),
|
||||
})
|
||||
}
|
||||
return false, ErrTransientRetryContinue
|
||||
msgs = restartMsgs
|
||||
iter = startRunnerIter(msgs)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
takePartial := func(runErr error) (*RunResult, error) {
|
||||
@@ -514,10 +565,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
|
||||
for {
|
||||
// 检测 context 取消(用户关闭浏览器、请求超时等),flush pending 工具状态避免 UI 卡在 "执行中"。
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
flushAllPendingAsFailed(ctx.Err())
|
||||
// iter.Next 可能长时间阻塞(工具执行、模型推理);须与 ctx 联动,否则取消/超时无法及时 flush pending。
|
||||
ev, ok, iterCtxErr := nextAgentEventWithContext(ctx, iter)
|
||||
if iterCtxErr != nil {
|
||||
flushAllPendingAsFailed(iterCtxErr)
|
||||
if progress != nil {
|
||||
if isInterruptContinue(ctx) {
|
||||
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
|
||||
@@ -526,17 +577,14 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
"kind": "interrupt_continue",
|
||||
})
|
||||
} else {
|
||||
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{
|
||||
progress("error", iterCtxErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
}
|
||||
return takePartial(ctx.Err())
|
||||
default:
|
||||
return takePartial(iterCtxErr)
|
||||
}
|
||||
|
||||
ev, ok := iter.Next()
|
||||
if !ok {
|
||||
// iter 结束并不总是“正常完成”:
|
||||
// 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。
|
||||
@@ -583,9 +631,15 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
continue
|
||||
}
|
||||
if ev.Err != nil {
|
||||
if _, retErr := maybeRetryTransientRun(ev.Err); retErr != nil {
|
||||
restarted, retErr := maybeRetryTransientRun(ev.Err)
|
||||
if retErr != nil {
|
||||
return takePartial(retErr)
|
||||
}
|
||||
if restarted {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
transientRetrier.reset()
|
||||
}
|
||||
if ev.AgentName != "" && progress != nil {
|
||||
iterEinoAgent := orchestratorName
|
||||
@@ -630,13 +684,16 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
}
|
||||
}
|
||||
// 仅在代理切换时更新进度标题;同一代理的每个 ADK 事件不再重复刷 progress。
|
||||
if einoLastAgent != ev.AgentName {
|
||||
progress("progress", fmt.Sprintf("[Eino] %s", ev.AgentName), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
einoLastAgent = ev.AgentName
|
||||
progress("progress", fmt.Sprintf("[Eino] %s", ev.AgentName), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
if ev.Output == nil || ev.Output.MessageOutput == nil {
|
||||
continue
|
||||
@@ -645,29 +702,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
|
||||
if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
|
||||
toolName := strings.TrimSpace(mv.ToolName)
|
||||
var toolBuf strings.Builder
|
||||
streamToolCallID := ""
|
||||
var toolStreamRecvErr error
|
||||
for {
|
||||
chunk, rerr := mv.MessageStream.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
toolStreamRecvErr = rerr
|
||||
break
|
||||
}
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
if chunk.Content != "" {
|
||||
toolBuf.WriteString(chunk.Content)
|
||||
}
|
||||
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
|
||||
streamToolCallID = tid
|
||||
}
|
||||
}
|
||||
content := toolBuf.String()
|
||||
content, streamToolCallID, toolStreamRecvErr := recvSchemaMessageStream(ctx, mv.MessageStream)
|
||||
isErr := false
|
||||
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||
isErr = true
|
||||
@@ -948,9 +983,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
})
|
||||
}
|
||||
if _, retErr := maybeRetryTransientRun(streamRecvErr); retErr != nil {
|
||||
restarted, retErr := maybeRetryTransientRun(streamRecvErr)
|
||||
if retErr != nil {
|
||||
return takePartial(retErr)
|
||||
}
|
||||
if restarted {
|
||||
continue
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -1054,32 +1093,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs),
|
||||
lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
|
||||
)
|
||||
if shouldEinoEmptyResponseContinue(out, emptyHint, len(runAccumulatedMsgs), baseAccumulatedCount) {
|
||||
if logger != nil {
|
||||
logger.Info("eino empty response, ending run segment for handler resume",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("orchestration", orchMode),
|
||||
zap.Int("traceMessages", len(runAccumulatedMsgs)))
|
||||
}
|
||||
if progress != nil {
|
||||
progress("eino_empty_response_continue", "会话已结束但未产生助手正文,正在基于轨迹自动续跑…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"resumeKind": "trace_segment",
|
||||
})
|
||||
}
|
||||
return out, ErrEmptyResponseContinue
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func shouldEinoEmptyResponseContinue(out *RunResult, emptyHint string, accumulatedLen, baseCount int) bool {
|
||||
if out == nil || accumulatedLen <= baseCount {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(out.Response) == strings.TrimSpace(emptyHint)
|
||||
}
|
||||
|
||||
func persistTraceSource(args *einoADKRunLoopArgs, fallback []adk.Message) []adk.Message {
|
||||
if args != nil && args.ModelFacingTrace != nil {
|
||||
if snap := args.ModelFacingTrace.Snapshot(); len(snap) > 0 {
|
||||
@@ -1105,6 +1121,78 @@ func friendlyEinoExecuteInvokeTail(invokeErr error) string {
|
||||
return "[执行未正常结束] " + invokeErr.Error()
|
||||
}
|
||||
|
||||
// nextAgentEventWithContext 在 ctx 取消时不再无限阻塞于 iter.Next()(工具执行/模型推理期间常见)。
|
||||
func nextAgentEventWithContext(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) (ev *adk.AgentEvent, ok bool, ctxErr error) {
|
||||
if iter == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
type nextRes struct {
|
||||
ev *adk.AgentEvent
|
||||
ok bool
|
||||
}
|
||||
ch := make(chan nextRes, 1)
|
||||
go func() {
|
||||
e, o := iter.Next()
|
||||
ch <- nextRes{e, o}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, false, ctx.Err()
|
||||
case res := <-ch:
|
||||
return res.ev, res.ok, nil
|
||||
}
|
||||
}
|
||||
|
||||
// recvSchemaMessageStream 消费 ADK Tool 流式结果;ctx 取消时立即返回,避免 amass 等无输出时永久阻塞。
|
||||
func recvSchemaMessageStream(ctx context.Context, stream *schema.StreamReader[*schema.Message]) (content, toolCallID string, recvErr error) {
|
||||
if stream == nil {
|
||||
return "", "", nil
|
||||
}
|
||||
type streamMsg struct {
|
||||
chunk *schema.Message
|
||||
err error
|
||||
}
|
||||
recvCh := make(chan streamMsg, 8)
|
||||
go func() {
|
||||
defer close(recvCh)
|
||||
for {
|
||||
ch, rerr := stream.Recv()
|
||||
recvCh <- streamMsg{chunk: ch, err: rerr}
|
||||
if rerr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
var buf strings.Builder
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return buf.String(), toolCallID, ctx.Err()
|
||||
case sm, open := <-recvCh:
|
||||
if !open {
|
||||
return buf.String(), toolCallID, nil
|
||||
}
|
||||
rerr := sm.err
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
return buf.String(), toolCallID, nil
|
||||
}
|
||||
if rerr != nil {
|
||||
return buf.String(), toolCallID, rerr
|
||||
}
|
||||
chunk := sm.chunk
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
if chunk.Content != "" {
|
||||
buf.WriteString(chunk.Content)
|
||||
}
|
||||
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
|
||||
toolCallID = tid
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildEinoRunResultFromAccumulated(
|
||||
orchMode string,
|
||||
runAccumulatedMsgs []adk.Message,
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestRecvSchemaMessageStream_EOF(t *testing.T) {
|
||||
sr, sw := schema.Pipe[*schema.Message](4)
|
||||
_ = sw.Send(schema.ToolMessage("hello", "tc-1"), nil)
|
||||
sw.Close()
|
||||
|
||||
content, tid, err := recvSchemaMessageStream(context.Background(), sr)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
if content != "hello" {
|
||||
t.Fatalf("content=%q want hello", content)
|
||||
}
|
||||
if tid != "tc-1" {
|
||||
t.Fatalf("toolCallID=%q want tc-1", tid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvSchemaMessageStream_ContextCancel(t *testing.T) {
|
||||
sr, sw := schema.Pipe[*schema.Message](4)
|
||||
t.Cleanup(func() { sw.Close() })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
content, _, err := recvSchemaMessageStream(ctx, sr)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("want context.Canceled, got %v content=%q", err, content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvSchemaMessageStream_RecvError(t *testing.T) {
|
||||
sr, sw := schema.Pipe[*schema.Message](4)
|
||||
want := errors.New("stream broken")
|
||||
_ = sw.Send(nil, want)
|
||||
sw.Close()
|
||||
|
||||
_, _, err := recvSchemaMessageStream(context.Background(), sr)
|
||||
if !errors.Is(err, want) {
|
||||
t.Fatalf("want %v, got %v", want, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvSchemaMessageStream_NilStream(t *testing.T) {
|
||||
content, tid, err := recvSchemaMessageStream(context.Background(), nil)
|
||||
if err != nil || content != "" || tid != "" {
|
||||
t.Fatalf("nil stream: content=%q tid=%q err=%v", content, tid, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvSchemaMessageStream_EOFViaEmptyRead(t *testing.T) {
|
||||
sr, sw := schema.Pipe[*schema.Message](4)
|
||||
_ = sw.Send(nil, io.EOF)
|
||||
sw.Close()
|
||||
|
||||
_, _, err := recvSchemaMessageStream(context.Background(), sr)
|
||||
if err != nil {
|
||||
t.Fatalf("EOF should not surface as error, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// einoChatModelTailConfig configures middleware appended after reduction/skill/plantask
|
||||
// and immediately before each ChatModel invocation pipeline completes.
|
||||
//
|
||||
// Order (best practice):
|
||||
// 1. system merge — accurate token count for summarization
|
||||
// 2. continuation user dedup — drop stale session-resume injections
|
||||
// 3. summarization
|
||||
// 4. orphan tool prune
|
||||
// 5. telemetry
|
||||
// 6. model-facing trace snapshot
|
||||
type einoChatModelTailConfig struct {
|
||||
logger *zap.Logger
|
||||
phase string
|
||||
summarization adk.ChatModelAgentMiddleware
|
||||
modelName string
|
||||
conversationID string
|
||||
trace *modelFacingTraceHolder
|
||||
skipOrphanPruner bool
|
||||
skipTelemetry bool
|
||||
skipTrace bool
|
||||
}
|
||||
|
||||
func appendEinoChatModelTailMiddlewares(handlers []adk.ChatModelAgentMiddleware, cfg einoChatModelTailConfig) []adk.ChatModelAgentMiddleware {
|
||||
handlers = append(handlers, newSystemMessageNormalizerMiddleware(cfg.logger, cfg.phase))
|
||||
handlers = append(handlers, newContinuationUserDedupMiddleware(cfg.logger, cfg.phase))
|
||||
if cfg.summarization != nil {
|
||||
handlers = append(handlers, cfg.summarization)
|
||||
}
|
||||
if !cfg.skipOrphanPruner {
|
||||
handlers = append(handlers, newOrphanToolPrunerMiddleware(cfg.logger, cfg.phase))
|
||||
}
|
||||
if !cfg.skipTelemetry {
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(cfg.logger, cfg.modelName, cfg.conversationID, cfg.phase); teleMw != nil {
|
||||
handlers = append(handlers, teleMw)
|
||||
}
|
||||
}
|
||||
if !cfg.skipTrace && cfg.trace != nil {
|
||||
if capMw := newModelFacingTraceMiddleware(cfg.trace); capMw != nil {
|
||||
handlers = append(handlers, capMw)
|
||||
}
|
||||
}
|
||||
return handlers
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package multiagent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestShouldEinoEmptyResponseContinue(t *testing.T) {
|
||||
t.Parallel()
|
||||
hint := "(empty hint)"
|
||||
out := &RunResult{Response: hint}
|
||||
if !shouldEinoEmptyResponseContinue(out, hint, 3, 1) {
|
||||
t.Fatal("expected continue when response is empty hint and trace grew")
|
||||
}
|
||||
if shouldEinoEmptyResponseContinue(out, hint, 1, 1) {
|
||||
t.Fatal("expected no continue when trace did not grow")
|
||||
}
|
||||
if shouldEinoEmptyResponseContinue(&RunResult{Response: "hello"}, hint, 3, 1) {
|
||||
t.Fatal("expected no continue when response has content")
|
||||
}
|
||||
if shouldEinoEmptyResponseContinue(nil, hint, 3, 1) {
|
||||
t.Fatal("expected no continue for nil result")
|
||||
}
|
||||
}
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/cloudwego/eino/adk/filesystem"
|
||||
@@ -34,13 +36,22 @@ func einoExecuteTimeoutUserHint() string {
|
||||
return "已超时终止 · Timed out"
|
||||
}
|
||||
|
||||
// einoExecuteRecvErrIsToolTimeout 判断 Recv 错误是否由 agent.tool_timeout_minutes 触发。
|
||||
// WithTimeout 到期后 local 侧常报 canceled / exit -1,但 execCtx.Err() 仍为 DeadlineExceeded。
|
||||
func einoExecuteRecvErrIsToolTimeout(rerr error, tctx context.Context) bool {
|
||||
if tctx != nil && errors.Is(tctx.Err(), context.DeadlineExceeded) {
|
||||
return true
|
||||
}
|
||||
return errors.Is(rerr, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
// einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShell(cloudwego eino-ext local.Local)。
|
||||
// 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连,
|
||||
// streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。
|
||||
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
|
||||
//
|
||||
// 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire,
|
||||
// 保证 run loop 在模型开始下一轮输出前已记录 execute 结果(用于 UI 与「重复助手复述」去重)。
|
||||
// run loop 收到 Fire 后立即推送 tool_result(toolResultSent 去重),避免 ADK Tool 事件迟到时 UI 卡在「执行中」。
|
||||
//
|
||||
// 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire;
|
||||
// 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。
|
||||
@@ -71,18 +82,36 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
req.Command = prependPythonUnbufferedEnv(req.Command)
|
||||
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
||||
agentTag := strings.TrimSpace(w.einoAgentName)
|
||||
convID := mcp.MCPConversationIDFromContext(ctx)
|
||||
execReg := mcp.EinoExecuteRunRegistryFromContext(ctx)
|
||||
|
||||
execCtx := ctx
|
||||
var execCancel context.CancelFunc
|
||||
execCtx, execCancel := context.WithCancel(ctx)
|
||||
var timeoutCancel context.CancelFunc
|
||||
if w.toolTimeoutMinutes > 0 {
|
||||
execCtx, execCancel = context.WithTimeout(ctx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
|
||||
execCtx, timeoutCancel = context.WithTimeout(execCtx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
|
||||
}
|
||||
if execReg != nil && convID != "" {
|
||||
execReg.RegisterActiveEinoExecute(convID, execCancel)
|
||||
}
|
||||
|
||||
sr, err := w.inner.ExecuteStreaming(execCtx, &req)
|
||||
if err != nil {
|
||||
if timeoutCancel != nil {
|
||||
timeoutCancel()
|
||||
}
|
||||
if execCancel != nil {
|
||||
execCancel()
|
||||
}
|
||||
if einoExecuteRecvErrIsToolTimeout(err, execCtx) {
|
||||
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||||
if w.recordMonitor != nil {
|
||||
w.recordMonitor(tid, userCmd, hint, false, context.DeadlineExceeded)
|
||||
}
|
||||
if w.invokeNotify != nil && tid != "" {
|
||||
w.invokeNotify.Fire(tid, "execute", agentTag, false, hint, context.DeadlineExceeded)
|
||||
}
|
||||
return schema.StreamReaderFromArray([]*filesystem.ExecuteResponse{{Output: hint}}), nil
|
||||
}
|
||||
if w.recordMonitor != nil {
|
||||
w.recordMonitor(tid, userCmd, "", false, err)
|
||||
}
|
||||
@@ -91,7 +120,10 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if sr == nil || w.invokeNotify == nil || tid == "" {
|
||||
if sr == nil || w.invokeNotify == nil {
|
||||
if timeoutCancel != nil {
|
||||
timeoutCancel()
|
||||
}
|
||||
if execCancel != nil {
|
||||
execCancel()
|
||||
}
|
||||
@@ -100,11 +132,32 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
|
||||
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
||||
|
||||
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, tctx context.Context) {
|
||||
defer inner.Close()
|
||||
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry) {
|
||||
var innerCloseOnce sync.Once
|
||||
closeInner := func() {
|
||||
innerCloseOnce.Do(func() { inner.Close() })
|
||||
}
|
||||
defer closeInner()
|
||||
if timeoutCleanup != nil {
|
||||
defer timeoutCleanup()
|
||||
}
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
if reg != nil && conversationID != "" {
|
||||
defer reg.UnregisterActiveEinoExecute(conversationID)
|
||||
}
|
||||
|
||||
// ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。
|
||||
stopWatch := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-tctx.Done():
|
||||
closeInner()
|
||||
case <-stopWatch:
|
||||
}
|
||||
}()
|
||||
defer close(stopWatch)
|
||||
|
||||
var sb strings.Builder
|
||||
success := true
|
||||
@@ -120,6 +173,15 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
if rerr != nil {
|
||||
success = false
|
||||
invokeErr = rerr
|
||||
// 单次 execute 超时须与 MCP 工具一致:写入工具结果尾标、继续迭代,不得向 ADK 流注入硬错误。
|
||||
if einoExecuteRecvErrIsToolTimeout(rerr, tctx) {
|
||||
invokeErr = context.DeadlineExceeded
|
||||
break
|
||||
}
|
||||
if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) {
|
||||
invokeErr = context.Canceled
|
||||
break
|
||||
}
|
||||
_ = outW.Send(nil, rerr)
|
||||
break
|
||||
}
|
||||
@@ -154,6 +216,21 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
success = false
|
||||
invokeErr = context.DeadlineExceeded
|
||||
}
|
||||
// 用户「中断并继续」终止 execute:合并说明进工具结果(与 MCP CancelToolExecutionWithNote 一致)。
|
||||
partialStreamed := sb.String()
|
||||
var abortNote string
|
||||
if reg != nil && conversationID != "" && (invokeErr != nil || errors.Is(tctx.Err(), context.Canceled)) {
|
||||
if note := reg.TakeEinoExecuteAbortNote(conversationID); note != "" {
|
||||
abortNote = note
|
||||
merged := mcp.MergePartialToolOutputAndAbortNote(partialStreamed, note)
|
||||
sb.Reset()
|
||||
sb.WriteString(merged)
|
||||
if invokeErr == nil {
|
||||
success = false
|
||||
invokeErr = context.Canceled
|
||||
}
|
||||
}
|
||||
}
|
||||
// ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。
|
||||
if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||||
@@ -163,12 +240,20 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
}
|
||||
sb.WriteString(hint)
|
||||
}
|
||||
// 中断时循环内已逐行写入 stdout;此处只追加 USER INTERRUPT NOTE,避免整段输出重复。
|
||||
if invokeErr != nil && errors.Is(invokeErr, context.Canceled) && abortNote != "" {
|
||||
if partialStreamed != "" {
|
||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: "\n\n" + mcp.AbortNoteBannerForModel + "\n" + abortNote}, nil)
|
||||
} else if text := strings.TrimSpace(sb.String()); text != "" {
|
||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil)
|
||||
}
|
||||
}
|
||||
if w.recordMonitor != nil {
|
||||
w.recordMonitor(tid, command, sb.String(), success, invokeErr)
|
||||
}
|
||||
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
||||
outW.Close()
|
||||
}(sr, userCmd, execCancel, execCtx)
|
||||
}(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg)
|
||||
|
||||
return outR, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,227 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"github.com/cloudwego/eino/adk/filesystem"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
type mockStreamingShell struct {
|
||||
immediateErr error
|
||||
recvErr error
|
||||
output string
|
||||
}
|
||||
|
||||
func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||||
if m.immediateErr != nil {
|
||||
return nil, m.immediateErr
|
||||
}
|
||||
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
|
||||
go func() {
|
||||
defer outW.Close()
|
||||
if strings.TrimSpace(m.output) != "" {
|
||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: m.output}, nil)
|
||||
}
|
||||
if m.recvErr != nil {
|
||||
_ = outW.Send(nil, m.recvErr)
|
||||
}
|
||||
}()
|
||||
return outR, nil
|
||||
}
|
||||
|
||||
func TestEinoExecuteRecvErrIsToolTimeout(t *testing.T) {
|
||||
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||
defer cancel()
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
<-tctx.Done()
|
||||
|
||||
if !einoExecuteRecvErrIsToolTimeout(context.Canceled, tctx) {
|
||||
t.Fatal("expected canceled recv with deadline exec ctx to count as tool timeout")
|
||||
}
|
||||
if !einoExecuteRecvErrIsToolTimeout(context.DeadlineExceeded, nil) {
|
||||
t.Fatal("expected DeadlineExceeded recv without tctx")
|
||||
}
|
||||
if einoExecuteRecvErrIsToolTimeout(errors.New("exit status 1"), context.Background()) {
|
||||
t.Fatal("unexpected timeout for generic error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoStreamingShellWrap_ToolTimeoutImmediateErrIsSoft(t *testing.T) {
|
||||
inner := &mockStreamingShell{immediateErr: context.DeadlineExceeded}
|
||||
wrap := &einoStreamingShellWrap{
|
||||
inner: inner,
|
||||
toolTimeoutMinutes: 60,
|
||||
}
|
||||
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "true"})
|
||||
if err != nil {
|
||||
t.Fatalf("immediate tool timeout must return soft stream, got err: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
var got strings.Builder
|
||||
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("outer stream must not hard-fail, got: %v", rerr)
|
||||
}
|
||||
if resp != nil && resp.Output != "" {
|
||||
got.WriteString(resp.Output)
|
||||
}
|
||||
}
|
||||
if !strings.Contains(got.String(), einoExecuteTimeoutUserHint()) {
|
||||
t.Fatalf("expected timeout hint, got: %q", got.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoStreamingShellWrap_ToolTimeoutRecvErrIsSoft(t *testing.T) {
|
||||
inner := &mockStreamingShell{recvErr: context.DeadlineExceeded}
|
||||
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||
wrap := &einoStreamingShellWrap{
|
||||
inner: inner,
|
||||
invokeNotify: notify,
|
||||
toolTimeoutMinutes: 60,
|
||||
}
|
||||
// 生产路径由 Eino compose 注入 toolCallID;单测通过已过期 execCtx 识别 tool_timeout 软错误。
|
||||
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||
defer cancel()
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
<-tctx.Done()
|
||||
|
||||
sr, err := wrap.ExecuteStreaming(tctx, &filesystem.ExecuteRequest{Command: "sleep 999"})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
var got strings.Builder
|
||||
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("outer stream must not hard-fail on tool timeout, got: %v", rerr)
|
||||
}
|
||||
if resp != nil && resp.Output != "" {
|
||||
got.WriteString(resp.Output)
|
||||
}
|
||||
}
|
||||
if !strings.Contains(got.String(), einoExecuteTimeoutUserHint()) {
|
||||
t.Fatalf("expected timeout hint in stream, got: %q", got.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoStreamingShellWrap_CapturesOutputWithToolTimeout(t *testing.T) {
|
||||
inner := &mockStreamingShell{output: "100\n"}
|
||||
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||
var firedContent string
|
||||
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||
firedContent = content
|
||||
})
|
||||
wrap := &einoStreamingShellWrap{
|
||||
inner: inner,
|
||||
invokeNotify: notify,
|
||||
toolTimeoutMinutes: 60,
|
||||
}
|
||||
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "echo 100"})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
var got strings.Builder
|
||||
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("unexpected stream error: %v", rerr)
|
||||
}
|
||||
if resp != nil && resp.Output != "" {
|
||||
got.WriteString(resp.Output)
|
||||
}
|
||||
}
|
||||
if !strings.Contains(got.String(), "100") {
|
||||
t.Fatalf("stream output = %q, want contains 100", got.String())
|
||||
}
|
||||
if !strings.Contains(firedContent, "100") {
|
||||
t.Fatalf("notify content = %q, want contains 100", firedContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoStreamingShellWrap_AbortNoteDoesNotDuplicateStreamedOutput(t *testing.T) {
|
||||
inner := &mockStreamingShell{output: "line1\nline2\n", recvErr: context.Canceled}
|
||||
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||
wrap := &einoStreamingShellWrap{
|
||||
inner: inner,
|
||||
invokeNotify: notify,
|
||||
}
|
||||
reg := &abortNoteTestRegistry{note: "改成20次"}
|
||||
ctx := mcp.WithEinoExecuteRunRegistry(
|
||||
mcp.WithMCPConversationID(context.Background(), "conv-abort-dup"),
|
||||
reg,
|
||||
)
|
||||
sr, err := wrap.ExecuteStreaming(ctx, &filesystem.ExecuteRequest{Command: "ping -c 10 baidu.com"})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
var got strings.Builder
|
||||
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("unexpected stream error: %v", rerr)
|
||||
}
|
||||
if resp != nil && resp.Output != "" {
|
||||
got.WriteString(resp.Output)
|
||||
}
|
||||
}
|
||||
out := got.String()
|
||||
if strings.Count(out, "line1") != 1 || strings.Count(out, "line2") != 1 {
|
||||
t.Fatalf("stream duplicated stdout: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "改成20次") {
|
||||
t.Fatalf("stream missing abort note: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
type abortNoteTestRegistry struct {
|
||||
note string
|
||||
}
|
||||
|
||||
func (r *abortNoteTestRegistry) RegisterActiveEinoExecute(string, context.CancelFunc) {}
|
||||
func (r *abortNoteTestRegistry) UnregisterActiveEinoExecute(string) {}
|
||||
func (r *abortNoteTestRegistry) AbortActiveEinoExecute(string, string) bool { return false }
|
||||
func (r *abortNoteTestRegistry) TakeEinoExecuteAbortNote(string) string { return r.note }
|
||||
|
||||
func TestEinoStreamingShellWrap_NonTimeoutRecvErrStillHard(t *testing.T) {
|
||||
inner := &mockStreamingShell{recvErr: errors.New("broken pipe")}
|
||||
wrap := &einoStreamingShellWrap{inner: inner}
|
||||
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "true"})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
_, rerr := sr.Recv()
|
||||
if rerr == nil || errors.Is(rerr, io.EOF) {
|
||||
t.Fatal("expected hard stream error for non-timeout failure")
|
||||
}
|
||||
}
|
||||
@@ -243,17 +243,14 @@ func prependEinoMiddlewares(
|
||||
return outTools, extraHandlers, toolSearchActive, nil
|
||||
}
|
||||
|
||||
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry *adk.ModelRetryConfig, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
||||
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
||||
if ma == nil {
|
||||
return "", nil, nil
|
||||
return "", nil
|
||||
}
|
||||
mw := ma.EinoMiddleware
|
||||
if k := strings.TrimSpace(mw.DeepOutputKey); k != "" {
|
||||
outputKey = k
|
||||
}
|
||||
if mw.DeepModelRetryMaxRetries > 0 {
|
||||
retry = &adk.ModelRetryConfig{MaxRetries: mw.DeepModelRetryMaxRetries}
|
||||
}
|
||||
prefix := strings.TrimSpace(mw.TaskToolDescriptionPrefix)
|
||||
if prefix != "" {
|
||||
taskDesc = func(ctx context.Context, agents []adk.Agent) (string, error) {
|
||||
@@ -274,5 +271,5 @@ func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry
|
||||
return prefix + "\n可用子代理(按名称 transfer / task 调用):" + strings.Join(names, "、"), nil
|
||||
}
|
||||
}
|
||||
return outputKey, retry, taskDesc
|
||||
return outputKey, taskDesc
|
||||
}
|
||||
|
||||
@@ -94,24 +94,20 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
||||
if a.SkillMiddleware != nil {
|
||||
execHandlers = append(execHandlers, a.SkillMiddleware)
|
||||
}
|
||||
// 4. summarization(最后,与 Deep/Supervisor 一致)
|
||||
// 4. pre-summarization normalize + continuation dedup, then summarization (与 Deep/Supervisor 一致)
|
||||
if a.AppCfg != nil {
|
||||
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.DB, a.ProjectID, a.Logger)
|
||||
if sumErr != nil {
|
||||
return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr)
|
||||
}
|
||||
execHandlers = append(execHandlers, sumMw)
|
||||
}
|
||||
// 5. 孤儿 tool 消息兜底:必须挂在所有改写历史中间件(summarization/reduction/skill)之后、
|
||||
// telemetry 之前,保证送入 ChatModel 的消息序列 tool_call ↔ tool_result 配对完整。
|
||||
execHandlers = append(execHandlers, newOrphanToolPrunerMiddleware(a.Logger, "plan_execute_executor"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil {
|
||||
execHandlers = append(execHandlers, teleMw)
|
||||
}
|
||||
if a.ModelFacingTrace != nil {
|
||||
if capMw := newModelFacingTraceMiddleware(a.ModelFacingTrace); capMw != nil {
|
||||
execHandlers = append(execHandlers, capMw)
|
||||
}
|
||||
execHandlers = appendEinoChatModelTailMiddlewares(execHandlers, einoChatModelTailConfig{
|
||||
logger: a.Logger,
|
||||
phase: "plan_execute_executor",
|
||||
summarization: sumMw,
|
||||
modelName: a.ModelName,
|
||||
conversationID: a.ConversationID,
|
||||
trace: a.ModelFacingTrace,
|
||||
})
|
||||
}
|
||||
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
|
||||
Model: a.ExecModel,
|
||||
|
||||
@@ -144,13 +144,14 @@ func RunEinoSingleChatModelAgent(
|
||||
}
|
||||
handlers = append(handlers, einoSkillMW)
|
||||
}
|
||||
handlers = append(handlers, mainSumMw)
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "eino_single"); teleMw != nil {
|
||||
handlers = append(handlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
handlers = append(handlers, capMw)
|
||||
}
|
||||
handlers = appendEinoChatModelTailMiddlewares(handlers, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "eino_single",
|
||||
summarization: mainSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
trace: modelFacingTrace,
|
||||
})
|
||||
|
||||
maxIter := agentMaxIterations(appCfg)
|
||||
|
||||
@@ -188,13 +189,10 @@ func RunEinoSingleChatModelAgent(
|
||||
MaxIterations: maxIter,
|
||||
Handlers: handlers,
|
||||
}
|
||||
outKey, modelRetry, _ := deepExtrasFromConfig(ma)
|
||||
outKey, _ := deepExtrasFromConfig(ma)
|
||||
if outKey != "" {
|
||||
chatCfg.OutputKey = outKey
|
||||
}
|
||||
if modelRetry != nil {
|
||||
chatCfg.ModelRetryConfig = modelRetry
|
||||
}
|
||||
|
||||
chatAgent, err := adk.NewChatModelAgent(ctx, chatCfg)
|
||||
if err != nil {
|
||||
|
||||
@@ -22,8 +22,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const defaultSummarizationRetryMax = 3
|
||||
|
||||
// einoSummarizeUserInstruction:压缩历史时保留渗透测试关键信息。
|
||||
const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史。
|
||||
|
||||
@@ -97,10 +95,8 @@ func newEinoSummarizationMiddleware(
|
||||
}
|
||||
}
|
||||
|
||||
retryMax := defaultSummarizationRetryMax
|
||||
if mwCfg != nil && mwCfg.SummarizationRetryMaxAttempts > 0 {
|
||||
retryMax = mwCfg.SummarizationRetryMaxAttempts
|
||||
}
|
||||
retryPolicy := einoTransientRunRetryPolicyFromMW(mwCfg)
|
||||
retryMax := retryPolicy.maxAttempts
|
||||
|
||||
// ModelOptions apply only to summarization Generate (same ChatModel instance as the agent).
|
||||
// Strip thinking/reasoning on this call path; mark requests for empty-choices diagnostics.
|
||||
@@ -137,13 +133,14 @@ func newEinoSummarizationMiddleware(
|
||||
Retry: &summarization.RetryConfig{
|
||||
MaxRetries: &retryMax,
|
||||
ShouldRetry: func(_ context.Context, _ adk.Message, err error) bool {
|
||||
if err != nil && logger != nil {
|
||||
logger.Warn("eino summarization generate attempt failed, will retry if attempts remain",
|
||||
retry := isEinoTransientRunError(err)
|
||||
if retry && logger != nil {
|
||||
logger.Warn("eino summarization generate transient error, will retry if attempts remain",
|
||||
zap.Error(err),
|
||||
zap.Int("max_retries", retryMax),
|
||||
)
|
||||
}
|
||||
return err != nil
|
||||
return retry
|
||||
},
|
||||
},
|
||||
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
|
||||
@@ -260,17 +257,19 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
nonSystem = append(nonSystem, msg)
|
||||
}
|
||||
|
||||
mergedSystem := mergeCollectedSystemMessages(systemMsgs)
|
||||
|
||||
if recentTrailTokenBudget <= 0 || len(nonSystem) == 0 {
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1)
|
||||
out = append(out, systemMsgs...)
|
||||
out := make([]adk.Message, 0, len(mergedSystem)+1)
|
||||
out = append(out, mergedSystem...)
|
||||
out = append(out, summary)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
rounds := splitMessagesIntoRounds(nonSystem)
|
||||
if len(rounds) == 0 {
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1)
|
||||
out = append(out, systemMsgs...)
|
||||
out := make([]adk.Message, 0, len(mergedSystem)+1)
|
||||
out = append(out, mergedSystem...)
|
||||
out = append(out, summary)
|
||||
return out, nil
|
||||
}
|
||||
@@ -322,8 +321,8 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
|
||||
}
|
||||
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1+len(selectedMsgs))
|
||||
out = append(out, systemMsgs...)
|
||||
out := make([]adk.Message, 0, len(mergedSystem)+1+len(selectedMsgs))
|
||||
out = append(out, mergedSystem...)
|
||||
out = append(out, summary)
|
||||
out = append(out, selectedMsgs...)
|
||||
return out, nil
|
||||
|
||||
@@ -192,8 +192,8 @@ func TestSummarizeFinalize_KeepsToolRoundIntact(t *testing.T) {
|
||||
if len(out) < 2 {
|
||||
t.Fatalf("output too short: %d", len(out))
|
||||
}
|
||||
if out[0] != sys {
|
||||
t.Fatalf("first message must be system")
|
||||
if out[0].Role != schema.System || out[0].Content != "sys" {
|
||||
t.Fatalf("first message must be system sys, got %s: %q", out[0].Role, out[0].Content)
|
||||
}
|
||||
if out[1] != summary {
|
||||
t.Fatalf("second message must be summary")
|
||||
@@ -293,12 +293,12 @@ func TestSummarizeFinalize_BudgetZeroFallsBackToSummaryOnly(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(out) != 2 || out[0] != sys || out[1] != summary {
|
||||
if len(out) != 2 || out[0].Role != schema.System || out[0].Content != "sys" || out[1] != summary {
|
||||
t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) {
|
||||
func TestSummarizeFinalize_MergesSystemMessages(t *testing.T) {
|
||||
sys1 := schema.SystemMessage("sys1")
|
||||
sys2 := schema.SystemMessage("sys2")
|
||||
summary := schema.AssistantMessage("s", nil)
|
||||
@@ -321,10 +321,13 @@ func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) {
|
||||
for _, m := range out {
|
||||
if m != nil && m.Role == schema.System {
|
||||
systemCount++
|
||||
if got := m.Content; got != "sys1\n\nsys2" {
|
||||
t.Fatalf("unexpected merged system content: %q", got)
|
||||
}
|
||||
}
|
||||
}
|
||||
if systemCount != 2 {
|
||||
t.Fatalf("want 2 system messages retained, got %d", systemCount)
|
||||
if systemCount != 1 {
|
||||
t.Fatalf("want 1 merged system message, got %d", systemCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -378,6 +381,12 @@ func TestWriteSummarizationTranscript(t *testing.T) {
|
||||
if !strings.Contains(text, "tool_calls:") || !strings.Contains(text, "nmap output") {
|
||||
t.Fatalf("missing tool round: %q", text)
|
||||
}
|
||||
if !strings.Contains(text, `"name":"stub_tool"`) || !strings.Contains(text, `"arguments":"{}"`) {
|
||||
t.Fatalf("missing tool name/arguments: %q", text)
|
||||
}
|
||||
if strings.Contains(text, "tool_call_id") || strings.Contains(text, `"id":"tc1"`) {
|
||||
t.Fatalf("transcript should omit tool_call_id: %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
|
||||
|
||||
@@ -23,6 +23,11 @@ const (
|
||||
transcriptSkillsSystemMarker = "# Skills System"
|
||||
)
|
||||
|
||||
type transcriptToolCall struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
// formatSummarizationTranscript renders pre-compaction messages for transcript.txt.
|
||||
// Best practice: keep full user/assistant/tool turns; slim system to dynamic blocks only.
|
||||
func formatSummarizationTranscript(msgs []adk.Message) string {
|
||||
@@ -138,15 +143,21 @@ func appendTranscriptMessage(sb *strings.Builder, msg adk.Message) {
|
||||
}
|
||||
}
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
if b, err := sonic.Marshal(msg.ToolCalls); err == nil {
|
||||
if b, err := sonic.Marshal(formatTranscriptToolCalls(msg.ToolCalls)); err == nil {
|
||||
sb.WriteString("tool_calls: ")
|
||||
sb.Write(b)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
if msg.ToolCallID != "" {
|
||||
sb.WriteString("tool_call_id: ")
|
||||
sb.WriteString(msg.ToolCallID)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
|
||||
func formatTranscriptToolCalls(calls []schema.ToolCall) []transcriptToolCall {
|
||||
out := make([]transcriptToolCall, 0, len(calls))
|
||||
for _, tc := range calls {
|
||||
out = append(out, transcriptToolCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package multiagent
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -17,8 +18,9 @@ const (
|
||||
defaultEinoRunRetryMaxBackoff = 30 * time.Second
|
||||
)
|
||||
|
||||
// isEinoTransientRunError 判断 ADK 运行期错误是否适合指数退避续跑(429、5xx、网络抖动等)。
|
||||
// 用户取消、超时、迭代上限等由 run loop 单独处理,不在此列。
|
||||
// isEinoTransientRunError 是 Eino 运行期「可退避重试 vs 直接失败」的唯一判据。
|
||||
// 429/5xx/网络抖动等返回 true;用户取消、超时、迭代上限、鉴权失败等返回 false。
|
||||
// 其它模块(run loop、summarization 等)只调用本函数,不在别处维护平行规则。
|
||||
func isEinoTransientRunError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
@@ -60,6 +62,7 @@ func isEinoTransientRunError(err error) bool {
|
||||
"dial tcp",
|
||||
"tls handshake timeout",
|
||||
"stream error",
|
||||
"goaway", // http2: server sent GOAWAY and closed the connection
|
||||
"unexpected eof",
|
||||
`": eof`, // net/http: Post "url": EOF (often wraps io.EOF)
|
||||
"unexpected end of json",
|
||||
@@ -78,6 +81,71 @@ func isEinoTransientRunError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type einoTransientRunRetryPolicy struct {
|
||||
maxAttempts int
|
||||
maxBackoff time.Duration
|
||||
}
|
||||
|
||||
func einoTransientRunRetryPolicyFromArgs(args *einoADKRunLoopArgs) einoTransientRunRetryPolicy {
|
||||
return einoTransientRunRetryPolicy{
|
||||
maxAttempts: einoRunRetryMaxAttempts(args),
|
||||
maxBackoff: einoRunRetryMaxBackoff(args),
|
||||
}
|
||||
}
|
||||
|
||||
func einoTransientRunRetryPolicyFromMW(mw *config.MultiAgentEinoMiddlewareConfig) einoTransientRunRetryPolicy {
|
||||
maxBackoff := defaultEinoRunRetryMaxBackoff
|
||||
if mw != nil && mw.RunRetryMaxBackoffSec > 0 {
|
||||
maxBackoff = time.Duration(mw.RunRetryMaxBackoffSec) * time.Second
|
||||
}
|
||||
return einoTransientRunRetryPolicy{
|
||||
maxAttempts: RunRetryMaxAttemptsFromConfig(mw),
|
||||
maxBackoff: maxBackoff,
|
||||
}
|
||||
}
|
||||
|
||||
// einoTransientRunRetrier 在 run loop 内对临时错误做指数退避并重启 Runner(唯一重试执行层)。
|
||||
type einoTransientRunRetrier struct {
|
||||
policy einoTransientRunRetryPolicy
|
||||
attempts int
|
||||
}
|
||||
|
||||
func newEinoTransientRunRetrier(policy einoTransientRunRetryPolicy) *einoTransientRunRetrier {
|
||||
return &einoTransientRunRetrier{policy: policy}
|
||||
}
|
||||
|
||||
// tryRetry 对临时错误退避后返回重启消息;次数用尽返回 exhausted 错误。
|
||||
func (r *einoTransientRunRetrier) tryRetry(
|
||||
ctx context.Context,
|
||||
runErr error,
|
||||
args *einoADKRunLoopArgs,
|
||||
baseMsgs, accumulated []adk.Message,
|
||||
baseCount int,
|
||||
) (restarted bool, restartMsgs []adk.Message, ctxSource einoRunRestartContextSource, backoff time.Duration, fatal error) {
|
||||
if runErr == nil || !isEinoTransientRunError(runErr) {
|
||||
return false, nil, "", 0, runErr
|
||||
}
|
||||
r.attempts++
|
||||
if r.attempts > r.policy.maxAttempts {
|
||||
return false, nil, "", 0, fmt.Errorf("transient retry exhausted after %d attempts: %w", r.policy.maxAttempts, runErr)
|
||||
}
|
||||
backoff = einoTransientRetryBackoff(r.attempts-1, r.policy.maxBackoff)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, nil, "", 0, ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
restartMsgs, ctxSource = einoMessagesForRunRestart(args, baseMsgs, accumulated, baseCount)
|
||||
return true, restartMsgs, ctxSource, backoff, nil
|
||||
}
|
||||
|
||||
func (r *einoTransientRunRetrier) attempt() int { return r.attempts }
|
||||
|
||||
func (r *einoTransientRunRetrier) maxAttempts() int { return r.policy.maxAttempts }
|
||||
|
||||
// reset 在一次成功推进后清零重试计数,使后续临时错误从第 1 次退避重新开始。
|
||||
func (r *einoTransientRunRetrier) reset() { r.attempts = 0 }
|
||||
|
||||
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
||||
if args != nil && args.RunRetryMaxAttempts > 0 {
|
||||
return args.RunRetryMaxAttempts
|
||||
@@ -85,7 +153,7 @@ func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
||||
return defaultEinoRunRetryMaxAttempts
|
||||
}
|
||||
|
||||
// RunRetryMaxAttemptsFromConfig 供 handler 分段续跑计数(与 eino_middleware.run_retry_max_attempts 一致)。
|
||||
// RunRetryMaxAttemptsFromConfig 与 eino_middleware.run_retry_max_attempts 一致。
|
||||
func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
|
||||
if mw != nil && mw.RunRetryMaxAttempts > 0 {
|
||||
return mw.RunRetryMaxAttempts
|
||||
@@ -93,15 +161,6 @@ func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) in
|
||||
return defaultEinoRunRetryMaxAttempts
|
||||
}
|
||||
|
||||
// TransientRetryBackoff 供 handler 在分段续跑前退避。
|
||||
func TransientRetryBackoff(attempt int, maxBackoffSec int) time.Duration {
|
||||
max := defaultEinoRunRetryMaxBackoff
|
||||
if maxBackoffSec > 0 {
|
||||
max = time.Duration(maxBackoffSec) * time.Second
|
||||
}
|
||||
return einoTransientRetryBackoff(attempt, max)
|
||||
}
|
||||
|
||||
func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration {
|
||||
if args != nil && args.RunRetryMaxBackoffSec > 0 {
|
||||
return time.Duration(args.RunRetryMaxBackoffSec) * time.Second
|
||||
@@ -122,10 +181,11 @@ const (
|
||||
// 1) ModelFacingTrace(与模型实际入参一致) 2) 事件流累积的 runAccumulatedMsgs 3) 初始 msgs。
|
||||
func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated []adk.Message, baseCount int) ([]adk.Message, einoRunRestartContextSource) {
|
||||
if trace := persistTraceSource(args, nil); len(trace) > 0 {
|
||||
return append([]adk.Message(nil), trace...), einoRestartContextModelTrace
|
||||
// modelFacingTrace includes prior Instruction system message(s); genModelInput will prepend again.
|
||||
return stripADKSystemMessages(trace), einoRestartContextModelTrace
|
||||
}
|
||||
if len(accumulated) > baseCount {
|
||||
return append([]adk.Message(nil), accumulated...), einoRestartContextAccumulated
|
||||
return stripADKSystemMessages(accumulated), einoRestartContextAccumulated
|
||||
}
|
||||
return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ func TestIsEinoTransientRunError(t *testing.T) {
|
||||
{"429", errors.New("HTTP 429 Too Many Requests"), true},
|
||||
{"rate limit", errors.New(`{"error":"rate limit exceeded"}`), true},
|
||||
{"connection reset", errors.New("read tcp: connection reset by peer"), true},
|
||||
{"http2 goaway", errors.New("failed to receive stream chunk: error, http2: server sent GOAWAY and closed the connection; LastStreamID=791, ErrCode=NO_ERROR"), true},
|
||||
{"unexpected eof", errors.New("unexpected EOF"), true},
|
||||
{"503", errors.New("upstream returned 503"), true},
|
||||
{"iteration limit", errors.New("max iteration reached"), false},
|
||||
@@ -90,6 +91,20 @@ func TestEinoRunRetryMaxAttemptsFromArgs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoTransientRunRetrierReset(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := newEinoTransientRunRetrier(einoTransientRunRetryPolicy{maxAttempts: 10, maxBackoff: 30 * time.Second})
|
||||
r.attempts = 3
|
||||
r.reset()
|
||||
if r.attempt() != 0 {
|
||||
t.Fatalf("after reset: attempt=%d, want 0", r.attempt())
|
||||
}
|
||||
// 重置后下一次退避应从 2s 起算(attempt index 0)。
|
||||
if got := einoTransientRetryBackoff(r.attempt(), r.policy.maxBackoff); got != 2*time.Second {
|
||||
t.Fatalf("backoff after reset: got %v, want 2s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUserMessageIfNeeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
msgs := []adk.Message{schema.UserMessage("old task")}
|
||||
@@ -102,10 +117,3 @@ func TestAppendUserMessageIfNeeded(t *testing.T) {
|
||||
t.Fatalf("should not duplicate user message: len=%d", len(dup))
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrTransientRetryContinue(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !errors.Is(ErrTransientRetryContinue, ErrTransientRetryContinue) {
|
||||
t.Fatal("sentinel should match")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,11 +5,3 @@ import "errors"
|
||||
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
|
||||
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
|
||||
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
|
||||
|
||||
// ErrTransientRetryContinue 表示 Run 因 429/网络等临时错误结束,应由 handler 落库轨迹后
|
||||
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue 同级的「分段续跑」语义)。
|
||||
var ErrTransientRetryContinue = errors.New("agent transient: retry after persisting trace")
|
||||
|
||||
// ErrEmptyResponseContinue 表示 Eino ADK 会话正常结束但未捕获到助手正文,应由 handler 落库轨迹后
|
||||
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue / ErrTransientRetryContinue 同级)。
|
||||
var ErrEmptyResponseContinue = errors.New("agent empty response: continue after persisting trace")
|
||||
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
// 本中间件与之互补,专职兜底正向孤儿。
|
||||
// - 仅剔除消息,不向历史里注入虚构 assistant(tc):虚构 tool_calls 反而会误导模型后续推理。
|
||||
// 摘要已覆盖被裁剪段的语义,丢一条原始 tool 结果对对话连贯性影响最小。
|
||||
// - 位置建议:挂在所有可能改写历史的中间件(summarization / reduction / skill / plantask /
|
||||
// - 位置建议:挂在 summarization / reduction / skill / plantask / system 合并 / 续聊 dedup 之后,
|
||||
// tool_search)之后,靠近 ChatModel 调用的那一端。
|
||||
type orphanToolPrunerMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
|
||||
@@ -231,13 +231,13 @@ func RunDeepAgent(
|
||||
}
|
||||
subHandlers = append(subHandlers, einoSkillMW)
|
||||
}
|
||||
subHandlers = append(subHandlers, subSumMw)
|
||||
// 孤儿 tool 消息兜底:放在 summarization 之后,telemetry 之前,
|
||||
// 以便 telemetry 记录的 token 数与 LLM 实际入参一致。
|
||||
subHandlers = append(subHandlers, newOrphanToolPrunerMiddleware(logger, "sub_agent:"+id))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "sub_agent"); teleMw != nil {
|
||||
subHandlers = append(subHandlers, teleMw)
|
||||
}
|
||||
subHandlers = appendEinoChatModelTailMiddlewares(subHandlers, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "sub_agent:" + id,
|
||||
summarization: subSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
})
|
||||
|
||||
subInstrFinal := project.AppendVisionImageAnalysisIfReady(instr, appCfg.Vision.Ready())
|
||||
subInstrFinal = injectToolNamesOnlyInstruction(ctx, subInstrFinal, subTools, subToolSearchActive)
|
||||
@@ -379,14 +379,14 @@ func RunDeepAgent(
|
||||
if einoSkillMW != nil {
|
||||
deepHandlers = append(deepHandlers, einoSkillMW)
|
||||
}
|
||||
deepHandlers = append(deepHandlers, mainSumMw)
|
||||
deepHandlers = append(deepHandlers, newOrphanToolPrunerMiddleware(logger, "deep_orchestrator"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil {
|
||||
deepHandlers = append(deepHandlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
deepHandlers = append(deepHandlers, capMw)
|
||||
}
|
||||
deepHandlers = appendEinoChatModelTailMiddlewares(deepHandlers, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "deep_orchestrator",
|
||||
summarization: mainSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
trace: modelFacingTrace,
|
||||
})
|
||||
|
||||
supHandlers := []adk.ChatModelAgentMiddleware{}
|
||||
if len(mainOrchestratorPre) > 0 {
|
||||
@@ -395,14 +395,14 @@ func RunDeepAgent(
|
||||
if einoSkillMW != nil {
|
||||
supHandlers = append(supHandlers, einoSkillMW)
|
||||
}
|
||||
supHandlers = append(supHandlers, mainSumMw)
|
||||
supHandlers = append(supHandlers, newOrphanToolPrunerMiddleware(logger, "supervisor_orchestrator"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil {
|
||||
supHandlers = append(supHandlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
supHandlers = append(supHandlers, capMw)
|
||||
}
|
||||
supHandlers = appendEinoChatModelTailMiddlewares(supHandlers, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "supervisor_orchestrator",
|
||||
summarization: mainSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
trace: modelFacingTrace,
|
||||
})
|
||||
|
||||
mainToolsCfg := adk.ToolsConfig{
|
||||
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||
@@ -416,7 +416,7 @@ func RunDeepAgent(
|
||||
EmitInternalEvents: true,
|
||||
}
|
||||
|
||||
deepOutKey, modelRetry, taskGen := deepExtrasFromConfig(ma)
|
||||
deepOutKey, taskGen := deepExtrasFromConfig(ma)
|
||||
|
||||
var da adk.Agent
|
||||
switch orchMode {
|
||||
@@ -451,12 +451,14 @@ func RunDeepAgent(
|
||||
SkillMiddleware: einoSkillMW,
|
||||
FilesystemMiddleware: peFsMw,
|
||||
ModelFacingTrace: modelFacingTrace,
|
||||
PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{
|
||||
mainSumMw,
|
||||
// 孤儿 tool 消息兜底:必须挂在 summarization 之后、telemetry 之前。
|
||||
newOrphanToolPrunerMiddleware(logger, "plan_execute_planner_replanner"),
|
||||
newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "plan_execute_planner_replanner_rewrite"),
|
||||
},
|
||||
PlannerReplannerRewriteHandlers: appendEinoChatModelTailMiddlewares(nil, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "plan_execute_planner_replanner",
|
||||
summarization: mainSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
skipTrace: true,
|
||||
}),
|
||||
})
|
||||
if perr != nil {
|
||||
return nil, perr
|
||||
@@ -473,9 +475,6 @@ func RunDeepAgent(
|
||||
Handlers: supHandlers,
|
||||
Exit: &adk.ExitTool{},
|
||||
}
|
||||
if modelRetry != nil {
|
||||
supCfg.ModelRetryConfig = modelRetry
|
||||
}
|
||||
if deepOutKey != "" {
|
||||
supCfg.OutputKey = deepOutKey
|
||||
}
|
||||
@@ -509,9 +508,6 @@ func RunDeepAgent(
|
||||
if deepOutKey != "" {
|
||||
dcfg.OutputKey = deepOutKey
|
||||
}
|
||||
if modelRetry != nil {
|
||||
dcfg.ModelRetryConfig = modelRetry
|
||||
}
|
||||
if taskGen != nil {
|
||||
dcfg.TaskToolDescriptionGenerator = taskGen
|
||||
}
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// systemMessageNormalizerMiddleware merges duplicate role=system messages into a single
|
||||
// leading system message before summarization and each ChatModel call.
|
||||
type systemMessageNormalizerMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
logger *zap.Logger
|
||||
phase string
|
||||
}
|
||||
|
||||
func newSystemMessageNormalizerMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware {
|
||||
return &systemMessageNormalizerMiddleware{logger: logger, phase: phase}
|
||||
}
|
||||
|
||||
func (m *systemMessageNormalizerMiddleware) BeforeModelRewriteState(
|
||||
ctx context.Context,
|
||||
state *adk.ChatModelAgentState,
|
||||
mc *adk.ModelContext,
|
||||
) (context.Context, *adk.ChatModelAgentState, error) {
|
||||
_ = mc
|
||||
if m == nil || state == nil || len(state.Messages) == 0 {
|
||||
return ctx, state, nil
|
||||
}
|
||||
before := countADKSystemMessages(state.Messages)
|
||||
if before <= 1 {
|
||||
return ctx, state, nil
|
||||
}
|
||||
normalized := normalizeSingleLeadingSystemMessage(state.Messages, "")
|
||||
if len(normalized) == len(state.Messages) && countADKSystemMessages(normalized) >= before {
|
||||
return ctx, state, nil
|
||||
}
|
||||
if m.logger != nil {
|
||||
m.logger.Info("eino system messages merged",
|
||||
zap.String("phase", m.phase),
|
||||
zap.Int("system_before", before),
|
||||
zap.Int("system_after", countADKSystemMessages(normalized)),
|
||||
zap.Int("messages_before", len(state.Messages)),
|
||||
zap.Int("messages_after", len(normalized)),
|
||||
)
|
||||
}
|
||||
out := *state
|
||||
out.Messages = normalized
|
||||
return ctx, &out, nil
|
||||
}
|
||||
|
||||
func countADKSystemMessages(msgs []adk.Message) int {
|
||||
n := 0
|
||||
for _, msg := range msgs {
|
||||
if msg != nil && msg.Role == schema.System {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// stripADKSystemMessages removes all system messages. Use before runner.Run restart when
|
||||
// genModelInput will prepend a fresh Instruction.
|
||||
func stripADKSystemMessages(msgs []adk.Message) []adk.Message {
|
||||
if len(msgs) == 0 {
|
||||
return msgs
|
||||
}
|
||||
out := make([]adk.Message, 0, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
if msg == nil || msg.Role == schema.System {
|
||||
continue
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// mergeCollectedSystemMessages collapses multiple system messages into one (or none).
|
||||
func mergeCollectedSystemMessages(systemMsgs []adk.Message) []adk.Message {
|
||||
if len(systemMsgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalizeSingleLeadingSystemMessage(systemMsgs, "")
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestStripADKSystemMessages(t *testing.T) {
|
||||
in := []adk.Message{
|
||||
schema.SystemMessage("a"),
|
||||
schema.UserMessage("u"),
|
||||
schema.SystemMessage("b"),
|
||||
schema.AssistantMessage("x", nil),
|
||||
}
|
||||
out := stripADKSystemMessages(in)
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("got %d messages, want 2", len(out))
|
||||
}
|
||||
if out[0].Role != schema.User || out[1].Role != schema.Assistant {
|
||||
t.Fatalf("unexpected roles: %s, %s", out[0].Role, out[1].Role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoMessagesForRunRestart_StripsSystemFromTrace(t *testing.T) {
|
||||
holder := newModelFacingTraceHolder()
|
||||
holder.storeFromState(&adk.ChatModelAgentState{Messages: []adk.Message{
|
||||
schema.SystemMessage("sys-1"),
|
||||
schema.SystemMessage("sys-2"),
|
||||
schema.UserMessage("task"),
|
||||
}})
|
||||
msgs, src := einoMessagesForRunRestart(&einoADKRunLoopArgs{ModelFacingTrace: holder}, nil, nil, 0)
|
||||
if src != einoRestartContextModelTrace {
|
||||
t.Fatalf("source: got %q want model_trace", src)
|
||||
}
|
||||
if len(msgs) != 1 || msgs[0].Role != schema.User {
|
||||
t.Fatalf("expected user-only restart msgs, got %+v", msgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemMessageNormalizerMiddleware_MergesDuplicates(t *testing.T) {
|
||||
mw := newSystemMessageNormalizerMiddleware(nil, "test")
|
||||
state := &adk.ChatModelAgentState{Messages: []adk.Message{
|
||||
schema.SystemMessage("a"),
|
||||
schema.SystemMessage("b"),
|
||||
schema.UserMessage("u"),
|
||||
}}
|
||||
_, out, err := mw.(*systemMessageNormalizerMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if countADKSystemMessages(out.Messages) != 1 {
|
||||
t.Fatalf("want 1 system, got %d", countADKSystemMessages(out.Messages))
|
||||
}
|
||||
if out.Messages[0].Content != "a\n\nb" {
|
||||
t.Fatalf("merged content: %q", out.Messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemMessageNormalizerMiddleware_NoOpSingleSystem(t *testing.T) {
|
||||
mw := newSystemMessageNormalizerMiddleware(nil, "test")
|
||||
state := &adk.ChatModelAgentState{Messages: []adk.Message{
|
||||
schema.SystemMessage("only"),
|
||||
schema.UserMessage("u"),
|
||||
}}
|
||||
_, out, err := mw.(*systemMessageNormalizerMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out != state {
|
||||
t.Fatalf("expected same state pointer for no-op")
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package project
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
@@ -24,11 +23,11 @@ func AppendSystemPromptBlock(base, block string) string {
|
||||
|
||||
const (
|
||||
factIndexFooterGetDetail = "需要完整内容(攻击链、POC、请求响应等)时必须调用 get_project_fact(fact_key),禁止凭摘要臆造细节。"
|
||||
factIndexFooterWriteHint = "写入事实时:summary 写「什么+在哪+如何验证」;body 写可复现全流程(发现/利用类 fact_key 建议 finding|chain|exploit|poc/ 前缀)。"
|
||||
factIndexFooterWriteHint = "写入事实 links 时用 from(来源 fact_key → 当前 fact),如 finding 上 {from:target/*, type:discovered_on};body 写可复现全流程(发现/利用类 fact_key 建议 finding|chain|exploit|poc/ 前缀)。"
|
||||
factIndexFooterEmpty = "需要写入请使用 upsert_project_fact;需要详情请调用 get_project_fact(fact_key)。"
|
||||
)
|
||||
|
||||
// BuildFactIndexBlock 为 Agent 系统提示生成项目黑板索引(仅 key + summary,不含 body)。
|
||||
// BuildFactIndexBlock 为 Agent 系统提示生成项目黑板索引(key + summary + 关系边 + 攻击路径,不含 body)。
|
||||
func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectConfig) (string, error) {
|
||||
if db == nil || !cfg.Enabled {
|
||||
return "", nil
|
||||
@@ -47,27 +46,38 @@ func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectCo
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
allEdges, _ := db.ListProjectFactEdgesByProject(projectID)
|
||||
_, incomingByTarget := indexEdgeGroupMaps(allEdges)
|
||||
|
||||
if len(facts) == 0 {
|
||||
return wrapFactIndexBlock(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n(暂无事实)\n%s", proj.Name, proj.ID, factIndexFooterEmpty)), nil
|
||||
}
|
||||
|
||||
sort.SliceStable(facts, func(i, j int) bool {
|
||||
if facts[i].Pinned != facts[j].Pinned {
|
||||
return facts[i].Pinned
|
||||
}
|
||||
return facts[i].UpdatedAt.After(facts[j].UpdatedAt)
|
||||
})
|
||||
sortFactsForIndex(facts)
|
||||
|
||||
maxRunes := cfg.FactIndexMaxRunesEffective()
|
||||
pathMaxRunes := cfg.FactIndexPathMaxRunesEffective()
|
||||
footer := factIndexFooterGetDetail + "\n" + factIndexFooterWriteHint
|
||||
footerRunes := len([]rune(footer))
|
||||
factsBudget := maxRunes - pathMaxRunes - footerRunes
|
||||
if factsBudget < 800 {
|
||||
factsBudget = maxRunes - footerRunes
|
||||
pathMaxRunes = 0
|
||||
}
|
||||
|
||||
indexedKeys := make(map[string]struct{}, len(facts))
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("## 项目黑板索引(project: %s, id: %s)\n", proj.Name, proj.ID))
|
||||
used := len([]rune(b.String()))
|
||||
omitted := 0
|
||||
|
||||
for _, f := range facts {
|
||||
line := fmt.Sprintf("- [%s] %s — %s (%s)\n", f.FactKey, f.Category, strings.TrimSpace(f.Summary), f.Confidence)
|
||||
indexedKeys[f.FactKey] = struct{}{}
|
||||
line := fmt.Sprintf("- [%s] %s — %s (%s)", f.FactKey, f.Category, strings.TrimSpace(f.Summary), f.Confidence)
|
||||
line += FormatFactIndexLinksHint(f.FactKey, incomingByTarget[f.FactKey])
|
||||
line += "\n"
|
||||
lineRunes := len([]rune(line))
|
||||
if used+lineRunes > maxRunes {
|
||||
if used+lineRunes > factsBudget {
|
||||
omitted++
|
||||
continue
|
||||
}
|
||||
@@ -78,8 +88,12 @@ func BuildFactIndexBlock(db *database.DB, projectID string, cfg config.ProjectCo
|
||||
if omitted > 0 {
|
||||
b.WriteString(fmt.Sprintf("\n(另有 %d 条未列入索引,请使用 list_project_facts 或 search_project_facts 查询。)\n", omitted))
|
||||
}
|
||||
b.WriteString(factIndexFooterGetDetail)
|
||||
b.WriteByte('\n')
|
||||
b.WriteString(factIndexFooterWriteHint)
|
||||
|
||||
if pathSection := BuildFactPathOverviewSection(allEdges, indexedKeys, pathMaxRunes); pathSection != "" {
|
||||
b.WriteString("\n")
|
||||
b.WriteString(pathSection)
|
||||
}
|
||||
|
||||
b.WriteString(footer)
|
||||
return wrapFactIndexBlock(b.String()), nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,256 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
var (
|
||||
bodyDepFactLine = regexp.MustCompile(`(?im)^[\s\-*]*依赖事实\s*[::]\s*([a-z0-9][a-z0-9._/-]*)`)
|
||||
bodyRelFactLine = regexp.MustCompile(`(?im)^[\s\-*]*相关\s*fact_key\s*[::]\s*([a-z0-9][a-z0-9._/-]*)`)
|
||||
bodyAssocSection = regexp.MustCompile(`(?im)^##\s*关联\s*$`)
|
||||
bodySyncLinksHead = "结构化关系边(自动同步)"
|
||||
)
|
||||
|
||||
// ParseLinksFromBody 从 body「关联」段落解析 from 语义的关系边(无显式 links 时的兜底)。
|
||||
func ParseLinksFromBody(body string) []database.ProjectFactEdgeFromInput {
|
||||
body = strings.TrimSpace(body)
|
||||
if body == "" {
|
||||
return nil
|
||||
}
|
||||
seen := map[string]struct{}{}
|
||||
var out []database.ProjectFactEdgeFromInput
|
||||
add := func(key, edgeType string) {
|
||||
key = strings.TrimSpace(key)
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
if err := database.ValidateFactKey(key); err != nil {
|
||||
return
|
||||
}
|
||||
sig := edgeType + "\x00" + key
|
||||
if _, ok := seen[sig]; ok {
|
||||
return
|
||||
}
|
||||
seen[sig] = struct{}{}
|
||||
out = append(out, database.ProjectFactEdgeFromInput{From: key, Type: edgeType})
|
||||
}
|
||||
for _, m := range bodyDepFactLine.FindAllStringSubmatch(body, -1) {
|
||||
if len(m) > 1 {
|
||||
add(m[1], "depends_on")
|
||||
}
|
||||
}
|
||||
for _, m := range bodyRelFactLine.FindAllStringSubmatch(body, -1) {
|
||||
if len(m) > 1 {
|
||||
add(m[1], "supports")
|
||||
}
|
||||
}
|
||||
// 自动同步块:type: key
|
||||
syncBlock := extractBodySyncLinksBlock(body)
|
||||
for _, line := range strings.Split(syncBlock, "\n") {
|
||||
line = strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(line), "-"))
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
edgeType, source, ok := strings.Cut(line, ":")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
edgeType = strings.TrimSpace(edgeType)
|
||||
source = strings.TrimSpace(source)
|
||||
if err := database.ValidateProjectFactEdgeType(edgeType); err != nil {
|
||||
continue
|
||||
}
|
||||
add(source, edgeType)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func extractBodySyncLinksBlock(body string) string {
|
||||
lines := strings.Split(body, "\n")
|
||||
var b strings.Builder
|
||||
inAssoc := false
|
||||
inSync := false
|
||||
for _, line := range lines {
|
||||
trim := strings.TrimSpace(line)
|
||||
if bodyAssocSection.MatchString(trim) {
|
||||
inAssoc = true
|
||||
inSync = false
|
||||
continue
|
||||
}
|
||||
if inAssoc && strings.HasPrefix(trim, "## ") && !strings.HasPrefix(trim, "## 关联") {
|
||||
break
|
||||
}
|
||||
if inAssoc && strings.Contains(trim, bodySyncLinksHead) {
|
||||
inSync = true
|
||||
continue
|
||||
}
|
||||
if inSync {
|
||||
if trim == "" || strings.HasPrefix(trim, "-") || strings.Contains(trim, ":") {
|
||||
if strings.HasPrefix(trim, "-") || (strings.Contains(trim, ":") && !strings.Contains(trim, "related_vulnerability")) {
|
||||
b.WriteString(trim)
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
} else if strings.HasPrefix(trim, "##") {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// SyncBodyLinksSection 将入边镜像写入 body 的「关联」段(人读用;结构化以 links 为准)。
|
||||
func SyncBodyLinksSection(body string, edges []*database.ProjectFactEdge) string {
|
||||
body = strings.TrimSpace(body)
|
||||
block := formatBodySyncLinksBlock(edges)
|
||||
if block == "" {
|
||||
return body
|
||||
}
|
||||
if body == "" {
|
||||
return "## 关联\n" + block
|
||||
}
|
||||
lines := strings.Split(body, "\n")
|
||||
var out []string
|
||||
inAssoc := false
|
||||
replaced := false
|
||||
for i := 0; i < len(lines); i++ {
|
||||
trim := strings.TrimSpace(lines[i])
|
||||
if bodyAssocSection.MatchString(trim) {
|
||||
inAssoc = true
|
||||
out = append(out, lines[i])
|
||||
// 跳过旧同步块
|
||||
j := i + 1
|
||||
for j < len(lines) {
|
||||
t := strings.TrimSpace(lines[j])
|
||||
if strings.HasPrefix(t, "## ") {
|
||||
break
|
||||
}
|
||||
if strings.Contains(t, bodySyncLinksHead) {
|
||||
for j < len(lines) {
|
||||
t2 := strings.TrimSpace(lines[j])
|
||||
if t2 != "" && !strings.HasPrefix(t2, "-") && !strings.Contains(t2, ":") && !strings.Contains(t2, bodySyncLinksHead) {
|
||||
if strings.HasPrefix(t2, "##") {
|
||||
break
|
||||
}
|
||||
}
|
||||
j++
|
||||
if j < len(lines) && strings.HasPrefix(strings.TrimSpace(lines[j]), "## ") {
|
||||
break
|
||||
}
|
||||
if j >= len(lines) {
|
||||
break
|
||||
}
|
||||
if j > i+1 && strings.TrimSpace(lines[j-1]) == "" && strings.HasPrefix(strings.TrimSpace(lines[j]), "## ") {
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
j++
|
||||
}
|
||||
out = append(out, block)
|
||||
i = j - 1
|
||||
replaced = true
|
||||
continue
|
||||
}
|
||||
out = append(out, lines[i])
|
||||
}
|
||||
if !replaced {
|
||||
if !inAssoc {
|
||||
out = append(out, "", "## 关联", block)
|
||||
} else {
|
||||
out = append(out, block)
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(strings.Join(out, "\n"))
|
||||
}
|
||||
|
||||
func formatBodySyncLinksBlock(edges []*database.ProjectFactEdge) string {
|
||||
if len(edges) == 0 {
|
||||
return fmt.Sprintf("- %s:\n (暂无)", bodySyncLinksHead)
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("- ")
|
||||
b.WriteString(bodySyncLinksHead)
|
||||
b.WriteString(":\n")
|
||||
for _, e := range edges {
|
||||
b.WriteString(fmt.Sprintf(" - %s: %s\n", e.EdgeType, e.SourceFactKey))
|
||||
}
|
||||
return strings.TrimRight(b.String(), "\n")
|
||||
}
|
||||
|
||||
// ResolveFactLinksForUpsert 合并显式 links、links_text 与 body 解析结果。
|
||||
func ResolveFactLinksForUpsert(explicit []database.ProjectFactEdgeFromInput, linksText *string, body string, explicitSet bool) ([]database.ProjectFactEdgeFromInput, bool, error) {
|
||||
if explicitSet {
|
||||
if len(explicit) > 0 {
|
||||
return explicit, true, nil
|
||||
}
|
||||
if linksText != nil {
|
||||
parsed, err := ParseFactLinksText(*linksText)
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
if parsed == nil {
|
||||
return []database.ProjectFactEdgeFromInput{}, true, nil
|
||||
}
|
||||
return parsed, true, nil
|
||||
}
|
||||
return []database.ProjectFactEdgeFromInput{}, true, nil
|
||||
}
|
||||
if parsed := ParseLinksFromBody(body); len(parsed) > 0 {
|
||||
return parsed, true, nil
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
// MergeLinkFromInputsUnique 合并多组 from 入边输入并去重。
|
||||
func MergeLinkFromInputsUnique(groups ...[]database.ProjectFactEdgeFromInput) []database.ProjectFactEdgeFromInput {
|
||||
seen := map[string]struct{}{}
|
||||
var out []database.ProjectFactEdgeFromInput
|
||||
for _, g := range groups {
|
||||
for _, in := range g {
|
||||
sig := in.Type + "\x00" + in.From
|
||||
if _, ok := seen[sig]; ok {
|
||||
continue
|
||||
}
|
||||
if err := database.ValidateProjectFactEdgeType(in.Type); err != nil {
|
||||
continue
|
||||
}
|
||||
if err := database.ValidateFactKey(in.From); err != nil {
|
||||
continue
|
||||
}
|
||||
seen[sig] = struct{}{}
|
||||
out = append(out, in)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// MergeLinkInputsUnique 合并多组 link 输入并去重(内部出边写入用)。
|
||||
func MergeLinkInputsUnique(groups ...[]database.ProjectFactEdgeInput) []database.ProjectFactEdgeInput {
|
||||
seen := map[string]struct{}{}
|
||||
var out []database.ProjectFactEdgeInput
|
||||
for _, g := range groups {
|
||||
for _, in := range g {
|
||||
sig := in.Type + "\x00" + in.To
|
||||
if _, ok := seen[sig]; ok {
|
||||
continue
|
||||
}
|
||||
if err := database.ValidateProjectFactEdgeType(in.Type); err != nil {
|
||||
continue
|
||||
}
|
||||
if err := database.ValidateFactKey(in.To); err != nil {
|
||||
continue
|
||||
}
|
||||
seen[sig] = struct{}{}
|
||||
out = append(out, in)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestParseLinksFromBodyDependsOn(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := "## 关联\n- 依赖事实: target/api\n- 相关 fact_key: auth/session"
|
||||
links := ParseLinksFromBody(body)
|
||||
if len(links) != 2 {
|
||||
t.Fatalf("want 2 links, got %d", len(links))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncBodyLinksSection(t *testing.T) {
|
||||
t.Parallel()
|
||||
body := "## 结论\nx\n\n## 关联\n- 依赖事实: old/key"
|
||||
edges := []*database.ProjectFactEdge{{EdgeType: "discovered_on", SourceFactKey: "target/a"}}
|
||||
out := SyncBodyLinksSection(body, edges)
|
||||
if !strings.Contains(out, "discovered_on: target/a") {
|
||||
t.Fatalf("missing synced edge: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFactGraphIntegration(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
dbPath := filepath.Join(dir, "test.db")
|
||||
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
p, err := db.CreateProject(&database.Project{Name: "g"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, spec := range []struct{ key, cat, summary string }{
|
||||
{"target/root", "target", "root"},
|
||||
{"finding/x", "finding", "finding x"},
|
||||
} {
|
||||
_, err := db.UpsertProjectFact(&database.ProjectFact{
|
||||
ProjectID: p.ID, FactKey: spec.key, Category: spec.cat, Summary: spec.summary, Confidence: "confirmed",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "finding/x", []database.ProjectFactEdgeFromInput{
|
||||
{From: "target/root", Type: "discovered_on"},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
graph, err := BuildProjectFactGraph(db, p.ID, "path", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(graph.Nodes) < 2 || len(graph.Edges) < 1 {
|
||||
t.Fatalf("expected graph nodes/edges, got %d/%d", len(graph.Nodes), len(graph.Edges))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,407 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/projectprompt"
|
||||
)
|
||||
|
||||
// PathGraphCategories 攻击路径视图包含的事实分类。
|
||||
var PathGraphCategories = map[string]struct{}{
|
||||
FactCategoryTarget: {},
|
||||
FactCategoryFinding: {},
|
||||
FactCategoryChain: {},
|
||||
FactCategoryExploit: {},
|
||||
FactCategoryPOC: {},
|
||||
"vuln": {},
|
||||
}
|
||||
|
||||
// GraphNodeType 将 fact category 映射为图节点类型(供前端样式与 ELK 分层)。
|
||||
// 优先使用 category;仅 synthetic 节点(vuln:)或无 category 时才回退到 fact_key 前缀。
|
||||
func GraphNodeType(category, factKey string) string {
|
||||
key := strings.ToLower(strings.TrimSpace(factKey))
|
||||
if strings.HasPrefix(key, "vuln:") {
|
||||
return "vulnerability"
|
||||
}
|
||||
c := strings.ToLower(strings.TrimSpace(category))
|
||||
if c != "" {
|
||||
switch c {
|
||||
case FactCategoryTarget:
|
||||
return "target"
|
||||
case FactCategoryExploit:
|
||||
return "exploit"
|
||||
case FactCategoryPOC:
|
||||
return "poc"
|
||||
case FactCategoryChain:
|
||||
return "chain"
|
||||
case FactCategoryFinding:
|
||||
return "finding"
|
||||
case "vuln":
|
||||
return "vulnerability"
|
||||
case FactCategoryAuth:
|
||||
return "auth"
|
||||
case FactCategoryInfra, FactCategoryBusiness:
|
||||
return "infra"
|
||||
case FactCategoryNote:
|
||||
return "note"
|
||||
case "missing":
|
||||
return "missing"
|
||||
default:
|
||||
return c
|
||||
}
|
||||
}
|
||||
switch {
|
||||
case strings.HasPrefix(key, "target/"):
|
||||
return "target"
|
||||
case strings.HasPrefix(key, "exploit/"), strings.HasPrefix(key, "evidence/"):
|
||||
return "exploit"
|
||||
case strings.HasPrefix(key, "poc/"):
|
||||
return "poc"
|
||||
case strings.HasPrefix(key, "chain/"):
|
||||
return "chain"
|
||||
case strings.HasPrefix(key, "finding/"):
|
||||
return "finding"
|
||||
case strings.HasPrefix(key, "auth/"):
|
||||
return "auth"
|
||||
case strings.HasPrefix(key, "infra/"), strings.HasPrefix(key, "business/"):
|
||||
return "infra"
|
||||
default:
|
||||
return "note"
|
||||
}
|
||||
}
|
||||
|
||||
func truncateGraphLabel(summary string, maxRunes int) string {
|
||||
summary = strings.TrimSpace(summary)
|
||||
if summary == "" {
|
||||
return "—"
|
||||
}
|
||||
r := []rune(summary)
|
||||
if len(r) <= maxRunes {
|
||||
return summary
|
||||
}
|
||||
return string(r[:maxRunes]) + "…"
|
||||
}
|
||||
|
||||
// BuildProjectFactGraph 构建项目事实图(nodes + edges)。
|
||||
func BuildProjectFactGraph(db *database.DB, projectID string, view string, excludeDeprecated bool) (*database.ProjectFactGraph, error) {
|
||||
if db == nil {
|
||||
return nil, fmt.Errorf("database 未初始化")
|
||||
}
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID == "" {
|
||||
return nil, fmt.Errorf("project_id 不能为空")
|
||||
}
|
||||
|
||||
view = strings.TrimSpace(strings.ToLower(view))
|
||||
if view == "" {
|
||||
view = "path"
|
||||
}
|
||||
|
||||
filter := database.ProjectFactListFilter{}
|
||||
if excludeDeprecated {
|
||||
filter.ExcludeDeprecated = true
|
||||
}
|
||||
facts, err := db.ListProjectFacts(projectID, filter, 1000, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
edges, err := db.ListProjectFactEdgesByProject(projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if excludeDeprecated {
|
||||
edges = filterDeprecatedEdges(edges)
|
||||
}
|
||||
|
||||
factByKey := make(map[string]*database.ProjectFact, len(facts))
|
||||
for _, f := range facts {
|
||||
factByKey[f.FactKey] = f
|
||||
}
|
||||
|
||||
pathMode := view == "path"
|
||||
nodeKeys := make(map[string]struct{})
|
||||
|
||||
if pathMode {
|
||||
for _, f := range facts {
|
||||
if isPathGraphFact(f.Category, f.FactKey) {
|
||||
nodeKeys[f.FactKey] = struct{}{}
|
||||
}
|
||||
}
|
||||
// 路径视图中保留作为依赖目标的 auth/infra 节点
|
||||
for _, e := range edges {
|
||||
if _, ok := nodeKeys[e.SourceFactKey]; !ok {
|
||||
continue
|
||||
}
|
||||
if f, ok := factByKey[e.TargetFactKey]; ok && isDependencyGraphFact(f.Category, f.FactKey) {
|
||||
nodeKeys[e.TargetFactKey] = struct{}{}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, f := range facts {
|
||||
nodeKeys[f.FactKey] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// 边上引用的 endpoint 纳入节点集
|
||||
for _, e := range edges {
|
||||
if pathMode {
|
||||
if _, ok := nodeKeys[e.SourceFactKey]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := nodeKeys[e.TargetFactKey]; ok {
|
||||
// already included
|
||||
} else if f, ok := factByKey[e.TargetFactKey]; !ok {
|
||||
nodeKeys[e.TargetFactKey] = struct{}{} // 占位节点
|
||||
} else if isPathGraphFact(f.Category, f.FactKey) || isDependencyGraphFact(f.Category, f.FactKey) {
|
||||
nodeKeys[e.TargetFactKey] = struct{}{}
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
nodeKeys[e.SourceFactKey] = struct{}{}
|
||||
nodeKeys[e.TargetFactKey] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
nodes := make([]database.ProjectFactGraphNode, 0, len(nodeKeys))
|
||||
for key := range nodeKeys {
|
||||
if f, ok := factByKey[key]; ok {
|
||||
nodes = append(nodes, database.ProjectFactGraphNode{
|
||||
ID: f.FactKey,
|
||||
FactKey: f.FactKey,
|
||||
Category: f.Category,
|
||||
Label: truncateGraphLabel(f.Summary, 48),
|
||||
Summary: strings.TrimSpace(f.Summary),
|
||||
Confidence: f.Confidence,
|
||||
Type: GraphNodeType(f.Category, f.FactKey),
|
||||
Pinned: f.Pinned,
|
||||
})
|
||||
continue
|
||||
}
|
||||
nodes = append(nodes, database.ProjectFactGraphNode{
|
||||
ID: key,
|
||||
FactKey: key,
|
||||
Category: "missing",
|
||||
Label: key,
|
||||
Confidence: "tentative",
|
||||
Type: "missing",
|
||||
Pinned: false,
|
||||
})
|
||||
}
|
||||
|
||||
graphEdges := make([]database.ProjectFactGraphEdge, 0, len(edges))
|
||||
for _, e := range edges {
|
||||
if pathMode {
|
||||
if _, ok := nodeKeys[e.SourceFactKey]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := nodeKeys[e.TargetFactKey]; !ok {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if _, ok := nodeKeys[e.SourceFactKey]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := nodeKeys[e.TargetFactKey]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
graphEdges = append(graphEdges, database.ProjectFactGraphEdge{
|
||||
ID: e.ID,
|
||||
Source: e.SourceFactKey,
|
||||
Target: e.TargetFactKey,
|
||||
Type: e.EdgeType,
|
||||
Confidence: e.Confidence,
|
||||
})
|
||||
}
|
||||
|
||||
// related_vulnerability_id 合成边(source=fact → target=vuln:<id>)
|
||||
for _, f := range facts {
|
||||
if _, ok := nodeKeys[f.FactKey]; !ok {
|
||||
continue
|
||||
}
|
||||
vid := strings.TrimSpace(f.RelatedVulnerabilityID)
|
||||
if vid == "" {
|
||||
continue
|
||||
}
|
||||
vulnNodeID := "vuln:" + vid
|
||||
if _, exists := nodeKeys[vulnNodeID]; !exists {
|
||||
nodeKeys[vulnNodeID] = struct{}{}
|
||||
label := "漏洞"
|
||||
if len(vid) >= 8 {
|
||||
label += " " + vid[:8] + "…"
|
||||
} else {
|
||||
label += " " + vid
|
||||
}
|
||||
nodes = append(nodes, database.ProjectFactGraphNode{
|
||||
ID: vulnNodeID,
|
||||
FactKey: vulnNodeID,
|
||||
Category: "vuln",
|
||||
Label: label,
|
||||
Confidence: f.Confidence,
|
||||
Type: "vulnerability",
|
||||
Pinned: false,
|
||||
})
|
||||
}
|
||||
graphEdges = append(graphEdges, database.ProjectFactGraphEdge{
|
||||
ID: "vuln-link:" + f.FactKey + ":" + vid,
|
||||
Source: f.FactKey,
|
||||
Target: vulnNodeID,
|
||||
Type: "links_vuln",
|
||||
Confidence: f.Confidence,
|
||||
})
|
||||
}
|
||||
|
||||
return &database.ProjectFactGraph{Nodes: nodes, Edges: graphEdges}, nil
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func isPathGraphFact(category, factKey string) bool {
|
||||
c := strings.ToLower(strings.TrimSpace(category))
|
||||
if _, ok := PathGraphCategories[c]; ok {
|
||||
return true
|
||||
}
|
||||
if c != "" {
|
||||
return false
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(factKey))
|
||||
for _, p := range []string{"target/", "finding/", "chain/", "exploit/", "poc/", "evidence/"} {
|
||||
if strings.HasPrefix(key, p) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isDependencyGraphFact(category, factKey string) bool {
|
||||
c := strings.ToLower(strings.TrimSpace(category))
|
||||
if c == FactCategoryAuth || c == FactCategoryInfra || c == FactCategoryBusiness {
|
||||
return true
|
||||
}
|
||||
if c != "" {
|
||||
return false
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(factKey))
|
||||
return strings.HasPrefix(key, "auth/") || strings.HasPrefix(key, "infra/") || strings.HasPrefix(key, "business/")
|
||||
}
|
||||
|
||||
func filterDeprecatedEdges(edges []*database.ProjectFactEdge) []*database.ProjectFactEdge {
|
||||
out := make([]*database.ProjectFactEdge, 0, len(edges))
|
||||
for _, e := range edges {
|
||||
if strings.EqualFold(strings.TrimSpace(e.Confidence), "deprecated") {
|
||||
continue
|
||||
}
|
||||
out = append(out, e)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ParsedFactLinks 解析 links 参数(from → 当前 fact)。
|
||||
type ParsedFactLinks struct {
|
||||
Incoming []database.ProjectFactEdgeFromInput
|
||||
}
|
||||
|
||||
// ParseFactLinkInputs 从 MCP links 参数解析;空数组表示清空全部入边。
|
||||
func ParseFactLinkInputs(raw interface{}) (*ParsedFactLinks, error) {
|
||||
if raw == nil {
|
||||
return nil, nil
|
||||
}
|
||||
items, ok := raw.([]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("links 须为数组")
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return &ParsedFactLinks{
|
||||
Incoming: []database.ProjectFactEdgeFromInput{},
|
||||
}, nil
|
||||
}
|
||||
parsed := &ParsedFactLinks{}
|
||||
for i, item := range items {
|
||||
m, ok := item.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("links[%d] 格式无效", i)
|
||||
}
|
||||
from, _ := m["from"].(string)
|
||||
edgeType, _ := m["type"].(string)
|
||||
from = strings.TrimSpace(from)
|
||||
edgeType = strings.TrimSpace(edgeType)
|
||||
if from == "" {
|
||||
return nil, fmt.Errorf("links[%d] 须含 from", i)
|
||||
}
|
||||
if edgeType == "" {
|
||||
return nil, fmt.Errorf("links[%d] 须含 type", i)
|
||||
}
|
||||
conf, _ := m["confidence"].(string)
|
||||
parsed.Incoming = append(parsed.Incoming, database.ProjectFactEdgeFromInput{
|
||||
From: from, Type: edgeType, Confidence: strings.TrimSpace(conf),
|
||||
})
|
||||
}
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// ParseFactLinksText 解析 UI 文本:`type: source_fact_key` 每行一条(from 语义)。
|
||||
func ParseFactLinksText(text string) ([]database.ProjectFactEdgeFromInput, error) {
|
||||
return ParseFactIncomingLinksText(text)
|
||||
}
|
||||
|
||||
// FormatFactLinksText 将入边格式化为 UI 文本。
|
||||
func FormatFactLinksText(edges []*database.ProjectFactEdge) string {
|
||||
return FormatFactIncomingLinksText(edges)
|
||||
}
|
||||
|
||||
// ParseFactIncomingLinksText 解析 UI 入边文本:`type: source_fact_key` 每行一条。
|
||||
func ParseFactIncomingLinksText(text string) ([]database.ProjectFactEdgeFromInput, error) {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return nil, nil
|
||||
}
|
||||
var out []database.ProjectFactEdgeFromInput
|
||||
for i, line := range strings.Split(text, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
edgeType, source, ok := strings.Cut(line, ":")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("第 %d 行格式无效,应为 type: fact_key", i+1)
|
||||
}
|
||||
edgeType = strings.TrimSpace(edgeType)
|
||||
source = strings.TrimSpace(source)
|
||||
if edgeType == "" || source == "" {
|
||||
return nil, fmt.Errorf("第 %d 行 type 或 fact_key 为空", i+1)
|
||||
}
|
||||
out = append(out, database.ProjectFactEdgeFromInput{From: source, Type: edgeType})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// FormatFactIncomingLinksText 将入边格式化为 UI 文本。
|
||||
func FormatFactIncomingLinksText(edges []*database.ProjectFactEdge) string {
|
||||
if len(edges) == 0 {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
for i, e := range edges {
|
||||
if i > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString(e.EdgeType)
|
||||
b.WriteString(": ")
|
||||
b.WriteString(e.SourceFactKey)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// FactEdgeRecordingGuidance 写入边时的 Agent 规范。
|
||||
func FactEdgeRecordingGuidance() string {
|
||||
return projectprompt.FactEdgeRecordingGuidance()
|
||||
}
|
||||
@@ -0,0 +1,96 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
// ApplyFactOutgoingLinks 替换某事实的出边(links 为 nil 时不修改)。
|
||||
func ApplyFactOutgoingLinks(db *database.DB, projectID, sourceFactKey, sourceConversationID string, links []database.ProjectFactEdgeInput) error {
|
||||
if links == nil {
|
||||
return nil
|
||||
}
|
||||
return db.ReplaceOutgoingProjectFactEdges(projectID, sourceFactKey, sourceConversationID, links)
|
||||
}
|
||||
|
||||
// ResolveFactLinkInputs 合并 links 数组与 links_text 文本(数组优先)。
|
||||
func ResolveFactLinkInputs(links []database.ProjectFactEdgeFromInput, linksText string) ([]database.ProjectFactEdgeFromInput, error) {
|
||||
if len(links) > 0 {
|
||||
return links, nil
|
||||
}
|
||||
return ParseFactLinksText(linksText)
|
||||
}
|
||||
|
||||
// ApplyFactIncomingLinks 替换某事实的入边(links 为 nil 时不修改)。
|
||||
func ApplyFactIncomingLinks(db *database.DB, projectID, targetFactKey string, links []database.ProjectFactEdgeFromInput) error {
|
||||
if links == nil {
|
||||
return nil
|
||||
}
|
||||
return db.ReplaceIncomingProjectFactEdges(projectID, targetFactKey, links)
|
||||
}
|
||||
|
||||
// PersistFactIncomingLinks 写入入边并可选同步当前事实 body「关联」段。
|
||||
func PersistFactIncomingLinks(db *database.DB, projectID, targetFactKey string, links []database.ProjectFactEdgeFromInput, syncBody bool) error {
|
||||
if links == nil {
|
||||
return nil
|
||||
}
|
||||
if err := ApplyFactIncomingLinks(db, projectID, targetFactKey, links); err != nil {
|
||||
return err
|
||||
}
|
||||
if !syncBody {
|
||||
return nil
|
||||
}
|
||||
f, err := db.GetProjectFactByKey(projectID, targetFactKey)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
in, err := db.ListIncomingProjectFactEdges(projectID, targetFactKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
f.Body = SyncBodyLinksSection(f.Body, in)
|
||||
_, err = db.UpsertProjectFact(f)
|
||||
return err
|
||||
}
|
||||
|
||||
// PersistFactLinksFromParsed 写入解析后的 links(parsed 为 nil 表示不修改)。
|
||||
func PersistFactLinksFromParsed(db *database.DB, projectID, factKey, sourceConversationID string, parsed *ParsedFactLinks, syncBody bool) error {
|
||||
if parsed == nil || parsed.Incoming == nil {
|
||||
return nil
|
||||
}
|
||||
return PersistFactIncomingLinks(db, projectID, factKey, parsed.Incoming, syncBody)
|
||||
}
|
||||
|
||||
// PersistFactOutgoingLinks 写入出边(图连线等低层 API;body 同步请用 PersistFactIncomingLinks)。
|
||||
func PersistFactOutgoingLinks(db *database.DB, projectID, sourceFactKey, sourceConversationID string, links []database.ProjectFactEdgeInput, syncBody bool) error {
|
||||
if links == nil {
|
||||
return nil
|
||||
}
|
||||
return ApplyFactOutgoingLinks(db, projectID, sourceFactKey, sourceConversationID, links)
|
||||
}
|
||||
|
||||
// LinkCountMap 项目内各 fact 的入/出边计数。
|
||||
type LinkCountMap map[string]LinkCounts
|
||||
|
||||
// LinkCounts 单 fact 的入/出边数。
|
||||
type LinkCounts struct {
|
||||
Outgoing int `json:"outgoing"`
|
||||
Incoming int `json:"incoming"`
|
||||
}
|
||||
|
||||
// LoadProjectFactLinkCounts 批量加载边计数。
|
||||
func LoadProjectFactLinkCounts(db *database.DB, projectID string) (LinkCountMap, error) {
|
||||
edges, err := db.ListProjectFactEdgesByProject(projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := LinkCountMap{}
|
||||
for _, e := range edges {
|
||||
c := m[e.SourceFactKey]
|
||||
c.Outgoing++
|
||||
m[e.SourceFactKey] = c
|
||||
c = m[e.TargetFactKey]
|
||||
c.Incoming++
|
||||
m[e.TargetFactKey] = c
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
@@ -0,0 +1,296 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestParseFactLinksText(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputs, err := ParseFactLinksText("discovered_on: target/api\nleads_to: finding/swagger")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(inputs) != 2 {
|
||||
t.Fatalf("want 2 links, got %d", len(inputs))
|
||||
}
|
||||
if inputs[0].Type != "discovered_on" || inputs[0].From != "target/api" {
|
||||
t.Fatalf("unexpected first link: %+v", inputs[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFactIncomingLinksText(t *testing.T) {
|
||||
t.Parallel()
|
||||
inputs, err := ParseFactIncomingLinksText("leads_to: finding/swagger\ndepends_on: target/api")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(inputs) != 2 {
|
||||
t.Fatalf("want 2 links, got %d", len(inputs))
|
||||
}
|
||||
if inputs[0].Type != "leads_to" || inputs[0].From != "finding/swagger" {
|
||||
t.Fatalf("unexpected first link: %+v", inputs[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatFactIncomingLinksText(t *testing.T) {
|
||||
t.Parallel()
|
||||
text := FormatFactIncomingLinksText([]*database.ProjectFactEdge{
|
||||
{EdgeType: "leads_to", SourceFactKey: "finding/a"},
|
||||
{EdgeType: "depends_on", SourceFactKey: "target/b"},
|
||||
})
|
||||
want := "leads_to: finding/a\ndepends_on: target/b"
|
||||
if text != want {
|
||||
t.Fatalf("got %q want %q", text, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFactLinkInputsEmptyClears(t *testing.T) {
|
||||
t.Parallel()
|
||||
parsed, err := ParseFactLinkInputs([]interface{}{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if parsed == nil || parsed.Incoming == nil || len(parsed.Incoming) != 0 {
|
||||
t.Fatalf("empty array should clear incoming links, got %v", parsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFactLinkInputsFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
raw := []interface{}{
|
||||
map[string]interface{}{
|
||||
"from": "target/primary_domain",
|
||||
"type": "discovered_on",
|
||||
},
|
||||
}
|
||||
parsed, err := ParseFactLinkInputs(raw)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(parsed.Incoming) != 1 || parsed.Incoming[0].From != "target/primary_domain" {
|
||||
t.Fatalf("unexpected incoming: %+v", parsed.Incoming)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFactLinkInputsRequiresFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
raw := []interface{}{
|
||||
map[string]interface{}{
|
||||
"to": "target/primary_domain",
|
||||
"type": "discovered_on",
|
||||
},
|
||||
}
|
||||
_, err := ParseFactLinkInputs(raw)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when from is missing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphNodeType(t *testing.T) {
|
||||
t.Parallel()
|
||||
if GraphNodeType("chain", "chain/x") != "chain" {
|
||||
t.Fatal("chain category")
|
||||
}
|
||||
if GraphNodeType("finding", "finding/x") != "finding" {
|
||||
t.Fatal("finding category")
|
||||
}
|
||||
if GraphNodeType("exploit", "exploit/x") != "exploit" {
|
||||
t.Fatal("exploit category")
|
||||
}
|
||||
if GraphNodeType("finding", "evidence/x") != "finding" {
|
||||
t.Fatal("category should override evidence key prefix")
|
||||
}
|
||||
if GraphNodeType("note", "target/x") != "note" {
|
||||
t.Fatal("category should override target key prefix")
|
||||
}
|
||||
if GraphNodeType("vuln", "finding/x") != "vulnerability" {
|
||||
t.Fatal("vuln category maps to vulnerability node type")
|
||||
}
|
||||
if GraphNodeType("", "target/x") != "target" {
|
||||
t.Fatal("empty category falls back to target key prefix")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildProjectFactGraphPreservesStoredEdgeDirection(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := database.NewDB(filepath.Join(dir, "test.db"), zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
p, err := db.CreateProject(&database.Project{Name: "path-edges"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, spec := range []struct{ key, cat string }{
|
||||
{"target/primary_domain", "target"},
|
||||
{"chain/full_attack_path", "chain"},
|
||||
{"finding/mysql_public", "finding"},
|
||||
{"exploit/mysql_creds_extract", "exploit"},
|
||||
} {
|
||||
if _, err := db.UpsertProjectFact(&database.ProjectFact{
|
||||
ProjectID: p.ID, FactKey: spec.key, Category: spec.cat, Summary: spec.key, Confidence: "confirmed",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "finding/mysql_public", []database.ProjectFactEdgeFromInput{
|
||||
{From: "target/primary_domain", Type: "discovered_on"},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "finding/mysql_public", []database.ProjectFactEdgeFromInput{
|
||||
{From: "target/primary_domain", Type: "discovered_on"},
|
||||
{From: "exploit/mysql_creds_extract", Type: "exploits"},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "chain/full_attack_path", []database.ProjectFactEdgeFromInput{
|
||||
{From: "target/primary_domain", Type: "discovered_on"},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "exploit/mysql_creds_extract", []database.ProjectFactEdgeFromInput{
|
||||
{From: "chain/full_attack_path", Type: "leads_to"},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
graph, err := BuildProjectFactGraph(db, p.ID, "path", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := map[string]struct{}{
|
||||
"target/primary_domain|discovered_on|finding/mysql_public": {},
|
||||
"exploit/mysql_creds_extract|exploits|finding/mysql_public": {},
|
||||
"target/primary_domain|discovered_on|chain/full_attack_path": {},
|
||||
"chain/full_attack_path|leads_to|exploit/mysql_creds_extract": {},
|
||||
}
|
||||
for _, e := range graph.Edges {
|
||||
key := e.Source + "|" + e.Type + "|" + e.Target
|
||||
delete(want, key)
|
||||
}
|
||||
if len(want) > 0 {
|
||||
t.Fatalf("missing expected stored-direction edges: %v", want)
|
||||
}
|
||||
countInOut := func(factKey string) (out, in int) {
|
||||
for _, e := range graph.Edges {
|
||||
if e.Source == factKey {
|
||||
out++
|
||||
}
|
||||
if e.Target == factKey {
|
||||
in++
|
||||
}
|
||||
}
|
||||
return out, in
|
||||
}
|
||||
if out, in := countInOut("chain/full_attack_path"); out != 1 || in != 1 {
|
||||
t.Fatalf("chain/full_attack_path want out=1 in=1 got out=%d in=%d", out, in)
|
||||
}
|
||||
if out, in := countInOut("exploit/mysql_creds_extract"); out != 1 || in != 1 {
|
||||
t.Fatalf("exploit/mysql_creds_extract want out=1 in=1 got out=%d in=%d", out, in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPersistFactLinksFromUsesFromAsIncoming(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := database.NewDB(filepath.Join(dir, "test.db"), zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
p, err := db.CreateProject(&database.Project{Name: "from-links"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, spec := range []struct{ key, cat string }{
|
||||
{"target/primary_domain", "target"},
|
||||
{"finding/sqli", "finding"},
|
||||
} {
|
||||
if _, err := db.UpsertProjectFact(&database.ProjectFact{
|
||||
ProjectID: p.ID, FactKey: spec.key, Category: spec.cat, Summary: spec.key, Confidence: "confirmed",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
parsed := &ParsedFactLinks{
|
||||
Incoming: []database.ProjectFactEdgeFromInput{
|
||||
{From: "target/primary_domain", Type: "discovered_on"},
|
||||
},
|
||||
}
|
||||
if err := PersistFactLinksFromParsed(db, p.ID, "finding/sqli", "", parsed, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
graph, err := BuildProjectFactGraph(db, p.ID, "path", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := "target/primary_domain|discovered_on|finding/sqli"
|
||||
for _, e := range graph.Edges {
|
||||
key := e.Source + "|" + e.Type + "|" + e.Target
|
||||
if key == want {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("expected edge %s, got %+v", want, graph.Edges)
|
||||
}
|
||||
|
||||
func TestFormatOutgoingLinksHint(t *testing.T) {
|
||||
t.Parallel()
|
||||
hint := FormatOutgoingLinksHint([]*database.ProjectFactEdge{
|
||||
{EdgeType: "discovered_on", TargetFactKey: "target/a"},
|
||||
})
|
||||
if hint == "" || hint[0] != ' ' {
|
||||
t.Fatalf("unexpected hint: %q", hint)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReplaceIncomingAllowsNotYetCreatedSource(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
db, err := database.NewDB(filepath.Join(dir, "test.db"), zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
p, err := db.CreateProject(&database.Project{Name: "parallel-links"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := db.UpsertProjectFact(&database.ProjectFact{
|
||||
ProjectID: p.ID, FactKey: "exploit/sqli", Category: "exploit", Summary: "exploit", Confidence: "confirmed",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := db.ReplaceIncomingProjectFactEdges(p.ID, "exploit/sqli", []database.ProjectFactEdgeFromInput{
|
||||
{From: "finding/sqli_endpoint", Type: "exploits"},
|
||||
}); err != nil {
|
||||
t.Fatalf("incoming edge should not require source fact to exist yet: %v", err)
|
||||
}
|
||||
if _, err := db.UpsertProjectFact(&database.ProjectFact{
|
||||
ProjectID: p.ID, FactKey: "finding/sqli_endpoint", Category: "finding", Summary: "finding", Confidence: "confirmed",
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
in, err := db.ListIncomingProjectFactEdges(p.ID, "exploit/sqli")
|
||||
if err != nil || len(in) != 1 || in[0].SourceFactKey != "finding/sqli_endpoint" {
|
||||
t.Fatalf("expected persisted edge from finding, got %+v err=%v", in, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateProjectFactEdgeType(t *testing.T) {
|
||||
t.Parallel()
|
||||
if err := database.ValidateProjectFactEdgeType("leads_to"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := database.ValidateProjectFactEdgeType("invalid"); err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
)
|
||||
|
||||
var factIndexEdgeTypeOrder = []string{
|
||||
"discovered_on", "leads_to", "enables", "depends_on", "exploits", "contains", "part_of", "supports",
|
||||
}
|
||||
|
||||
func filterIndexEdges(edges []*database.ProjectFactEdge) []*database.ProjectFactEdge {
|
||||
if len(edges) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*database.ProjectFactEdge, 0, len(edges))
|
||||
for _, e := range edges {
|
||||
if e == nil {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(e.Confidence), "deprecated") {
|
||||
continue
|
||||
}
|
||||
edgeType := strings.ToLower(strings.TrimSpace(e.EdgeType))
|
||||
if _, ok := database.ValidProjectFactEdgeTypes[edgeType]; !ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, e)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func edgeConfidenceSuffix(confidence string) string {
|
||||
c := strings.ToLower(strings.TrimSpace(confidence))
|
||||
if c == "" || c == "confirmed" {
|
||||
return ""
|
||||
}
|
||||
return " (" + c + ")"
|
||||
}
|
||||
|
||||
func formatRelationHintPart(e *database.ProjectFactEdge) string {
|
||||
return fmt.Sprintf("%s←%s%s", e.EdgeType, e.SourceFactKey, edgeConfidenceSuffix(e.Confidence))
|
||||
}
|
||||
|
||||
func formatOutgoingHintPart(e *database.ProjectFactEdge) string {
|
||||
return fmt.Sprintf("%s→%s%s", e.EdgeType, e.TargetFactKey, edgeConfidenceSuffix(e.Confidence))
|
||||
}
|
||||
|
||||
func formatIncomingHintPart(e *database.ProjectFactEdge) string {
|
||||
return formatRelationHintPart(e)
|
||||
}
|
||||
|
||||
func joinEdgeHintParts(edges []*database.ProjectFactEdge, formatter func(*database.ProjectFactEdge) string) string {
|
||||
parts := make([]string, 0, len(edges))
|
||||
for _, e := range edges {
|
||||
parts = append(parts, formatter(e))
|
||||
}
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
// FormatOutgoingLinksHint 黑板索引用出边摘要(全部有效边类型,不截断)。
|
||||
func FormatOutgoingLinksHint(edges []*database.ProjectFactEdge) string {
|
||||
edges = filterIndexEdges(edges)
|
||||
if len(edges) == 0 {
|
||||
return ""
|
||||
}
|
||||
return " {出边: " + joinEdgeHintParts(edges, formatOutgoingHintPart) + "}"
|
||||
}
|
||||
|
||||
// FormatIncomingLinksHint 黑板索引用入边摘要(全部有效边类型,不截断)。
|
||||
func FormatIncomingLinksHint(edges []*database.ProjectFactEdge) string {
|
||||
edges = filterIndexEdges(edges)
|
||||
if len(edges) == 0 {
|
||||
return ""
|
||||
}
|
||||
return " {入边: " + joinEdgeHintParts(edges, formatIncomingHintPart) + "}"
|
||||
}
|
||||
|
||||
// FormatFactIndexLinksHint 黑板索引行内关系边(from → 当前 fact,与 upsert links 一致)。
|
||||
func FormatFactIndexLinksHint(_ string, incoming []*database.ProjectFactEdge) string {
|
||||
in := filterIndexEdges(incoming)
|
||||
if len(in) == 0 {
|
||||
return ""
|
||||
}
|
||||
return " {关系边: " + joinEdgeHintParts(in, formatRelationHintPart) + "}"
|
||||
}
|
||||
|
||||
func indexEdgeGroupMaps(edges []*database.ProjectFactEdge) (outgoing, incoming map[string][]*database.ProjectFactEdge) {
|
||||
outgoing = map[string][]*database.ProjectFactEdge{}
|
||||
incoming = map[string][]*database.ProjectFactEdge{}
|
||||
for _, e := range filterIndexEdges(edges) {
|
||||
outgoing[e.SourceFactKey] = append(outgoing[e.SourceFactKey], e)
|
||||
incoming[e.TargetFactKey] = append(incoming[e.TargetFactKey], e)
|
||||
}
|
||||
return outgoing, incoming
|
||||
}
|
||||
|
||||
func relationOverviewLine(e *database.ProjectFactEdge) string {
|
||||
return fmt.Sprintf("- %s → %s%s · %s", e.SourceFactKey, e.TargetFactKey, edgeConfidenceSuffix(e.Confidence), e.EdgeType)
|
||||
}
|
||||
|
||||
func indexEdgeSortKey(e *database.ProjectFactEdge) (int, int, string) {
|
||||
confRank := 0
|
||||
if strings.EqualFold(strings.TrimSpace(e.Confidence), "tentative") {
|
||||
confRank = 1
|
||||
}
|
||||
typeRank := len(factIndexEdgeTypeOrder) + 1
|
||||
for i, t := range factIndexEdgeTypeOrder {
|
||||
if strings.EqualFold(e.EdgeType, t) {
|
||||
typeRank = i
|
||||
break
|
||||
}
|
||||
}
|
||||
return confRank, typeRank, e.SourceFactKey + ">" + e.TargetFactKey + ">" + e.EdgeType
|
||||
}
|
||||
|
||||
func sortIndexOverviewEdges(edges []*database.ProjectFactEdge) {
|
||||
sort.SliceStable(edges, func(i, j int) bool {
|
||||
ci, ti, ki := indexEdgeSortKey(edges[i])
|
||||
cj, tj, kj := indexEdgeSortKey(edges[j])
|
||||
if ci != cj {
|
||||
return ci < cj
|
||||
}
|
||||
if ti != tj {
|
||||
return ti < tj
|
||||
}
|
||||
return ki < kj
|
||||
})
|
||||
}
|
||||
|
||||
// BuildFactPathOverviewSection 生成事实关系速览(全部有效边类型,不含 body)。
|
||||
func BuildFactPathOverviewSection(edges []*database.ProjectFactEdge, indexedKeys map[string]struct{}, maxRunes int) string {
|
||||
if maxRunes <= 0 {
|
||||
return ""
|
||||
}
|
||||
candidates := filterIndexEdges(edges)
|
||||
if len(candidates) == 0 {
|
||||
return ""
|
||||
}
|
||||
filtered := make([]*database.ProjectFactEdge, 0, len(candidates))
|
||||
for _, e := range candidates {
|
||||
if len(indexedKeys) > 0 {
|
||||
if _, ok := indexedKeys[e.SourceFactKey]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, ok := indexedKeys[e.TargetFactKey]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
filtered = append(filtered, e)
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return ""
|
||||
}
|
||||
sortIndexOverviewEdges(filtered)
|
||||
|
||||
header := "### 攻击路径(事实关系)\n"
|
||||
header += "source → target · type(与攻击路径图/库中方向一致;写入时在目标 fact 的 links 用 from 声明来源)\n"
|
||||
var b strings.Builder
|
||||
b.WriteString(header)
|
||||
used := len([]rune(header))
|
||||
omitted := 0
|
||||
|
||||
for _, e := range filtered {
|
||||
line := relationOverviewLine(e) + "\n"
|
||||
lineRunes := len([]rune(line))
|
||||
if used+lineRunes > maxRunes {
|
||||
omitted++
|
||||
continue
|
||||
}
|
||||
b.WriteString(line)
|
||||
used += lineRunes
|
||||
}
|
||||
if omitted > 0 {
|
||||
extra := fmt.Sprintf("(另有 %d 条关系边未列入,请 get_project_fact 查看完整关系。)\n", omitted)
|
||||
if used+len([]rune(extra)) <= maxRunes {
|
||||
b.WriteString(extra)
|
||||
}
|
||||
}
|
||||
if used <= len([]rune(header)) {
|
||||
return ""
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func factIndexSortPriority(f *database.ProjectFact) int {
|
||||
if f == nil {
|
||||
return 0
|
||||
}
|
||||
score := 0
|
||||
if f.Pinned {
|
||||
score += 1000
|
||||
}
|
||||
c := strings.ToLower(strings.TrimSpace(f.Category))
|
||||
switch c {
|
||||
case FactCategoryTarget:
|
||||
score += 400
|
||||
case FactCategoryFinding, FactCategoryChain:
|
||||
score += 300
|
||||
case FactCategoryExploit, FactCategoryPOC:
|
||||
score += 250
|
||||
case "auth", "infra", "business":
|
||||
score += 200
|
||||
case "note":
|
||||
score += 50
|
||||
default:
|
||||
key := strings.ToLower(strings.TrimSpace(f.FactKey))
|
||||
if strings.HasPrefix(key, "target/") {
|
||||
score += 400
|
||||
} else if strings.HasPrefix(key, "finding/") || strings.HasPrefix(key, "chain/") {
|
||||
score += 300
|
||||
}
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(f.Confidence), "confirmed") {
|
||||
score += 80
|
||||
}
|
||||
return score
|
||||
}
|
||||
|
||||
func sortFactsForIndex(facts []*database.ProjectFact) {
|
||||
sort.SliceStable(facts, func(i, j int) bool {
|
||||
pi, pj := factIndexSortPriority(facts[i]), factIndexSortPriority(facts[j])
|
||||
if pi != pj {
|
||||
return pi > pj
|
||||
}
|
||||
return facts[i].UpdatedAt.After(facts[j].UpdatedAt)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestFormatIncomingLinksHint(t *testing.T) {
|
||||
t.Parallel()
|
||||
hint := FormatIncomingLinksHint([]*database.ProjectFactEdge{
|
||||
{EdgeType: "discovered_on", SourceFactKey: "finding/x", Confidence: "tentative"},
|
||||
})
|
||||
if !strings.Contains(hint, "入边:") {
|
||||
t.Fatalf("expected 入边 label: %q", hint)
|
||||
}
|
||||
if !strings.Contains(hint, "discovered_on←finding/x") {
|
||||
t.Fatalf("unexpected hint: %q", hint)
|
||||
}
|
||||
if !strings.Contains(hint, "tentative") {
|
||||
t.Fatalf("expected tentative in hint: %q", hint)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatIncomingLinksHint_allEdges(t *testing.T) {
|
||||
t.Parallel()
|
||||
edges := make([]*database.ProjectFactEdge, 0, 5)
|
||||
for i := 1; i <= 5; i++ {
|
||||
edges = append(edges, &database.ProjectFactEdge{
|
||||
EdgeType: "discovered_on",
|
||||
SourceFactKey: fmt.Sprintf("finding/f%d", i),
|
||||
Confidence: "tentative",
|
||||
})
|
||||
}
|
||||
hint := FormatIncomingLinksHint(edges)
|
||||
if strings.Contains(hint, "+") {
|
||||
t.Fatalf("should not truncate with +N: %q", hint)
|
||||
}
|
||||
for i := 1; i <= 5; i++ {
|
||||
if !strings.Contains(hint, fmt.Sprintf("finding/f%d", i)) {
|
||||
t.Fatalf("missing edge f%d in hint: %q", i, hint)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatFactIndexLinksHint_incomingOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
in := []*database.ProjectFactEdge{
|
||||
{EdgeType: "discovered_on", SourceFactKey: "target/dev", Confidence: "tentative"},
|
||||
{EdgeType: "exploits", SourceFactKey: "exploit/rce", Confidence: "confirmed"},
|
||||
}
|
||||
hint := FormatFactIndexLinksHint("finding/sqli", in)
|
||||
if !strings.Contains(hint, "关系边:") {
|
||||
t.Fatalf("missing 关系边 label: %q", hint)
|
||||
}
|
||||
if !strings.Contains(hint, "discovered_on←target/dev") {
|
||||
t.Fatalf("missing discovered_on: %q", hint)
|
||||
}
|
||||
if !strings.Contains(hint, "exploits←exploit/rce") {
|
||||
t.Fatalf("missing exploits: %q", hint)
|
||||
}
|
||||
if strings.Contains(hint, "出边") || strings.Contains(hint, "入边") {
|
||||
t.Fatalf("should not use legacy 出边/入边 labels: %q", hint)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatFactIndexLinksHint_includesAuxiliaryEdgeTypes(t *testing.T) {
|
||||
t.Parallel()
|
||||
in := []*database.ProjectFactEdge{{EdgeType: "supports", SourceFactKey: "note/log"}}
|
||||
hint := FormatFactIndexLinksHint("finding/x", in)
|
||||
if !strings.Contains(hint, "supports←note/log") {
|
||||
t.Fatalf("supports edge should be included: %q", hint)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFactPathOverviewSection(t *testing.T) {
|
||||
t.Parallel()
|
||||
edges := []*database.ProjectFactEdge{
|
||||
{EdgeType: "discovered_on", SourceFactKey: "target/dev", TargetFactKey: "finding/sqli", Confidence: "tentative"},
|
||||
{EdgeType: "exploits", SourceFactKey: "exploit/rce", TargetFactKey: "finding/sqli", Confidence: "confirmed"},
|
||||
{EdgeType: "supports", SourceFactKey: "note/log", TargetFactKey: "finding/sqli"},
|
||||
}
|
||||
keys := map[string]struct{}{
|
||||
"target/dev": {}, "finding/sqli": {}, "exploit/rce": {}, "note/log": {},
|
||||
}
|
||||
section := BuildFactPathOverviewSection(edges, keys, 800)
|
||||
if !strings.Contains(section, "### 攻击路径(事实关系)") {
|
||||
t.Fatalf("missing header: %q", section)
|
||||
}
|
||||
if !strings.Contains(section, "target/dev → finding/sqli") {
|
||||
t.Fatalf("missing discovered_on line: %q", section)
|
||||
}
|
||||
if !strings.Contains(section, "exploit/rce → finding/sqli") {
|
||||
t.Fatalf("missing exploits line: %q", section)
|
||||
}
|
||||
if !strings.Contains(section, "note/log → finding/sqli") {
|
||||
t.Fatalf("supports edge should be included: %q", section)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFactIndexBlock_withLinksAndPathOverview(t *testing.T) {
|
||||
t.Parallel()
|
||||
dbPath := filepath.Join(t.TempDir(), "facts.db")
|
||||
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
proj, err := db.CreateProject(&database.Project{Name: "path-proj"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = db.UpsertProjectFact(&database.ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: "target/dev",
|
||||
Category: "target",
|
||||
Summary: "dev 子域",
|
||||
Confidence: "confirmed",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = db.UpsertProjectFact(&database.ProjectFact{
|
||||
ProjectID: proj.ID,
|
||||
FactKey: "finding/sqli",
|
||||
Category: "finding",
|
||||
Summary: "时间盲注",
|
||||
Confidence: "tentative",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = db.AddProjectFactEdge(proj.ID, database.ProjectFactEdgeInput{
|
||||
To: "finding/sqli",
|
||||
Type: "discovered_on",
|
||||
}, "target/dev", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
block, err := BuildFactIndexBlock(db, proj.ID, config.ProjectConfig{Enabled: true, FactIndexMaxRunes: 6500, FactIndexPathMaxRunes: 1000})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(block, "关系边: discovered_on←target/dev") {
|
||||
t.Fatalf("finding line should include relation hint: %q", block)
|
||||
}
|
||||
if !strings.Contains(block, "### 攻击路径(事实关系)") {
|
||||
t.Fatalf("missing relation overview: %q", block)
|
||||
}
|
||||
if !strings.Contains(block, "target/dev → finding/sqli") {
|
||||
t.Fatalf("missing overview edge: %q", block)
|
||||
}
|
||||
}
|
||||
@@ -1,100 +1,23 @@
|
||||
package project
|
||||
|
||||
import (
|
||||
"strings"
|
||||
import "cyberstrike-ai/internal/projectprompt"
|
||||
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
)
|
||||
|
||||
// 边渗透边记录:统一节奏文案(agents/*.md 须与 FactRecordingIncrementalRhythmMarkdown 保持一致)。
|
||||
const (
|
||||
factRhythmCore = "勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。"
|
||||
factRhythmCoordinatorSuffix = "委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。"
|
||||
factRhythmSubAgentSuffix = "若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。"
|
||||
)
|
||||
|
||||
// FactRecordingIncrementalRhythmMarkdown 返回边渗透边记录节奏(Markdown,供 agents/*.md 与文档对齐)。
|
||||
// FactRecordingIncrementalRhythmMarkdown 见 projectprompt。
|
||||
func FactRecordingIncrementalRhythmMarkdown(coordinator, subAgent bool) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("- **边渗透边记录(强制节奏)**:")
|
||||
b.WriteString(factRhythmCore)
|
||||
if coordinator {
|
||||
b.WriteString(factRhythmCoordinatorSuffix)
|
||||
}
|
||||
if subAgent {
|
||||
b.WriteString(factRhythmSubAgentSuffix)
|
||||
}
|
||||
return b.String()
|
||||
return projectprompt.FactRecordingIncrementalRhythmMarkdown(coordinator, subAgent)
|
||||
}
|
||||
|
||||
func factRecordingIncrementalRhythmBuiltin(coordinator, subAgent bool) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 ")
|
||||
b.WriteString(builtin.ToolUpsertProjectFact)
|
||||
b.WriteString("(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 ")
|
||||
b.WriteString(builtin.ToolRecordVulnerability)
|
||||
b.WriteString(";与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。")
|
||||
if coordinator {
|
||||
b.WriteString(factRhythmCoordinatorSuffix)
|
||||
}
|
||||
if subAgent {
|
||||
b.WriteString(factRhythmSubAgentSuffix)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// FactRecordingBlackboardSection 项目黑板与漏洞记录的完整系统提示块(单/多 Agent 主代理共用)。
|
||||
// coordinatorDelegate 为 true 时追加「协调者代子代理落库」说明(Deep / plan_execute / supervisor)。
|
||||
// FactRecordingBlackboardSection 见 projectprompt。
|
||||
func FactRecordingBlackboardSection(coordinatorDelegate bool) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n")
|
||||
b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 fact_key + 摘要)。**摘要不足时必须调用 ")
|
||||
b.WriteString(builtin.ToolGetProjectFact)
|
||||
b.WriteString("(fact_key) 获取 body,禁止凭摘要臆造细节。**\n\n")
|
||||
b.WriteString(factRecordingIncrementalRhythmBuiltin(coordinatorDelegate, false))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString("- **环境/目标/认证等认知**(非正式漏洞条目):使用 ")
|
||||
b.WriteString(builtin.ToolUpsertProjectFact)
|
||||
b.WriteString(",fact_key 建议 `category/slug`(如 target/primary_domain),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n")
|
||||
b.WriteString("- **发现与利用上下文**(审计复现):fact_key 建议 finding/、chain/、exploit/、poc/ 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 related_vulnerability_id),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n")
|
||||
b.WriteString("- **可交付漏洞**:使用 ")
|
||||
b.WriteString(builtin.ToolRecordVulnerability)
|
||||
b.WriteString(",含标题、严重程度、类型、目标、证明(POC)、影响、修复建议。记前可先 ")
|
||||
b.WriteString(builtin.ToolListVulnerabilities)
|
||||
b.WriteString(" 查重,详情用 ")
|
||||
b.WriteString(builtin.ToolGetVulnerability)
|
||||
b.WriteString("(id)(默认仅当前项目/会话)。\n")
|
||||
b.WriteString("- 同一发现可能需**各记一次**(事实记**完整攻击链与 exploit 细节**供复现,漏洞记正式 findings)。误报用 ")
|
||||
b.WriteString(builtin.ToolDeprecateProjectFact)
|
||||
b.WriteString(" 或漏洞状态 false_positive。\n")
|
||||
b.WriteString("- 事实多时用 ")
|
||||
b.WriteString(builtin.ToolListProjectFacts)
|
||||
b.WriteString(" / ")
|
||||
b.WriteString(builtin.ToolSearchProjectFacts)
|
||||
b.WriteString(" 检索。\n\n")
|
||||
b.WriteString(FactRecordingGuidanceBlock())
|
||||
b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。")
|
||||
return b.String()
|
||||
return projectprompt.FactRecordingBlackboardSection(coordinatorDelegate)
|
||||
}
|
||||
|
||||
// FactRecordingSubAgentSection 子代理边渗透边记录(无工具时输出待落库条目)。
|
||||
// FactRecordingSubAgentSection 见 projectprompt。
|
||||
func FactRecordingSubAgentSection() string {
|
||||
return "## 边渗透边记录\n\n" + factRecordingIncrementalRhythmBuiltin(false, true) + "\n"
|
||||
return projectprompt.FactRecordingSubAgentSection()
|
||||
}
|
||||
|
||||
// FactRecordingBlackboardSectionMarkdown 与 FactRecordingBlackboardSection 等价的 Markdown(工具名为字面量,供 agents/*.md)。
|
||||
// FactRecordingBlackboardSectionMarkdown 见 projectprompt。
|
||||
func FactRecordingBlackboardSectionMarkdown(coordinatorDelegate bool) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n")
|
||||
b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**\n\n")
|
||||
b.WriteString(FactRecordingIncrementalRhythmMarkdown(coordinatorDelegate, false))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString("- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n")
|
||||
b.WriteString("- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n")
|
||||
b.WriteString("- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。\n")
|
||||
b.WriteString("- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。\n")
|
||||
b.WriteString("- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。\n\n")
|
||||
b.WriteString(FactRecordingGuidanceBlock())
|
||||
b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。")
|
||||
return b.String()
|
||||
return projectprompt.FactRecordingBlackboardSectionMarkdown(coordinatorDelegate)
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ package project
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/projectprompt"
|
||||
)
|
||||
|
||||
// 事实 category 常量(写入 upsert_project_fact 的 category 字段)。
|
||||
@@ -90,7 +92,8 @@ const attackChainFactBodyTemplate = `## 结论(可验证,一句话)
|
||||
|
||||
## 关联
|
||||
- related_vulnerability_id: <可选,对应 record_vulnerability 的 id>
|
||||
- 依赖事实: <fact_key,如 auth/session_cookie>
|
||||
- links(upsert 参数): [{ "from": "<fact_key>", "type": "discovered_on|..." }](from → 当前 fact)
|
||||
- 依赖事实(body 可读镜像): <fact_key,如 auth/session_cookie>
|
||||
|
||||
## 备注与不确定性
|
||||
<待验证假设、环境差异、绕过尝试记录>`
|
||||
@@ -109,15 +112,7 @@ const envFactBodyTemplate = `## 摘要
|
||||
|
||||
// FactRecordingGuidanceBlock 写入系统提示:要求事实沉淀攻击链上下文而非仅结论。
|
||||
func FactRecordingGuidanceBlock() string {
|
||||
return `### 事实写入规范(审计复现 / 知识沉淀)
|
||||
|
||||
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
||||
- **body**:完整可复现上下文,写入 ` + "`upsert_project_fact`" + ` 的 body 字段;索引不含 body,后续会话须靠 ` + "`get_project_fact`" + ` 取回。
|
||||
- **category / fact_key 建议**:
|
||||
- 环境认知:` + "`target/`" + `、` + "`auth/`" + `、` + "`infra/`" + `、` + "`business/`" + `(body 用环境模板即可)
|
||||
- 发现与利用:` + "`finding/`" + `、` + "`chain/`" + `、` + "`exploit/`" + `、` + "`poc/`" + `(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
||||
- **与漏洞记录分工**:` + "`record_vulnerability`" + ` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
||||
- 更新同一发现时保持相同 ` + "`fact_key`" + ` 覆盖写入,勿散落多个 key 导致上下文丢失。`
|
||||
return projectprompt.FactRecordingGuidanceBlock()
|
||||
}
|
||||
|
||||
// SparseBodyWarning 攻击链类事实 body 不足时的工具返回提示(不阻断保存)。
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
// Package projectprompt 提供项目黑板相关的系统提示文本(纯字符串,无 database 依赖)。
|
||||
// 供 agent / multiagent 等包引用,避免 agent → project 导入环导致 gopls 元数据失败。
|
||||
package projectprompt
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
)
|
||||
|
||||
const (
|
||||
factRhythmCore = "勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 `upsert_project_fact`(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 `record_vulnerability`;与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。"
|
||||
factRhythmCoordinatorSuffix = "委派/子任务返回新认知或漏洞时,由协调者及时写入,勿假定子代理已记。"
|
||||
factRhythmSubAgentSuffix = "若工具集中无上述工具,须在交付物末尾给出「待落库」结构化条目(fact_key 建议、summary、body/POC 要点),供协调者**立即**写入。"
|
||||
)
|
||||
|
||||
// FactRecordingIncrementalRhythmMarkdown 返回边渗透边记录节奏(Markdown,供 agents/*.md 与文档对齐)。
|
||||
func FactRecordingIncrementalRhythmMarkdown(coordinator, subAgent bool) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("- **边渗透边记录(强制节奏)**:")
|
||||
b.WriteString(factRhythmCore)
|
||||
if coordinator {
|
||||
b.WriteString(factRhythmCoordinatorSuffix)
|
||||
}
|
||||
if subAgent {
|
||||
b.WriteString(factRhythmSubAgentSuffix)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func factRecordingIncrementalRhythmBuiltin(coordinator, subAgent bool) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("- **边渗透边记录(强制节奏)**:勿等会话结束或收尾再批量写入。每**确认**一条新认知(开放端口/服务版本、入口路径、认证态或凭据特征、可利用点或攻击面变化)后,**立即**调用 ")
|
||||
b.WriteString(builtin.ToolUpsertProjectFact)
|
||||
b.WriteString("(同 fact_key 覆盖更新)。每**验证**出一条可复现漏洞(含 POC/影响)后,**立即**调用 ")
|
||||
b.WriteString(builtin.ToolRecordVulnerability)
|
||||
b.WriteString(";与事实可各记一次。继续下一步工作前优先落库,避免上下文压缩后细节丢失。未绑项目时说明无法写黑板,仍在本轮保留证据摘要。")
|
||||
if coordinator {
|
||||
b.WriteString(factRhythmCoordinatorSuffix)
|
||||
}
|
||||
if subAgent {
|
||||
b.WriteString(factRhythmSubAgentSuffix)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func factEdgeRecordingGuidance() string {
|
||||
return `### 事实关系边(links)
|
||||
|
||||
- 写入 **finding / chain / exploit / poc** 时,**必须**在 ` + "`upsert_project_fact`" + ` 中提供 ` + "`links`" + `(**推荐 ` + "`from`" + `**:来源 fact 指向当前 fact,即 ` + "`from`" + ` → 当前 ` + "`fact_key`" + `)。
|
||||
- **最少要求**:finding 类至少 1 条 from=target/* + type=discovered_on(即 target → finding);在 finding 上记录 exploit 用 from=exploit/* + type=exploits(即 exploit → finding)。
|
||||
- **常用 type**:` + "`discovered_on`" + `(发现在哪)、` + "`depends_on`" + `(复现前置)、` + "`leads_to`" + `(认知推进)、` + "`enables`" + `(扩大攻击面)、` + "`exploits`" + `(利用关系)、` + "`contains`" + `(资产包含)、` + "`part_of`" + `(属于链/组)、` + "`supports`" + `(证据支撑)。
|
||||
- 更新时:**省略 links 保留已有边**;传入 links 则**替换**全部关系边(from → 当前 fact)。
|
||||
- body 中「依赖事实」段落可与 links 并存(人读);结构化关系以 links 为准。`
|
||||
}
|
||||
|
||||
func factRecordingGuidanceBlock() string {
|
||||
return `### 事实写入规范(审计复现 / 知识沉淀)
|
||||
|
||||
- **summary**:索引用一行,须含「什么 + 在哪 + 如何触发/验证」要点,禁止只写结论(如仅写「存在 SQLi」)。
|
||||
- **body**:完整可复现上下文,写入 ` + "`upsert_project_fact`" + ` 的 body 字段;索引不含 body,后续会话须靠 ` + "`get_project_fact`" + ` 取回。
|
||||
- **category / fact_key 建议**:
|
||||
- 环境认知:` + "`target/`" + `、` + "`auth/`" + `、` + "`infra/`" + `、` + "`business/`" + `(body 用环境模板即可)
|
||||
- 发现与利用:` + "`finding/`" + `、` + "`chain/`" + `、` + "`exploit/`" + `、` + "`poc/`" + `(**必须**用攻击链模板填满 body:入口、逐步攻击链、原始请求/响应或命令、证据、关联漏洞 ID)
|
||||
- **与漏洞记录分工**:` + "`record_vulnerability`" + ` 记可交付 findings;事实记**复现所需的全部上下文**(含失败尝试、绕过、依赖会话),二者可各记一次。
|
||||
- 更新同一发现时保持相同 ` + "`fact_key`" + ` 覆盖写入,勿散落多个 key 导致上下文丢失。`
|
||||
}
|
||||
|
||||
// FactRecordingBlackboardSection 项目黑板与漏洞记录的完整系统提示块(单/多 Agent 主代理共用)。
|
||||
func FactRecordingBlackboardSection(coordinatorDelegate bool) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n")
|
||||
b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 fact_key + 摘要)。**摘要不足时必须调用 ")
|
||||
b.WriteString(builtin.ToolGetProjectFact)
|
||||
b.WriteString("(fact_key) 获取 body,禁止凭摘要臆造细节。**\n\n")
|
||||
b.WriteString(factRecordingIncrementalRhythmBuiltin(coordinatorDelegate, false))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString("- **环境/目标/认证等认知**(非正式漏洞条目):使用 ")
|
||||
b.WriteString(builtin.ToolUpsertProjectFact)
|
||||
b.WriteString(",fact_key 建议 `category/slug`(如 target/primary_domain),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n")
|
||||
b.WriteString("- **发现与利用上下文**(审计复现):fact_key 建议 finding/、chain/、exploit/、poc/ 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 related_vulnerability_id),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n")
|
||||
b.WriteString("- **可交付漏洞**:使用 ")
|
||||
b.WriteString(builtin.ToolRecordVulnerability)
|
||||
b.WriteString(",含标题、严重程度、类型、目标、证明(POC)、影响、修复建议。记前可先 ")
|
||||
b.WriteString(builtin.ToolListVulnerabilities)
|
||||
b.WriteString(" 查重,详情用 ")
|
||||
b.WriteString(builtin.ToolGetVulnerability)
|
||||
b.WriteString("(id)(默认仅当前项目/会话)。\n")
|
||||
b.WriteString("- 同一发现可能需**各记一次**(事实记**完整攻击链与 exploit 细节**供复现,漏洞记正式 findings)。误报用 ")
|
||||
b.WriteString(builtin.ToolDeprecateProjectFact)
|
||||
b.WriteString(" 或漏洞状态 false_positive。\n")
|
||||
b.WriteString("- 事实多时用 ")
|
||||
b.WriteString(builtin.ToolListProjectFacts)
|
||||
b.WriteString(" / ")
|
||||
b.WriteString(builtin.ToolSearchProjectFacts)
|
||||
b.WriteString(" 检索。\n\n")
|
||||
b.WriteString(factEdgeRecordingGuidance())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(factRecordingGuidanceBlock())
|
||||
b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// FactRecordingSubAgentSection 子代理边渗透边记录(无工具时输出待落库条目)。
|
||||
func FactRecordingSubAgentSection() string {
|
||||
return "## 边渗透边记录\n\n" + factRecordingIncrementalRhythmBuiltin(false, true) + "\n"
|
||||
}
|
||||
|
||||
// FactRecordingBlackboardSectionMarkdown 与 FactRecordingBlackboardSection 等价的 Markdown(工具名为字面量,供 agents/*.md)。
|
||||
func FactRecordingBlackboardSectionMarkdown(coordinatorDelegate bool) string {
|
||||
var b strings.Builder
|
||||
b.WriteString("## 项目黑板(事实)与漏洞记录(分离)\n\n")
|
||||
b.WriteString("当前对话若已绑定项目,系统会自动注入「项目黑板索引」(仅 `fact_key` + 摘要)。**摘要不足时必须调用 `get_project_fact(fact_key)` 获取 body,禁止凭摘要臆造细节。**\n\n")
|
||||
b.WriteString(FactRecordingIncrementalRhythmMarkdown(coordinatorDelegate, false))
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString("- **环境/目标/认证等认知**(非正式漏洞):使用 **`upsert_project_fact`**,`fact_key` 建议 `category/slug`(如 `target/primary_domain`),同 key 覆盖更新;body 记端口/版本/凭据特征与证据来源。\n")
|
||||
b.WriteString("- **发现与利用上下文**(审计复现):`fact_key` 建议 `finding/`、`chain/`、`exploit/`、`poc/` 前缀;**body 必填**完整攻击链(入口 → 步骤 → 原始请求/响应或命令 → 现象 → 关联 `related_vulnerability_id`),**禁止仅写结论**;summary 写「什么 + 在哪 + 如何验证」一行要点。\n")
|
||||
b.WriteString("- **可交付漏洞**:使用 **`record_vulnerability`**(标题、描述、严重程度、类型、目标、证明 POC、影响、修复建议)。严重程度 critical / high / medium / low / info。\n")
|
||||
b.WriteString("- 同一发现可能需**各记一次**(事实记可复现攻击链,漏洞记正式 findings)。误报用 **`deprecate_project_fact`** 或漏洞状态 false_positive。\n")
|
||||
b.WriteString("- 事实多时用 **`list_project_facts`** / **`search_project_facts`** 检索。\n\n")
|
||||
b.WriteString(factEdgeRecordingGuidance())
|
||||
b.WriteString("\n\n")
|
||||
b.WriteString(factRecordingGuidanceBlock())
|
||||
b.WriteString("\n\n严重程度:critical / high / medium / low / info。证明须含足够证据(请求响应、截图、命令输出等)。")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// FactEdgeRecordingGuidance 写入边时的 Agent 规范(供 project 包复用)。
|
||||
func FactEdgeRecordingGuidance() string { return factEdgeRecordingGuidance() }
|
||||
|
||||
// FactRecordingGuidanceBlock 事实写入规范块(供 project 包复用)。
|
||||
func FactRecordingGuidanceBlock() string { return factRecordingGuidanceBlock() }
|
||||
+5
-5
@@ -27,13 +27,13 @@ parameters:
|
||||
type: "string"
|
||||
description: "数据源(wayback,commoncrawl,otx,urlscan)"
|
||||
required: false
|
||||
flag: "-providers"
|
||||
flag: "--providers"
|
||||
format: "flag"
|
||||
- name: "include_subs"
|
||||
type: "bool"
|
||||
description: "包含子域名"
|
||||
required: false
|
||||
flag: "-subs"
|
||||
flag: "--subs"
|
||||
format: "flag"
|
||||
default: true
|
||||
- name: "additional_args"
|
||||
@@ -42,9 +42,9 @@ parameters:
|
||||
额外的Gau参数。用于传递未在参数列表中定义的Gau选项。
|
||||
|
||||
**示例值:**
|
||||
- "-o output.txt": 输出到文件
|
||||
- "-t": 线程数
|
||||
- "-b": 黑名单扩展
|
||||
- "--o output.txt": 输出到文件
|
||||
- "--threads 4": 线程数
|
||||
- "--blacklist ttf,woff,svg,png": 黑名单扩展
|
||||
|
||||
**注意事项:**
|
||||
- 多个参数用空格分隔
|
||||
|
||||
+717
-39
@@ -2456,16 +2456,68 @@ header {
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.mcp-call-buttons {
|
||||
.mcp-call-buttons,
|
||||
.mcp-call-toolbar {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 6px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.mcp-tool-list {
|
||||
display: none;
|
||||
flex-wrap: wrap;
|
||||
gap: 6px;
|
||||
margin-top: 8px;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.mcp-tool-list.expanded {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.mcp-tools-toggle-btn {
|
||||
background: rgba(25, 118, 210, 0.1) !important;
|
||||
border-color: rgba(25, 118, 210, 0.35) !important;
|
||||
color: #1976d2 !important;
|
||||
}
|
||||
|
||||
.mcp-call-toolbar .process-detail-btn,
|
||||
.mcp-call-toolbar .mcp-tools-toggle-btn {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 6px;
|
||||
min-height: 32px;
|
||||
padding: 6px 12px;
|
||||
font-size: 0.8125rem;
|
||||
font-weight: 500;
|
||||
line-height: 1.25;
|
||||
box-sizing: border-box;
|
||||
white-space: nowrap;
|
||||
vertical-align: middle;
|
||||
}
|
||||
|
||||
.mcp-call-toolbar .process-detail-btn span,
|
||||
.mcp-call-toolbar .mcp-tools-toggle-btn span {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
line-height: 1.25;
|
||||
}
|
||||
|
||||
.mcp-tools-toggle-btn:hover {
|
||||
background: rgba(25, 118, 210, 0.18) !important;
|
||||
border-color: #1976d2 !important;
|
||||
color: #1565c0 !important;
|
||||
}
|
||||
|
||||
.process-detail-btn {
|
||||
background: rgba(156, 39, 176, 0.1) !important;
|
||||
border-color: rgba(156, 39, 176, 0.3) !important;
|
||||
color: #9c27b0 !important;
|
||||
}
|
||||
|
||||
.mcp-call-toolbar .process-detail-btn {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
@@ -3964,9 +4016,10 @@ header {
|
||||
background: var(--bg-tertiary);
|
||||
}
|
||||
|
||||
/* 迭代轮次:暖琥珀色条 + 极浅底,与紫(推理)/蓝(工具)区分但不抢视觉 */
|
||||
.timeline-item-iteration {
|
||||
border-left-color: var(--accent-color);
|
||||
background: rgba(0, 102, 255, 0.06);
|
||||
border-left-color: #c4a574;
|
||||
background: rgba(180, 140, 90, 0.045);
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -3974,13 +4027,18 @@ header {
|
||||
* 但不再在此处整卡铺色 + !important,否则会盖住工具调用/结果/思考的类型色。
|
||||
* 主编排 vs 子代理的区分由「迭代轮次」上的 timeline-eino-scope-* 负责。
|
||||
*/
|
||||
.timeline-item-iteration.timeline-eino-scope-main {
|
||||
border-left-color: #3949ab !important;
|
||||
background: rgba(57, 73, 171, 0.1) !important;
|
||||
.timeline-item.timeline-item-iteration.timeline-eino-scope-main {
|
||||
border-left-color: #b8956a;
|
||||
background: rgba(184, 149, 106, 0.05);
|
||||
}
|
||||
.timeline-item-iteration.timeline-eino-scope-sub {
|
||||
border-left-color: #00695c !important;
|
||||
background: rgba(0, 105, 92, 0.09) !important;
|
||||
.timeline-item.timeline-item-iteration.timeline-eino-scope-sub {
|
||||
border-left-color: #a6896c;
|
||||
background: rgba(166, 137, 108, 0.045);
|
||||
}
|
||||
|
||||
.timeline-item-iteration .timeline-item-title {
|
||||
color: var(--text-secondary);
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
/* 模型内部思考:弱化灰紫,避免与「助手输出」抢视觉 */
|
||||
@@ -11791,34 +11849,44 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
|
||||
background: transparent;
|
||||
color: var(--text-muted);
|
||||
cursor: pointer;
|
||||
border-radius: 4px;
|
||||
border-radius: 6px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.batch-delete-btn svg {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
transition: transform 0.2s ease;
|
||||
}
|
||||
|
||||
.batch-delete-btn:hover {
|
||||
background: rgba(220, 53, 69, 0.1);
|
||||
color: var(--error-color);
|
||||
}
|
||||
|
||||
.batch-delete-btn:hover svg {
|
||||
transform: scale(1.08);
|
||||
}
|
||||
|
||||
.batch-delete-btn:active {
|
||||
background: rgba(220, 53, 69, 0.2);
|
||||
transform: scale(0.95);
|
||||
}
|
||||
|
||||
.batch-manage-footer {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
justify-content: flex-end;
|
||||
padding: 16px 24px;
|
||||
border-top: 1px solid var(--border-color);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.select-all-checkbox {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
.batch-table-col-checkbox input[type="checkbox"] {
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.batch-footer-actions {
|
||||
@@ -15846,7 +15914,12 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
|
||||
overflow-y: visible;
|
||||
}
|
||||
.webshell-ai-process-block .webshell-ai-timeline-iteration {
|
||||
border-left-color: var(--accent-color);
|
||||
border-left-color: #c4a574;
|
||||
background: rgba(180, 140, 90, 0.04);
|
||||
}
|
||||
.webshell-ai-process-block .webshell-ai-timeline-iteration .webshell-ai-timeline-title {
|
||||
color: var(--text-secondary);
|
||||
font-weight: 500;
|
||||
}
|
||||
.webshell-ai-process-block .webshell-ai-timeline-thinking {
|
||||
border-left-color: #9c27b0;
|
||||
@@ -23860,9 +23933,17 @@ button.chat-files-dropdown-item:hover:not(:disabled) {
|
||||
min-height: 420px;
|
||||
}
|
||||
.projects-placeholder-icon {
|
||||
font-size: 3rem;
|
||||
margin-bottom: 16px;
|
||||
opacity: 0.85;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 88px;
|
||||
height: 88px;
|
||||
margin-bottom: 20px;
|
||||
color: #3b82f6;
|
||||
background: linear-gradient(145deg, #eff6ff 0%, #dbeafe 100%);
|
||||
border: 1px solid #bfdbfe;
|
||||
border-radius: 22px;
|
||||
box-shadow: 0 8px 24px rgba(59, 130, 246, 0.12);
|
||||
}
|
||||
.projects-detail-placeholder h3 {
|
||||
margin: 0 0 8px;
|
||||
@@ -23883,7 +23964,7 @@ button.chat-files-dropdown-item:hover:not(:disabled) {
|
||||
background: #ffffff;
|
||||
border: 1px solid var(--border-color, #e2e8f0);
|
||||
border-radius: 14px;
|
||||
box-shadow: 0 1px 3px rgba(15, 23, 42, 0.06);
|
||||
box-shadow: 0 1px 3px rgba(15, 23, 42, 0.06), 0 8px 24px rgba(15, 23, 42, 0.04);
|
||||
overflow: hidden;
|
||||
min-height: 0;
|
||||
align-self: stretch;
|
||||
@@ -24066,6 +24147,7 @@ button.chat-files-dropdown-item:hover:not(:disabled) {
|
||||
color: #0066ff;
|
||||
background: #fff;
|
||||
box-shadow: 0 1px 3px rgba(15, 23, 42, 0.08);
|
||||
font-weight: 600;
|
||||
}
|
||||
.projects-panel {
|
||||
flex: 1;
|
||||
@@ -24309,11 +24391,17 @@ button.chat-files-dropdown-item:hover:not(:disabled) {
|
||||
#project-panel-vulns .projects-table-wrap {
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
overflow-x: hidden;
|
||||
overflow-y: auto;
|
||||
overscroll-behavior: contain;
|
||||
-webkit-overflow-scrolling: touch;
|
||||
}
|
||||
#project-panel-conversations .projects-table-wrap,
|
||||
#project-panel-vulns .projects-table-wrap {
|
||||
overflow-x: hidden;
|
||||
}
|
||||
#project-panel-facts .projects-table-wrap {
|
||||
overflow-x: auto;
|
||||
}
|
||||
#project-panel-facts .projects-table-wrap .data-table--projects thead th,
|
||||
#project-panel-conversations .projects-table-wrap .data-table--projects thead th,
|
||||
#project-panel-vulns .projects-table-wrap .data-table--projects thead th {
|
||||
@@ -24332,12 +24420,6 @@ button.chat-files-dropdown-item:hover:not(:disabled) {
|
||||
.projects-panel-toolbar--hint .projects-fact-toolbar-hint {
|
||||
margin: 0;
|
||||
}
|
||||
#project-panel-facts .data-table--projects th:nth-child(1) { width: 20%; }
|
||||
#project-panel-facts .data-table--projects th:nth-child(2) { width: 9%; }
|
||||
#project-panel-facts .data-table--projects th:nth-child(3) { width: 30%; }
|
||||
#project-panel-facts .data-table--projects th:nth-child(4) { width: 9%; }
|
||||
#project-panel-facts .data-table--projects th:nth-child(5) { width: 10%; }
|
||||
#project-panel-facts .data-table--projects th:nth-child(6) { width: 10%; }
|
||||
#project-panel-facts .data-table--projects .cell-fact-key {
|
||||
overflow: hidden;
|
||||
max-width: 0;
|
||||
@@ -24345,6 +24427,16 @@ button.chat-files-dropdown-item:hover:not(:disabled) {
|
||||
#project-panel-facts .data-table--projects .cell-fact-category {
|
||||
white-space: nowrap;
|
||||
}
|
||||
#project-panel-facts .data-table--projects .cell-summary {
|
||||
max-width: 0;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
#project-panel-facts .data-table--projects .cell-fact-links {
|
||||
text-align: center;
|
||||
white-space: nowrap;
|
||||
}
|
||||
#project-panel-facts .projects-fact-key-chip {
|
||||
display: inline-block;
|
||||
max-width: 100%;
|
||||
@@ -24463,23 +24555,23 @@ button.chat-files-dropdown-item:hover:not(:disabled) {
|
||||
}
|
||||
#project-panel-facts .data-table--projects th:nth-child(1),
|
||||
#project-panel-facts .data-table--projects td:nth-child(1) {
|
||||
width: 19%;
|
||||
width: 13%;
|
||||
}
|
||||
#project-panel-facts .data-table--projects th:nth-child(2),
|
||||
#project-panel-facts .data-table--projects td:nth-child(2) {
|
||||
width: 9%;
|
||||
width: 7%;
|
||||
}
|
||||
#project-panel-facts .data-table--projects th:nth-child(3),
|
||||
#project-panel-facts .data-table--projects td:nth-child(3) {
|
||||
width: 28%;
|
||||
width: 22%;
|
||||
}
|
||||
#project-panel-facts .data-table--projects th:nth-child(4),
|
||||
#project-panel-facts .data-table--projects td:nth-child(4) {
|
||||
width: 8%;
|
||||
width: 5%;
|
||||
}
|
||||
#project-panel-facts .data-table--projects th:nth-child(5),
|
||||
#project-panel-facts .data-table--projects td:nth-child(5) {
|
||||
width: 9%;
|
||||
width: 8%;
|
||||
}
|
||||
#project-panel-facts .data-table--projects th:nth-child(6),
|
||||
#project-panel-facts .data-table--projects td:nth-child(6) {
|
||||
@@ -24487,8 +24579,593 @@ button.chat-files-dropdown-item:hover:not(:disabled) {
|
||||
}
|
||||
#project-panel-facts .data-table--projects th:nth-child(7),
|
||||
#project-panel-facts .data-table--projects td:nth-child(7) {
|
||||
width: 19%;
|
||||
width: 9%;
|
||||
}
|
||||
#project-panel-facts .data-table--projects th.col-actions,
|
||||
#project-panel-facts .data-table--projects td.col-actions {
|
||||
width: 28%;
|
||||
min-width: 196px;
|
||||
max-width: 240px;
|
||||
position: sticky;
|
||||
right: 0;
|
||||
z-index: 3;
|
||||
background: #fff;
|
||||
box-shadow: -6px 0 10px rgba(15, 23, 42, 0.05);
|
||||
}
|
||||
#project-panel-facts .data-table--projects thead th.col-actions {
|
||||
z-index: 6;
|
||||
background: #f8fafc;
|
||||
}
|
||||
#project-panel-facts .data-table--projects tbody tr:hover td.col-actions {
|
||||
background: #f8fafc;
|
||||
}
|
||||
#project-panel-facts .data-table--projects .col-actions .projects-table-actions {
|
||||
flex-wrap: nowrap;
|
||||
}
|
||||
|
||||
/* 项目事实攻击路径图 */
|
||||
#project-panel-graph.projects-panel--graph {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
overflow: hidden;
|
||||
min-height: 0;
|
||||
padding-bottom: 0;
|
||||
}
|
||||
#project-panel-graph .projects-graph-toolbar {
|
||||
flex: 0 0 auto;
|
||||
}
|
||||
#project-panel-graph .project-fact-graph-layout {
|
||||
flex: 1 1 0;
|
||||
min-height: 0;
|
||||
max-height: 100%;
|
||||
overflow: hidden;
|
||||
}
|
||||
#project-panel-graph .project-fact-graph-container {
|
||||
min-height: 0;
|
||||
height: 100%;
|
||||
}
|
||||
#project-panel-graph .project-fact-graph-footer {
|
||||
flex: 0 0 auto;
|
||||
flex-shrink: 0;
|
||||
position: relative;
|
||||
z-index: 20;
|
||||
margin: 0;
|
||||
padding: 10px 0 12px;
|
||||
background: #fff;
|
||||
border-top: 1px solid #eef2f7;
|
||||
}
|
||||
.projects-graph-toolbar-row {
|
||||
align-items: flex-end;
|
||||
}
|
||||
.projects-graph-search-field {
|
||||
flex: 1 1 180px;
|
||||
max-width: 280px;
|
||||
}
|
||||
.projects-graph-actions {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
margin-left: auto;
|
||||
padding: 3px;
|
||||
background: #f8fafc;
|
||||
border: 1px solid #e2e8f0;
|
||||
border-radius: 10px;
|
||||
}
|
||||
.projects-graph-action-btn {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 5px;
|
||||
padding: 6px 11px;
|
||||
font-size: 0.8125rem;
|
||||
font-weight: 500;
|
||||
color: #475569;
|
||||
background: transparent;
|
||||
border: 1px solid transparent;
|
||||
border-radius: 7px;
|
||||
cursor: pointer;
|
||||
white-space: nowrap;
|
||||
transition: background 0.15s ease, color 0.15s ease, border-color 0.15s ease, box-shadow 0.15s ease;
|
||||
}
|
||||
.projects-graph-action-btn svg {
|
||||
flex-shrink: 0;
|
||||
opacity: 0.75;
|
||||
}
|
||||
.projects-graph-action-btn:hover {
|
||||
color: #0f172a;
|
||||
background: #fff;
|
||||
border-color: #e2e8f0;
|
||||
box-shadow: 0 1px 2px rgba(15, 23, 42, 0.06);
|
||||
}
|
||||
.projects-graph-action-btn--connect {
|
||||
color: #4338ca;
|
||||
background: #eef2ff;
|
||||
border-color: #c7d2fe;
|
||||
}
|
||||
.projects-graph-action-btn--connect:hover,
|
||||
.projects-graph-action-btn--connect-active {
|
||||
color: #fff;
|
||||
background: linear-gradient(135deg, #4f46e5 0%, #6366f1 100%);
|
||||
border-color: transparent;
|
||||
box-shadow: 0 2px 8px rgba(79, 70, 229, 0.35);
|
||||
}
|
||||
.projects-graph-legend {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: center;
|
||||
justify-content: flex-end;
|
||||
gap: 8px 14px;
|
||||
}
|
||||
.projects-graph-legend-group {
|
||||
display: inline-flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: center;
|
||||
gap: 6px 10px;
|
||||
}
|
||||
.projects-graph-legend-heading {
|
||||
font-size: 0.6875rem;
|
||||
font-weight: 600;
|
||||
letter-spacing: 0.04em;
|
||||
text-transform: uppercase;
|
||||
color: #94a3b8;
|
||||
}
|
||||
.projects-graph-legend-divider {
|
||||
display: inline-block;
|
||||
width: 1px;
|
||||
height: 18px;
|
||||
background: #e2e8f0;
|
||||
flex: 0 0 auto;
|
||||
}
|
||||
.projects-graph-legend-item {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
font-size: 0.75rem;
|
||||
color: #64748b;
|
||||
}
|
||||
.projects-graph-legend-item--edge i {
|
||||
display: inline-block;
|
||||
width: 22px;
|
||||
height: 0;
|
||||
border-top: 2.5px solid var(--legend-color, #cbd5e1);
|
||||
border-radius: 2px;
|
||||
}
|
||||
.projects-graph-legend-item--edge.projects-graph-legend-item--dashed i {
|
||||
border-top-style: dashed;
|
||||
opacity: 0.7;
|
||||
}
|
||||
.projects-graph-legend-item--node i {
|
||||
display: inline-block;
|
||||
width: 14px;
|
||||
height: 14px;
|
||||
border: 1.5px solid var(--legend-color, #cbd5e1);
|
||||
border-radius: 4px;
|
||||
background: linear-gradient(135deg, #ffffff 0%, var(--legend-bg, #f8fafc) 100%);
|
||||
box-shadow: inset 0 1px 0 rgba(255, 255, 255, 0.9);
|
||||
}
|
||||
.projects-graph-legend-item--node-dashed i {
|
||||
border-style: dashed;
|
||||
opacity: 0.85;
|
||||
}
|
||||
.project-fact-graph-layout {
|
||||
position: relative;
|
||||
display: flex;
|
||||
min-height: 0;
|
||||
align-items: stretch;
|
||||
}
|
||||
.project-fact-graph-container {
|
||||
flex: 1 1 auto;
|
||||
width: 100%;
|
||||
min-height: 240px;
|
||||
border: 1px solid var(--border-color, #e2e8f0);
|
||||
border-radius: 14px;
|
||||
background-color: #f8fafc;
|
||||
background-image:
|
||||
radial-gradient(circle at 1px 1px, rgba(148, 163, 184, 0.35) 1px, transparent 0);
|
||||
background-size: 20px 20px;
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
box-shadow: inset 0 1px 0 rgba(255, 255, 255, 0.8), 0 1px 3px rgba(15, 23, 42, 0.04);
|
||||
}
|
||||
.project-fact-graph-container::after {
|
||||
content: '';
|
||||
position: absolute;
|
||||
inset: 0;
|
||||
pointer-events: none;
|
||||
background: radial-gradient(ellipse at center, transparent 55%, rgba(241, 245, 249, 0.65) 100%);
|
||||
z-index: 1;
|
||||
}
|
||||
.project-fact-graph-container .loading-spinner,
|
||||
.project-fact-graph-container .project-fact-graph-empty,
|
||||
.project-fact-graph-container .error-message {
|
||||
position: relative;
|
||||
z-index: 2;
|
||||
}
|
||||
.project-fact-graph-empty {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
text-align: center;
|
||||
height: 100%;
|
||||
min-height: 420px;
|
||||
padding: 40px 32px;
|
||||
}
|
||||
.project-fact-graph-empty-icon {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 80px;
|
||||
height: 80px;
|
||||
margin-bottom: 18px;
|
||||
background: rgba(255, 255, 255, 0.85);
|
||||
border: 1px solid #e2e8f0;
|
||||
border-radius: 20px;
|
||||
box-shadow: 0 4px 16px rgba(15, 23, 42, 0.06);
|
||||
}
|
||||
.project-fact-graph-empty-title {
|
||||
margin: 0 0 8px;
|
||||
font-size: 1.0625rem;
|
||||
font-weight: 600;
|
||||
color: #0f172a;
|
||||
letter-spacing: -0.01em;
|
||||
}
|
||||
.project-fact-graph-empty-hint {
|
||||
margin: 0 0 16px;
|
||||
max-width: 420px;
|
||||
font-size: 0.875rem;
|
||||
line-height: 1.6;
|
||||
color: #64748b;
|
||||
}
|
||||
.project-fact-graph-empty-steps {
|
||||
margin: 0 0 20px;
|
||||
padding-left: 1.2rem;
|
||||
max-width: 400px;
|
||||
text-align: left;
|
||||
font-size: 0.8125rem;
|
||||
line-height: 1.65;
|
||||
color: #475569;
|
||||
}
|
||||
.project-fact-graph-empty-steps li {
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
.project-fact-graph-empty-steps li::marker {
|
||||
color: #6366f1;
|
||||
font-weight: 600;
|
||||
}
|
||||
.project-fact-graph-empty-cta {
|
||||
margin-top: 4px;
|
||||
}
|
||||
.project-fact-graph-sidebar {
|
||||
position: absolute;
|
||||
top: 12px;
|
||||
right: 12px;
|
||||
bottom: 12px;
|
||||
width: min(300px, calc(100% - 24px));
|
||||
z-index: 12;
|
||||
border: 1px solid rgba(226, 232, 240, 0.95);
|
||||
border-radius: 14px;
|
||||
padding: 16px;
|
||||
background: rgba(255, 255, 255, 0.96);
|
||||
backdrop-filter: blur(12px);
|
||||
-webkit-backdrop-filter: blur(12px);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 12px;
|
||||
box-shadow: 0 8px 32px rgba(15, 23, 42, 0.12), 0 2px 8px rgba(15, 23, 42, 0.06);
|
||||
animation: projectGraphSidebarIn 0.2s ease;
|
||||
overflow-x: hidden;
|
||||
overflow-y: auto;
|
||||
}
|
||||
.project-fact-graph-sidebar[hidden] {
|
||||
display: none !important;
|
||||
}
|
||||
@keyframes projectGraphSidebarIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateX(12px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateX(0);
|
||||
}
|
||||
}
|
||||
.project-fact-graph-sidebar-header {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
justify-content: space-between;
|
||||
gap: 10px;
|
||||
}
|
||||
.project-fact-graph-sidebar-title-wrap {
|
||||
min-width: 0;
|
||||
flex: 1;
|
||||
}
|
||||
.project-fact-graph-sidebar-header h4 {
|
||||
margin: 4px 0 0;
|
||||
font-size: 0.9rem;
|
||||
font-weight: 600;
|
||||
word-break: break-all;
|
||||
color: #0f172a;
|
||||
line-height: 1.35;
|
||||
}
|
||||
.project-fact-graph-node-category {
|
||||
display: inline-block;
|
||||
font-size: 0.6875rem;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
padding: 2px 8px;
|
||||
border-radius: 999px;
|
||||
background: #f1f5f9;
|
||||
color: #64748b;
|
||||
border: 1px solid #e2e8f0;
|
||||
}
|
||||
.project-fact-graph-node-category--target { color: #4338ca; background: #eef2ff; border-color: #c7d2fe; }
|
||||
.project-fact-graph-node-category--finding { color: #be123c; background: #fff1f2; border-color: #fecdd3; }
|
||||
.project-fact-graph-node-category--vulnerability { color: #7e22ce; background: #f5f3ff; border-color: #ddd6fe; }
|
||||
.project-fact-graph-node-category--exploit,
|
||||
.project-fact-graph-node-category--poc { color: #c2410c; background: #ffedd5; border-color: #fdba74; }
|
||||
.project-fact-graph-node-category--chain { color: #6d28d9; background: #f5f3ff; border-color: #ddd6fe; }
|
||||
.project-fact-graph-node-category--auth { color: #0f766e; background: #f0fdfa; border-color: #99f6e4; }
|
||||
.project-fact-graph-node-category--infra { color: #475569; background: #f1f5f9; border-color: #cbd5e1; }
|
||||
.project-fact-graph-node-category--business { color: #0369a1; background: #f0f9ff; border-color: #bae6fd; }
|
||||
.project-fact-graph-node-category--note { color: #64748b; background: #f8fafc; border-color: #e2e8f0; }
|
||||
.project-fact-graph-node-category--missing { color: #94a3b8; background: #f1f5f9; border-color: #e2e8f0; font-style: italic; }
|
||||
.project-fact-graph-sidebar-close {
|
||||
flex-shrink: 0;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
padding: 0;
|
||||
border: 1px solid #e2e8f0;
|
||||
border-radius: 8px;
|
||||
background: #fff;
|
||||
color: #64748b;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s ease, color 0.15s ease, border-color 0.15s ease;
|
||||
}
|
||||
.project-fact-graph-sidebar-close:hover {
|
||||
color: #0f172a;
|
||||
border-color: #cbd5e1;
|
||||
background: #f8fafc;
|
||||
}
|
||||
.project-fact-graph-node-meta {
|
||||
margin: 0;
|
||||
font-size: 0.8125rem;
|
||||
line-height: 1.55;
|
||||
color: #64748b;
|
||||
flex: 0 0 auto;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
gap: 6px;
|
||||
min-width: 0;
|
||||
word-break: break-word;
|
||||
overflow-wrap: anywhere;
|
||||
}
|
||||
.project-fact-graph-node-summary {
|
||||
display: block;
|
||||
width: 100%;
|
||||
min-width: 0;
|
||||
color: #475569;
|
||||
}
|
||||
.project-fact-graph-node-vuln-hint {
|
||||
display: block;
|
||||
width: 100%;
|
||||
font-size: 0.75rem;
|
||||
line-height: 1.45;
|
||||
color: #64748b;
|
||||
}
|
||||
.project-fact-graph-edges-wrap {
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 6px;
|
||||
padding-top: 4px;
|
||||
border-top: 1px solid #f1f5f9;
|
||||
}
|
||||
.project-fact-graph-edges-title {
|
||||
margin: 0;
|
||||
font-size: 0.75rem;
|
||||
font-weight: 600;
|
||||
color: #475569;
|
||||
letter-spacing: 0.02em;
|
||||
}
|
||||
.project-fact-graph-edges-hint {
|
||||
margin: 0;
|
||||
font-size: 0.72rem;
|
||||
line-height: 1.45;
|
||||
color: #94a3b8;
|
||||
word-break: break-word;
|
||||
overflow-wrap: anywhere;
|
||||
}
|
||||
.project-fact-graph-edges-list {
|
||||
flex: 1 1 auto;
|
||||
min-height: 0;
|
||||
overflow-y: auto;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
}
|
||||
.project-fact-graph-edges-empty {
|
||||
margin: 0;
|
||||
font-size: 0.8125rem;
|
||||
color: #94a3b8;
|
||||
}
|
||||
.project-fact-graph-edge-item {
|
||||
display: grid;
|
||||
grid-template-columns: auto auto 1fr auto;
|
||||
align-items: center;
|
||||
gap: 4px 6px;
|
||||
padding: 6px 8px;
|
||||
font-size: 0.75rem;
|
||||
border: 1px solid #e2e8f0;
|
||||
border-radius: 8px;
|
||||
background: #f8fafc;
|
||||
cursor: pointer;
|
||||
transition: border-color 0.15s ease, background 0.15s ease, box-shadow 0.15s ease;
|
||||
}
|
||||
.project-fact-graph-edge-item:hover {
|
||||
border-color: #cbd5e1;
|
||||
background: #fff;
|
||||
}
|
||||
.project-fact-graph-edge-item.is-selected {
|
||||
border-color: #818cf8;
|
||||
background: #eef2ff;
|
||||
box-shadow: 0 0 0 1px rgba(99, 102, 241, 0.25);
|
||||
}
|
||||
.project-fact-graph-edge-dir {
|
||||
font-size: 0.6875rem;
|
||||
font-weight: 600;
|
||||
color: #64748b;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.project-fact-graph-edge-type {
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
|
||||
font-size: 0.6875rem;
|
||||
color: #4338ca;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.project-fact-graph-edge-arrow {
|
||||
color: #94a3b8;
|
||||
}
|
||||
.project-fact-graph-edge-peer {
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
color: #334155;
|
||||
min-width: 0;
|
||||
}
|
||||
.project-fact-graph-edge-delete {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 22px;
|
||||
height: 22px;
|
||||
padding: 0;
|
||||
border: 1px solid #fecaca;
|
||||
border-radius: 6px;
|
||||
background: #fff;
|
||||
color: #dc2626;
|
||||
font-size: 1rem;
|
||||
line-height: 1;
|
||||
cursor: pointer;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
.project-fact-graph-edge-delete:hover {
|
||||
background: #fef2f2;
|
||||
border-color: #f87171;
|
||||
}
|
||||
.project-fact-graph-edge-synthetic {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 22px;
|
||||
color: #cbd5e1;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
.projects-incoming-links-readonly {
|
||||
margin-top: 4px;
|
||||
}
|
||||
.projects-incoming-links-list {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
list-style: none;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 6px;
|
||||
}
|
||||
.projects-incoming-links-item {
|
||||
padding: 8px 10px;
|
||||
font-size: 0.8125rem;
|
||||
border: 1px solid #e2e8f0;
|
||||
border-radius: 8px;
|
||||
background: #f8fafc;
|
||||
color: #334155;
|
||||
word-break: break-all;
|
||||
}
|
||||
.projects-incoming-links-item code {
|
||||
font-size: 0.75rem;
|
||||
}
|
||||
.projects-incoming-links-empty {
|
||||
margin: 0;
|
||||
font-size: 0.8125rem;
|
||||
color: #94a3b8;
|
||||
}
|
||||
.projects-edge-type {
|
||||
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
|
||||
font-size: 0.75rem;
|
||||
color: #4338ca;
|
||||
}
|
||||
.projects-edge-arrow {
|
||||
color: #94a3b8;
|
||||
}
|
||||
.project-fact-graph-sidebar-actions {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
flex-wrap: wrap;
|
||||
padding-top: 4px;
|
||||
border-top: 1px solid #f1f5f9;
|
||||
}
|
||||
.project-fact-graph-footer {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: 8px 12px;
|
||||
margin: 10px 0 0;
|
||||
flex: 0 0 auto;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
.project-fact-graph-stats {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
margin: 0;
|
||||
flex: 0 1 auto;
|
||||
}
|
||||
.projects-graph-stat-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
font-size: 0.8125rem;
|
||||
color: #64748b;
|
||||
background: #f8fafc;
|
||||
border: 1px solid #e2e8f0;
|
||||
padding: 4px 12px;
|
||||
border-radius: 999px;
|
||||
}
|
||||
.projects-graph-stat-badge strong {
|
||||
font-size: 0.9375rem;
|
||||
font-weight: 700;
|
||||
color: #0f172a;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
#project-panel-graph .projects-fact-toolbar-filters {
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.projects-fact-link-badge {
|
||||
font-size: 0.78rem;
|
||||
font-variant-numeric: tabular-nums;
|
||||
color: var(--text-secondary, #64748b);
|
||||
}
|
||||
.projects-fact-link-badge--empty {
|
||||
opacity: 0.45;
|
||||
}
|
||||
@media (max-width: 1100px) {
|
||||
.project-fact-graph-sidebar {
|
||||
width: min(280px, calc(100% - 24px));
|
||||
}
|
||||
.projects-graph-actions {
|
||||
margin-left: 0;
|
||||
width: 100%;
|
||||
justify-content: flex-end;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 1400px) {
|
||||
.projects-detail-header {
|
||||
padding: 16px 18px 14px;
|
||||
@@ -24513,11 +25190,12 @@ button.chat-files-dropdown-item:hover:not(:disabled) {
|
||||
}
|
||||
#project-panel-facts .data-table--projects th:nth-child(3),
|
||||
#project-panel-facts .data-table--projects td:nth-child(3) {
|
||||
width: 24%;
|
||||
width: 22%;
|
||||
}
|
||||
#project-panel-facts .data-table--projects th:nth-child(7),
|
||||
#project-panel-facts .data-table--projects td:nth-child(7) {
|
||||
width: 23%;
|
||||
#project-panel-facts .data-table--projects th.col-actions,
|
||||
#project-panel-facts .data-table--projects td.col-actions {
|
||||
min-width: 188px;
|
||||
max-width: 220px;
|
||||
}
|
||||
}
|
||||
/* —— 项目设置:左右分栏 + 底部危险区,无内层滚动 —— */
|
||||
|
||||
@@ -258,10 +258,73 @@
|
||||
"vulnerabilityManagement": "Vulnerability management",
|
||||
"addFactCta": "+ Add fact",
|
||||
"tabFacts": "Fact board",
|
||||
"tabGraph": "Attack path",
|
||||
"tabConversations": "Bound conversations",
|
||||
"tabVulns": "Related vulnerabilities",
|
||||
"tabSettings": "Settings",
|
||||
"factToolbarHint": "Index includes key and summary only (must include what + where + how to verify); put attack chain / POC in body, and reproduce via get_project_fact.",
|
||||
"graphToolbarHint": "Graph arrows match stored fact links (source → target). Nodes are layered target→infra→finding→exploit. Dashed edges are tentative.",
|
||||
"graphView": "View",
|
||||
"graphViewPath": "Attack path",
|
||||
"graphViewFull": "Full graph",
|
||||
"graphSearchSr": "Search nodes",
|
||||
"graphSearchPlaceholder": "Search nodes…",
|
||||
"graphRefresh": "Refresh",
|
||||
"graphCenter": "Center",
|
||||
"graphEmpty": "No graph data yet. Add links on finding/exploit facts (discovered_on → target/*) to build the path.",
|
||||
"graphEmptyTitle": "Build your attack path",
|
||||
"graphEmptyStep1": "Add target facts (domains, endpoints, scope)",
|
||||
"graphEmptyStep2": "Record findings/exploits with links between facts",
|
||||
"graphEmptyStep3": "Use Connect mode or edit facts to add relationships",
|
||||
"graphEmptyCta": "Add first fact",
|
||||
"graphStats": "Nodes: {{nodes}} | Edges: {{edges}}",
|
||||
"graphStatsNodes": "Nodes",
|
||||
"graphStatsEdges": "Edges",
|
||||
"graphLegendNodes": "Nodes",
|
||||
"graphLegendEdges": "Edges",
|
||||
"graphLegendNodeTarget": "TARGET",
|
||||
"graphLegendNodeInfra": "INFRA",
|
||||
"graphLegendNodeFinding": "FINDING",
|
||||
"graphLegendNodeVuln": "VULN",
|
||||
"graphLegendNodeExploit": "EXPLOIT",
|
||||
"graphLegendNodeMissing": "MISSING",
|
||||
"graphLegendDiscovered": "discovered_on",
|
||||
"graphLegendLeads": "leads_to",
|
||||
"graphLegendExploits": "exploits",
|
||||
"graphLegendTentative": "Tentative (dashed)",
|
||||
"factLinksLabel": "Links (from → this fact)",
|
||||
"factLinksPlaceholder": "discovered_on: target/primary_domain\nexploits: exploit/upload-rce",
|
||||
"factLinksHint": "One per line: type: source_fact_key (source → this fact). Common types: discovered_on, depends_on, leads_to, enables, exploits. Saving replaces all links.",
|
||||
"factIncomingLinksLabel": "Incoming links (read-only)",
|
||||
"factIncomingLinksHint": "Derived from outgoing links on source facts. e.g. finding discovered_on → target/* appears as incoming on the target; edit the source fact's outgoing links.",
|
||||
"factIncomingLinksEmpty": "No incoming links",
|
||||
"graphEdgeFromSelf": "From this node",
|
||||
"graphEdgeToSelf": "To this node",
|
||||
"linksColumn": "Links",
|
||||
"linkCountsTitle": "Outgoing / incoming edge counts",
|
||||
"graphConnect": "Connect",
|
||||
"graphConnectActive": "Connecting…",
|
||||
"graphConnectPickTarget": "Source {{source}} selected — click target node",
|
||||
"graphEdgeTypePrompt": "Edge type (discovered_on / leads_to / depends_on / enables / exploits)",
|
||||
"graphConnectFailed": "Failed to create edge",
|
||||
"graphConnectSuccess": "Edge created",
|
||||
"graphEdgesTitle": "Links",
|
||||
"graphEdgesHint": "Arrow direction matches the database and edit modal (source → target). Click an edge to focus it.",
|
||||
"graphEdgesEmpty": "No links yet",
|
||||
"graphEdgeOutgoing": "Outgoing",
|
||||
"graphEdgeIncoming": "Incoming",
|
||||
"graphEdgeSynthetic": "Auto-generated from fact link; edit the fact to remove",
|
||||
"confirmDeleteGraphEdge": "Delete this link?",
|
||||
"graphEdgeDeleteFailed": "Failed to delete edge",
|
||||
"graphEdgeDeleteSuccess": "Edge deleted",
|
||||
"graphDeleteEdge": "Delete",
|
||||
"viewVulnerability": "View vulnerability",
|
||||
"graphVulnSidebarHint": "Linked vulnerability node. Use the button below to open it in Vulnerability Management.",
|
||||
"promoteAttackChain": "Promote chain",
|
||||
"promoteAttackChainTitle": "Promote conversation attack chain to project facts",
|
||||
"confirmPromoteAttackChain": "Promote this conversation's attack chain into the project? Facts and edges will be created or updated.",
|
||||
"promoteAttackChainFailed": "Promote failed",
|
||||
"promoteAttackChainSuccess": "Promoted: {{facts_created}} new / {{facts_updated}} updated / {{edges_created}} edges",
|
||||
"searchFactsSr": "Search facts",
|
||||
"searchFactsPlaceholder": "Search key, summary, body…",
|
||||
"category": "Category",
|
||||
@@ -467,6 +530,8 @@
|
||||
"noMatchTools": "No matching tools",
|
||||
"penetrationTestDetail": "Penetration test details",
|
||||
"expandDetail": "Expand details",
|
||||
"toolExecutionsCount": "{{n}} tool runs",
|
||||
"collapseToolExecutions": "Collapse tool runs",
|
||||
"noProcessDetail": "No process details (execution may be too fast or no detailed events)",
|
||||
"copyMessageTitle": "Copy message",
|
||||
"deleteTurnTitle": "Delete this turn",
|
||||
@@ -1593,6 +1658,7 @@
|
||||
"rateWarning": "Some failures detected",
|
||||
"rateCritical": "High failure rate",
|
||||
"statsSubtitle": "Refreshed {{time}} · {{count}} tools",
|
||||
"retentionHint": "Execution records are kept for {{days}} days, then purged automatically.",
|
||||
"timelineTitle": "Call trend",
|
||||
"timelineHint": "All tools combined (not split by tool)",
|
||||
"timelineRange24h": "24h",
|
||||
@@ -2514,6 +2580,8 @@
|
||||
"agentModeSingle": "Single-agent (Eino ADK)",
|
||||
"agentModeMulti": "Multi-agent (Eino)",
|
||||
"agentModeHint": "Same as chat: Eino single-agent (ADK), or Deep / Plan-Execute / Supervisor (last three require multi_agent.enabled).",
|
||||
"concurrency": "Concurrency",
|
||||
"concurrencyHint": "Number of subtasks to run in parallel (1-8). Default 1 is serial; use 1-2 for scan-heavy tasks.",
|
||||
"scheduleMode": "Schedule mode",
|
||||
"scheduleModeManual": "Manual",
|
||||
"scheduleModeCron": "Cron expression",
|
||||
@@ -2528,8 +2596,8 @@
|
||||
"tasksList": "Task list (one task per line)",
|
||||
"tasksListPlaceholder": "Enter task list, one per line",
|
||||
"tasksListPlaceholderExample": "Enter task list, one per line, for example:\nScan open ports of 192.168.1.1\nCheck if https://example.com has SQL injection\nEnumerate subdomains of example.com",
|
||||
"tasksListHint": "Enter one task command per line; the system will execute them in order. Empty lines are ignored.",
|
||||
"tasksListHintFull": "Hint: Enter one task command per line; the system will execute these tasks in order. Empty lines are ignored.",
|
||||
"tasksListHint": "Enter one task command per line; the system runs them via a concurrency pool. Empty lines are ignored.",
|
||||
"tasksListHintFull": "Hint: Enter one task command per line; the system runs them via a concurrency pool. Empty lines are ignored.",
|
||||
"createQueue": "Create queue"
|
||||
},
|
||||
"batchQueueDetailModal": {
|
||||
@@ -2563,6 +2631,8 @@
|
||||
"scheduleToggleFailed": "Failed to update schedule toggle",
|
||||
"completedAt": "Completed at",
|
||||
"taskTotal": "Total tasks",
|
||||
"concurrency": "Concurrency",
|
||||
"concurrencyEditHint": "Click to edit. Cannot change while the queue is running.",
|
||||
"taskList": "Task list",
|
||||
"startLabel": "Start",
|
||||
"completeLabel": "Complete",
|
||||
|
||||
@@ -246,10 +246,73 @@
|
||||
"vulnerabilityManagement": "漏洞管理",
|
||||
"addFactCta": "+ 添加事实",
|
||||
"tabFacts": "事实黑板",
|
||||
"tabGraph": "攻击路径",
|
||||
"tabConversations": "关联对话",
|
||||
"tabVulns": "关联漏洞",
|
||||
"tabSettings": "设置",
|
||||
"factToolbarHint": "索引仅含 key 与摘要(须含「什么 + 在哪 + 如何验证」);攻击链 / POC 写在 body,Agent 通过 get_project_fact 复现",
|
||||
"graphToolbarHint": "攻击路径图箭头与事实存储方向一致(source → target);节点按 target→infra→finding→exploit 分层排布。虚线边为待确认。",
|
||||
"graphView": "视图",
|
||||
"graphViewPath": "攻击路径",
|
||||
"graphViewFull": "完整关系",
|
||||
"graphSearchSr": "搜索节点",
|
||||
"graphSearchPlaceholder": "搜索节点…",
|
||||
"graphRefresh": "刷新",
|
||||
"graphCenter": "居中",
|
||||
"graphEmpty": "暂无路径图数据。为 finding/exploit 类事实添加关系边(discovered_on → target/*)后将在此展示。",
|
||||
"graphEmptyTitle": "构建攻击路径图",
|
||||
"graphEmptyStep1": "添加 target 类事实(目标、域名、入口)",
|
||||
"graphEmptyStep2": "记录 finding / exploit 并在 links 中连边",
|
||||
"graphEmptyStep3": "使用「连边」模式或编辑事实手动补关系",
|
||||
"graphEmptyCta": "添加第一条事实",
|
||||
"graphStats": "节点: {{nodes}} | 边: {{edges}}",
|
||||
"graphStatsNodes": "节点",
|
||||
"graphStatsEdges": "边",
|
||||
"graphLegendNodes": "节点",
|
||||
"graphLegendEdges": "连线",
|
||||
"graphLegendNodeTarget": "TARGET · 目标",
|
||||
"graphLegendNodeInfra": "INFRA · 基础设施",
|
||||
"graphLegendNodeFinding": "FINDING · 发现",
|
||||
"graphLegendNodeVuln": "VULN · 漏洞",
|
||||
"graphLegendNodeExploit": "EXPLOIT · 利用",
|
||||
"graphLegendNodeMissing": "MISSING · 缺失",
|
||||
"graphLegendDiscovered": "discovered_on",
|
||||
"graphLegendLeads": "leads_to",
|
||||
"graphLegendExploits": "exploits",
|
||||
"graphLegendTentative": "待确认(虚线)",
|
||||
"factLinksLabel": "关系边(from → 本事实)",
|
||||
"factLinksPlaceholder": "discovered_on: target/primary_domain\nexploits: exploit/upload-rce",
|
||||
"factLinksHint": "每行一条:type: source_fact_key(来源 → 当前事实)。常用 type:discovered_on、depends_on、leads_to、enables、exploits。保存时替换全部关系边。",
|
||||
"factIncomingLinksLabel": "入边(只读)",
|
||||
"factIncomingLinksHint": "由来源事实的出边产生。例如 finding 的 discovered_on → target/*,在目标上会显示为入边;请编辑来源事实的出边。",
|
||||
"factIncomingLinksEmpty": "暂无入边",
|
||||
"graphEdgeFromSelf": "本节点指出",
|
||||
"graphEdgeToSelf": "指向本节点",
|
||||
"linksColumn": "关系",
|
||||
"linkCountsTitle": "出边数 / 入边数",
|
||||
"graphConnect": "连边",
|
||||
"graphConnectActive": "连边中…",
|
||||
"graphConnectPickTarget": "已选 {{source}},请点击目标节点",
|
||||
"graphEdgeTypePrompt": "边类型(discovered_on / leads_to / depends_on / enables / exploits)",
|
||||
"graphConnectFailed": "创建边失败",
|
||||
"graphConnectSuccess": "边已创建",
|
||||
"graphEdgesTitle": "关系边",
|
||||
"graphEdgesHint": "箭头方向与数据库/编辑弹窗一致(source → target);点击连线可定位。",
|
||||
"graphEdgesEmpty": "暂无关系边",
|
||||
"graphEdgeOutgoing": "出边",
|
||||
"graphEdgeIncoming": "入边",
|
||||
"graphEdgeSynthetic": "由事实关联自动生成,请编辑事实解除",
|
||||
"confirmDeleteGraphEdge": "确定删除此关系边?",
|
||||
"graphEdgeDeleteFailed": "删除边失败",
|
||||
"graphEdgeDeleteSuccess": "边已删除",
|
||||
"graphDeleteEdge": "删边",
|
||||
"viewVulnerability": "查看漏洞",
|
||||
"graphVulnSidebarHint": "关联漏洞节点,点击下方按钮在漏洞管理中查看详情。",
|
||||
"promoteAttackChain": "沉淀攻击链",
|
||||
"promoteAttackChainTitle": "将对话攻击链沉淀为项目事实与边",
|
||||
"confirmPromoteAttackChain": "将该对话的攻击链沉淀到本项目?会创建/更新事实与关系边。",
|
||||
"promoteAttackChainFailed": "沉淀失败",
|
||||
"promoteAttackChainSuccess": "已沉淀:新建 {{facts_created}} / 更新 {{facts_updated}} / 边 {{edges_created}}",
|
||||
"searchFactsSr": "搜索事实",
|
||||
"searchFactsPlaceholder": "搜索 key、摘要、body…",
|
||||
"category": "分类",
|
||||
@@ -455,6 +518,8 @@
|
||||
"noMatchTools": "没有匹配的工具",
|
||||
"penetrationTestDetail": "渗透测试详情",
|
||||
"expandDetail": "展开详情",
|
||||
"toolExecutionsCount": "{{n}}次工具执行",
|
||||
"collapseToolExecutions": "收起工具执行",
|
||||
"noProcessDetail": "暂无过程详情(可能执行过快或未触发详细事件)",
|
||||
"copyMessageTitle": "复制消息内容",
|
||||
"deleteTurnTitle": "删除本轮对话",
|
||||
@@ -1581,6 +1646,7 @@
|
||||
"rateWarning": "存在失败调用",
|
||||
"rateCritical": "失败率偏高",
|
||||
"statsSubtitle": "最后刷新 {{time}} · 共 {{count}} 个工具",
|
||||
"retentionHint": "执行记录保留 {{days}} 天,超期自动清理",
|
||||
"timelineTitle": "调用趋势",
|
||||
"timelineHint": "全部工具合计,不按工具拆分",
|
||||
"timelineRange24h": "24 小时",
|
||||
@@ -2502,6 +2568,8 @@
|
||||
"agentModeSingle": "单代理(Eino ADK)",
|
||||
"agentModeMulti": "多代理(Eino)",
|
||||
"agentModeHint": "与对话页一致:Eino 单代理(ADK),或 Deep / Plan-Execute / Supervisor(后三种需已启用多代理)。",
|
||||
"concurrency": "并发数",
|
||||
"concurrencyHint": "同时执行的子任务数量(1-8)。默认 1 为串行;含扫描类工具时建议 1-2。",
|
||||
"scheduleMode": "调度方式",
|
||||
"scheduleModeManual": "手工执行",
|
||||
"scheduleModeCron": "调度表达式(Cron)",
|
||||
@@ -2516,8 +2584,8 @@
|
||||
"tasksList": "任务列表(每行一个任务)",
|
||||
"tasksListPlaceholder": "请输入任务列表,每行一个任务",
|
||||
"tasksListPlaceholderExample": "请输入任务列表,每行一个任务,例如:\n扫描 192.168.1.1 的开放端口\n检查 https://example.com 是否存在SQL注入\n枚举 example.com 的子域名",
|
||||
"tasksListHint": "每行输入一个任务指令,系统将依次执行这些任务。空行会被自动忽略。",
|
||||
"tasksListHintFull": "提示:每行输入一个任务指令,系统将依次执行这些任务。空行会被自动忽略。",
|
||||
"tasksListHint": "每行输入一个任务指令,系统将按并发池执行这些任务。空行会被自动忽略。",
|
||||
"tasksListHintFull": "提示:每行输入一个任务指令,系统将按并发池执行这些任务。空行会被自动忽略。",
|
||||
"createQueue": "创建队列"
|
||||
},
|
||||
"batchQueueDetailModal": {
|
||||
@@ -2551,6 +2619,8 @@
|
||||
"scheduleToggleFailed": "更新调度开关失败",
|
||||
"completedAt": "完成时间",
|
||||
"taskTotal": "任务总数",
|
||||
"concurrency": "并发数",
|
||||
"concurrencyEditHint": "点击可修改;队列运行中不可改。",
|
||||
"taskList": "任务列表",
|
||||
"startLabel": "开始",
|
||||
"completeLabel": "完成",
|
||||
|
||||
+450
-108
@@ -1,6 +1,39 @@
|
||||
let currentConversationId = null;
|
||||
let loadConversationRequestSeq = 0;
|
||||
|
||||
/** 轻量会话 LRU 缓存:来回切换已加载会话时避免重复网络 + 全量 DOM 重建 */
|
||||
const CONVERSATION_LITE_CACHE_MAX = 12;
|
||||
const conversationLiteCache = new Map();
|
||||
|
||||
function getConversationLiteFromCache(conversationId) {
|
||||
if (!conversationId) return null;
|
||||
const hit = conversationLiteCache.get(conversationId);
|
||||
if (!hit) return null;
|
||||
conversationLiteCache.delete(conversationId);
|
||||
conversationLiteCache.set(conversationId, hit);
|
||||
return hit;
|
||||
}
|
||||
|
||||
function putConversationLiteCache(conversationId, data) {
|
||||
if (!conversationId || !data) return;
|
||||
conversationLiteCache.delete(conversationId);
|
||||
conversationLiteCache.set(conversationId, data);
|
||||
while (conversationLiteCache.size > CONVERSATION_LITE_CACHE_MAX) {
|
||||
const oldest = conversationLiteCache.keys().next().value;
|
||||
conversationLiteCache.delete(oldest);
|
||||
}
|
||||
}
|
||||
|
||||
function invalidateConversationLiteCache(conversationId) {
|
||||
if (conversationId) {
|
||||
conversationLiteCache.delete(conversationId);
|
||||
} else {
|
||||
conversationLiteCache.clear();
|
||||
}
|
||||
}
|
||||
|
||||
window.invalidateConversationLiteCache = invalidateConversationLiteCache;
|
||||
|
||||
// @ 提及相关状态
|
||||
let mentionTools = [];
|
||||
let mentionToolsLoaded = false;
|
||||
@@ -886,6 +919,9 @@ async function sendMessage() {
|
||||
window.CyberStrikeChatScroll.onUserSendMessage();
|
||||
}
|
||||
addMessage('user', displayMessage, null, null, null, { scroll: 'none' });
|
||||
if (currentConversationId) {
|
||||
invalidateConversationLiteCache(currentConversationId);
|
||||
}
|
||||
|
||||
// 清除防抖定时器,防止在清空输入框后重新保存草稿
|
||||
if (draftSaveTimer) {
|
||||
@@ -2027,31 +2063,13 @@ function addMessage(role, content, mcpExecutionIds = null, progressId = null, cr
|
||||
|
||||
// 有 MCP 执行记录且非流式占位消息时展示调用按钮;带 progressId 的流式占位不挂此条(与进度卡片一致,结束时 integrate 再创建)
|
||||
if (role === 'assistant' && (mcpExecutionIds && Array.isArray(mcpExecutionIds) && mcpExecutionIds.length > 0) && !progressId) {
|
||||
const mcpSection = document.createElement('div');
|
||||
mcpSection.className = 'mcp-call-section';
|
||||
|
||||
const mcpLabel = document.createElement('div');
|
||||
mcpLabel.className = 'mcp-call-label';
|
||||
mcpLabel.textContent = '📋 ' + (typeof window.t === 'function' ? window.t('chat.penetrationTestDetail') : '渗透测试详情');
|
||||
mcpSection.appendChild(mcpLabel);
|
||||
|
||||
const buttonsContainer = document.createElement('div');
|
||||
buttonsContainer.className = 'mcp-call-buttons';
|
||||
|
||||
mcpExecutionIds.forEach((execId, index) => {
|
||||
const detailBtn = document.createElement('button');
|
||||
detailBtn.className = 'mcp-detail-btn';
|
||||
detailBtn.dataset.execId = execId;
|
||||
detailBtn.dataset.execIndex = String(index + 1);
|
||||
detailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.callNumber', { n: index + 1 }) : '调用 #' + (index + 1)) + '</span>';
|
||||
detailBtn.onclick = () => showMCPDetail(execId);
|
||||
buttonsContainer.appendChild(detailBtn);
|
||||
});
|
||||
// 使用批量 API 一次性获取所有工具名称(消除 N 次单独请求)
|
||||
batchUpdateButtonToolNames(buttonsContainer, mcpExecutionIds);
|
||||
|
||||
mcpSection.appendChild(buttonsContainer);
|
||||
contentWrapper.appendChild(mcpSection);
|
||||
if (options && options.deferMcpButtons) {
|
||||
try {
|
||||
messageDiv.dataset.pendingMcpExecutionIds = JSON.stringify(mcpExecutionIds);
|
||||
} catch (e) { /* ignore */ }
|
||||
} else {
|
||||
appendMcpCallButtons(messageDiv, mcpExecutionIds);
|
||||
}
|
||||
}
|
||||
|
||||
messageDiv.appendChild(contentWrapper);
|
||||
@@ -2151,11 +2169,13 @@ function copyMessageToClipboard(messageDiv, button) {
|
||||
function showCopySuccess(button) {
|
||||
if (button) {
|
||||
const originalText = button.innerHTML;
|
||||
button.dataset.copySuccessActive = '1';
|
||||
button.innerHTML = '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M20 6L9 17l-5-5" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round" fill="none"/></svg><span>' + (typeof window.t === 'function' ? window.t('common.copied') : '已复制') + '</span>';
|
||||
button.style.color = '#10b981';
|
||||
button.style.background = 'rgba(16, 185, 129, 0.1)';
|
||||
button.style.borderColor = 'rgba(16, 185, 129, 0.3)';
|
||||
setTimeout(() => {
|
||||
delete button.dataset.copySuccessActive;
|
||||
button.innerHTML = originalText;
|
||||
button.style.color = '';
|
||||
button.style.background = '';
|
||||
@@ -2252,10 +2272,22 @@ async function syncAssistantReasoningContentFromServer(backendMessageId, domAssi
|
||||
window.normalizeReasoningContentForDisplay = normalizeReasoningContentForDisplay;
|
||||
window.setMessageReasoningContent = setMessageReasoningContent;
|
||||
window.getMessageReasoningContent = getMessageReasoningContent;
|
||||
window.filterNoiseProcessDetails = filterNoiseProcessDetails;
|
||||
window.mergeMessageReasoningContentIntoProcessDetails = mergeMessageReasoningContentIntoProcessDetails;
|
||||
window.syncAssistantReasoningContentFromServer = syncAssistantReasoningContentFromServer;
|
||||
|
||||
/** 相邻且类型/正文/data 完全一致的过程详情只保留一条(与后端去重一致,避免时间线叠多条相同块) */
|
||||
function isEinoAgentHeartbeatProgress(detail) {
|
||||
if (!detail || detail.eventType !== 'progress') return false;
|
||||
const msg = String(detail.message != null ? detail.message : '').trim();
|
||||
return /^\[Eino\]\s+\S/.test(msg);
|
||||
}
|
||||
|
||||
function filterNoiseProcessDetails(details) {
|
||||
if (!Array.isArray(details)) return details;
|
||||
return details.filter(function (d) { return !isEinoAgentHeartbeatProgress(d); });
|
||||
}
|
||||
|
||||
function dedupeConsecutiveProcessDetailRows(details) {
|
||||
if (!Array.isArray(details) || details.length < 2) {
|
||||
return details;
|
||||
@@ -2289,47 +2321,20 @@ function processDetailRowFingerprint(d) {
|
||||
}
|
||||
|
||||
// 渲染过程详情
|
||||
function renderProcessDetails(messageId, processDetails) {
|
||||
// options.append=true 时分页追加;options.markLoaded=false 时保留 lazy 标记(分页加载中)
|
||||
function renderProcessDetails(messageId, processDetails, options) {
|
||||
const renderOpts = options || {};
|
||||
const appendMode = !!renderOpts.append;
|
||||
const markLoaded = renderOpts.markLoaded !== false;
|
||||
const messageElement = document.getElementById(messageId);
|
||||
if (!messageElement) {
|
||||
return;
|
||||
}
|
||||
|
||||
// 查找或创建MCP调用区域
|
||||
let mcpSection = messageElement.querySelector('.mcp-call-section');
|
||||
if (!mcpSection) {
|
||||
mcpSection = document.createElement('div');
|
||||
mcpSection.className = 'mcp-call-section';
|
||||
|
||||
const contentWrapper = messageElement.querySelector('.message-content');
|
||||
if (contentWrapper) {
|
||||
contentWrapper.appendChild(mcpSection);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// 确保有标签和按钮容器(统一结构)
|
||||
let mcpLabel = mcpSection.querySelector('.mcp-call-label');
|
||||
let buttonsContainer = mcpSection.querySelector('.mcp-call-buttons');
|
||||
|
||||
// 如果没有标签,创建一个(当没有工具调用时)
|
||||
if (!mcpLabel && !buttonsContainer) {
|
||||
mcpLabel = document.createElement('div');
|
||||
mcpLabel.className = 'mcp-call-label';
|
||||
mcpLabel.textContent = '📋 ' + (typeof window.t === 'function' ? window.t('chat.penetrationTestDetail') : '渗透测试详情');
|
||||
mcpSection.appendChild(mcpLabel);
|
||||
} else if (mcpLabel && mcpLabel.textContent !== ('📋 ' + (typeof window.t === 'function' ? window.t('chat.penetrationTestDetail') : '渗透测试详情'))) {
|
||||
// 如果标签存在但不是统一格式,更新它
|
||||
mcpLabel.textContent = '📋 ' + (typeof window.t === 'function' ? window.t('chat.penetrationTestDetail') : '渗透测试详情');
|
||||
}
|
||||
|
||||
// 如果没有按钮容器,创建一个
|
||||
if (!buttonsContainer) {
|
||||
buttonsContainer = document.createElement('div');
|
||||
buttonsContainer.className = 'mcp-call-buttons';
|
||||
mcpSection.appendChild(buttonsContainer);
|
||||
}
|
||||
// 查找或创建 MCP 区域(工具栏 + 工具列表 + 迭代时间线 分区)
|
||||
const chrome = ensureMcpCallSectionChrome(messageElement, messageId);
|
||||
if (!chrome) return;
|
||||
const { mcpSection, toolbar: buttonsContainer } = chrome;
|
||||
|
||||
// 添加过程详情按钮(如果还没有)
|
||||
let processDetailBtn = buttonsContainer.querySelector('.process-detail-btn');
|
||||
@@ -2340,17 +2345,20 @@ function renderProcessDetails(messageId, processDetails) {
|
||||
processDetailBtn.onclick = () => toggleProcessDetails(null, messageId);
|
||||
buttonsContainer.appendChild(processDetailBtn);
|
||||
}
|
||||
syncMcpToolsToggleButton(messageElement);
|
||||
|
||||
// 创建过程详情容器(放在按钮容器之后)
|
||||
// 创建过程详情容器(放在工具列表之后)
|
||||
const detailsId = 'process-details-' + messageId;
|
||||
let detailsContainer = document.getElementById(detailsId);
|
||||
const toolListEl = chrome.toolList;
|
||||
|
||||
if (!detailsContainer) {
|
||||
detailsContainer = document.createElement('div');
|
||||
detailsContainer.id = detailsId;
|
||||
detailsContainer.className = 'process-details-container';
|
||||
// 确保容器在按钮容器之后
|
||||
if (buttonsContainer.nextSibling) {
|
||||
if (toolListEl) {
|
||||
toolListEl.after(detailsContainer);
|
||||
} else if (buttonsContainer.nextSibling) {
|
||||
mcpSection.insertBefore(detailsContainer, buttonsContainer.nextSibling);
|
||||
} else {
|
||||
mcpSection.appendChild(detailsContainer);
|
||||
@@ -2379,36 +2387,42 @@ function renderProcessDetails(messageId, processDetails) {
|
||||
if (isLazyNotLoaded && !reasoningFromMessage) {
|
||||
detailsContainer.dataset.lazyNotLoaded = '1';
|
||||
detailsContainer.dataset.loaded = '0';
|
||||
timeline.innerHTML = '<div class="progress-timeline-empty">' +
|
||||
(typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') +
|
||||
'(点击后加载)</div>';
|
||||
const expandLabel = typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情';
|
||||
let lazyHint = expandLabel + '(点击后加载迭代详情)';
|
||||
timeline.innerHTML = '<div class="progress-timeline-empty">' + lazyHint + '</div>';
|
||||
timeline.classList.remove('expanded');
|
||||
prefetchProcessDetailsSummaryHint(messageId, messageElement);
|
||||
return;
|
||||
}
|
||||
if (isLazyNotLoaded) {
|
||||
detailsContainer.dataset.lazyNotLoaded = '1';
|
||||
detailsContainer.dataset.loaded = '0';
|
||||
processDetails = [];
|
||||
} else {
|
||||
if (!appendMode) {
|
||||
prefetchProcessDetailsSummaryHint(messageId, messageElement);
|
||||
}
|
||||
} else if (markLoaded) {
|
||||
detailsContainer.dataset.lazyNotLoaded = '0';
|
||||
detailsContainer.dataset.loaded = '1';
|
||||
}
|
||||
processDetails = mergeMessageReasoningContentIntoProcessDetails(processDetails, reasoningFromMessage);
|
||||
processDetails = filterNoiseProcessDetails(processDetails);
|
||||
processDetails = dedupeConsecutiveProcessDetailRows(processDetails);
|
||||
if (typeof window.coalesceProcessDetailsToolPairs === 'function') {
|
||||
processDetails = window.coalesceProcessDetailsToolPairs(processDetails);
|
||||
}
|
||||
// 如果没有processDetails或为空,显示空状态
|
||||
if (!processDetails || processDetails.length === 0) {
|
||||
// 显示空状态提示
|
||||
timeline.innerHTML = '<div class="progress-timeline-empty">' + (typeof window.t === 'function' ? window.t('chat.noProcessDetail') : '暂无过程详情(可能执行过快或未触发详细事件)') + '</div>';
|
||||
// 默认折叠
|
||||
timeline.classList.remove('expanded');
|
||||
if (!appendMode) {
|
||||
timeline.innerHTML = '<div class="progress-timeline-empty">' + (typeof window.t === 'function' ? window.t('chat.noProcessDetail') : '暂无过程详情(可能执行过快或未触发详细事件)') + '</div>';
|
||||
timeline.classList.remove('expanded');
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// 清空时间线并重新渲染
|
||||
timeline.innerHTML = '';
|
||||
if (!appendMode) {
|
||||
timeline.innerHTML = '';
|
||||
}
|
||||
|
||||
|
||||
function processDetailAgentPrefix(d) {
|
||||
@@ -2417,14 +2431,12 @@ function renderProcessDetails(messageId, processDetails) {
|
||||
return s ? ('[' + s + '] ') : '';
|
||||
}
|
||||
|
||||
// 渲染每个过程详情事件
|
||||
processDetails.forEach(detail => {
|
||||
function renderOneProcessDetail(detail) {
|
||||
const eventType = detail.eventType || '';
|
||||
const title = detail.message || '';
|
||||
const data = detail.data || {};
|
||||
const agPx = processDetailAgentPrefix(data);
|
||||
|
||||
// 根据事件类型渲染不同的内容
|
||||
let itemTitle = title;
|
||||
if (eventType === 'iteration') {
|
||||
const n = data.iteration || 1;
|
||||
@@ -2517,15 +2529,38 @@ function renderProcessDetails(messageId, processDetails) {
|
||||
title: itemTitle,
|
||||
message: detail.message || '',
|
||||
data: data,
|
||||
createdAt: detail.createdAt // 传递实际的事件创建时间
|
||||
createdAt: detail.createdAt
|
||||
};
|
||||
if (eventType === 'tool_call' && data._mergedResult) {
|
||||
timelineOpts.mergedResult = data._mergedResult;
|
||||
}
|
||||
addTimelineItem(timeline, eventType, timelineOpts);
|
||||
});
|
||||
}
|
||||
|
||||
if (isLazyNotLoaded && reasoningFromMessage) {
|
||||
const TIMELINE_RENDER_BATCH = 40;
|
||||
const renderTimelineBatch = (startIdx) => {
|
||||
const endIdx = Math.min(startIdx + TIMELINE_RENDER_BATCH, processDetails.length);
|
||||
for (let i = startIdx; i < endIdx; i++) {
|
||||
renderOneProcessDetail(processDetails[i]);
|
||||
}
|
||||
if (endIdx < processDetails.length) {
|
||||
requestAnimationFrame(() => renderTimelineBatch(endIdx));
|
||||
} else if (markLoaded) {
|
||||
finishProcessDetailsRender(messageElement, processDetails, isLazyNotLoaded, timeline);
|
||||
}
|
||||
};
|
||||
if (processDetails.length > TIMELINE_RENDER_BATCH) {
|
||||
renderTimelineBatch(0);
|
||||
} else {
|
||||
processDetails.forEach(renderOneProcessDetail);
|
||||
if (markLoaded) {
|
||||
finishProcessDetailsRender(messageElement, processDetails, isLazyNotLoaded, timeline);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function finishProcessDetailsRender(messageElement, processDetails, isLazyNotLoaded, timeline) {
|
||||
if (isLazyNotLoaded && getMessageReasoningContent(messageElement)) {
|
||||
const lazyHint = document.createElement('div');
|
||||
lazyHint.className = 'progress-timeline-empty progress-timeline-lazy-hint';
|
||||
lazyHint.textContent = (typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') +
|
||||
@@ -2533,15 +2568,12 @@ function renderProcessDetails(messageId, processDetails) {
|
||||
timeline.appendChild(lazyHint);
|
||||
}
|
||||
|
||||
// 检查是否有错误或取消事件,如果有,确保详情默认折叠(但仍有待审批 HITL 时保持展开,由 restoreHitlInlineForConversation 处理)
|
||||
const hasPendingHitlInDetails = processDetails.some(d => d && d.eventType === 'hitl_interrupt');
|
||||
const hasErrorOrCancelled = processDetails.some(d =>
|
||||
d.eventType === 'error' || d.eventType === 'cancelled'
|
||||
);
|
||||
if (hasErrorOrCancelled && !hasPendingHitlInDetails) {
|
||||
// 确保时间线是折叠的
|
||||
timeline.classList.remove('expanded');
|
||||
// 更新按钮文本为"展开详情"
|
||||
const processDetailBtn = messageElement.querySelector('.process-detail-btn');
|
||||
if (processDetailBtn) {
|
||||
processDetailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') + '</span>';
|
||||
@@ -2549,6 +2581,36 @@ function renderProcessDetails(messageId, processDetails) {
|
||||
}
|
||||
}
|
||||
|
||||
/** 懒加载折叠态:后台拉摘要,提示迭代规模而不加载全量详情 */
|
||||
function prefetchProcessDetailsSummaryHint(messageId, messageElement) {
|
||||
if (!messageElement || !messageElement.dataset || !messageElement.dataset.backendMessageId) return;
|
||||
const backendId = String(messageElement.dataset.backendMessageId).trim();
|
||||
if (!backendId || typeof apiFetch !== 'function') return;
|
||||
const detailsContainer = document.getElementById('process-details-' + messageId);
|
||||
if (!detailsContainer || detailsContainer.dataset.summaryFetched === '1') return;
|
||||
detailsContainer.dataset.summaryFetched = '1';
|
||||
apiFetch('/api/messages/' + encodeURIComponent(backendId) + '/process-details?summary=1')
|
||||
.then(async (res) => {
|
||||
const j = await res.json().catch(() => ({}));
|
||||
if (!res.ok || !j.summary) return;
|
||||
const s = j.summary;
|
||||
const timeline = detailsContainer.querySelector('.progress-timeline');
|
||||
if (!timeline || detailsContainer.dataset.loaded === '1') return;
|
||||
const expandLabel = typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情';
|
||||
let hint = expandLabel + '(点击后加载迭代详情)';
|
||||
if (s.maxIteration > 0) {
|
||||
hint = expandLabel + '(共 ' + s.maxIteration + ' 轮迭代,' + (s.total || 0) + ' 条详情)';
|
||||
} else if (s.total > 0) {
|
||||
hint = expandLabel + '(共 ' + (s.total || 0) + ' 条详情)';
|
||||
}
|
||||
const empty = timeline.querySelector('.progress-timeline-empty');
|
||||
if (empty) {
|
||||
empty.textContent = hint;
|
||||
}
|
||||
})
|
||||
.catch(() => {});
|
||||
}
|
||||
|
||||
// 移除消息
|
||||
function removeMessage(id) {
|
||||
const messageDiv = document.getElementById(id);
|
||||
@@ -2610,6 +2672,201 @@ async function updateButtonWithToolName(button, executionId, index) {
|
||||
}
|
||||
}
|
||||
|
||||
function getPendingMcpExecutionCount(messageElement) {
|
||||
if (!messageElement || !messageElement.dataset || !messageElement.dataset.pendingMcpExecutionIds) {
|
||||
return 0;
|
||||
}
|
||||
try {
|
||||
const ids = JSON.parse(messageElement.dataset.pendingMcpExecutionIds);
|
||||
return Array.isArray(ids) ? ids.length : 0;
|
||||
} catch (e) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
function getMcpExecutionCount(messageElement) {
|
||||
const pending = getPendingMcpExecutionCount(messageElement);
|
||||
if (pending > 0) return pending;
|
||||
const toolList = messageElement && messageElement.querySelector('.mcp-tool-list');
|
||||
if (toolList) {
|
||||
return toolList.querySelectorAll('.mcp-detail-btn[data-exec-id]').length;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
function formatMcpToolsToggleLabel(count, expanded) {
|
||||
if (expanded) {
|
||||
if (typeof window.t === 'function') {
|
||||
const s = window.t('chat.collapseToolExecutions');
|
||||
if (s && s !== 'chat.collapseToolExecutions') return s;
|
||||
}
|
||||
return '收起工具执行';
|
||||
}
|
||||
if (typeof window.t === 'function') {
|
||||
const s = window.t('chat.toolExecutionsCount', { n: count });
|
||||
if (s && s !== 'chat.toolExecutionsCount') return s;
|
||||
}
|
||||
return count + '次工具执行';
|
||||
}
|
||||
|
||||
/** 渗透测试区:工具栏(展开详情 | N次工具执行)+ 独立工具列表 + 迭代时间线 */
|
||||
function ensureMcpCallSectionChrome(messageElement, messageId) {
|
||||
const contentWrapper = messageElement && messageElement.querySelector('.message-content');
|
||||
if (!contentWrapper) return null;
|
||||
|
||||
let mcpSection = messageElement.querySelector('.mcp-call-section');
|
||||
if (!mcpSection) {
|
||||
mcpSection = document.createElement('div');
|
||||
mcpSection.className = 'mcp-call-section';
|
||||
const mcpLabel = document.createElement('div');
|
||||
mcpLabel.className = 'mcp-call-label';
|
||||
mcpLabel.textContent = '📋 ' + (typeof window.t === 'function' ? window.t('chat.penetrationTestDetail') : '渗透测试详情');
|
||||
mcpSection.appendChild(mcpLabel);
|
||||
contentWrapper.appendChild(mcpSection);
|
||||
} else {
|
||||
const mcpLabel = mcpSection.querySelector('.mcp-call-label');
|
||||
const labelText = '📋 ' + (typeof window.t === 'function' ? window.t('chat.penetrationTestDetail') : '渗透测试详情');
|
||||
if (mcpLabel && mcpLabel.textContent !== labelText) {
|
||||
mcpLabel.textContent = labelText;
|
||||
}
|
||||
}
|
||||
|
||||
let toolbar = mcpSection.querySelector('.mcp-call-toolbar');
|
||||
const legacyButtons = mcpSection.querySelector('.mcp-call-buttons');
|
||||
if (!toolbar) {
|
||||
toolbar = document.createElement('div');
|
||||
toolbar.className = 'mcp-call-toolbar';
|
||||
if (legacyButtons) {
|
||||
const processBtn = legacyButtons.querySelector('.process-detail-btn');
|
||||
if (processBtn) toolbar.appendChild(processBtn);
|
||||
mcpSection.replaceChild(toolbar, legacyButtons);
|
||||
} else {
|
||||
mcpSection.appendChild(toolbar);
|
||||
}
|
||||
}
|
||||
|
||||
let toolList = mcpSection.querySelector('.mcp-tool-list');
|
||||
if (!toolList) {
|
||||
toolList = document.createElement('div');
|
||||
toolList.className = 'mcp-tool-list';
|
||||
const detailsContainer = mcpSection.querySelector('.process-details-container');
|
||||
if (detailsContainer) {
|
||||
mcpSection.insertBefore(toolList, detailsContainer);
|
||||
} else {
|
||||
toolbar.after(toolList);
|
||||
}
|
||||
}
|
||||
|
||||
if (legacyButtons && legacyButtons.parentNode === mcpSection) {
|
||||
legacyButtons.querySelectorAll('.mcp-detail-btn[data-exec-id]').forEach((btn) => toolList.appendChild(btn));
|
||||
legacyButtons.remove();
|
||||
}
|
||||
|
||||
const clientId = messageId || messageElement.id;
|
||||
if (clientId && !toolbar.querySelector('.process-detail-btn')) {
|
||||
const processDetailBtn = document.createElement('button');
|
||||
processDetailBtn.className = 'mcp-detail-btn process-detail-btn';
|
||||
processDetailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') + '</span>';
|
||||
processDetailBtn.onclick = () => toggleProcessDetails(null, clientId);
|
||||
toolbar.appendChild(processDetailBtn);
|
||||
}
|
||||
|
||||
return { mcpSection, toolbar, toolList };
|
||||
}
|
||||
|
||||
function syncMcpToolsToggleButton(messageElement) {
|
||||
if (!messageElement) return;
|
||||
const chrome = ensureMcpCallSectionChrome(messageElement, messageElement.id);
|
||||
if (!chrome) return;
|
||||
const { toolbar, toolList } = chrome;
|
||||
const count = getMcpExecutionCount(messageElement);
|
||||
let toolsToggle = toolbar.querySelector('.mcp-tools-toggle-btn');
|
||||
if (count <= 0) {
|
||||
if (toolsToggle) toolsToggle.remove();
|
||||
return;
|
||||
}
|
||||
if (!toolsToggle) {
|
||||
toolsToggle = document.createElement('button');
|
||||
toolsToggle.type = 'button';
|
||||
toolsToggle.className = 'mcp-detail-btn mcp-tools-toggle-btn';
|
||||
toolsToggle.onclick = function (e) {
|
||||
e.stopPropagation();
|
||||
toggleMcpToolList(messageElement.id);
|
||||
};
|
||||
toolbar.appendChild(toolsToggle);
|
||||
}
|
||||
const expanded = toolList.classList.contains('expanded');
|
||||
toolsToggle.innerHTML = '<span>' + formatMcpToolsToggleLabel(count, expanded) + '</span>';
|
||||
}
|
||||
|
||||
function toggleMcpToolList(assistantMessageId) {
|
||||
const messageEl = document.getElementById(assistantMessageId);
|
||||
if (!messageEl) return;
|
||||
const chrome = ensureMcpCallSectionChrome(messageEl, assistantMessageId);
|
||||
if (!chrome) return;
|
||||
const { toolList } = chrome;
|
||||
const willExpand = !toolList.classList.contains('expanded');
|
||||
if (willExpand) {
|
||||
ensureMcpCallButtons(messageEl);
|
||||
toolList.classList.add('expanded');
|
||||
} else {
|
||||
toolList.classList.remove('expanded');
|
||||
}
|
||||
syncMcpToolsToggleButton(messageEl);
|
||||
}
|
||||
|
||||
window.toggleMcpToolList = toggleMcpToolList;
|
||||
window.syncMcpToolsToggleButton = syncMcpToolsToggleButton;
|
||||
window.ensureMcpCallSectionChrome = ensureMcpCallSectionChrome;
|
||||
|
||||
/** 将 MCP 工具按钮挂到独立工具列表,并批量解析工具名 */
|
||||
function appendMcpCallButtons(messageElement, executionIds) {
|
||||
if (!messageElement || !Array.isArray(executionIds) || executionIds.length === 0) {
|
||||
return;
|
||||
}
|
||||
const chrome = ensureMcpCallSectionChrome(messageElement, messageElement.id);
|
||||
if (!chrome) return;
|
||||
const toolList = chrome.toolList;
|
||||
|
||||
executionIds.forEach((execId, index) => {
|
||||
if (toolList.querySelector('.mcp-detail-btn[data-exec-id="' + CSS.escape(String(execId)) + '"]')) {
|
||||
return;
|
||||
}
|
||||
const detailBtn = document.createElement('button');
|
||||
detailBtn.className = 'mcp-detail-btn';
|
||||
detailBtn.dataset.execId = execId;
|
||||
detailBtn.dataset.execIndex = String(index + 1);
|
||||
detailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.callNumber', { n: index + 1 }) : '调用 #' + (index + 1)) + '</span>';
|
||||
detailBtn.onclick = () => showMCPDetail(execId);
|
||||
toolList.appendChild(detailBtn);
|
||||
});
|
||||
batchUpdateButtonToolNames(toolList, executionIds);
|
||||
syncMcpToolsToggleButton(messageElement);
|
||||
}
|
||||
|
||||
/** 历史会话懒加载:用户展开工具列表时再渲染工具按钮 */
|
||||
function ensureMcpCallButtons(messageElement) {
|
||||
if (!messageElement || !messageElement.dataset || !messageElement.dataset.pendingMcpExecutionIds) {
|
||||
return;
|
||||
}
|
||||
let executionIds;
|
||||
try {
|
||||
executionIds = JSON.parse(messageElement.dataset.pendingMcpExecutionIds);
|
||||
} catch (e) {
|
||||
delete messageElement.dataset.pendingMcpExecutionIds;
|
||||
return;
|
||||
}
|
||||
if (!Array.isArray(executionIds) || executionIds.length === 0) {
|
||||
delete messageElement.dataset.pendingMcpExecutionIds;
|
||||
return;
|
||||
}
|
||||
appendMcpCallButtons(messageElement, executionIds);
|
||||
delete messageElement.dataset.pendingMcpExecutionIds;
|
||||
}
|
||||
|
||||
window.ensureMcpCallButtons = ensureMcpCallButtons;
|
||||
window.appendMcpCallButtons = appendMcpCallButtons;
|
||||
|
||||
// 批量获取工具名称并更新按钮(消除 N 次单独 API 请求,合并为 1 次)
|
||||
async function batchUpdateButtonToolNames(buttonsContainer, executionIds) {
|
||||
if (!executionIds || executionIds.length === 0) return;
|
||||
@@ -3169,40 +3426,63 @@ function getConversationGroup(dateObj, todayStart, sevenDaysCutoff, yesterdaySta
|
||||
}
|
||||
|
||||
// 加载对话
|
||||
/** 轻量加载会话后,拉取最后一条助手消息的 process_details(机器人等无 SSE 场景) */
|
||||
/** 轻量加载会话后,仅对「处理中…」占位回复拉取过程详情(机器人等非 SSE 场景);已完成会话不预取全量 */
|
||||
async function prefetchLastAssistantProcessDetails() {
|
||||
const nodes = document.querySelectorAll('#chat-messages .message.assistant');
|
||||
if (!nodes.length) return;
|
||||
const last = nodes[nodes.length - 1];
|
||||
if (!last || !last.id) return;
|
||||
const bubble = last.querySelector('.message-bubble');
|
||||
const visibleText = bubble ? String(bubble.textContent || '').trim() : '';
|
||||
const isPlaceholder = visibleText === '处理中...' || visibleText === 'Processing...';
|
||||
if (!isPlaceholder) return;
|
||||
const container = document.getElementById('process-details-' + last.id);
|
||||
if (!container || container.dataset.lazyNotLoaded !== '1') return;
|
||||
const backendId = last.dataset && last.dataset.backendMessageId;
|
||||
if (!backendId || typeof apiFetch !== 'function') return;
|
||||
if (typeof window.loadProcessDetailsPaginated === 'function') {
|
||||
await window.loadProcessDetailsPaginated(last.id, backendId);
|
||||
return;
|
||||
}
|
||||
const res = await apiFetch('/api/messages/' + encodeURIComponent(String(backendId)) + '/process-details');
|
||||
const j = await res.json().catch(() => ({}));
|
||||
if (!res.ok || !Array.isArray(j.processDetails) || j.processDetails.length === 0) return;
|
||||
if (typeof renderProcessDetails === 'function') {
|
||||
renderProcessDetails(last.id, j.processDetails);
|
||||
}
|
||||
if (typeof window.expandProcessDetailsTimeline === 'function') {
|
||||
window.expandProcessDetailsTimeline(last.id);
|
||||
}
|
||||
}
|
||||
|
||||
async function loadConversation(conversationId) {
|
||||
const seq = ++loadConversationRequestSeq;
|
||||
try {
|
||||
// 轻量加载:不带 processDetails,避免历史会话切换卡顿;展开详情时再按需拉取
|
||||
const response = await apiFetch(`/api/conversations/${conversationId}?include_process_details=0`);
|
||||
if (seq !== loadConversationRequestSeq) {
|
||||
return;
|
||||
}
|
||||
const conversation = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
showChatToast('加载对话失败: ' + (conversation.error || '未知错误'), 'error');
|
||||
return;
|
||||
const cachedConversation = getConversationLiteFromCache(conversationId);
|
||||
const fetchPromise = apiFetch(`/api/conversations/${conversationId}?include_process_details=0`)
|
||||
.then(async (response) => {
|
||||
const data = await response.json();
|
||||
return { response, data };
|
||||
});
|
||||
|
||||
let conversation;
|
||||
let response;
|
||||
if (cachedConversation) {
|
||||
conversation = cachedConversation;
|
||||
fetchPromise.then(({ response: freshResp, data }) => {
|
||||
if (freshResp.ok && data && seq === loadConversationRequestSeq && currentConversationId === conversationId) {
|
||||
putConversationLiteCache(conversationId, data);
|
||||
}
|
||||
}).catch(() => {});
|
||||
} else {
|
||||
const fetched = await fetchPromise;
|
||||
response = fetched.response;
|
||||
conversation = fetched.data;
|
||||
if (seq !== loadConversationRequestSeq) {
|
||||
return;
|
||||
}
|
||||
if (!response.ok) {
|
||||
showChatToast('加载对话失败: ' + (conversation.error || '未知错误'), 'error');
|
||||
return;
|
||||
}
|
||||
putConversationLiteCache(conversationId, conversation);
|
||||
}
|
||||
if (seq !== loadConversationRequestSeq) {
|
||||
return;
|
||||
@@ -3252,11 +3532,15 @@ async function loadConversation(conversationId) {
|
||||
if (typeof refreshChatProjectSelector === 'function') {
|
||||
refreshChatProjectSelector();
|
||||
}
|
||||
if (typeof window.syncHitlConfigFromServer === 'function') {
|
||||
await window.syncHitlConfigFromServer(conversationId);
|
||||
} else {
|
||||
refreshHitlConfigByCurrentConversation();
|
||||
}
|
||||
refreshHitlConfigByCurrentConversation();
|
||||
const hitlSyncPromise = (typeof window.syncHitlConfigFromServer === 'function')
|
||||
? window.syncHitlConfigFromServer(conversationId).then(() => {
|
||||
if (seq === loadConversationRequestSeq && currentConversationId === conversationId) {
|
||||
refreshHitlConfigByCurrentConversation();
|
||||
}
|
||||
}).catch(() => {})
|
||||
: Promise.resolve();
|
||||
void hitlSyncPromise;
|
||||
updateActiveConversation();
|
||||
|
||||
// 如果攻击链模态框打开且显示的不是当前对话,关闭它
|
||||
@@ -3323,7 +3607,9 @@ async function loadConversation(conversationId) {
|
||||
// - user: createdAt 即可(发送后不会再更新)
|
||||
// - assistant: 如果后端提供 updatedAt(任务完成时写回),优先用它,避免占位消息“任务开始时间”误导
|
||||
const msgTime = (msg && msg.role === 'assistant' && msg.updatedAt) ? msg.updatedAt : (msg ? msg.createdAt : null);
|
||||
const messageId = addMessage(msg.role, displayContent, msg.mcpExecutionIds || [], null, msgTime);
|
||||
const mcpIds = (msg.mcpExecutionIds && Array.isArray(msg.mcpExecutionIds)) ? msg.mcpExecutionIds : [];
|
||||
const addOpts = (msg.role === 'assistant' && mcpIds.length > 0) ? { deferMcpButtons: true } : null;
|
||||
const messageId = addMessage(msg.role, displayContent, mcpIds, null, msgTime, addOpts);
|
||||
const messageEl = document.getElementById(messageId);
|
||||
if (messageEl && msg && msg.id) {
|
||||
messageEl.dataset.backendMessageId = String(msg.id);
|
||||
@@ -3491,6 +3777,7 @@ async function deleteConversationTurnFromUI(anchorBackendMessageId) {
|
||||
if (!response.ok) {
|
||||
throw new Error(data.error || data.message || 'delete failed');
|
||||
}
|
||||
invalidateConversationLiteCache(currentConversationId);
|
||||
await loadConversation(currentConversationId);
|
||||
if (typeof loadConversationsWithGroups === 'function') {
|
||||
loadConversationsWithGroups();
|
||||
@@ -3537,6 +3824,7 @@ async function deleteConversation(conversationId, skipConfirm = false) {
|
||||
|
||||
// 更新缓存 - 立即删除,确保后续加载时能正确识别
|
||||
delete conversationGroupMappingCache[conversationId];
|
||||
invalidateConversationLiteCache(conversationId);
|
||||
// 同时从待保留映射中移除
|
||||
delete pendingGroupMappings[conversationId];
|
||||
|
||||
@@ -7437,14 +7725,14 @@ async function showBatchManageModal() {
|
||||
updateBatchManageTitle(allConversationsForBatch.length);
|
||||
|
||||
renderBatchConversations();
|
||||
openAppModal('batch-manage-modal');
|
||||
openAppModal('batch-manage-modal', { focus: false });
|
||||
} catch (error) {
|
||||
console.error('加载对话列表失败:', error);
|
||||
// 错误时使用空数组,不显示错误提示(更友好的用户体验)
|
||||
allConversationsForBatch = [];
|
||||
updateBatchManageTitle(0);
|
||||
renderBatchConversations();
|
||||
openAppModal('batch-manage-modal');
|
||||
openAppModal('batch-manage-modal', { focus: false });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7504,6 +7792,7 @@ function renderBatchConversations(filtered = null) {
|
||||
checkbox.type = 'checkbox';
|
||||
checkbox.className = 'batch-conversation-checkbox';
|
||||
checkbox.dataset.conversationId = conv.id;
|
||||
checkbox.addEventListener('change', syncSelectAllBatchCheckbox);
|
||||
|
||||
const name = document.createElement('div');
|
||||
name.className = 'batch-table-col-name';
|
||||
@@ -7529,9 +7818,21 @@ function renderBatchConversations(filtered = null) {
|
||||
const action = document.createElement('div');
|
||||
action.className = 'batch-table-col-action';
|
||||
const deleteBtn = document.createElement('button');
|
||||
deleteBtn.type = 'button';
|
||||
deleteBtn.className = 'batch-delete-btn';
|
||||
deleteBtn.innerHTML = '🗑️';
|
||||
deleteBtn.onclick = () => deleteConversation(conv.id);
|
||||
deleteBtn.innerHTML = `
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
|
||||
<path d="M3 6h18M8 6V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2m3 0v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6h14zM10 11v6M14 11v6"
|
||||
stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
`;
|
||||
const deleteLabel = typeof window.t === 'function' ? window.t('contextMenu.deleteConversation') : '删除此对话';
|
||||
deleteBtn.title = deleteLabel;
|
||||
deleteBtn.setAttribute('aria-label', deleteLabel);
|
||||
deleteBtn.onclick = (e) => {
|
||||
e.stopPropagation();
|
||||
deleteConversation(conv.id);
|
||||
};
|
||||
action.appendChild(deleteBtn);
|
||||
|
||||
row.appendChild(checkbox);
|
||||
@@ -7541,6 +7842,8 @@ function renderBatchConversations(filtered = null) {
|
||||
|
||||
list.appendChild(row);
|
||||
});
|
||||
|
||||
syncSelectAllBatchCheckbox();
|
||||
}
|
||||
|
||||
// 筛选批量管理对话
|
||||
@@ -7562,12 +7865,35 @@ function filterBatchConversations(query) {
|
||||
function toggleSelectAllBatch() {
|
||||
const selectAll = document.getElementById('batch-select-all');
|
||||
const checkboxes = document.querySelectorAll('.batch-conversation-checkbox');
|
||||
|
||||
|
||||
if (selectAll) {
|
||||
selectAll.indeterminate = false;
|
||||
}
|
||||
checkboxes.forEach(cb => {
|
||||
cb.checked = selectAll.checked;
|
||||
});
|
||||
}
|
||||
|
||||
function syncSelectAllBatchCheckbox() {
|
||||
const selectAll = document.getElementById('batch-select-all');
|
||||
if (!selectAll) return;
|
||||
|
||||
const checkboxes = document.querySelectorAll('.batch-conversation-checkbox');
|
||||
const total = checkboxes.length;
|
||||
const checked = document.querySelectorAll('.batch-conversation-checkbox:checked').length;
|
||||
|
||||
if (total === 0 || checked === 0) {
|
||||
selectAll.checked = false;
|
||||
selectAll.indeterminate = false;
|
||||
} else if (checked === total) {
|
||||
selectAll.checked = true;
|
||||
selectAll.indeterminate = false;
|
||||
} else {
|
||||
selectAll.checked = false;
|
||||
selectAll.indeterminate = true;
|
||||
}
|
||||
}
|
||||
|
||||
// 删除选中的对话
|
||||
async function deleteSelectedConversations() {
|
||||
const checkboxes = document.querySelectorAll('.batch-conversation-checkbox:checked');
|
||||
@@ -7591,6 +7917,7 @@ async function deleteSelectedConversations() {
|
||||
const selectAll = document.getElementById('batch-select-all');
|
||||
if (selectAll) {
|
||||
selectAll.checked = false;
|
||||
selectAll.indeterminate = false;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('删除失败:', error);
|
||||
@@ -7606,6 +7933,7 @@ function closeBatchManageModal() {
|
||||
const selectAll = document.getElementById('batch-select-all');
|
||||
if (selectAll) {
|
||||
selectAll.checked = false;
|
||||
selectAll.indeterminate = false;
|
||||
}
|
||||
allConversationsForBatch = [];
|
||||
}
|
||||
@@ -7640,6 +7968,20 @@ function refreshChatPanelI18n() {
|
||||
const expanded = timeline && timeline.classList.contains('expanded');
|
||||
span.textContent = expanded ? t('tasks.collapseDetail') : t('chat.expandDetail');
|
||||
});
|
||||
const copyLabel = t('common.copy');
|
||||
const copyTitle = t('chat.copyMessageTitle');
|
||||
messagesEl.querySelectorAll('.message-copy-btn').forEach(function (btn) {
|
||||
if (btn.dataset.copySuccessActive === '1') return;
|
||||
const span = btn.querySelector('span');
|
||||
if (span) span.textContent = copyLabel;
|
||||
btn.title = copyTitle;
|
||||
btn.setAttribute('aria-label', copyTitle);
|
||||
});
|
||||
messagesEl.querySelectorAll('.message.assistant').forEach(function (msgEl) {
|
||||
if (typeof window.syncMcpToolsToggleButton === 'function') {
|
||||
window.syncMcpToolsToggleButton(msgEl);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (isAppModalOpen('mcp-detail-modal')) {
|
||||
|
||||
@@ -0,0 +1,680 @@
|
||||
/**
|
||||
* 项目事实图渲染(Cytoscape + ELK),供项目管理页使用。
|
||||
* 节点采用 SVG 卡片背景(图标 + 多行文字),避免 Cytoscape 原生 label 定位问题。
|
||||
*/
|
||||
(function (global) {
|
||||
'use strict';
|
||||
|
||||
let _cy = null;
|
||||
let _graphData = null;
|
||||
let _onNodeSelect = null;
|
||||
let _onEdgeSelect = null;
|
||||
let _resizeObs = null;
|
||||
|
||||
const EDGE_COLORS = {
|
||||
discovered_on: '#4F46E5',
|
||||
leads_to: '#64748B',
|
||||
enables: '#E11D48',
|
||||
exploits: '#DC2626',
|
||||
depends_on: '#0D9488',
|
||||
contains: '#6366F1',
|
||||
part_of: '#6366F1',
|
||||
supports: '#94A3B8',
|
||||
links_vuln: '#BE123C',
|
||||
};
|
||||
|
||||
const CARD_PAD = 14;
|
||||
const CARD_TEXT_PAD_RIGHT = 12;
|
||||
const CARD_ICON = 36;
|
||||
const CARD_ICON_GAP = 12;
|
||||
const CARD_TEXT_X = CARD_PAD + CARD_ICON + CARD_ICON_GAP;
|
||||
const CARD_MIN_W = 300;
|
||||
const CARD_TARGET_W = 360;
|
||||
const CARD_MIN_H = 88;
|
||||
const CARD_MAX_H = 176;
|
||||
const CARD_HEADER_FS = 11;
|
||||
const CARD_HEADER_LH = 16;
|
||||
const CARD_KEY_FS = 10;
|
||||
const CARD_KEY_LH = 14;
|
||||
const CARD_SUMMARY_FS = 13;
|
||||
const CARD_SUMMARY_LH = 18;
|
||||
const CARD_SECTION_GAP = 6;
|
||||
const CARD_FONT =
|
||||
'-apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", "PingFang SC", "Microsoft YaHei", sans-serif';
|
||||
const CARD_KEY_FONT =
|
||||
'ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", monospace';
|
||||
|
||||
function nodeTheme(type) {
|
||||
switch (type) {
|
||||
case 'target':
|
||||
return { typeLabel: '目标', typeEn: 'TARGET', accent: '#4F46E5', bgEnd: '#F5F3FF', icon: 'target' };
|
||||
case 'finding':
|
||||
return { typeLabel: '发现', typeEn: 'FINDING', accent: '#E11D48', bgEnd: '#FFF1F2', icon: 'finding', cardStyle: 'default' };
|
||||
case 'exploit':
|
||||
return { typeLabel: '利用', typeEn: 'EXPLOIT', accent: '#B45309', bgEnd: '#FFFBEB', icon: 'vulnerability', cardStyle: 'default' };
|
||||
case 'vulnerability':
|
||||
return { typeLabel: '漏洞', typeEn: 'VULN', accent: '#9333EA', bgEnd: '#F5F3FF', icon: 'vuln', cardStyle: 'default' };
|
||||
case 'auth':
|
||||
return { typeLabel: '认证', typeEn: 'AUTH', accent: '#0D9488', bgEnd: '#F0FDFA', icon: 'default' };
|
||||
case 'infra':
|
||||
return { typeLabel: '基础设施', typeEn: 'INFRA', accent: '#64748B', bgEnd: '#F8FAFC', icon: 'default' };
|
||||
case 'chain':
|
||||
return { typeLabel: '攻击链', typeEn: 'CHAIN', accent: '#7C3AED', bgEnd: '#F5F3FF', icon: 'vulnerability' };
|
||||
case 'poc':
|
||||
return { typeLabel: 'POC', typeEn: 'POC', accent: '#C2410C', bgEnd: '#FFEDD5', icon: 'vulnerability' };
|
||||
case 'business':
|
||||
return { typeLabel: '业务', typeEn: 'BUSINESS', accent: '#0369A1', bgEnd: '#F0F9FF', icon: 'default' };
|
||||
case 'missing':
|
||||
return { typeLabel: '缺失', typeEn: 'MISSING', accent: '#CBD5E1', bgEnd: '#F1F5F9', icon: 'default' };
|
||||
default:
|
||||
return { typeLabel: '备注', typeEn: 'NOTE', accent: '#94A3B8', bgEnd: '#F8FAFC', icon: 'default' };
|
||||
}
|
||||
}
|
||||
|
||||
function escapeXml(str) {
|
||||
return String(str)
|
||||
.replace(/&/g, '&')
|
||||
.replace(/</g, '<')
|
||||
.replace(/>/g, '>')
|
||||
.replace(/"/g, '"')
|
||||
.replace(/'/g, ''');
|
||||
}
|
||||
|
||||
function escapeHtml(str) {
|
||||
return escapeXml(str);
|
||||
}
|
||||
|
||||
function buildStatusBadge(confidence) {
|
||||
const conf = (confidence || '').toLowerCase();
|
||||
if (conf === 'tentative') return '待确认';
|
||||
if (conf === 'deprecated') return '已废弃';
|
||||
return '';
|
||||
}
|
||||
|
||||
function buildHeaderText(theme, statusBadge) {
|
||||
const line = (theme.typeEn || '') + ' · ' + (theme.typeLabel || '');
|
||||
return statusBadge ? line + ' · ' + statusBadge : line;
|
||||
}
|
||||
|
||||
function isWideChar(ch) {
|
||||
const code = ch.codePointAt(0) || 0;
|
||||
if (code >= 0x4e00 && code <= 0x9fff) return true;
|
||||
if (code >= 0x3400 && code <= 0x4dbf) return true;
|
||||
if (code >= 0xf900 && code <= 0xfaff) return true;
|
||||
if (code >= 0xff00 && code <= 0xffef) return true;
|
||||
return /[·:,。;!?【】()《》、「」]/.test(ch);
|
||||
}
|
||||
|
||||
function charWidth(ch, fontSize, bold) {
|
||||
const scale = bold ? 1.05 : 1;
|
||||
if (ch === ' ') return fontSize * 0.3 * scale;
|
||||
if (isWideChar(ch)) return fontSize * scale;
|
||||
return fontSize * 0.58 * scale;
|
||||
}
|
||||
|
||||
function lineWidth(text, fontSize, bold) {
|
||||
let width = 0;
|
||||
for (const ch of text) width += charWidth(ch, fontSize, bold);
|
||||
return width;
|
||||
}
|
||||
|
||||
function wrapTextLines(text, maxWidth, fontSize, maxLines, bold) {
|
||||
const raw = String(text || '').replace(/\s+/g, ' ').trim();
|
||||
if (!raw) return ['—'];
|
||||
const safeWidth = Math.max(40, maxWidth - 4);
|
||||
const chars = [...raw];
|
||||
const lines = [];
|
||||
let index = 0;
|
||||
while (index < chars.length && lines.length < maxLines) {
|
||||
let line = '';
|
||||
let width = 0;
|
||||
while (index < chars.length) {
|
||||
const ch = chars[index];
|
||||
const nextWidth = charWidth(ch, fontSize, bold);
|
||||
if (line && width + nextWidth > safeWidth) break;
|
||||
line += ch;
|
||||
width += nextWidth;
|
||||
index += 1;
|
||||
if (width >= safeWidth) break;
|
||||
}
|
||||
if (line) lines.push(line);
|
||||
}
|
||||
if (index < chars.length && lines.length) {
|
||||
let last = lines[lines.length - 1];
|
||||
while (last.length > 1 && lineWidth(last + '…', fontSize, bold) > safeWidth) {
|
||||
last = last.slice(0, -1);
|
||||
}
|
||||
lines[lines.length - 1] = last + '…';
|
||||
}
|
||||
return lines.length ? lines : ['—'];
|
||||
}
|
||||
|
||||
function cardTextWidth(nodeWidth) {
|
||||
return nodeWidth - CARD_TEXT_X - CARD_PAD - CARD_TEXT_PAD_RIGHT;
|
||||
}
|
||||
|
||||
function computeNodeLayout(type, summary, statusBadge, theme, factKey) {
|
||||
const width = type === 'target' ? CARD_TARGET_W : CARD_MIN_W;
|
||||
const textW = cardTextWidth(width);
|
||||
const t = theme || nodeTheme(type);
|
||||
const headerLines = wrapTextLines(buildHeaderText(t, statusBadge), textW, CARD_HEADER_FS, 2, true);
|
||||
const keyText = String(factKey || '').trim();
|
||||
const keyLines = keyText ? wrapTextLines(keyText, textW, CARD_KEY_FS, 2, false) : [];
|
||||
const summaryLines = wrapTextLines(summary, textW, CARD_SUMMARY_FS, keyLines.length ? 3 : 4, true);
|
||||
const keyBlockHeight = keyLines.length
|
||||
? CARD_SECTION_GAP + keyLines.length * CARD_KEY_LH + CARD_SECTION_GAP
|
||||
: CARD_SECTION_GAP;
|
||||
const height = Math.min(
|
||||
CARD_MAX_H,
|
||||
Math.max(
|
||||
CARD_MIN_H,
|
||||
CARD_PAD +
|
||||
headerLines.length * CARD_HEADER_LH +
|
||||
keyBlockHeight +
|
||||
summaryLines.length * CARD_SUMMARY_LH +
|
||||
CARD_PAD,
|
||||
),
|
||||
);
|
||||
return {
|
||||
width,
|
||||
height,
|
||||
headerLines,
|
||||
keyLines,
|
||||
summaryLines,
|
||||
searchLabel: [headerLines.join(' '), keyLines.join(' '), summaryLines.join(' ')]
|
||||
.filter(Boolean)
|
||||
.join('\n'),
|
||||
};
|
||||
}
|
||||
|
||||
function svgIconGroup(kind, color, x, y) {
|
||||
const scale = (CARD_ICON / 24).toFixed(3);
|
||||
if (kind === 'target') {
|
||||
return (
|
||||
`<g transform="translate(${x}, ${y}) scale(${scale})">` +
|
||||
`<circle cx="12" cy="12" r="6" fill="none" stroke="${color}" stroke-width="2"/>` +
|
||||
`<circle cx="12" cy="12" r="2.5" fill="${color}"/></g>`
|
||||
);
|
||||
}
|
||||
if (kind === 'finding') {
|
||||
return (
|
||||
`<g transform="translate(${x}, ${y}) scale(${scale})">` +
|
||||
`<circle cx="10" cy="10" r="6" fill="none" stroke="${color}" stroke-width="2"/>` +
|
||||
`<line x1="14.5" y1="14.5" x2="19" y2="19" stroke="${color}" stroke-width="2" stroke-linecap="round"/></g>`
|
||||
);
|
||||
}
|
||||
if (kind === 'vuln') {
|
||||
return (
|
||||
`<g transform="translate(${x}, ${y}) scale(${scale})">` +
|
||||
`<path d="M12 2.5l7.5 3v6.2c0 4.6-3.1 8.1-7.5 9.3-4.4-1.2-7.5-4.7-7.5-9.3V5.5z" fill="${color}" fill-opacity="0.12" stroke="${color}" stroke-width="2"/>` +
|
||||
`<line x1="12" y1="8.5" x2="12" y2="12.5" stroke="${color}" stroke-width="2" stroke-linecap="round"/>` +
|
||||
`<circle cx="12" cy="15.5" r="1.1" fill="${color}"/></g>`
|
||||
);
|
||||
}
|
||||
if (kind === 'vulnerability') {
|
||||
return (
|
||||
`<g transform="translate(${x}, ${y}) scale(${scale})">` +
|
||||
`<path d="M12 3l9 16H3z" fill="none" stroke="${color}" stroke-width="2"/>` +
|
||||
`<line x1="12" y1="9" x2="12" y2="13" stroke="${color}" stroke-width="2"/>` +
|
||||
`<circle cx="12" cy="16" r="1" fill="${color}"/></g>`
|
||||
);
|
||||
}
|
||||
return (
|
||||
`<g transform="translate(${x}, ${y}) scale(${scale})">` +
|
||||
`<circle cx="12" cy="12" r="5" fill="${color}" opacity="0.85"/></g>`
|
||||
);
|
||||
}
|
||||
|
||||
function buildNodeCardSvgUrl(theme, layout, confidence) {
|
||||
const { width, height, headerLines, keyLines, summaryLines } = layout;
|
||||
const accent = theme.accent;
|
||||
const bgEnd = theme.bgEnd;
|
||||
const conf = (confidence || '').toLowerCase();
|
||||
const isTentative = conf === 'tentative';
|
||||
const isDeprecated = conf === 'deprecated';
|
||||
const iconX = CARD_PAD;
|
||||
const iconY = (height - CARD_ICON) / 2;
|
||||
const headerY = CARD_PAD + CARD_HEADER_FS;
|
||||
const keyY = CARD_PAD + headerLines.length * CARD_HEADER_LH + CARD_SECTION_GAP + CARD_KEY_FS;
|
||||
const summaryY =
|
||||
CARD_PAD +
|
||||
headerLines.length * CARD_HEADER_LH +
|
||||
(keyLines.length
|
||||
? CARD_SECTION_GAP + keyLines.length * CARD_KEY_LH + CARD_SECTION_GAP
|
||||
: CARD_SECTION_GAP) +
|
||||
CARD_SUMMARY_FS;
|
||||
|
||||
const stroke = isTentative
|
||||
? `stroke="${accent}" stroke-width="1.5" stroke-dasharray="8 5" stroke-opacity="0.9"`
|
||||
: `stroke="${accent}" stroke-width="1.5" stroke-opacity="0.72"`;
|
||||
|
||||
const headerSvg = headerLines
|
||||
.map(
|
||||
(line, i) =>
|
||||
`<text x="${CARD_TEXT_X}" y="${headerY + i * CARD_HEADER_LH}" font-size="${CARD_HEADER_FS}" font-weight="700" fill="${accent}" fill-opacity="0.88" font-family='${CARD_FONT}'>${escapeXml(line)}</text>`,
|
||||
)
|
||||
.join('');
|
||||
|
||||
const keySvg = keyLines
|
||||
.map(
|
||||
(line, i) =>
|
||||
`<text x="${CARD_TEXT_X}" y="${keyY + i * CARD_KEY_LH}" font-size="${CARD_KEY_FS}" font-weight="500" fill="#64748b" font-family='${CARD_KEY_FONT}'>${escapeXml(line)}</text>`,
|
||||
)
|
||||
.join('');
|
||||
|
||||
const summarySvg = summaryLines
|
||||
.map(
|
||||
(line, i) =>
|
||||
`<text x="${CARD_TEXT_X}" y="${summaryY + i * CARD_SUMMARY_LH}" font-size="${CARD_SUMMARY_FS}" font-weight="600" fill="#0f172a" font-family='${CARD_FONT}'>${escapeXml(line)}</text>`,
|
||||
)
|
||||
.join('');
|
||||
|
||||
const textClipW = width - CARD_TEXT_X - CARD_PAD - 2;
|
||||
const textClipH = height - CARD_PAD * 2 + 4;
|
||||
|
||||
const svg =
|
||||
`<svg xmlns="http://www.w3.org/2000/svg" width="${width}" height="${height}" viewBox="0 0 ${width} ${height}">` +
|
||||
`<defs><linearGradient id="bg" x1="0%" y1="0%" x2="100%" y2="100%">` +
|
||||
`<stop offset="0%" stop-color="#FFFFFF"/><stop offset="100%" stop-color="${bgEnd}"/></linearGradient>` +
|
||||
`<clipPath id="textClip"><rect x="${CARD_TEXT_X}" y="${CARD_PAD - 2}" width="${textClipW}" height="${textClipH}"/></clipPath></defs>` +
|
||||
`<g${isDeprecated ? ' opacity="0.55"' : ''}>` +
|
||||
`<rect x="0.75" y="0.75" width="${width - 1.5}" height="${height - 1.5}" rx="12" fill="url(#bg)" ${stroke}/>` +
|
||||
svgIconGroup(theme.icon, accent, iconX, iconY) +
|
||||
`<g clip-path="url(#textClip)">${headerSvg}${keySvg}${summarySvg}</g>` +
|
||||
`</g></svg>`;
|
||||
|
||||
try {
|
||||
return 'data:image/svg+xml;base64,' + btoa(unescape(encodeURIComponent(svg)));
|
||||
} catch (e) {
|
||||
return 'data:image/svg+xml;charset=utf-8,' + encodeURIComponent(svg);
|
||||
}
|
||||
}
|
||||
|
||||
function destroy() {
|
||||
if (_resizeObs) {
|
||||
_resizeObs.disconnect();
|
||||
_resizeObs = null;
|
||||
}
|
||||
if (_cy) {
|
||||
_cy.destroy();
|
||||
_cy = null;
|
||||
}
|
||||
_graphData = null;
|
||||
}
|
||||
|
||||
function observeContainerResize(container) {
|
||||
if (_resizeObs) {
|
||||
_resizeObs.disconnect();
|
||||
_resizeObs = null;
|
||||
}
|
||||
if (!container || typeof ResizeObserver === 'undefined') return;
|
||||
_resizeObs = new ResizeObserver(() => {
|
||||
if (_cy) {
|
||||
try {
|
||||
_cy.resize();
|
||||
} catch (e) {
|
||||
console.warn('graph resize', e);
|
||||
}
|
||||
}
|
||||
});
|
||||
_resizeObs.observe(container);
|
||||
}
|
||||
|
||||
function centerGraph() {
|
||||
if (!_cy) return;
|
||||
try {
|
||||
_cy.resize();
|
||||
_cy.fit(undefined, 56);
|
||||
if (_cy.zoom() < 0.65) {
|
||||
_cy.zoom(0.65);
|
||||
_cy.center();
|
||||
}
|
||||
} catch (e) {
|
||||
console.warn('centerGraph', e);
|
||||
}
|
||||
}
|
||||
|
||||
// ELK 分层(仅影响节点纵向位置,不修改边的 source/target)
|
||||
function pathGraphNodeLayer(type, factKey) {
|
||||
const key = (factKey || '').toLowerCase();
|
||||
if (key.startsWith('vuln:')) return '4';
|
||||
const t = (type || '').toLowerCase();
|
||||
if (t === 'target') return '0';
|
||||
if (t === 'infra' || t === 'auth' || t === 'business') return '1';
|
||||
if (t === 'exploit' || t === 'poc') return '3';
|
||||
if (t === 'vulnerability' || t === 'vuln') return '3';
|
||||
if (t === 'chain' || t === 'finding') return '2';
|
||||
if (t === 'note') return '2';
|
||||
return '2';
|
||||
}
|
||||
|
||||
function applyElkLayout(validEdges, isComplex) {
|
||||
const layoutOptions = {
|
||||
name: 'breadthfirst',
|
||||
directed: true,
|
||||
spacingFactor: isComplex ? 3.0 : 2.5,
|
||||
padding: 40,
|
||||
};
|
||||
const elkInstance = typeof ELK !== 'undefined' ? new ELK() : null;
|
||||
if (!elkInstance) {
|
||||
const layout = _cy.layout(layoutOptions);
|
||||
layout.one('layoutstop', () => setTimeout(centerGraph, 100));
|
||||
layout.run();
|
||||
return;
|
||||
}
|
||||
const nodeGap = isComplex ? 45 : 60;
|
||||
const layerGap = isComplex ? 70 : 95;
|
||||
const elkGraph = {
|
||||
id: 'root',
|
||||
layoutOptions: {
|
||||
'elk.algorithm': 'layered',
|
||||
'elk.direction': 'DOWN',
|
||||
'elk.spacing.nodeNode': String(nodeGap),
|
||||
'elk.layered.spacing.nodeNodeBetweenLayers': String(layerGap),
|
||||
'elk.layered.nodePlacement.strategy': 'BRANDES_KOEPF',
|
||||
},
|
||||
children: (_graphData.nodes || []).map((node) => {
|
||||
const n = _cy ? _cy.getElementById(node.id) : null;
|
||||
const w = n.length ? n.data('nodeWidth') : node.type === 'target' ? CARD_TARGET_W : CARD_MIN_W;
|
||||
const h = n.length ? n.data('nodeHeight') : CARD_MIN_H;
|
||||
const nodeKey = node.fact_key || node.id;
|
||||
return {
|
||||
id: node.id,
|
||||
width: w,
|
||||
height: h,
|
||||
layoutOptions: {
|
||||
'org.eclipse.elk.layered.layering.layerId': pathGraphNodeLayer(node.type, nodeKey),
|
||||
},
|
||||
};
|
||||
}),
|
||||
edges: validEdges.map((edge) => ({
|
||||
id: edge.id,
|
||||
sources: [edge.source],
|
||||
targets: [edge.target],
|
||||
})),
|
||||
};
|
||||
elkInstance
|
||||
.layout(elkGraph)
|
||||
.then((laidOut) => {
|
||||
(laidOut.children || []).forEach((elkNode) => {
|
||||
const cyNode = _cy.getElementById(elkNode.id);
|
||||
if (cyNode.length && elkNode.x != null) {
|
||||
cyNode.position({
|
||||
x: elkNode.x + (elkNode.width || 0) / 2,
|
||||
y: elkNode.y + (elkNode.height || 0) / 2,
|
||||
});
|
||||
}
|
||||
});
|
||||
setTimeout(centerGraph, 120);
|
||||
})
|
||||
.catch(() => {
|
||||
const layout = _cy.layout(layoutOptions);
|
||||
layout.one('layoutstop', () => setTimeout(centerGraph, 100));
|
||||
layout.run();
|
||||
});
|
||||
}
|
||||
|
||||
function render(container, graphData, options) {
|
||||
if (!container || typeof cytoscape === 'undefined') {
|
||||
if (container) {
|
||||
container.innerHTML = '<div class="error-message">Cytoscape 未加载</div>';
|
||||
}
|
||||
return null;
|
||||
}
|
||||
destroy();
|
||||
_graphData = graphData || { nodes: [], edges: [] };
|
||||
_onNodeSelect = options && options.onNodeSelect;
|
||||
_onEdgeSelect = options && options.onEdgeSelect;
|
||||
|
||||
const nodes = _graphData.nodes || [];
|
||||
const edges = _graphData.edges || [];
|
||||
if (!nodes.length) {
|
||||
const title = (options && options.emptyTitle) || '';
|
||||
const hint = (options && options.emptyText) || '暂无事实关系';
|
||||
const steps = (options && options.emptySteps) || [];
|
||||
const actionLabel = options && options.emptyActionLabel;
|
||||
const stepsHtml = steps.length
|
||||
? '<ol class="project-fact-graph-empty-steps">' +
|
||||
steps.map((s) => '<li>' + escapeHtml(String(s)) + '</li>').join('') +
|
||||
'</ol>'
|
||||
: '';
|
||||
const actionHtml =
|
||||
actionLabel && options.onEmptyAction
|
||||
? '<button type="button" class="btn-primary btn-small project-fact-graph-empty-cta">' +
|
||||
escapeHtml(actionLabel) +
|
||||
'</button>'
|
||||
: '';
|
||||
container.innerHTML =
|
||||
'<div class="project-fact-graph-empty">' +
|
||||
'<div class="project-fact-graph-empty-icon" aria-hidden="true">' +
|
||||
'<svg width="48" height="48" viewBox="0 0 24 24" fill="none"><circle cx="6" cy="6" r="2.5" fill="#4F46E5" opacity="0.9"/><circle cx="18" cy="6" r="2.5" fill="#E11D48" opacity="0.9"/><circle cx="12" cy="18" r="2.5" fill="#0D9488" opacity="0.9"/>' +
|
||||
'<path d="M8 7l4 9M16 7l-4 9M8 7h8" stroke="#CBD5E1" stroke-width="1.5" stroke-linecap="round"/></svg>' +
|
||||
'</div>' +
|
||||
(title ? '<h4 class="project-fact-graph-empty-title">' + escapeHtml(title) + '</h4>' : '') +
|
||||
'<p class="project-fact-graph-empty-hint">' + escapeHtml(hint) + '</p>' +
|
||||
stepsHtml +
|
||||
actionHtml +
|
||||
'</div>';
|
||||
const cta = container.querySelector('.project-fact-graph-empty-cta');
|
||||
if (cta && typeof options.onEmptyAction === 'function') {
|
||||
cta.addEventListener('click', options.onEmptyAction);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
container.innerHTML = '';
|
||||
const isComplex = nodes.length > 15 || edges.length > 25;
|
||||
const elements = [];
|
||||
const nodeIds = new Set();
|
||||
|
||||
nodes.forEach((node) => {
|
||||
nodeIds.add(node.id);
|
||||
const visualType = resolveGraphNodeType(node);
|
||||
const theme = nodeTheme(visualType);
|
||||
const factKey = node.fact_key || node.id;
|
||||
const summary = (node.summary || node.label || '').trim() || '—';
|
||||
const statusBadge = buildStatusBadge(node.confidence);
|
||||
const layout = computeNodeLayout(visualType, summary, statusBadge, theme, factKey);
|
||||
elements.push({
|
||||
data: {
|
||||
id: node.id,
|
||||
label: layout.searchLabel,
|
||||
factKey: node.fact_key || node.id,
|
||||
category: node.category || '',
|
||||
type: visualType,
|
||||
typeLabel: theme.typeLabel,
|
||||
typeEn: theme.typeEn,
|
||||
accentColor: theme.accent,
|
||||
statusBadge: statusBadge,
|
||||
confidence: node.confidence || '',
|
||||
nodeWidth: layout.width,
|
||||
nodeHeight: layout.height,
|
||||
cardSvgUrl: buildNodeCardSvgUrl(theme, layout, node.confidence),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
const validEdges = [];
|
||||
edges.forEach((edge, idx) => {
|
||||
if (!nodeIds.has(edge.source) || !nodeIds.has(edge.target)) return;
|
||||
const id = edge.id || 'e-' + idx;
|
||||
validEdges.push({ ...edge, id });
|
||||
elements.push({
|
||||
data: {
|
||||
id,
|
||||
source: edge.source,
|
||||
target: edge.target,
|
||||
type: edge.type || 'leads_to',
|
||||
confidence: edge.confidence || 'confirmed',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
_cy = cytoscape({
|
||||
container,
|
||||
elements,
|
||||
style: [
|
||||
{
|
||||
selector: 'node',
|
||||
style: {
|
||||
label: '',
|
||||
width: (ele) => ele.data('nodeWidth') || CARD_MIN_W,
|
||||
height: (ele) => ele.data('nodeHeight') || CARD_MIN_H,
|
||||
shape: 'round-rectangle',
|
||||
'background-color': '#ffffff',
|
||||
'background-image': (ele) => ele.data('cardSvgUrl') || 'none',
|
||||
'background-width': (ele) => (ele.data('nodeWidth') || CARD_MIN_W) + 'px',
|
||||
'background-height': (ele) => (ele.data('nodeHeight') || CARD_MIN_H) + 'px',
|
||||
'background-position-x': '50%',
|
||||
'background-position-y': '50%',
|
||||
'background-fit': 'none',
|
||||
'border-width': 0,
|
||||
'background-opacity': 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
selector: 'edge',
|
||||
style: {
|
||||
width: 2.2,
|
||||
'line-color': (ele) => EDGE_COLORS[ele.data('type')] || '#CBD5E1',
|
||||
'target-arrow-color': (ele) => EDGE_COLORS[ele.data('type')] || '#CBD5E1',
|
||||
'target-arrow-shape': 'triangle',
|
||||
'curve-style': 'bezier',
|
||||
opacity: (ele) => (ele.data('confidence') === 'tentative' ? 0.55 : 0.9),
|
||||
'line-style': (ele) => (ele.data('confidence') === 'tentative' ? 'dashed' : 'solid'),
|
||||
},
|
||||
},
|
||||
{
|
||||
selector: 'edge:selected',
|
||||
style: {
|
||||
width: 3.5,
|
||||
opacity: 1,
|
||||
'line-color': '#4F46E5',
|
||||
'target-arrow-color': '#4F46E5',
|
||||
},
|
||||
},
|
||||
{
|
||||
selector: 'node:selected',
|
||||
style: {
|
||||
'border-width': 3,
|
||||
'border-color': '#4F46E5',
|
||||
'border-opacity': 1,
|
||||
},
|
||||
},
|
||||
],
|
||||
minZoom: 0.35,
|
||||
maxZoom: 3,
|
||||
});
|
||||
|
||||
_cy.on('tap', 'node', (evt) => {
|
||||
const d = evt.target.data();
|
||||
const key = d.factKey || d.id;
|
||||
if (_connectMode && _connectPick) {
|
||||
_connectPick(key);
|
||||
return;
|
||||
}
|
||||
if (typeof _onNodeSelect === 'function') {
|
||||
_onNodeSelect(key, d);
|
||||
}
|
||||
});
|
||||
|
||||
_cy.on('tap', 'edge', (evt) => {
|
||||
if (_connectMode && _connectPick) return;
|
||||
const d = evt.target.data();
|
||||
if (typeof _onEdgeSelect === 'function') {
|
||||
_onEdgeSelect(d.id, d);
|
||||
}
|
||||
});
|
||||
|
||||
_cy.on('tap', (evt) => {
|
||||
if (evt.target === _cy) {
|
||||
clearEdgeSelection();
|
||||
}
|
||||
});
|
||||
|
||||
applyElkLayout(validEdges, isComplex);
|
||||
observeContainerResize(container);
|
||||
return _cy;
|
||||
}
|
||||
|
||||
function filterBySearch(query) {
|
||||
if (!_cy) return;
|
||||
const q = (query || '').trim().toLowerCase();
|
||||
_cy.nodes().forEach((n) => {
|
||||
if (!q) {
|
||||
n.style('opacity', 1);
|
||||
return;
|
||||
}
|
||||
const text = (
|
||||
(n.data('label') || '') +
|
||||
' ' +
|
||||
(n.data('factKey') || '') +
|
||||
' ' +
|
||||
(n.data('typeLabel') || '')
|
||||
).toLowerCase();
|
||||
n.style('opacity', text.includes(q) ? 1 : 0.15);
|
||||
});
|
||||
_cy.edges().forEach((e) => {
|
||||
e.style('opacity', q ? 0.12 : 0.9);
|
||||
});
|
||||
}
|
||||
|
||||
let _connectMode = false;
|
||||
let _connectPick = null;
|
||||
|
||||
function selectEdge(edgeId) {
|
||||
if (!_cy || !edgeId) return;
|
||||
_cy.elements().unselect();
|
||||
const edge = _cy.getElementById(edgeId);
|
||||
if (edge.length) edge.select();
|
||||
}
|
||||
|
||||
function clearEdgeSelection() {
|
||||
if (!_cy) return;
|
||||
_cy.elements().unselect();
|
||||
}
|
||||
|
||||
function setConnectMode(enabled, onPick) {
|
||||
_connectMode = !!enabled;
|
||||
_connectPick = typeof onPick === 'function' ? onPick : null;
|
||||
if (_cy) {
|
||||
_cy.userPanningEnabled(!_connectMode);
|
||||
}
|
||||
}
|
||||
|
||||
/** 与后端 GraphNodeType 一致:优先 category,vuln: 合成节点例外;无 category 时回退 type/key。 */
|
||||
function resolveGraphNodeType(node) {
|
||||
if (!node) return 'note';
|
||||
const key = String(node.fact_key || node.id || '').toLowerCase();
|
||||
if (key.startsWith('vuln:')) return 'vulnerability';
|
||||
const cat = String(node.category || '').toLowerCase();
|
||||
if (cat) {
|
||||
if (cat === 'vuln') return 'vulnerability';
|
||||
if (cat === 'missing') return 'missing';
|
||||
return cat;
|
||||
}
|
||||
const t = String(node.type || '').toLowerCase();
|
||||
if (t === 'vuln') return 'vulnerability';
|
||||
if (t) return t;
|
||||
if (key.startsWith('target/')) return 'target';
|
||||
if (key.startsWith('exploit/') || key.startsWith('evidence/')) return 'exploit';
|
||||
if (key.startsWith('poc/')) return 'poc';
|
||||
if (key.startsWith('chain/')) return 'chain';
|
||||
if (key.startsWith('finding/')) return 'finding';
|
||||
if (key.startsWith('auth/')) return 'auth';
|
||||
if (key.startsWith('infra/') || key.startsWith('business/')) return 'infra';
|
||||
return 'note';
|
||||
}
|
||||
|
||||
global.ProjectFactGraph = {
|
||||
render,
|
||||
destroy,
|
||||
center: centerGraph,
|
||||
filterBySearch,
|
||||
setConnectMode,
|
||||
selectEdge,
|
||||
clearEdgeSelection,
|
||||
nodeTheme,
|
||||
resolveGraphNodeType,
|
||||
};
|
||||
})(typeof window !== 'undefined' ? window : globalThis);
|
||||
+101
-85
@@ -1259,101 +1259,60 @@ function integrateProgressToMCPSection(progressId, assistantMessageId, mcpExecut
|
||||
return;
|
||||
}
|
||||
|
||||
// 查找或创建 MCP 区域
|
||||
let mcpSection = assistantElement.querySelector('.mcp-call-section');
|
||||
if (!mcpSection) {
|
||||
mcpSection = document.createElement('div');
|
||||
mcpSection.className = 'mcp-call-section';
|
||||
const mcpLabel = document.createElement('div');
|
||||
mcpLabel.className = 'mcp-call-label';
|
||||
mcpLabel.textContent = '📋 ' + (typeof window.t === 'function' ? window.t('chat.penetrationTestDetail') : '渗透测试详情');
|
||||
mcpSection.appendChild(mcpLabel);
|
||||
const buttonsContainerInit = document.createElement('div');
|
||||
buttonsContainerInit.className = 'mcp-call-buttons';
|
||||
mcpSection.appendChild(buttonsContainerInit);
|
||||
contentWrapper.appendChild(mcpSection);
|
||||
// 查找或创建 MCP 区域(工具栏 + 工具列表 + 迭代时间线)
|
||||
if (typeof window.ensureMcpCallSectionChrome === 'function') {
|
||||
window.ensureMcpCallSectionChrome(assistantElement, assistantMessageId);
|
||||
}
|
||||
|
||||
// 获取时间线内容
|
||||
const hasContent = timelineHTML.trim().length > 0;
|
||||
|
||||
// 检查时间线中是否有错误项
|
||||
const hasError = timeline && timeline.querySelector('.timeline-item-error');
|
||||
|
||||
// 确保按钮容器存在
|
||||
let buttonsContainer = mcpSection.querySelector('.mcp-call-buttons');
|
||||
if (!buttonsContainer) {
|
||||
buttonsContainer = document.createElement('div');
|
||||
buttonsContainer.className = 'mcp-call-buttons';
|
||||
mcpSection.appendChild(buttonsContainer);
|
||||
const mcpSection = assistantElement.querySelector('.mcp-call-section');
|
||||
if (!mcpSection) {
|
||||
removeMessage(progressId);
|
||||
return;
|
||||
}
|
||||
|
||||
let maxExecIndex = 0;
|
||||
const existingExecBtns = buttonsContainer.querySelectorAll('.mcp-detail-btn:not(.process-detail-btn)');
|
||||
existingExecBtns.forEach(function (btn) {
|
||||
const n = parseInt(btn.dataset.execIndex, 10);
|
||||
if (!isNaN(n) && n > maxExecIndex) maxExecIndex = n;
|
||||
});
|
||||
const seenExec = new Set();
|
||||
existingExecBtns.forEach(function (btn) {
|
||||
if (btn.dataset.execId) seenExec.add(String(btn.dataset.execId).trim());
|
||||
});
|
||||
let appendedAny = false;
|
||||
if (mcpIds.length > 0) {
|
||||
mcpIds.forEach(function (execId) {
|
||||
const id = execId != null ? String(execId).trim() : '';
|
||||
if (!id || seenExec.has(id)) return;
|
||||
seenExec.add(id);
|
||||
maxExecIndex += 1;
|
||||
appendedAny = true;
|
||||
const detailBtn = document.createElement('button');
|
||||
detailBtn.className = 'mcp-detail-btn';
|
||||
detailBtn.dataset.execId = id;
|
||||
detailBtn.dataset.execIndex = String(maxExecIndex);
|
||||
detailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.callNumber', { n: maxExecIndex }) : '调用 #' + maxExecIndex) + '</span>';
|
||||
detailBtn.onclick = function () { showMCPDetail(id); };
|
||||
buttonsContainer.appendChild(detailBtn);
|
||||
});
|
||||
if (appendedAny && typeof batchUpdateButtonToolNames === 'function') {
|
||||
batchUpdateButtonToolNames(buttonsContainer, mcpIds);
|
||||
}
|
||||
const hasContent = timelineHTML.trim().length > 0;
|
||||
|
||||
if (mcpIds.length > 0 && typeof window.appendMcpCallButtons === 'function') {
|
||||
window.appendMcpCallButtons(assistantElement, mcpIds);
|
||||
const toolList = mcpSection.querySelector('.mcp-tool-list');
|
||||
if (toolList) toolList.classList.remove('expanded');
|
||||
}
|
||||
if (!buttonsContainer.querySelector('.process-detail-btn')) {
|
||||
if (typeof window.syncMcpToolsToggleButton === 'function') {
|
||||
window.syncMcpToolsToggleButton(assistantElement);
|
||||
}
|
||||
|
||||
const toolbar = mcpSection.querySelector('.mcp-call-toolbar');
|
||||
if (toolbar && !toolbar.querySelector('.process-detail-btn')) {
|
||||
const progressDetailBtn = document.createElement('button');
|
||||
progressDetailBtn.className = 'mcp-detail-btn process-detail-btn';
|
||||
progressDetailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') + '</span>';
|
||||
progressDetailBtn.onclick = () => toggleProcessDetails(null, assistantMessageId);
|
||||
buttonsContainer.appendChild(progressDetailBtn);
|
||||
toolbar.appendChild(progressDetailBtn);
|
||||
}
|
||||
|
||||
// 创建详情容器,放在MCP按钮区域下方(统一结构)
|
||||
|
||||
const detailsId = 'process-details-' + assistantMessageId;
|
||||
let detailsContainer = document.getElementById(detailsId);
|
||||
const toolListEl = mcpSection.querySelector('.mcp-tool-list');
|
||||
|
||||
if (!detailsContainer) {
|
||||
detailsContainer = document.createElement('div');
|
||||
detailsContainer.id = detailsId;
|
||||
detailsContainer.className = 'process-details-container';
|
||||
// 确保容器在按钮容器之后
|
||||
if (buttonsContainer.nextSibling) {
|
||||
mcpSection.insertBefore(detailsContainer, buttonsContainer.nextSibling);
|
||||
if (toolListEl) {
|
||||
toolListEl.after(detailsContainer);
|
||||
} else {
|
||||
mcpSection.appendChild(detailsContainer);
|
||||
}
|
||||
}
|
||||
|
||||
// 设置详情内容(如果有错误,默认折叠;否则默认折叠)
|
||||
detailsContainer.innerHTML = `
|
||||
<div class="process-details-content">
|
||||
${hasContent ? `<div class="progress-timeline" id="${detailsId}-timeline">${timelineHTML}</div>` : '<div class="progress-timeline-empty">' + (typeof window.t === 'function' ? window.t('chat.noProcessDetail') : '暂无过程详情(可能执行过快或未触发详细事件)') + '</div>'}
|
||||
</div>
|
||||
`;
|
||||
|
||||
// 确保初始状态是折叠的(默认折叠,特别是错误时)
|
||||
if (hasContent) {
|
||||
const timeline = document.getElementById(detailsId + '-timeline');
|
||||
if (timeline) {
|
||||
// 如果有错误,确保折叠;否则也默认折叠
|
||||
timeline.classList.remove('expanded');
|
||||
}
|
||||
|
||||
@@ -1363,10 +1322,47 @@ function integrateProgressToMCPSection(progressId, assistantMessageId, mcpExecut
|
||||
});
|
||||
}
|
||||
|
||||
// 移除原来的进度消息(详情已快照到助手消息下的 process-details)
|
||||
removeMessage(progressId);
|
||||
}
|
||||
|
||||
const PROCESS_DETAILS_PAGE_SIZE = 100;
|
||||
|
||||
/**
|
||||
* 分页加载过程详情并增量渲染,避免数百轮迭代一次性阻塞主线程。
|
||||
*/
|
||||
async function loadProcessDetailsPaginated(assistantMessageId, backendMessageId) {
|
||||
if (!assistantMessageId || !backendMessageId || typeof apiFetch !== 'function' || typeof renderProcessDetails !== 'function') {
|
||||
return;
|
||||
}
|
||||
const PAGE = PROCESS_DETAILS_PAGE_SIZE;
|
||||
let offset = 0;
|
||||
let isFirst = true;
|
||||
while (true) {
|
||||
const res = await apiFetch(
|
||||
'/api/messages/' + encodeURIComponent(String(backendMessageId)) +
|
||||
'/process-details?limit=' + PAGE + '&offset=' + offset
|
||||
);
|
||||
const j = await res.json().catch(() => ({}));
|
||||
if (!res.ok) {
|
||||
throw new Error((j && j.error) ? j.error : String(res.status));
|
||||
}
|
||||
const details = (j && Array.isArray(j.processDetails)) ? j.processDetails : [];
|
||||
const hasMore = !!(j && j.hasMore);
|
||||
renderProcessDetails(assistantMessageId, details, {
|
||||
append: !isFirst,
|
||||
markLoaded: !hasMore
|
||||
});
|
||||
if (!hasMore || details.length === 0) {
|
||||
break;
|
||||
}
|
||||
offset += details.length;
|
||||
isFirst = false;
|
||||
await new Promise((resolve) => requestAnimationFrame(resolve));
|
||||
}
|
||||
}
|
||||
|
||||
window.loadProcessDetailsPaginated = loadProcessDetailsPaginated;
|
||||
|
||||
// 切换过程详情显示
|
||||
function toggleProcessDetails(progressId, assistantMessageId) {
|
||||
const detailsId = 'process-details-' + assistantMessageId;
|
||||
@@ -1383,26 +1379,17 @@ function toggleProcessDetails(progressId, assistantMessageId) {
|
||||
// 正在加载中,避免重复请求
|
||||
} else {
|
||||
detailsContainer.dataset.loading = '1';
|
||||
// 先展开容器,显示加载态
|
||||
const timeline = detailsContainer.querySelector('.progress-timeline');
|
||||
if (timeline) {
|
||||
timeline.innerHTML = '<div class="progress-timeline-empty">' + ((typeof window.t === 'function') ? window.t('common.loading') : '加载中…') + '</div>';
|
||||
}
|
||||
apiFetch(`/api/messages/${encodeURIComponent(String(backendMessageId))}/process-details`)
|
||||
.then(async (res) => {
|
||||
const j = await res.json().catch(() => ({}));
|
||||
if (!res.ok) throw new Error((j && j.error) ? j.error : res.status);
|
||||
const details = (j && Array.isArray(j.processDetails)) ? j.processDetails : [];
|
||||
// 重新渲染详情(renderProcessDetails 会清掉 lazy 标记并写入 loaded)
|
||||
renderProcessDetails(assistantMessageId, details);
|
||||
})
|
||||
loadProcessDetailsPaginated(assistantMessageId, backendMessageId)
|
||||
.catch((e) => {
|
||||
console.error('加载过程详情失败:', e);
|
||||
const tl = detailsContainer.querySelector('.progress-timeline');
|
||||
if (tl) {
|
||||
tl.innerHTML = '<div class="progress-timeline-empty">' + ((typeof window.t === 'function') ? window.t('chat.noProcessDetail') : '暂无过程详情(加载失败)') + '</div>';
|
||||
}
|
||||
// 失败时保留 lazy 状态,允许用户重试
|
||||
detailsContainer.dataset.lazyNotLoaded = '1';
|
||||
detailsContainer.dataset.loaded = '0';
|
||||
})
|
||||
@@ -1944,6 +1931,7 @@ function handleStreamEvent(event, progressElement, progressId,
|
||||
message: event.message || '',
|
||||
data: d
|
||||
});
|
||||
finalizeOutstandingToolCallsForProgress(progressId, 'failed');
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -2755,12 +2743,16 @@ async function restoreHitlInlineForConversation(conversationId) {
|
||||
if (detailsContainer.dataset.lazyNotLoaded === '1' && detailsContainer.dataset.loaded !== '1') {
|
||||
try {
|
||||
detailsContainer.dataset.loading = '1';
|
||||
const res = await apiFetch('/api/messages/' + encodeURIComponent(backendMsgId) + '/process-details');
|
||||
const j = await res.json().catch(function () { return {}; });
|
||||
if (!res.ok) throw new Error((j && j.error) ? j.error : String(res.status));
|
||||
const details = (j && Array.isArray(j.processDetails)) ? j.processDetails : [];
|
||||
if (typeof renderProcessDetails === 'function') {
|
||||
renderProcessDetails(clientMsgId, details);
|
||||
if (typeof loadProcessDetailsPaginated === 'function') {
|
||||
await loadProcessDetailsPaginated(clientMsgId, backendMsgId);
|
||||
} else {
|
||||
const res = await apiFetch('/api/messages/' + encodeURIComponent(backendMsgId) + '/process-details');
|
||||
const j = await res.json().catch(function () { return {}; });
|
||||
if (!res.ok) throw new Error((j && j.error) ? j.error : String(res.status));
|
||||
const details = (j && Array.isArray(j.processDetails)) ? j.processDetails : [];
|
||||
if (typeof renderProcessDetails === 'function') {
|
||||
renderProcessDetails(clientMsgId, details);
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('加载过程详情失败(HITL 恢复):', e);
|
||||
@@ -3531,6 +3523,7 @@ const monitorState = {
|
||||
timelineRange: null,
|
||||
timelineError: null,
|
||||
lastFetchedAt: null,
|
||||
retentionDays: 0,
|
||||
pagination: {
|
||||
page: 1,
|
||||
pageSize: (() => {
|
||||
@@ -3626,6 +3619,7 @@ async function refreshMonitorPanel(page = null) {
|
||||
monitorState.timeline = timeline;
|
||||
monitorState.timelineError = timelineError;
|
||||
monitorState.lastFetchedAt = new Date();
|
||||
monitorState.retentionDays = typeof result.retention_days === 'number' ? result.retention_days : 0;
|
||||
|
||||
// 更新分页信息
|
||||
if (result.total !== undefined) {
|
||||
@@ -3709,6 +3703,7 @@ async function refreshMonitorPanelWithFilter(statusFilter = 'all', toolFilter =
|
||||
monitorState.timeline = timeline;
|
||||
monitorState.timelineError = timelineError;
|
||||
monitorState.lastFetchedAt = new Date();
|
||||
monitorState.retentionDays = typeof result.retention_days === 'number' ? result.retention_days : 0;
|
||||
|
||||
// 更新分页信息
|
||||
if (result.total !== undefined) {
|
||||
@@ -4526,15 +4521,20 @@ function renderMcpStatsStackedBar(success, failed) {
|
||||
</div>`;
|
||||
}
|
||||
|
||||
function updateMonitorStatsSubtitle(lastFetchedAt, toolCount) {
|
||||
function updateMonitorStatsSubtitle(lastFetchedAt, toolCount, retentionDays) {
|
||||
const subtitle = document.getElementById('monitor-stats-subtitle');
|
||||
if (!subtitle) return;
|
||||
const locale = (typeof window.__locale === 'string' && window.__locale.startsWith('zh')) ? 'zh-CN' : 'en-US';
|
||||
const timeText = lastFetchedAt
|
||||
? (lastFetchedAt.toLocaleString ? lastFetchedAt.toLocaleString(locale) : String(lastFetchedAt))
|
||||
: '—';
|
||||
const text = mcpMonitorT('statsSubtitle', { time: timeText, count: toolCount })
|
||||
let text = mcpMonitorT('statsSubtitle', { time: timeText, count: toolCount })
|
||||
|| monitorFallback(`最后刷新 ${timeText} · 共 ${toolCount} 个工具`, `Refreshed ${timeText} · ${toolCount} tools`);
|
||||
if (typeof retentionDays === 'number' && retentionDays > 0) {
|
||||
const hint = mcpMonitorT('retentionHint', { days: retentionDays })
|
||||
|| monitorFallback(`执行记录保留 ${retentionDays} 天,超期自动清理`, `Execution records are kept for ${retentionDays} days, then purged automatically.`);
|
||||
text += ' · ' + hint;
|
||||
}
|
||||
subtitle.textContent = text;
|
||||
subtitle.hidden = false;
|
||||
}
|
||||
@@ -4959,7 +4959,7 @@ function renderMonitorStats(statsMap = {}, lastFetchedAt = null) {
|
||||
} else if (toolFilterEl) {
|
||||
toolFilterEl.classList.remove('is-filter-active');
|
||||
}
|
||||
updateMonitorStatsSubtitle(lastFetchedAt, entries.length);
|
||||
updateMonitorStatsSubtitle(lastFetchedAt, entries.length, monitorState.retentionDays);
|
||||
}
|
||||
|
||||
function renderMonitorExecutions(executions = [], statusFilter = 'all') {
|
||||
@@ -5459,6 +5459,22 @@ function refreshProgressAndTimelineI18n() {
|
||||
const expanded = timeline && timeline.classList.contains('expanded');
|
||||
span.textContent = expanded ? _t('tasks.collapseDetail') : _t('chat.expandDetail');
|
||||
});
|
||||
|
||||
document.querySelectorAll('#chat-messages .message.assistant').forEach(function (msgEl) {
|
||||
if (typeof window.syncMcpToolsToggleButton === 'function') {
|
||||
window.syncMcpToolsToggleButton(msgEl);
|
||||
}
|
||||
});
|
||||
|
||||
const copyLabel = _t('common.copy');
|
||||
const copyTitle = _t('chat.copyMessageTitle');
|
||||
document.querySelectorAll('#chat-messages .message-copy-btn').forEach(function (btn) {
|
||||
if (btn.dataset.copySuccessActive === '1') return;
|
||||
const span = btn.querySelector('span');
|
||||
if (span) span.textContent = copyLabel;
|
||||
btn.title = copyTitle;
|
||||
btn.setAttribute('aria-label', copyTitle);
|
||||
});
|
||||
}
|
||||
|
||||
document.addEventListener('languagechange', function () {
|
||||
|
||||
+355
-5
@@ -64,6 +64,8 @@ Host: ...
|
||||
## 关联
|
||||
- related_vulnerability_id: <可选>
|
||||
- 依赖事实: <fact_key,如 auth/session_cookie>
|
||||
- 结构化关系边(自动同步;links 文本格式 type: source_fact_key):
|
||||
- discovered_on: target/primary_domain
|
||||
|
||||
## 备注与不确定性
|
||||
<待验证假设、环境差异、绕过尝试记录>`;
|
||||
@@ -730,20 +732,316 @@ async function selectProject(id) {
|
||||
|
||||
function switchProjectTab(tab) {
|
||||
currentProjectTab = tab;
|
||||
['facts', 'conversations', 'vulns', 'settings'].forEach((t) => {
|
||||
['facts', 'graph', 'conversations', 'vulns', 'settings'].forEach((t) => {
|
||||
const btn = document.getElementById(`project-tab-${t}`);
|
||||
const panel = document.getElementById(`project-panel-${t}`);
|
||||
if (btn) btn.classList.toggle('is-active', t === tab);
|
||||
if (panel) panel.hidden = t !== tab;
|
||||
});
|
||||
if (tab === 'facts') loadProjectFacts();
|
||||
if (tab === 'graph') loadProjectFactGraph();
|
||||
if (tab === 'conversations') loadProjectConversations();
|
||||
if (tab === 'vulns') loadProjectVulnerabilities();
|
||||
}
|
||||
|
||||
let _selectedGraphFactKey = null;
|
||||
let _selectedGraphEdgeId = null;
|
||||
let _currentGraphData = null;
|
||||
let _graphConnectMode = false;
|
||||
let _graphConnectSource = null;
|
||||
|
||||
function toggleProjectFactGraphConnectMode() {
|
||||
_graphConnectMode = !_graphConnectMode;
|
||||
_graphConnectSource = null;
|
||||
const btn = document.getElementById('project-graph-connect-btn');
|
||||
if (btn) {
|
||||
btn.classList.toggle('is-active', _graphConnectMode);
|
||||
btn.textContent = _graphConnectMode ? tp('projects.graphConnectActive') : tp('projects.graphConnect');
|
||||
btn.classList.toggle('projects-graph-action-btn--connect-active', _graphConnectMode);
|
||||
}
|
||||
if (typeof ProjectFactGraph !== 'undefined') {
|
||||
ProjectFactGraph.setConnectMode(_graphConnectMode, handleGraphConnectNodePick);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleGraphConnectNodePick(factKey) {
|
||||
if (!factKey || String(factKey).startsWith('vuln:')) return;
|
||||
if (!_graphConnectSource) {
|
||||
_graphConnectSource = factKey;
|
||||
if (typeof showNotification === 'function') {
|
||||
showNotification(tpFmt('projects.graphConnectPickTarget', `已选源节点 ${factKey},请点击目标节点`, { source: factKey }), 'info');
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (_graphConnectSource === factKey) return;
|
||||
const edgeType = window.prompt(tp('projects.graphEdgeTypePrompt'), 'leads_to');
|
||||
if (!edgeType) {
|
||||
_graphConnectSource = null;
|
||||
return;
|
||||
}
|
||||
const res = await apiFetch(`/api/projects/${currentProjectId}/fact-edges`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
source_fact_key: _graphConnectSource,
|
||||
target_fact_key: factKey,
|
||||
edge_type: edgeType.trim(),
|
||||
}),
|
||||
});
|
||||
_graphConnectSource = null;
|
||||
if (!res.ok) {
|
||||
const err = await res.json().catch(() => ({}));
|
||||
return alert(err.error || tp('projects.graphConnectFailed'));
|
||||
}
|
||||
if (typeof showNotification === 'function') showNotification(tp('projects.graphConnectSuccess'), 'success');
|
||||
loadProjectFactGraph();
|
||||
loadProjectFacts();
|
||||
}
|
||||
|
||||
function formatIncomingLinksForModal(links) {
|
||||
if (!links || !links.length) return '';
|
||||
return links
|
||||
.map((e) => `${e.edge_type || e.type}: ${e.source_fact_key || e.from}`)
|
||||
.join('\n');
|
||||
}
|
||||
|
||||
|
||||
async function loadProjectFactGraph() {
|
||||
const container = document.getElementById('project-fact-graph-container');
|
||||
const statsEl = document.getElementById('project-fact-graph-stats');
|
||||
if (!container || !currentProjectId) return;
|
||||
container.innerHTML = `<div class="loading-spinner">${escapeHtml(tp('common.loading'))}</div>`;
|
||||
closeProjectFactGraphSidebar();
|
||||
const view = document.getElementById('project-graph-view')?.value || 'path';
|
||||
const hideDeprecated = document.getElementById('project-facts-filter-hide-deprecated')?.checked !== false;
|
||||
const params = new URLSearchParams({ view });
|
||||
if (!hideDeprecated) params.set('exclude_deprecated', '0');
|
||||
try {
|
||||
const res = await apiFetch(`/api/projects/${currentProjectId}/fact-graph?${params}`);
|
||||
if (!res.ok) throw new Error(tp('common.loadFailed'));
|
||||
const data = await res.json();
|
||||
_currentGraphData = data;
|
||||
if (typeof ProjectFactGraph !== 'undefined') {
|
||||
ProjectFactGraph.render(container, data, {
|
||||
emptyText: tp('projects.graphEmpty'),
|
||||
emptyTitle: tp('projects.graphEmptyTitle'),
|
||||
emptySteps: [
|
||||
tp('projects.graphEmptyStep1'),
|
||||
tp('projects.graphEmptyStep2'),
|
||||
tp('projects.graphEmptyStep3'),
|
||||
],
|
||||
emptyActionLabel: tp('projects.graphEmptyCta'),
|
||||
onEmptyAction: () => showAddFactModal(),
|
||||
onNodeSelect: (factKey) => showProjectFactGraphNode(factKey, _currentGraphData),
|
||||
onEdgeSelect: (edgeId) => showProjectFactGraphEdge(edgeId, _currentGraphData),
|
||||
});
|
||||
}
|
||||
const nodeCount = (data.nodes || []).length;
|
||||
const edgeCount = (data.edges || []).length;
|
||||
if (statsEl) {
|
||||
statsEl.innerHTML =
|
||||
`<span class="projects-graph-stat-badge"><strong>${nodeCount}</strong> ${escapeHtml(tp('projects.graphStatsNodes'))}</span>` +
|
||||
`<span class="projects-graph-stat-badge"><strong>${edgeCount}</strong> ${escapeHtml(tp('projects.graphStatsEdges'))}</span>`;
|
||||
}
|
||||
} catch (e) {
|
||||
container.innerHTML = `<div class="error-message">${escapeHtml(e.message || tp('common.loadFailed'))}</div>`;
|
||||
if (statsEl) statsEl.textContent = '';
|
||||
}
|
||||
}
|
||||
|
||||
function filterProjectFactGraph() {
|
||||
const q = document.getElementById('project-graph-search')?.value || '';
|
||||
if (typeof ProjectFactGraph !== 'undefined') {
|
||||
ProjectFactGraph.filterBySearch(q);
|
||||
}
|
||||
}
|
||||
|
||||
function centerProjectFactGraph() {
|
||||
if (typeof ProjectFactGraph !== 'undefined') ProjectFactGraph.center();
|
||||
}
|
||||
|
||||
function closeProjectFactGraphSidebar() {
|
||||
_selectedGraphFactKey = null;
|
||||
_selectedGraphEdgeId = null;
|
||||
if (typeof ProjectFactGraph !== 'undefined') ProjectFactGraph.clearEdgeSelection();
|
||||
const sidebar = document.getElementById('project-fact-graph-sidebar');
|
||||
if (sidebar) sidebar.hidden = true;
|
||||
}
|
||||
|
||||
function isSyntheticGraphEdge(edge) {
|
||||
if (!edge) return true;
|
||||
const id = String(edge.id || '');
|
||||
const type = String(edge.type || '');
|
||||
return id.startsWith('vuln-link:') || type === 'links_vuln';
|
||||
}
|
||||
|
||||
function getGraphEdgesForFact(factKey, graphData) {
|
||||
if (!factKey || !graphData?.edges) return [];
|
||||
return graphData.edges.filter((e) => e.source === factKey || e.target === factKey);
|
||||
}
|
||||
|
||||
function renderGraphEdgesListHtml(factKey, graphData, selectedEdgeId) {
|
||||
const edges = getGraphEdgesForFact(factKey, graphData);
|
||||
if (!edges.length) {
|
||||
return `<p class="project-fact-graph-edges-empty">${escapeHtml(tp('projects.graphEdgesEmpty'))}</p>`;
|
||||
}
|
||||
return edges
|
||||
.map((e) => {
|
||||
const isOut = e.source === factKey;
|
||||
const dirLabel = isOut ? tp('projects.graphEdgeFromSelf') : tp('projects.graphEdgeToSelf');
|
||||
const src = e.source || '';
|
||||
const tgt = e.target || '';
|
||||
const selected = e.id === selectedEdgeId ? ' is-selected' : '';
|
||||
const synthetic = isSyntheticGraphEdge(e);
|
||||
const deleteBtn = synthetic
|
||||
? `<span class="project-fact-graph-edge-synthetic" title="${escapeHtml(tp('projects.graphEdgeSynthetic'))}">—</span>`
|
||||
: `<button type="button" class="project-fact-graph-edge-delete" data-edge-id="${escapeHtml(e.id)}" onclick="event.stopPropagation(); deleteProjectFactEdge(this.dataset.edgeId)" title="${escapeHtml(tp('projects.graphDeleteEdge'))}">×</button>`;
|
||||
return `<div class="project-fact-graph-edge-item${selected}" data-edge-id="${escapeHtml(e.id)}" onclick="focusProjectFactGraphEdge(${JSON.stringify(e.id)})">
|
||||
<span class="project-fact-graph-edge-dir">${escapeHtml(dirLabel)}</span>
|
||||
<span class="project-fact-graph-edge-type">${escapeHtml(e.type || '')}</span>
|
||||
<span class="project-fact-graph-edge-peer" title="${escapeHtml(src + ' → ' + tgt)}">${escapeHtml(src)} → ${escapeHtml(tgt)}</span>
|
||||
${deleteBtn}
|
||||
</div>`;
|
||||
})
|
||||
.join('');
|
||||
}
|
||||
|
||||
function renderProjectFactGraphEdges(factKey, graphData, selectedEdgeId) {
|
||||
const wrap = document.getElementById('project-fact-graph-edges-wrap');
|
||||
const list = document.getElementById('project-fact-graph-edges-list');
|
||||
if (!wrap || !list) return;
|
||||
const edges = getGraphEdgesForFact(factKey, graphData);
|
||||
wrap.hidden = false;
|
||||
list.innerHTML = renderGraphEdgesListHtml(factKey, graphData, selectedEdgeId);
|
||||
if (selectedEdgeId) {
|
||||
const selectedEl = list.querySelector('[data-edge-id="' + String(selectedEdgeId).replace(/\\/g, '\\\\').replace(/"/g, '\\"') + '"]');
|
||||
if (selectedEl) selectedEl.scrollIntoView({ block: 'nearest' });
|
||||
}
|
||||
if (!edges.length) wrap.hidden = false;
|
||||
}
|
||||
|
||||
function graphVulnIdFromKey(factKey) {
|
||||
const key = String(factKey || '');
|
||||
if (!key.startsWith('vuln:')) return null;
|
||||
return key.slice(5);
|
||||
}
|
||||
|
||||
function showProjectFactGraphNode(factKey, graphData, selectedEdgeId) {
|
||||
if (!factKey) {
|
||||
closeProjectFactGraphSidebar();
|
||||
return;
|
||||
}
|
||||
_selectedGraphFactKey = factKey;
|
||||
_selectedGraphEdgeId = selectedEdgeId || null;
|
||||
const node = (graphData?.nodes || []).find((n) => n.fact_key === factKey || n.id === factKey);
|
||||
const vulnId = graphVulnIdFromKey(factKey);
|
||||
const isVulnNode = !!vulnId;
|
||||
const sidebar = document.getElementById('project-fact-graph-sidebar');
|
||||
const titleEl = document.getElementById('project-fact-graph-node-title');
|
||||
const metaEl = document.getElementById('project-fact-graph-node-meta');
|
||||
const categoryEl = document.getElementById('project-fact-graph-node-category');
|
||||
const detailBtn = document.getElementById('project-fact-graph-detail-btn');
|
||||
const editBtn = document.getElementById('project-fact-graph-edit-btn');
|
||||
if (!sidebar || !titleEl || !metaEl) return;
|
||||
titleEl.textContent = isVulnNode ? vulnId : factKey;
|
||||
titleEl.title = isVulnNode ? vulnId : factKey;
|
||||
if (categoryEl) {
|
||||
const visualType =
|
||||
typeof ProjectFactGraph !== 'undefined' && ProjectFactGraph.resolveGraphNodeType
|
||||
? ProjectFactGraph.resolveGraphNodeType(node)
|
||||
: node?.type || node?.category || 'note';
|
||||
const theme =
|
||||
typeof ProjectFactGraph !== 'undefined' && ProjectFactGraph.nodeTheme
|
||||
? ProjectFactGraph.nodeTheme(visualType)
|
||||
: { typeEn: String(visualType).toUpperCase(), typeLabel: visualType };
|
||||
categoryEl.textContent = theme.typeEn || String(visualType).toUpperCase();
|
||||
categoryEl.hidden = false;
|
||||
categoryEl.className = 'project-fact-graph-node-category project-fact-graph-node-category--' + visualType;
|
||||
categoryEl.title = theme.typeLabel || visualType;
|
||||
}
|
||||
const conf = node?.confidence || '';
|
||||
const summary = (node?.summary || node?.label || '').trim();
|
||||
if (summary || conf || isVulnNode) {
|
||||
const parts = [];
|
||||
if (summary) {
|
||||
parts.push(`<span class="project-fact-graph-node-summary">${escapeHtml(summary)}</span>`);
|
||||
}
|
||||
if (isVulnNode) {
|
||||
parts.push(
|
||||
`<span class="project-fact-graph-node-vuln-hint">${escapeHtml(tp('projects.graphVulnSidebarHint'))}</span>`,
|
||||
);
|
||||
}
|
||||
if (conf) {
|
||||
parts.push(formatConfidenceBadge(conf));
|
||||
}
|
||||
metaEl.innerHTML = parts.join('');
|
||||
} else {
|
||||
metaEl.textContent = '';
|
||||
}
|
||||
if (detailBtn) {
|
||||
detailBtn.textContent = isVulnNode ? tp('projects.viewVulnerability') : tp('projects.details');
|
||||
}
|
||||
if (editBtn) {
|
||||
editBtn.hidden = isVulnNode;
|
||||
}
|
||||
renderProjectFactGraphEdges(factKey, graphData, _selectedGraphEdgeId);
|
||||
if (_selectedGraphEdgeId && typeof ProjectFactGraph !== 'undefined') {
|
||||
ProjectFactGraph.selectEdge(_selectedGraphEdgeId);
|
||||
} else if (typeof ProjectFactGraph !== 'undefined') {
|
||||
ProjectFactGraph.clearEdgeSelection();
|
||||
}
|
||||
sidebar.hidden = false;
|
||||
}
|
||||
|
||||
function showProjectFactGraphEdge(edgeId, graphData) {
|
||||
const edge = (graphData?.edges || []).find((e) => e.id === edgeId);
|
||||
if (!edge) return;
|
||||
const anchorKey = edge.source && !String(edge.source).startsWith('vuln:') ? edge.source : edge.target;
|
||||
showProjectFactGraphNode(anchorKey, graphData, edgeId);
|
||||
}
|
||||
|
||||
function focusProjectFactGraphEdge(edgeId) {
|
||||
if (!edgeId || !_currentGraphData) return;
|
||||
showProjectFactGraphEdge(edgeId, _currentGraphData);
|
||||
}
|
||||
|
||||
async function deleteProjectFactEdge(edgeId) {
|
||||
if (!edgeId || !currentProjectId) return;
|
||||
const edge = (_currentGraphData?.edges || []).find((e) => e.id === edgeId);
|
||||
if (isSyntheticGraphEdge(edge)) return;
|
||||
if (!confirm(tp('projects.confirmDeleteGraphEdge'))) return;
|
||||
const res = await apiFetch(`/api/projects/${currentProjectId}/fact-edges/${encodeURIComponent(edgeId)}`, {
|
||||
method: 'DELETE',
|
||||
});
|
||||
if (!res.ok) {
|
||||
const err = await res.json().catch(() => ({}));
|
||||
return alert(err.error || tp('projects.graphEdgeDeleteFailed'));
|
||||
}
|
||||
if (typeof showNotification === 'function') showNotification(tp('projects.graphEdgeDeleteSuccess'), 'success');
|
||||
const keepKey = _selectedGraphFactKey;
|
||||
await loadProjectFactGraph();
|
||||
if (keepKey) showProjectFactGraphNode(keepKey, _currentGraphData);
|
||||
loadProjectFacts();
|
||||
}
|
||||
|
||||
function openSelectedGraphFactDetail() {
|
||||
if (!_selectedGraphFactKey) return;
|
||||
const vulnId = graphVulnIdFromKey(_selectedGraphFactKey);
|
||||
if (vulnId) {
|
||||
openVulnerabilityDetail(vulnId);
|
||||
return;
|
||||
}
|
||||
viewProjectFactBody(_selectedGraphFactKey);
|
||||
}
|
||||
|
||||
function editSelectedGraphFact() {
|
||||
if (_selectedGraphFactKey) showEditFactModal(_selectedGraphFactKey);
|
||||
}
|
||||
|
||||
function buildProjectFactsQueryParams() {
|
||||
const params = new URLSearchParams();
|
||||
params.set('limit', '200');
|
||||
params.set('include_link_counts', 'true');
|
||||
const search = document.getElementById('project-facts-search')?.value?.trim();
|
||||
const category = document.getElementById('project-facts-filter-category')?.value?.trim();
|
||||
const confidence = document.getElementById('project-facts-filter-confidence')?.value?.trim();
|
||||
@@ -768,11 +1066,11 @@ function debouncedLoadProjectFacts() {
|
||||
async function loadProjectFacts() {
|
||||
const tbody = document.getElementById('project-facts-tbody');
|
||||
if (!tbody || !currentProjectId) return;
|
||||
tbody.innerHTML = `<tr class="is-empty-row"><td colspan="7">${escapeHtml(tp('common.loading'))}</td></tr>`;
|
||||
tbody.innerHTML = `<tr class="is-empty-row"><td colspan="8">${escapeHtml(tp('common.loading'))}</td></tr>`;
|
||||
const qs = buildProjectFactsQueryParams().toString();
|
||||
const res = await apiFetch(`/api/projects/${currentProjectId}/facts?${qs}`);
|
||||
if (!res.ok) {
|
||||
tbody.innerHTML = `<tr class="is-empty-row"><td colspan="7">${escapeHtml(tp('common.loadFailed'))}</td></tr>`;
|
||||
tbody.innerHTML = `<tr class="is-empty-row"><td colspan="8">${escapeHtml(tp('common.loadFailed'))}</td></tr>`;
|
||||
return;
|
||||
}
|
||||
const facts = await res.json();
|
||||
@@ -782,7 +1080,7 @@ async function loadProjectFacts() {
|
||||
document.getElementById('project-facts-filter-category')?.value ||
|
||||
document.getElementById('project-facts-filter-confidence')?.value ||
|
||||
document.getElementById('project-facts-filter-sparse')?.checked;
|
||||
tbody.innerHTML = `<tr class="is-empty-row"><td colspan="7">${
|
||||
tbody.innerHTML = `<tr class="is-empty-row"><td colspan="8">${
|
||||
hasFilter ? tp('projects.noMatchingFacts') : tp('projects.noFacts')
|
||||
}</td></tr>`;
|
||||
refreshProjectHeaderStats();
|
||||
@@ -797,10 +1095,16 @@ async function loadProjectFacts() {
|
||||
const pinBadge = f.pinned
|
||||
? `<span class="projects-list-item-badge" title="${escapeHtml(tp('projects.pinned'))}">${escapeHtml(tp('projects.pinned'))}</span>`
|
||||
: '';
|
||||
const lc = f.link_counts || {};
|
||||
const linkBadge =
|
||||
lc.outgoing || lc.incoming
|
||||
? `<span class="projects-fact-link-badge" title="${escapeHtml(tp('projects.linkCountsTitle'))}">↑${lc.outgoing || 0} ↓${lc.incoming || 0}</span>`
|
||||
: '<span class="projects-fact-link-badge projects-fact-link-badge--empty">—</span>';
|
||||
return `<tr>
|
||||
<td class="cell-fact-key"><code class="projects-fact-key-chip" title="${keyEsc}">${keyEsc}</code>${pinBadge}${vulnLink}</td>
|
||||
<td class="cell-fact-category">${formatCategoryBadge(f.category)}</td>
|
||||
<td class="cell-summary" title="${escapeHtml(f.summary)}">${escapeHtml(f.summary)}</td>
|
||||
<td class="cell-fact-links">${linkBadge}</td>
|
||||
<td>${formatFactBodyBadge(f)}</td>
|
||||
<td>${formatConfidenceBadge(f.confidence)}</td>
|
||||
<td>${formatProjectTime(f.updated_at, f.created_at)}</td>
|
||||
@@ -849,6 +1153,7 @@ async function loadProjectConversations() {
|
||||
<td class="col-actions">
|
||||
<div class="projects-table-actions">
|
||||
<button type="button" class="projects-action-btn projects-action-btn--view" data-conv-id="${idEsc}" onclick="openProjectConversation(this.dataset.convId)">${escapeHtml(tp('projects.open'))}</button>
|
||||
<button type="button" class="projects-action-btn" data-conv-id="${idEsc}" onclick="promoteConversationAttackChain(this.dataset.convId)" title="${escapeHtml(tp('projects.promoteAttackChainTitle'))}">${escapeHtml(tp('projects.promoteAttackChain'))}</button>
|
||||
<button type="button" class="projects-action-btn projects-action-btn--mute" data-conv-id="${idEsc}" onclick="unbindConversationFromProject(this.dataset.convId)" title="${escapeHtml(tp('projects.unbindProjectTitle'))}">${escapeHtml(tp('projects.unbind'))}</button>
|
||||
</div>
|
||||
</td>
|
||||
@@ -869,6 +1174,32 @@ function openProjectConversation(conversationId) {
|
||||
}, 200);
|
||||
}
|
||||
|
||||
async function promoteConversationAttackChain(conversationId) {
|
||||
if (!currentProjectId || !conversationId) return;
|
||||
if (!confirm(tp('projects.confirmPromoteAttackChain'))) return;
|
||||
const res = await apiFetch(
|
||||
`/api/projects/${currentProjectId}/promote-attack-chain/${encodeURIComponent(conversationId)}`,
|
||||
{ method: 'POST' },
|
||||
);
|
||||
if (!res.ok) {
|
||||
const err = await res.json().catch(() => ({}));
|
||||
return alert(err.error || tp('projects.promoteAttackChainFailed'));
|
||||
}
|
||||
const data = await res.json();
|
||||
if (typeof showNotification === 'function') {
|
||||
showNotification(
|
||||
tpFmt(
|
||||
'projects.promoteAttackChainSuccess',
|
||||
`已沉淀 ${data.facts_created || 0} 新 / ${data.facts_updated || 0} 更新 / ${data.edges_created || 0} 边`,
|
||||
data,
|
||||
),
|
||||
'success',
|
||||
);
|
||||
}
|
||||
loadProjectFacts();
|
||||
if (currentProjectTab === 'graph') loadProjectFactGraph();
|
||||
}
|
||||
|
||||
async function unbindConversationFromProject(conversationId) {
|
||||
if (!conversationId || !confirm(tp('projects.confirmUnbindConversation'))) return;
|
||||
const res = await apiFetch(`/api/conversations/${encodeURIComponent(conversationId)}/project`, {
|
||||
@@ -1509,6 +1840,10 @@ function resetFactModalForm() {
|
||||
if (pinEl) pinEl.checked = false;
|
||||
const rel = document.getElementById('fact-modal-related-vuln');
|
||||
if (rel) rel.value = '';
|
||||
const linksEl = document.getElementById('fact-modal-links');
|
||||
if (linksEl) linksEl.value = '';
|
||||
const incomingWrap = document.getElementById('fact-modal-incoming-links-wrap');
|
||||
if (incomingWrap) incomingWrap.hidden = true;
|
||||
updateFactFormHints();
|
||||
}
|
||||
|
||||
@@ -1540,6 +1875,8 @@ function fillFactModalForm(f) {
|
||||
}
|
||||
const rel = document.getElementById('fact-modal-related-vuln');
|
||||
if (rel) rel.value = f.related_vulnerability_id || '';
|
||||
const linksEl = document.getElementById('fact-modal-links');
|
||||
if (linksEl) linksEl.value = formatIncomingLinksForModal(f.incoming_links);
|
||||
const pinEl = document.getElementById('fact-modal-pinned');
|
||||
if (pinEl) pinEl.checked = !!f.pinned;
|
||||
updateFactFormHints();
|
||||
@@ -1556,7 +1893,7 @@ async function showEditFactModal(factKey) {
|
||||
resetFactModalForm();
|
||||
openProjectsOverlay('fact-modal', { focus: false });
|
||||
const res = await apiFetch(
|
||||
`/api/projects/${currentProjectId}/facts?fact_key=${encodeURIComponent(factKey)}`,
|
||||
`/api/projects/${currentProjectId}/facts?fact_key=${encodeURIComponent(factKey)}&include_links=true`,
|
||||
);
|
||||
if (!res.ok) {
|
||||
closeFactModal();
|
||||
@@ -1594,6 +1931,7 @@ async function saveFactModal() {
|
||||
confidence: document.getElementById('fact-modal-confidence').value,
|
||||
pinned: !!document.getElementById('fact-modal-pinned')?.checked,
|
||||
related_vulnerability_id: document.getElementById('fact-modal-related-vuln')?.value?.trim() || '',
|
||||
links_text: document.getElementById('fact-modal-links')?.value || '',
|
||||
};
|
||||
const editId = window._factModalEditId;
|
||||
const res = editId
|
||||
@@ -1613,12 +1951,14 @@ async function saveFactModal() {
|
||||
}
|
||||
closeFactModal();
|
||||
loadProjectFacts();
|
||||
if (currentProjectTab === 'graph') loadProjectFactGraph();
|
||||
}
|
||||
|
||||
async function deleteProjectFact(id) {
|
||||
if (!confirm(tp('projects.confirmDeleteFact'))) return;
|
||||
await apiFetch(`/api/projects/${currentProjectId}/facts/${id}`, { method: 'DELETE' });
|
||||
loadProjectFacts();
|
||||
if (currentProjectTab === 'graph') loadProjectFactGraph();
|
||||
}
|
||||
|
||||
function parseProjectDate(t) {
|
||||
@@ -1974,5 +2314,15 @@ window.viewFactsForVulnerability = viewFactsForVulnerability;
|
||||
window.openProjectConversation = openProjectConversation;
|
||||
window.unbindConversationFromProject = unbindConversationFromProject;
|
||||
window.loadProjectConversations = loadProjectConversations;
|
||||
window.loadProjectFactGraph = loadProjectFactGraph;
|
||||
window.filterProjectFactGraph = filterProjectFactGraph;
|
||||
window.centerProjectFactGraph = centerProjectFactGraph;
|
||||
window.closeProjectFactGraphSidebar = closeProjectFactGraphSidebar;
|
||||
window.openSelectedGraphFactDetail = openSelectedGraphFactDetail;
|
||||
window.editSelectedGraphFact = editSelectedGraphFact;
|
||||
window.promoteConversationAttackChain = promoteConversationAttackChain;
|
||||
window.deleteProjectFactEdge = deleteProjectFactEdge;
|
||||
window.focusProjectFactGraphEdge = focusProjectFactGraphEdge;
|
||||
window.toggleProjectFactGraphConnectMode = toggleProjectFactGraphConnectMode;
|
||||
window.rebuildProjectNameMap = rebuildProjectNameMap;
|
||||
window.projectNameById = projectNameById;
|
||||
|
||||
@@ -990,6 +990,7 @@ async function createBatchQueue() {
|
||||
const roleSelect = document.getElementById('batch-queue-role');
|
||||
const projectSelect = document.getElementById('batch-queue-project-id');
|
||||
const agentModeSelect = document.getElementById('batch-queue-agent-mode');
|
||||
const concurrencyInput = document.getElementById('batch-queue-concurrency');
|
||||
const scheduleModeSelect = document.getElementById('batch-queue-schedule-mode');
|
||||
const cronExprInput = document.getElementById('batch-queue-cron-expr');
|
||||
const executeNowCheckbox = document.getElementById('batch-queue-execute-now');
|
||||
@@ -1019,6 +1020,9 @@ async function createBatchQueue() {
|
||||
const scheduleMode = scheduleModeSelect ? (scheduleModeSelect.value === 'cron' ? 'cron' : 'manual') : 'manual';
|
||||
const cronExpr = cronExprInput ? cronExprInput.value.trim() : '';
|
||||
const executeNow = executeNowCheckbox ? !!executeNowCheckbox.checked : false;
|
||||
let concurrency = concurrencyInput ? parseInt(concurrencyInput.value, 10) : 1;
|
||||
if (!Number.isFinite(concurrency) || concurrency < 1) concurrency = 1;
|
||||
if (concurrency > 8) concurrency = 8;
|
||||
if (scheduleMode === 'cron' && !cronExpr) {
|
||||
alert(_t('batchImportModal.cronExprRequired'));
|
||||
return;
|
||||
@@ -1043,6 +1047,7 @@ async function createBatchQueue() {
|
||||
cronExpr,
|
||||
executeNow,
|
||||
projectId,
|
||||
concurrency,
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -1489,6 +1494,7 @@ async function showBatchQueueDetail(queueId) {
|
||||
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.role'))}</span><span class="bq-kv__v" id="bq-role-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditRole()" title="${escapeHtml(_t('common.edit'))}">${roleLineVal}</span>` : roleLineVal}</span></div>
|
||||
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchImportModal.agentMode'))}</span><span class="bq-kv__v" id="bq-agentmode-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditAgentMode()" title="${escapeHtml(_t('common.edit'))}">${escapeHtml(agentModeText)}</span>` : escapeHtml(agentModeText)}</span></div>
|
||||
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchImportModal.scheduleMode'))}</span><span class="bq-kv__v" id="bq-schedule-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditSchedule()" title="${escapeHtml(_t('common.edit'))}">${scheduleDetail}</span>` : scheduleDetail}</span></div>
|
||||
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.concurrency'))}</span><span class="bq-kv__v" id="bq-concurrency-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditConcurrency()" title="${escapeHtml(_t('common.edit'))}">${escapeHtml(String(queue.concurrency && queue.concurrency > 0 ? queue.concurrency : 1))}</span>` : escapeHtml(String(queue.concurrency && queue.concurrency > 0 ? queue.concurrency : 1))}</span></div>
|
||||
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.taskTotal'))}</span><span class="bq-kv__v">${queue.tasks.length}</span></div>
|
||||
${queue.scheduleMode === 'cron' ? `<div class="bq-kv bq-kv--block"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.scheduleCronAuto'))}</span><span class="bq-kv__v bq-kv__v--control"><label class="bq-cron-toggle"><input type="checkbox" ${queue.scheduleEnabled !== false ? 'checked' : ''} onchange="updateBatchQueueScheduleEnabled(this.checked)" /><span class="bq-cron-toggle__hint">${escapeHtml(_t('batchQueueDetailModal.scheduleCronAutoHint'))}</span></label></span></div>` : ''}
|
||||
</section>
|
||||
@@ -2287,6 +2293,75 @@ async function saveInlineAgentMode() {
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeBatchQueueConcurrencyInput(raw) {
|
||||
let n = parseInt(raw, 10);
|
||||
if (!Number.isFinite(n) || n < 1) n = 1;
|
||||
if (n > 8) n = 8;
|
||||
return n;
|
||||
}
|
||||
|
||||
// --- 内联编辑:并发数 ---
|
||||
function startInlineEditConcurrency() {
|
||||
const container = document.getElementById('bq-concurrency-val');
|
||||
if (!container) return;
|
||||
const queueId = batchQueuesState.currentQueueId;
|
||||
if (!queueId) return;
|
||||
apiFetch(`/api/batch-tasks/${queueId}`).then(r => r.json()).then(detail => {
|
||||
const queue = detail.queue || {};
|
||||
const current = normalizeBatchQueueConcurrencyInput(queue.concurrency || 1);
|
||||
container.innerHTML = `<span class="bq-inline-edit-controls">
|
||||
<input type="number" id="bq-edit-concurrency" min="1" max="8" value="${current}" style="width:72px;" />
|
||||
</span>`;
|
||||
const inp = document.getElementById('bq-edit-concurrency');
|
||||
if (!inp) return;
|
||||
inp.focus();
|
||||
inp.select();
|
||||
let cancelled = false;
|
||||
inp.addEventListener('keydown', (e) => {
|
||||
if (e.key === 'Enter') { e.preventDefault(); inp.blur(); }
|
||||
if (e.key === 'Escape') { cancelled = true; cancelAllInlineEdits(); }
|
||||
});
|
||||
inp.addEventListener('blur', () => {
|
||||
if (!cancelled) saveInlineConcurrency();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
async function saveInlineConcurrency() {
|
||||
if (_bqInlineSaving) return;
|
||||
_bqInlineSaving = true;
|
||||
const queueId = batchQueuesState.currentQueueId;
|
||||
if (!queueId) { _bqInlineSaving = false; return; }
|
||||
const inp = document.getElementById('bq-edit-concurrency');
|
||||
const concurrency = normalizeBatchQueueConcurrencyInput(inp ? inp.value : 1);
|
||||
try {
|
||||
const detailResp = await apiFetch(`/api/batch-tasks/${queueId}`);
|
||||
const detail = await detailResp.json();
|
||||
const q = detail.queue || {};
|
||||
const response = await apiFetch(`/api/batch-tasks/${queueId}/metadata`, {
|
||||
method: 'PUT',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
title: q.title || '',
|
||||
role: q.role || '',
|
||||
agentMode: q.agentMode || 'eino_single',
|
||||
concurrency,
|
||||
}),
|
||||
});
|
||||
if (!response.ok) {
|
||||
const result = await response.json().catch(() => ({}));
|
||||
throw new Error(result.error || _t('tasks.updateTaskFailed'));
|
||||
}
|
||||
_bqInlineSaving = false;
|
||||
showBatchQueueDetail(queueId);
|
||||
refreshBatchQueues();
|
||||
} catch (e) {
|
||||
_bqInlineSaving = false;
|
||||
console.error(e);
|
||||
alert(e.message);
|
||||
}
|
||||
}
|
||||
|
||||
// --- 单条执行 ---
|
||||
async function runSingleBatchTask(queueId, taskId) {
|
||||
if (!queueId || !taskId) return;
|
||||
@@ -2441,6 +2516,8 @@ window.startInlineEditRole = startInlineEditRole;
|
||||
window.saveInlineRole = saveInlineRole;
|
||||
window.startInlineEditAgentMode = startInlineEditAgentMode;
|
||||
window.saveInlineAgentMode = saveInlineAgentMode;
|
||||
window.startInlineEditConcurrency = startInlineEditConcurrency;
|
||||
window.saveInlineConcurrency = saveInlineConcurrency;
|
||||
window.runSingleBatchTask = runSingleBatchTask;
|
||||
window.startInlineEditSchedule = startInlineEditSchedule;
|
||||
window.toggleInlineScheduleCron = toggleInlineScheduleCron;
|
||||
|
||||
@@ -1989,6 +1989,10 @@ function buildWebshellTimelineItemFromDetail(detail) {
|
||||
// 渲染「执行过程及调用工具」折叠块(默认折叠,刷新后加载历史时保留并可展开)
|
||||
function renderWebshellProcessDetailsBlock(processDetails, defaultCollapsed) {
|
||||
if (!processDetails || processDetails.length === 0) return null;
|
||||
if (typeof window.filterNoiseProcessDetails === 'function') {
|
||||
processDetails = window.filterNoiseProcessDetails(processDetails);
|
||||
}
|
||||
if (!processDetails.length) return null;
|
||||
if (typeof window.coalesceProcessDetailsToolPairs === 'function') {
|
||||
processDetails = window.coalesceProcessDetailsToolPairs(processDetails);
|
||||
}
|
||||
|
||||
+108
-6
@@ -1498,6 +1498,13 @@
|
||||
</aside>
|
||||
<main class="projects-detail" id="projects-detail-main">
|
||||
<div class="projects-detail-placeholder" id="projects-detail-placeholder">
|
||||
<div class="projects-placeholder-icon" aria-hidden="true">
|
||||
<svg width="56" height="56" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<rect x="3" y="4" width="18" height="16" rx="3" stroke="currentColor" stroke-width="1.5"/>
|
||||
<path d="M3 9h18M8 4V9M16 4V9" stroke="currentColor" stroke-width="1.5" stroke-linecap="round"/>
|
||||
<path d="M8 14h8M8 17h5" stroke="currentColor" stroke-width="1.5" stroke-linecap="round"/>
|
||||
</svg>
|
||||
</div>
|
||||
<h3 data-i18n="projects.selectOrCreateTitle">选择或创建项目</h3>
|
||||
<p data-i18n="projects.selectOrCreateHint">项目用于跨对话共享「事实黑板」:目标、环境、认证等信息会在绑定项目的对话中自动注入。</p>
|
||||
<button class="btn-primary" type="button" onclick="showNewProjectModal()" data-i18n="projects.createFirstProject">创建第一个项目</button>
|
||||
@@ -1527,6 +1534,7 @@
|
||||
</header>
|
||||
<nav class="projects-tabs" role="tablist">
|
||||
<button type="button" id="project-tab-facts" class="projects-tab is-active" role="tab" onclick="switchProjectTab('facts')" data-i18n="projects.tabFacts">事实黑板</button>
|
||||
<button type="button" id="project-tab-graph" class="projects-tab" role="tab" onclick="switchProjectTab('graph')" data-i18n="projects.tabGraph">攻击路径</button>
|
||||
<button type="button" id="project-tab-conversations" class="projects-tab" role="tab" onclick="switchProjectTab('conversations')" data-i18n="projects.tabConversations">关联对话</button>
|
||||
<button type="button" id="project-tab-vulns" class="projects-tab" role="tab" onclick="switchProjectTab('vulns')" data-i18n="projects.tabVulns">关联漏洞</button>
|
||||
<button type="button" id="project-tab-settings" class="projects-tab" role="tab" onclick="switchProjectTab('settings')" data-i18n="projects.tabSettings">设置</button>
|
||||
@@ -1587,11 +1595,96 @@
|
||||
</div>
|
||||
<div class="projects-table-wrap">
|
||||
<table class="data-table data-table--projects">
|
||||
<thead><tr><th>Key</th><th data-i18n="projects.category">分类</th><th data-i18n="projects.summary">摘要</th><th>Body</th><th data-i18n="projects.confidence">置信度</th><th data-i18n="projects.updated">更新</th><th class="col-actions" data-i18n="common.actions">操作</th></tr></thead>
|
||||
<thead><tr><th>Key</th><th data-i18n="projects.category">分类</th><th data-i18n="projects.summary">摘要</th><th data-i18n="projects.linksColumn">关系</th><th>Body</th><th data-i18n="projects.confidence">置信度</th><th data-i18n="projects.updated">更新</th><th class="col-actions" data-i18n="common.actions">操作</th></tr></thead>
|
||||
<tbody id="project-facts-tbody"></tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
<div id="project-panel-graph" class="projects-panel projects-panel--graph" role="tabpanel" hidden>
|
||||
<div class="projects-fact-toolbar projects-graph-toolbar">
|
||||
<p class="projects-fact-toolbar-hint" role="note">
|
||||
<svg class="projects-fact-toolbar-hint-icon" width="16" height="16" viewBox="0 0 24 24" fill="none" aria-hidden="true" xmlns="http://www.w3.org/2000/svg">
|
||||
<circle cx="12" cy="12" r="9" stroke="currentColor" stroke-width="2"/>
|
||||
<path d="M12 10v6M12 8h.01" stroke="currentColor" stroke-width="2" stroke-linecap="round"/>
|
||||
</svg>
|
||||
<span data-i18n="projects.graphToolbarHint">攻击路径图箭头与事实存储方向一致(source → target);节点按 target→infra→finding→exploit 分层排布。虚线边为待确认。</span>
|
||||
</p>
|
||||
<div class="projects-fact-toolbar-filters projects-graph-toolbar-row">
|
||||
<label class="projects-fact-filter-field">
|
||||
<span class="projects-fact-filter-label" data-i18n="projects.graphView">视图</span>
|
||||
<select id="project-graph-view" onchange="loadProjectFactGraph()">
|
||||
<option value="path" data-i18n="projects.graphViewPath">攻击路径</option>
|
||||
<option value="full" data-i18n="projects.graphViewFull">完整关系</option>
|
||||
</select>
|
||||
</label>
|
||||
<label class="projects-fact-filter-field projects-fact-filter-field--search projects-graph-search-field">
|
||||
<span class="sr-only" data-i18n="projects.graphSearchSr">搜索节点</span>
|
||||
<svg class="projects-fact-search-icon" width="16" height="16" viewBox="0 0 24 24" fill="none" aria-hidden="true" xmlns="http://www.w3.org/2000/svg">
|
||||
<circle cx="11" cy="11" r="7" stroke="currentColor" stroke-width="2"/>
|
||||
<path d="M20 20L16 16" stroke="currentColor" stroke-width="2" stroke-linecap="round"/>
|
||||
</svg>
|
||||
<input type="search" id="project-graph-search" placeholder="搜索节点…" oninput="filterProjectFactGraph()" autocomplete="off" data-i18n="projects.graphSearchPlaceholder" data-i18n-attr="placeholder">
|
||||
</label>
|
||||
<div class="projects-graph-actions" role="group" aria-label="Graph actions">
|
||||
<button type="button" class="projects-graph-action-btn" onclick="loadProjectFactGraph()" title="刷新" data-i18n="projects.graphRefresh" data-i18n-attr="title">
|
||||
<svg width="15" height="15" viewBox="0 0 24 24" fill="none" aria-hidden="true"><path d="M21 12a9 9 0 1 1-2.64-6.36" stroke="currentColor" stroke-width="2" stroke-linecap="round"/><path d="M21 3v6h-6" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/></svg>
|
||||
<span data-i18n="projects.graphRefresh">刷新</span>
|
||||
</button>
|
||||
<button type="button" class="projects-graph-action-btn" onclick="centerProjectFactGraph()" title="居中" data-i18n="projects.graphCenter" data-i18n-attr="title">
|
||||
<svg width="15" height="15" viewBox="0 0 24 24" fill="none" aria-hidden="true"><circle cx="12" cy="12" r="3" stroke="currentColor" stroke-width="2"/><path d="M12 2v4M12 18v4M2 12h4M18 12h4" stroke="currentColor" stroke-width="2" stroke-linecap="round"/></svg>
|
||||
<span data-i18n="projects.graphCenter">居中</span>
|
||||
</button>
|
||||
<button type="button" class="projects-graph-action-btn projects-graph-action-btn--connect" id="project-graph-connect-btn" onclick="toggleProjectFactGraphConnectMode()" data-i18n="projects.graphConnect">连边</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="project-fact-graph-layout">
|
||||
<div id="project-fact-graph-container" class="project-fact-graph-container"></div>
|
||||
<aside id="project-fact-graph-sidebar" class="project-fact-graph-sidebar" hidden>
|
||||
<div class="project-fact-graph-sidebar-header">
|
||||
<div class="project-fact-graph-sidebar-title-wrap">
|
||||
<span id="project-fact-graph-node-category" class="project-fact-graph-node-category"></span>
|
||||
<h4 id="project-fact-graph-node-title">—</h4>
|
||||
</div>
|
||||
<button type="button" class="project-fact-graph-sidebar-close" onclick="closeProjectFactGraphSidebar()" aria-label="关闭" data-i18n="common.close" data-i18n-attr="aria-label">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" aria-hidden="true"><path d="M6 6l12 12M18 6L6 18" stroke="currentColor" stroke-width="2" stroke-linecap="round"/></svg>
|
||||
</button>
|
||||
</div>
|
||||
<p id="project-fact-graph-node-meta" class="project-fact-graph-node-meta"></p>
|
||||
<div id="project-fact-graph-edges-wrap" class="project-fact-graph-edges-wrap" hidden>
|
||||
<h5 class="project-fact-graph-edges-title" data-i18n="projects.graphEdgesTitle">关系边</h5>
|
||||
<p class="project-fact-graph-edges-hint" data-i18n="projects.graphEdgesHint">箭头方向与数据库/编辑弹窗一致(source → target);点击连线可定位。</p>
|
||||
<div id="project-fact-graph-edges-list" class="project-fact-graph-edges-list"></div>
|
||||
</div>
|
||||
<div class="project-fact-graph-sidebar-actions">
|
||||
<button type="button" class="btn-primary btn-small" id="project-fact-graph-detail-btn" onclick="openSelectedGraphFactDetail()" data-i18n="projects.details">详情</button>
|
||||
<button type="button" class="btn-secondary btn-small" id="project-fact-graph-edit-btn" onclick="editSelectedGraphFact()" data-i18n="common.edit">编辑</button>
|
||||
</div>
|
||||
</aside>
|
||||
</div>
|
||||
<div class="project-fact-graph-footer">
|
||||
<div id="project-fact-graph-stats" class="project-fact-graph-stats"></div>
|
||||
<div class="projects-graph-legend" role="group" aria-label="Graph legend">
|
||||
<div class="projects-graph-legend-group">
|
||||
<span class="projects-graph-legend-heading" data-i18n="projects.graphLegendNodes">节点</span>
|
||||
<span class="projects-graph-legend-item projects-graph-legend-item--node"><i style="--legend-color:#4F46E5;--legend-bg:#F5F3FF"></i><span data-i18n="projects.graphLegendNodeTarget">TARGET · 目标</span></span>
|
||||
<span class="projects-graph-legend-item projects-graph-legend-item--node"><i style="--legend-color:#64748B;--legend-bg:#F8FAFC"></i><span data-i18n="projects.graphLegendNodeInfra">INFRA · 基础设施</span></span>
|
||||
<span class="projects-graph-legend-item projects-graph-legend-item--node"><i style="--legend-color:#E11D48;--legend-bg:#FFF1F2"></i><span data-i18n="projects.graphLegendNodeFinding">FINDING · 发现</span></span>
|
||||
<span class="projects-graph-legend-item projects-graph-legend-item--node"><i style="--legend-color:#9333EA;--legend-bg:#F5F3FF"></i><span data-i18n="projects.graphLegendNodeVuln">VULN · 漏洞</span></span>
|
||||
<span class="projects-graph-legend-item projects-graph-legend-item--node"><i style="--legend-color:#B45309;--legend-bg:#FFFBEB"></i><span data-i18n="projects.graphLegendNodeExploit">EXPLOIT · 利用</span></span>
|
||||
<span class="projects-graph-legend-item projects-graph-legend-item--node projects-graph-legend-item--node-dashed"><i style="--legend-color:#CBD5E1;--legend-bg:#F1F5F9"></i><span data-i18n="projects.graphLegendNodeMissing">MISSING · 缺失</span></span>
|
||||
</div>
|
||||
<span class="projects-graph-legend-divider" aria-hidden="true"></span>
|
||||
<div class="projects-graph-legend-group">
|
||||
<span class="projects-graph-legend-heading" data-i18n="projects.graphLegendEdges">连线</span>
|
||||
<span class="projects-graph-legend-item projects-graph-legend-item--edge"><i style="--legend-color:#4F46E5"></i><span data-i18n="projects.graphLegendDiscovered">discovered_on</span></span>
|
||||
<span class="projects-graph-legend-item projects-graph-legend-item--edge"><i style="--legend-color:#64748B"></i><span data-i18n="projects.graphLegendLeads">leads_to</span></span>
|
||||
<span class="projects-graph-legend-item projects-graph-legend-item--edge"><i style="--legend-color:#DC2626"></i><span data-i18n="projects.graphLegendExploits">exploits</span></span>
|
||||
<span class="projects-graph-legend-item projects-graph-legend-item--edge projects-graph-legend-item--dashed"><i style="--legend-color:#94A3B8"></i><span data-i18n="projects.graphLegendTentative">待确认(虚线)</span></span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div id="project-panel-conversations" class="projects-panel" role="tabpanel" hidden>
|
||||
<div class="projects-panel-toolbar projects-panel-toolbar--hint">
|
||||
<p class="projects-fact-toolbar-hint" role="note">
|
||||
@@ -3684,7 +3777,9 @@
|
||||
<div class="modal-body batch-manage-body">
|
||||
<div class="batch-conversations-table">
|
||||
<div class="batch-table-header">
|
||||
<div class="batch-table-col-checkbox"></div>
|
||||
<div class="batch-table-col-checkbox">
|
||||
<input type="checkbox" id="batch-select-all" onchange="toggleSelectAllBatch()" data-i18n="batchManageModal.selectAll" data-i18n-attr="title" title="全选" />
|
||||
</div>
|
||||
<div class="batch-table-col-name" data-i18n="batchManageModal.conversationName">对话名称</div>
|
||||
<div class="batch-table-col-time" data-i18n="batchManageModal.lastTime">最近一次对话时间</div>
|
||||
<div class="batch-table-col-action" data-i18n="batchManageModal.action">操作</div>
|
||||
@@ -3693,10 +3788,6 @@
|
||||
</div>
|
||||
</div>
|
||||
<div class="modal-footer batch-manage-footer">
|
||||
<label class="select-all-checkbox">
|
||||
<input type="checkbox" id="batch-select-all" onchange="toggleSelectAllBatch()" />
|
||||
<span data-i18n="batchManageModal.selectAll">全选</span>
|
||||
</label>
|
||||
<div class="batch-footer-actions">
|
||||
<button class="btn-secondary" onclick="closeBatchManageModal()" data-i18n="common.cancel">取消</button>
|
||||
<button class="btn-primary" onclick="deleteSelectedConversations()" data-i18n="batchManageModal.deleteSelected">删除所选</button>
|
||||
@@ -3919,6 +4010,11 @@
|
||||
</select>
|
||||
<div class="form-hint" style="margin-top: 4px;" data-i18n="batchImportModal.agentModeHint">与对话页一致:Eino 单代理(ADK),或 Deep / Plan-Execute / Supervisor(后三种需已启用多代理)。</div>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="batch-queue-concurrency" data-i18n="batchImportModal.concurrency">并发数</label>
|
||||
<input type="number" id="batch-queue-concurrency" min="1" max="8" value="1" style="width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 0.875rem;" />
|
||||
<div class="form-hint" style="margin-top: 4px;" data-i18n="batchImportModal.concurrencyHint">同时执行的子任务数量(1-8)。默认 1 为串行;含扫描类工具时建议 1-2。</div>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="batch-queue-schedule-mode" data-i18n="batchImportModal.scheduleMode">调度方式</label>
|
||||
<select id="batch-queue-schedule-mode" onchange="handleBatchScheduleModeChange()" style="width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 0.875rem;">
|
||||
@@ -4339,6 +4435,11 @@
|
||||
<label for="fact-modal-related-vuln" data-i18n="projects.relatedVulnIdLabel">关联漏洞 ID</label>
|
||||
<input type="text" id="fact-modal-related-vuln" class="form-input" placeholder="可选" data-i18n="projects.optional" data-i18n-attr="placeholder">
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label for="fact-modal-links" data-i18n="projects.factLinksLabel">关系边(from → 本事实)</label>
|
||||
<textarea id="fact-modal-links" class="form-input" rows="4" placeholder="discovered_on: target/primary_domain exploits: exploit/upload-rce" data-i18n="projects.factLinksPlaceholder" data-i18n-attr="placeholder"></textarea>
|
||||
<p class="projects-field-hint" data-i18n="projects.factLinksHint">每行一条:type: source_fact_key(来源 → 当前事实)。常用 type:discovered_on、depends_on、leads_to、enables、exploits。保存时替换全部关系边。</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="projects-modal-footer">
|
||||
<button class="btn-secondary" type="button" onclick="closeFactModal()" data-i18n="common.cancel">取消</button>
|
||||
@@ -4396,6 +4497,7 @@
|
||||
<script src="/static/js/terminal.js"></script>
|
||||
<script src="/static/js/knowledge.js"></script>
|
||||
<script src="/static/js/skills.js"></script>
|
||||
<script src="/static/js/fact-graph.js"></script>
|
||||
<script src="/static/js/projects.js"></script>
|
||||
<script src="/static/js/vulnerability.js?v=12"></script>
|
||||
<script src="/static/js/webshell.js"></script>
|
||||
|
||||
Reference in New Issue
Block a user