Compare commits

...

26 Commits

Author SHA1 Message Date
公明 bcfb7b8da1 Update config.yaml 2026-04-29 04:11:31 +08:00
公明 f420ae0265 Add files via upload 2026-04-29 03:28:32 +08:00
公明 e3f59b29ab Add files via upload 2026-04-29 03:26:27 +08:00
公明 87cba37203 Add files via upload 2026-04-29 03:24:48 +08:00
公明 4773b9e963 Update config.yaml 2026-04-29 03:01:21 +08:00
公明 eda5f9bba1 Add files via upload 2026-04-29 02:59:34 +08:00
公明 1318607813 Add files via upload 2026-04-29 02:57:22 +08:00
公明 5100924abe Add files via upload 2026-04-29 02:54:43 +08:00
公明 44079674dd Add files via upload 2026-04-28 14:07:01 +08:00
公明 d959390e27 Update config.yaml 2026-04-28 11:45:27 +08:00
公明 62a0d8cb71 Add files via upload 2026-04-28 11:40:09 +08:00
公明 b53cae3a02 Add files via upload 2026-04-28 11:37:52 +08:00
公明 3b3d094dc4 Add files via upload 2026-04-28 10:26:09 +08:00
公明 47922c2083 Add files via upload 2026-04-28 10:23:24 +08:00
公明 dfaf0bc77f Update config.yaml 2026-04-28 01:23:57 +08:00
公明 3eb7edb1b8 Add files via upload 2026-04-28 01:23:33 +08:00
公明 f82f6b861e Add files via upload 2026-04-28 01:22:21 +08:00
公明 2acf43c454 Add files via upload 2026-04-28 01:19:01 +08:00
公明 fad6b3c808 Add files via upload 2026-04-28 01:05:58 +08:00
公明 0597838217 Add files via upload 2026-04-28 01:04:58 +08:00
公明 1532426b4f Add files via upload 2026-04-28 01:02:30 +08:00
公明 3aeb8c3474 Add files via upload 2026-04-28 00:37:46 +08:00
公明 b2b166972a Add files via upload 2026-04-28 00:33:29 +08:00
公明 36b669771c Delete internal/multiagent directory 2026-04-28 00:30:34 +08:00
公明 96564d4d89 Update default_single_system_prompt.go 2026-04-27 14:58:49 +08:00
公明 d85afa2d39 Add files via upload 2026-04-27 11:29:16 +08:00
68 changed files with 2994 additions and 1299 deletions
+15 -5
View File
@@ -10,7 +10,7 @@
# ============================================ # ============================================
# 前端显示的版本号(可选,不填则显示默认版本) # 前端显示的版本号(可选,不填则显示默认版本)
version: "v1.5.8" version: "v1.5.13"
# 服务器配置 # 服务器配置
server: server:
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口 host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
@@ -70,7 +70,7 @@ multi_agent:
robot_use_multi_agent: true # true 时企业微信/钉钉/飞书机器人也走 Eino 多代理(成本更高) robot_use_multi_agent: true # true 时企业微信/钉钉/飞书机器人也走 Eino 多代理(成本更高)
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高) batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。Executor 未暴露 Handlerspatch/reduction/plantask 不作用于 PE,但 tool_search 工具列表拆分仍通过共享 ToolsConfig 作用于执行器 # plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件
plan_execute_loop_max_iterations: 0 plan_execute_loop_max_iterations: 0
sub_agent_max_iterations: 120 sub_agent_max_iterations: 120
sub_agent_user_context_max_runes: 0 # 子代理 task 描述中自动注入用户原始请求的字符上限;0=默认2000,负数=禁用 sub_agent_user_context_max_runes: 0 # 子代理 task 描述中自动注入用户原始请求的字符上限;0=默认2000,负数=禁用
@@ -87,15 +87,25 @@ multi_agent:
# Eino ADK 中间件与 Deep/Supervisor 调参(结构体见 internal/config/config.go → MultiAgentEinoMiddlewareConfig # Eino ADK 中间件与 Deep/Supervisor 调参(结构体见 internal/config/config.go → MultiAgentEinoMiddlewareConfig
eino_middleware: eino_middleware:
patch_tool_calls: true # true:修补历史中无 tool_result 的悬空 tool_call(流式中断/重试后更稳);false:关闭;字段省略时默认等同 true patch_tool_calls: true # true:修补历史中无 tool_result 的悬空 tool_call(流式中断/重试后更稳);false:关闭;字段省略时默认等同 true
tool_search_enable: false # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文 tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用 tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁 tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
tool_search_always_visible_tools: [read_file, glob, grep, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过 plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过
plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放 plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放
reduction_enable: false # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载 reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
reduction_max_length_for_trunc: 50000 # 单条工具结果超过该字符数(bytes)时截断并落盘(由 reduction 中间件处理)
reduction_max_tokens_for_clear: 160000 # 历史工具结果清理阈值(tokens),超阈值时在模型调用前清理旧结果
reduction_root_dir: "" # 非空:截断/清理内容落盘根路径;空:使用系统临时目录下按会话隔离的默认路径 reduction_root_dir: "" # 非空:截断/清理内容落盘根路径;空:使用系统临时目录下按会话隔离的默认路径
reduction_clear_exclude: [] # 不参与「清理阶段」的工具名额外列表(会与 task/transfer/exit 等内置排除项合并);需要时用 YAML 列表填写 reduction_clear_exclude: [] # 不参与「清理阶段」的工具名额外列表(会与 task/transfer/exit 等内置排除项合并);需要时用 YAML 列表填写
reduction_sub_agents: false # true:子代理也挂 reductionfalse:仅编排主代理使用 reduction reduction_sub_agents: true # true:子代理也挂 reductionfalse:仅编排主代理使用 reduction
summarization_trigger_ratio: 0.8 # summarization 触发比例(max_total_tokens * ratio),建议 0.75~0.85
summarization_emit_internal_events: true # true:发出 summarization 内部事件(便于诊断)
history_input_budget_ratio: 0.35 # 历史入队预算比例(max_total_tokens * ratio
plan_execute_user_input_budget_ratio: 0.35 # plan_execute 中 userInput 预算比例(planner/replanner/executor 共用)
plan_execute_executed_steps_budget_ratio: 0.2 # plan_execute 中 executed_steps 预算比例
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接 checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入 deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
deep_model_retry_max_retries: 0 # >0ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试 deep_model_retry_max_retries: 0 # >0ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
+49 -30
View File
@@ -39,6 +39,7 @@ type Agent struct {
toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具) toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具)
currentConversationID string // 当前对话ID(用于自动传递给工具) currentConversationID string // 当前对话ID(用于自动传递给工具)
promptBaseDir string // 解析 system_prompt_path 时相对路径的基准目录(通常为 config.yaml 所在目录) promptBaseDir string // 解析 system_prompt_path 时相对路径的基准目录(通常为 config.yaml 所在目录)
toolDescriptionMode string // 工具描述模式: "short" | "full",默认 short
} }
// ResultStorage 结果存储接口(直接使用 storage 包的类型) // ResultStorage 结果存储接口(直接使用 storage 包的类型)
@@ -162,6 +163,7 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
resultStorage: resultStorage, resultStorage: resultStorage,
largeResultThreshold: largeResultThreshold, largeResultThreshold: largeResultThreshold,
toolNameMapping: make(map[string]string), // 初始化工具名称映射 toolNameMapping: make(map[string]string), // 初始化工具名称映射
toolDescriptionMode: "short",
} }
} }
@@ -338,8 +340,8 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
type AgentLoopResult struct { type AgentLoopResult struct {
Response string Response string
MCPExecutionIDs []string MCPExecutionIDs []string
LastReActInput string // 最后一轮ReAct的输入(压缩后的messagesJSON格式 LastAgentTraceInput string // 最后一轮代理消息轨迹(压缩后的 messagesJSON;与 multiagent.RunResult 字段对齐
LastReActOutput string // 最终大模型的输出 LastAgentTraceOutput string // 最终助手输出文本
} }
// ProgressCallback 进度回调函数类型 // ProgressCallback 进度回调函数类型
@@ -471,7 +473,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
} }
// 用于保存当前的messages,以便在异常情况下也能保存ReAct输入 // 用于保存当前的messages,以便在异常情况下也能保存ReAct输入
var currentReActInput string var currentAgentTraceInput string
maxIterations := a.maxIterations maxIterations := a.maxIterations
thinkingStreamSeq := 0 thinkingStreamSeq := 0
@@ -490,9 +492,9 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
if err != nil { if err != nil {
a.logger.Warn("序列化ReAct输入失败", zap.Error(err)) a.logger.Warn("序列化ReAct输入失败", zap.Error(err))
} else { } else {
currentReActInput = string(messagesJSON) currentAgentTraceInput = string(messagesJSON)
// 更新result中的值,确保始终保存最新的ReAct输入(压缩后的) // 更新result中的值,确保始终保存最新的ReAct输入(压缩后的)
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
} }
// 检查上下文是否已取消 // 检查上下文是否已取消
@@ -500,13 +502,13 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
case <-ctx.Done(): case <-ctx.Done():
// 上下文被取消(可能是用户主动暂停或其他原因) // 上下文被取消(可能是用户主动暂停或其他原因)
a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err())) a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err()))
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
if ctx.Err() == context.Canceled { if ctx.Err() == context.Canceled {
result.Response = "任务已被取消。" result.Response = "任务已被取消。"
} else { } else {
result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err()) result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err())
} }
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, ctx.Err() return result, ctx.Err()
default: default:
} }
@@ -600,10 +602,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}) })
if err != nil { if err != nil {
// API调用失败,保存当前的ReAct输入和错误信息作为输出 // API调用失败,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err) errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err)
result.Response = errorMsg result.Response = errorMsg
result.LastReActOutput = errorMsg result.LastAgentTraceOutput = errorMsg
a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err)) a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err))
return result, fmt.Errorf("调用OpenAI失败: %w", err) return result, fmt.Errorf("调用OpenAI失败: %w", err)
} }
@@ -629,19 +631,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
continue continue
} }
// OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出 // OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message) errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message)
result.Response = errorMsg result.Response = errorMsg
result.LastReActOutput = errorMsg result.LastAgentTraceOutput = errorMsg
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message) return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
} }
if len(response.Choices) == 0 { if len(response.Choices) == 0 {
// 没有收到响应,保存当前的ReAct输入和错误信息作为输出 // 没有收到响应,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := "没有收到响应" errorMsg := "没有收到响应"
result.Response = errorMsg result.Response = errorMsg
result.LastReActOutput = errorMsg result.LastAgentTraceOutput = errorMsg
return result, fmt.Errorf("没有收到响应") return result, fmt.Errorf("没有收到响应")
} }
@@ -816,7 +818,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
result.Response = streamText result.Response = streamText
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil) sendProgress("progress", "总结生成完成", nil)
return result, nil return result, nil
} }
@@ -863,14 +865,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
result.Response = streamText result.Response = streamText
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil) sendProgress("progress", "总结生成完成", nil)
return result, nil return result, nil
} }
// 如果获取总结失败,使用当前回复作为结果 // 如果获取总结失败,使用当前回复作为结果
if choice.Message.Content != "" { if choice.Message.Content != "" {
result.Response = choice.Message.Content result.Response = choice.Message.Content
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, nil return result, nil
} }
// 如果都没有内容,跳出循环,让后续逻辑处理 // 如果都没有内容,跳出循环,让后续逻辑处理
@@ -881,7 +883,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
if choice.FinishReason == "stop" { if choice.FinishReason == "stop" {
sendProgress("progress", "正在生成最终回复...", nil) sendProgress("progress", "正在生成最终回复...", nil)
result.Response = choice.Message.Content result.Response = choice.Message.Content
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, nil return result, nil
} }
} }
@@ -910,19 +912,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}) })
if strings.TrimSpace(streamText) != "" { if strings.TrimSpace(streamText) != "" {
result.Response = streamText result.Response = streamText
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil) sendProgress("progress", "总结生成完成", nil)
return result, nil return result, nil
} }
// 如果无法生成总结,返回友好的提示 // 如果无法生成总结,返回友好的提示
result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations) result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations)
result.LastReActOutput = result.Response result.LastAgentTraceOutput = result.Response
return result, nil return result, nil
} }
// getAvailableTools 获取可用工具 // getAvailableTools 获取可用工具
// 从MCP服务器动态获取工具列表,使用简短描述以减少token消耗 // 从MCP服务器动态获取工具列表,描述模式由 tool_description_mode 控制
// roleTools: 角色配置的工具列表(toolKey格式),如果为空或nil,则使用所有工具(默认角色) // roleTools: 角色配置的工具列表(toolKey格式),如果为空或nil,则使用所有工具(默认角色)
func (a *Agent) getAvailableTools(roleTools []string) []Tool { func (a *Agent) getAvailableTools(roleTools []string) []Tool {
// 构建角色工具集合(用于快速查找) // 构建角色工具集合(用于快速查找)
@@ -946,11 +948,7 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
continue // 不在角色工具列表中,跳过 continue // 不在角色工具列表中,跳过
} }
} }
// 使用简短描述(如果存在),否则使用详细描述 description := a.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description)
description := mcpTool.ShortDescription
if description == "" {
description = mcpTool.Description
}
// 转换schema中的类型为OpenAI标准类型 // 转换schema中的类型为OpenAI标准类型
convertedSchema := a.convertSchemaTypes(mcpTool.InputSchema) convertedSchema := a.convertSchemaTypes(mcpTool.InputSchema)
@@ -1024,11 +1022,7 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
continue continue
} }
// 使用简短描述(如果存在),否则使用详细描述 description := a.pickToolDescription(externalTool.ShortDescription, externalTool.Description)
description := externalTool.ShortDescription
if description == "" {
description = externalTool.Description
}
// 转换schema中的类型为OpenAI标准类型 // 转换schema中的类型为OpenAI标准类型
convertedSchema := a.convertSchemaTypes(externalTool.InputSchema) convertedSchema := a.convertSchemaTypes(externalTool.InputSchema)
@@ -1063,6 +1057,19 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
return tools return tools
} }
func (a *Agent) pickToolDescription(shortDesc, fullDesc string) string {
a.mu.RLock()
mode := strings.TrimSpace(strings.ToLower(a.toolDescriptionMode))
a.mu.RUnlock()
if mode == "full" {
return fullDesc
}
if shortDesc != "" {
return shortDesc
}
return fullDesc
}
// convertSchemaTypes 递归转换schema中的类型为OpenAI标准类型 // convertSchemaTypes 递归转换schema中的类型为OpenAI标准类型
func (a *Agent) convertSchemaTypes(schema map[string]interface{}) map[string]interface{} { func (a *Agent) convertSchemaTypes(schema map[string]interface{}) map[string]interface{} {
if schema == nil { if schema == nil {
@@ -1665,6 +1672,18 @@ func (a *Agent) UpdateMaxIterations(maxIterations int) {
} }
} }
// UpdateToolDescriptionMode 更新工具描述模式(short/full)
func (a *Agent) UpdateToolDescriptionMode(mode string) {
a.mu.Lock()
defer a.mu.Unlock()
mode = strings.TrimSpace(strings.ToLower(mode))
if mode != "full" {
mode = "short"
}
a.toolDescriptionMode = mode
a.logger.Info("Agent工具描述模式已更新", zap.String("tool_description_mode", mode))
}
// formatToolError 格式化工具错误信息,提供更友好的错误描述 // formatToolError 格式化工具错误信息,提供更友好的错误描述
func (a *Agent) formatToolError(toolName string, args map[string]interface{}, err error) string { func (a *Agent) formatToolError(toolName string, args map[string]interface{}, err error) string {
errorMsg := fmt.Sprintf(`工具执行失败 errorMsg := fmt.Sprintf(`工具执行失败
-1
View File
@@ -283,4 +283,3 @@ func TestAgent_NewAgent_CustomConfig(t *testing.T) {
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold) t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
} }
} }
@@ -91,6 +91,20 @@ func DefaultSingleAgentSystemPrompt() string {
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。 当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
## 结束条件与停止约束
- 在「未完成用户目标」前,不得输出纯计划/纯建议式结论并结束本轮;必须继续给出可执行下一步,并优先通过工具验证。
- 若你准备结束回答,先执行一次自检:
1) 是否已有可验证证据支撑“任务完成/无法继续”的结论;
2) 是否至少尝试过当前路径的合理替代(参数、路径、方法、入口);
3) 是否仍存在可执行且低成本的下一步验证动作。
- 仅当满足以下任一条件时,才允许输出最终收尾:
1) 已达到用户目标并给出证据;
2) 达到明确边界(超时、权限、目标不可达、工具不可用且无替代),并清楚说明阻断点与已尝试项;
3) 用户明确要求停止。
- 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。
- 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。
## 漏洞记录 ## 漏洞记录
发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 记录:标题、描述、严重程度、类型、目标、证明(POC)、影响、修复建议。 发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 记录:标题、描述、严重程度、类型、目标、证明(POC)、影响、修复建议。
+3
View File
@@ -133,6 +133,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
maxIterations = 30 // 默认值 maxIterations = 30 // 默认值
} }
agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations) agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations)
agent.UpdateToolDescriptionMode(cfg.Security.ToolDescriptionMode)
// 设置结果存储到Agent // 设置结果存储到Agent
agent.SetResultStorage(resultStorage) agent.SetResultStorage(resultStorage)
@@ -901,6 +902,8 @@ func setupRoutes(
// 漏洞管理 // 漏洞管理
protected.GET("/vulnerabilities", vulnerabilityHandler.ListVulnerabilities) protected.GET("/vulnerabilities", vulnerabilityHandler.ListVulnerabilities)
protected.GET("/vulnerabilities/export", vulnerabilityHandler.ExportVulnerabilities)
protected.GET("/vulnerabilities/filter-options", vulnerabilityHandler.GetVulnerabilityFilterOptions)
protected.GET("/vulnerabilities/stats", vulnerabilityHandler.GetVulnerabilityStats) protected.GET("/vulnerabilities/stats", vulnerabilityHandler.GetVulnerabilityStats)
protected.GET("/vulnerabilities/:id", vulnerabilityHandler.GetVulnerability) protected.GET("/vulnerabilities/:id", vulnerabilityHandler.GetVulnerability)
protected.POST("/vulnerabilities", vulnerabilityHandler.CreateVulnerability) protected.POST("/vulnerabilities", vulnerabilityHandler.CreateVulnerability)
+10 -10
View File
@@ -145,7 +145,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
} }
// 1. 优先尝试从数据库获取保存的最后一轮ReAct输入和输出 // 1. 优先尝试从数据库获取保存的最后一轮ReAct输入和输出
reactInputJSON, modelOutput, err := b.db.GetReActData(conversationID) reactInputJSON, modelOutput, err := b.db.GetAgentTrace(conversationID)
if err != nil { if err != nil {
b.logger.Warn("获取保存的ReAct数据失败,将使用消息历史构建", zap.Error(err)) b.logger.Warn("获取保存的ReAct数据失败,将使用消息历史构建", zap.Error(err))
// 继续使用原来的逻辑 // 继续使用原来的逻辑
@@ -170,7 +170,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
messageCount = len(tempMessages) messageCount = len(tempMessages)
} }
dataSource = "database_last_react_input" dataSource = "database_last_agent_trace"
b.logger.Info("使用保存的ReAct数据构建攻击链", b.logger.Info("使用保存的ReAct数据构建攻击链",
zap.String("conversationId", conversationID), zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource), zap.String("dataSource", dataSource),
@@ -183,7 +183,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
// userInput = b.extractUserInputFromReActInput(reactInputJSON) // userInput = b.extractUserInputFromReActInput(reactInputJSON)
// 将JSON格式的messages转换为可读格式 // 将JSON格式的messages转换为可读格式
reactInputFinal = b.formatReActInputFromJSON(reactInputJSON) reactInputFinal = b.formatAgentTraceInputFromJSON(reactInputJSON)
} else { } else {
// 2. 如果没有保存的ReAct数据,从对话消息构建 // 2. 如果没有保存的ReAct数据,从对话消息构建
dataSource = "messages_table" dataSource = "messages_table"
@@ -201,7 +201,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
} }
// 提取最后一轮ReAct的输入(历史消息+当前用户输入) // 提取最后一轮ReAct的输入(历史消息+当前用户输入)
reactInputFinal = b.buildReActInput(messages) reactInputFinal = b.buildAgentTraceInput(messages)
// 提取大模型最后的输出(最后一条assistant消息) // 提取大模型最后的输出(最后一条assistant消息)
for i := len(messages) - 1; i >= 0; i-- { for i := len(messages) - 1; i >= 0; i-- {
@@ -212,7 +212,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
} }
} }
// 多代理:保存的 last_react_input 可能仅为首轮用户消息,不含工具轨迹;补充最后一轮助手的过程详情(与单代理「最后一轮 ReAct」对齐) // 多代理:保存的轨迹列可能仅为首轮用户消息,不含工具轨迹;补充最后一轮助手的过程详情(与单代理完整轨迹对齐)
hasMCPOnAssistant := false hasMCPOnAssistant := false
var lastAssistantID string var lastAssistantID string
for i := len(messages) - 1; i >= 0; i-- { for i := len(messages) - 1; i >= 0; i-- {
@@ -320,7 +320,7 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
} }
// 1) 编排器的工具调用/结果:保留(这是“主 agent 调了什么工具”) // 1) 编排器的工具调用/结果:保留(这是“主 agent 调了什么工具”)
if (d.EventType == "tool_call" || d.EventType == "tool_result" || d.EventType == "tool_calls_detected" || d.EventType == "iteration" || d.EventType == "eino_recovery") && einoRole == "orchestrator" { if (d.EventType == "tool_call" || d.EventType == "tool_result" || d.EventType == "tool_calls_detected" || d.EventType == "iteration") && einoRole == "orchestrator" {
sb.WriteString("[") sb.WriteString("[")
sb.WriteString(d.EventType) sb.WriteString(d.EventType)
sb.WriteString("] ") sb.WriteString("] ")
@@ -366,8 +366,8 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
return strings.TrimSpace(sb.String()) return strings.TrimSpace(sb.String())
} }
// buildReActInput 构建最后一轮ReAct的输入(历史消息+当前用户输入) // buildAgentTraceInput 构建最后一轮ReAct的输入(历史消息+当前用户输入)
func (b *Builder) buildReActInput(messages []database.Message) string { func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
var builder strings.Builder var builder strings.Builder
for _, msg := range messages { for _, msg := range messages {
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content)) builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
@@ -396,8 +396,8 @@ func (b *Builder) buildReActInput(messages []database.Message) string {
// return "" // return ""
// } // }
// formatReActInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式 // formatAgentTraceInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string { func (b *Builder) formatAgentTraceInputFromJSON(reactInputJSON string) string {
var messages []map[string]interface{} var messages []map[string]interface{}
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil { if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err)) b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
+112
View File
@@ -72,6 +72,8 @@ type MultiAgentEinoMiddlewareConfig struct {
ToolSearchEnable bool `yaml:"tool_search_enable,omitempty" json:"tool_search_enable,omitempty"` ToolSearchEnable bool `yaml:"tool_search_enable,omitempty" json:"tool_search_enable,omitempty"`
ToolSearchMinTools int `yaml:"tool_search_min_tools,omitempty" json:"tool_search_min_tools,omitempty"` // default 20; applies when len(tools) >= this ToolSearchMinTools int `yaml:"tool_search_min_tools,omitempty" json:"tool_search_min_tools,omitempty"` // default 20; applies when len(tools) >= this
ToolSearchAlwaysVisible int `yaml:"tool_search_always_visible,omitempty" json:"tool_search_always_visible,omitempty"` // default 12; first N tools stay always visible ToolSearchAlwaysVisible int `yaml:"tool_search_always_visible,omitempty" json:"tool_search_always_visible,omitempty"` // default 12; first N tools stay always visible
// ToolSearchAlwaysVisibleTools keeps specified tool names always visible (never hidden by tool_search).
ToolSearchAlwaysVisibleTools []string `yaml:"tool_search_always_visible_tools,omitempty" json:"tool_search_always_visible_tools,omitempty"`
// Plantask adds TaskCreate/Get/Update/List (file-backed under skills dir); requires eino_skills + local backend. // Plantask adds TaskCreate/Get/Update/List (file-backed under skills dir); requires eino_skills + local backend.
PlantaskEnable bool `yaml:"plantask_enable,omitempty" json:"plantask_enable,omitempty"` PlantaskEnable bool `yaml:"plantask_enable,omitempty" json:"plantask_enable,omitempty"`
// PlantaskRelDir relative to skills_dir for per-conversation task boards (default .eino/plantask). // PlantaskRelDir relative to skills_dir for per-conversation task boards (default .eino/plantask).
@@ -79,8 +81,24 @@ type MultiAgentEinoMiddlewareConfig struct {
// Reduction truncates/offloads large tool outputs (requires eino local backend for Write). // Reduction truncates/offloads large tool outputs (requires eino local backend for Write).
ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"` ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"`
ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // default: os temp + conversation id ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // default: os temp + conversation id
ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000
ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000
ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"` ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"`
ReductionSubAgents bool `yaml:"reduction_sub_agents,omitempty" json:"reduction_sub_agents,omitempty"` // also attach to sub-agents ReductionSubAgents bool `yaml:"reduction_sub_agents,omitempty" json:"reduction_sub_agents,omitempty"` // also attach to sub-agents
// SummarizationTriggerRatio controls summarization trigger threshold as max_total_tokens * ratio (default 0.8).
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
// HistoryInputBudgetRatio caps pre-agent history tokens as max_total_tokens * ratio (default 0.35).
HistoryInputBudgetRatio float64 `yaml:"history_input_budget_ratio,omitempty" json:"history_input_budget_ratio,omitempty"`
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
// PlanExecuteExecutedStepsBudgetRatio caps executed_steps prompt budget ratio (default 0.2).
PlanExecuteExecutedStepsBudgetRatio float64 `yaml:"plan_execute_executed_steps_budget_ratio,omitempty" json:"plan_execute_executed_steps_budget_ratio,omitempty"`
// PlanExecuteMaxStepResultRunes caps each executed step result length for prompt view (default 4000).
PlanExecuteMaxStepResultRunes int `yaml:"plan_execute_max_step_result_runes,omitempty" json:"plan_execute_max_step_result_runes,omitempty"`
// PlanExecuteKeepLastSteps keeps only the tail steps in prompt view (default 8).
PlanExecuteKeepLastSteps int `yaml:"plan_execute_keep_last_steps,omitempty" json:"plan_execute_keep_last_steps,omitempty"`
// CheckpointDir when non-empty enables adk.Runner CheckPointStore (file-backed) for interrupt/resume persistence. // CheckpointDir when non-empty enables adk.Runner CheckPointStore (file-backed) for interrupt/resume persistence.
CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"` CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"`
// DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off. // DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off.
@@ -91,6 +109,97 @@ type MultiAgentEinoMiddlewareConfig struct {
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"` TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
} }
func (c MultiAgentEinoMiddlewareConfig) SummarizationTriggerRatioEffective() float64 {
v := c.SummarizationTriggerRatio
if v <= 0 {
return 0.8
}
if v < 0.5 {
return 0.5
}
if v > 0.95 {
return 0.95
}
return v
}
func (c MultiAgentEinoMiddlewareConfig) SummarizationEmitInternalEventsEffective() bool {
if c.SummarizationEmitInternalEvents != nil {
return *c.SummarizationEmitInternalEvents
}
return true
}
func (c MultiAgentEinoMiddlewareConfig) HistoryInputBudgetRatioEffective() float64 {
v := c.HistoryInputBudgetRatio
if v <= 0 {
return 0.35
}
if v < 0.15 {
return 0.15
}
if v > 0.6 {
return 0.6
}
return v
}
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteUserInputBudgetRatioEffective() float64 {
v := c.PlanExecuteUserInputBudgetRatio
if v <= 0 {
return 0.35
}
if v < 0.1 {
return 0.1
}
if v > 0.6 {
return 0.6
}
return v
}
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteExecutedStepsBudgetRatioEffective() float64 {
v := c.PlanExecuteExecutedStepsBudgetRatio
if v <= 0 {
return 0.2
}
if v < 0.08 {
return 0.08
}
if v > 0.5 {
return 0.5
}
return v
}
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteMaxStepResultRunesEffective() int {
if c.PlanExecuteMaxStepResultRunes > 0 {
return c.PlanExecuteMaxStepResultRunes
}
return 4000
}
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteKeepLastStepsEffective() int {
if c.PlanExecuteKeepLastSteps > 0 {
return c.PlanExecuteKeepLastSteps
}
return 8
}
func (c MultiAgentEinoMiddlewareConfig) ReductionMaxLengthForTruncEffective() int {
if c.ReductionMaxLengthForTrunc > 0 {
return c.ReductionMaxLengthForTrunc
}
return 12000
}
func (c MultiAgentEinoMiddlewareConfig) ReductionMaxTokensForClearEffective() int {
if c.ReductionMaxTokensForClear > 0 {
return c.ReductionMaxTokensForClear
}
return 50000
}
// MultiAgentEinoSkillsConfig toggles Eino official skill progressive disclosure and host filesystem tools. // MultiAgentEinoSkillsConfig toggles Eino official skill progressive disclosure and host filesystem tools.
type MultiAgentEinoSkillsConfig struct { type MultiAgentEinoSkillsConfig struct {
// Disable skips skill middleware (and does not attach local FS tools for Deep). // Disable skips skill middleware (and does not attach local FS tools for Deep).
@@ -137,6 +246,8 @@ type MultiAgentPublic struct {
SubAgentCount int `json:"sub_agent_count"` SubAgentCount int `json:"sub_agent_count"`
Orchestration string `json:"orchestration,omitempty"` Orchestration string `json:"orchestration,omitempty"`
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"` PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"`
ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"`
} }
// NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。 // NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。
@@ -158,6 +269,7 @@ type MultiAgentAPIUpdate struct {
RobotUseMultiAgent bool `json:"robot_use_multi_agent"` RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
BatchUseMultiAgent bool `json:"batch_use_multi_agent"` BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"` PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"`
} }
// RobotsConfig 机器人配置(企业微信、钉钉、飞书等) // RobotsConfig 机器人配置(企业微信、钉钉、飞书等)
-1
View File
@@ -165,4 +165,3 @@ func (db *DB) DeleteAttackChain(conversationID string) error {
return nil return nil
} }
+21 -10
View File
@@ -4,6 +4,8 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os"
"path/filepath"
"strings" "strings"
"time" "time"
@@ -416,25 +418,34 @@ func (db *DB) DeleteConversation(id string) error {
if err != nil { if err != nil {
return fmt.Errorf("删除对话失败: %w", err) return fmt.Errorf("删除对话失败: %w", err)
} }
// Best-effort cleanup for conversation-scoped filesystem artifacts
// (e.g., summarization transcript, reduction/checkpoint files under conversation_artifacts/<id>).
if base := strings.TrimSpace(db.conversationArtifactsDir); base != "" {
artDir := filepath.Join(base, id)
if rmErr := os.RemoveAll(artDir); rmErr != nil {
db.logger.Warn("删除会话 artifacts 目录失败", zap.String("conversationId", id), zap.String("dir", artDir), zap.Error(rmErr))
}
}
db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id)) db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id))
return nil return nil
} }
// SaveReActData 保存最后一轮ReAct的输入和输出 // SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
func (db *DB) SaveReActData(conversationID, reactInput, reactOutput string) error { // SQLite 列名仍为 last_react_input / last_react_output,与历史库表兼容;语义上为「全模式代理轨迹」,非仅 ReAct。
func (db *DB) SaveAgentTrace(conversationID, traceInputJSON, assistantOutput string) error {
_, err := db.Exec( _, err := db.Exec(
"UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?", "UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?",
reactInput, reactOutput, time.Now(), conversationID, traceInputJSON, assistantOutput, time.Now(), conversationID,
) )
if err != nil { if err != nil {
return fmt.Errorf("保存ReAct数据失败: %w", err) return fmt.Errorf("保存代理轨迹失败: %w", err)
} }
return nil return nil
} }
// GetReActData 获取最后一轮ReAct的输入和输出 // GetAgentTrace 读取 conversations 中保存的代理轨迹(列名 last_react_*)。
func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput string, err error) { func (db *DB) GetAgentTrace(conversationID string) (traceInputJSON, assistantOutput string, err error) {
var input, output sql.NullString var input, output sql.NullString
err = db.QueryRow( err = db.QueryRow(
"SELECT last_react_input, last_react_output FROM conversations WHERE id = ?", "SELECT last_react_input, last_react_output FROM conversations WHERE id = ?",
@@ -444,17 +455,17 @@ func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput strin
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", "", fmt.Errorf("对话不存在") return "", "", fmt.Errorf("对话不存在")
} }
return "", "", fmt.Errorf("获取ReAct数据失败: %w", err) return "", "", fmt.Errorf("获取代理轨迹失败: %w", err)
} }
if input.Valid { if input.Valid {
reactInput = input.String traceInputJSON = input.String
} }
if output.Valid { if output.Valid {
reactOutput = output.String assistantOutput = output.String
} }
return reactInput, reactOutput, nil return traceInputJSON, assistantOutput, nil
} }
// ConversationHasToolProcessDetails 对话是否存在已落库的工具调用/结果(用于多代理等场景下 MCP execution id 未汇总时的攻击链判定)。 // ConversationHasToolProcessDetails 对话是否存在已落库的工具调用/结果(用于多代理等场景下 MCP execution id 未汇总时的攻击链判定)。
+50 -1
View File
@@ -3,6 +3,8 @@ package database
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"os"
"path/filepath"
"strings" "strings"
"time" "time"
@@ -22,6 +24,7 @@ func configureDBPool(db *sql.DB) {
type DB struct { type DB struct {
*sql.DB *sql.DB
logger *zap.Logger logger *zap.Logger
conversationArtifactsDir string
} }
// NewDB 创建数据库连接 // NewDB 创建数据库连接
@@ -41,6 +44,13 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
DB: db, DB: db,
logger: logger, logger: logger,
} }
// Keep conversation-scoped artifacts near database files, so cleanup can follow conversation lifecycle.
baseDir := filepath.Join(filepath.Dir(dbPath), "conversation_artifacts")
if mkErr := os.MkdirAll(baseDir, 0o755); mkErr == nil {
database.conversationArtifactsDir = baseDir
} else if logger != nil {
logger.Warn("创建 conversation artifacts 目录失败", zap.String("dir", baseDir), zap.Error(mkErr))
}
// 初始化表 // 初始化表
if err := database.initTables(); err != nil { if err := database.initTables(); err != nil {
@@ -52,7 +62,7 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
// initTables 初始化数据库表 // initTables 初始化数据库表
func (db *DB) initTables() error { func (db *DB) initTables() error {
// 创建对话表 // 创建对话表last_react_input / last_react_output 存「代理消息轨迹」JSON 与助手摘要,列名保留以兼容已有库)
createConversationsTable := ` createConversationsTable := `
CREATE TABLE IF NOT EXISTS conversations ( CREATE TABLE IF NOT EXISTS conversations (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
@@ -197,6 +207,8 @@ func (db *DB) initTables() error {
CREATE TABLE IF NOT EXISTS vulnerabilities ( CREATE TABLE IF NOT EXISTS vulnerabilities (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL, conversation_id TEXT NOT NULL,
conversation_tag TEXT,
task_tag TEXT,
title TEXT NOT NULL, title TEXT NOT NULL,
description TEXT, description TEXT,
severity TEXT NOT NULL, severity TEXT NOT NULL,
@@ -289,6 +301,8 @@ func (db *DB) initTables() error {
CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_group ON conversation_group_mappings(group_id); CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_group ON conversation_group_mappings(group_id);
CREATE INDEX IF NOT EXISTS idx_conversations_pinned ON conversations(pinned); CREATE INDEX IF NOT EXISTS idx_conversations_pinned ON conversations(pinned);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_tag ON vulnerabilities(conversation_tag);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_task_tag ON vulnerabilities(task_tag);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at); CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
@@ -383,6 +397,10 @@ func (db *DB) initTables() error {
db.logger.Warn("迁移batch_task_queues表失败", zap.Error(err)) db.logger.Warn("迁移batch_task_queues表失败", zap.Error(err))
// 不返回错误,允许继续运行 // 不返回错误,允许继续运行
} }
if err := db.migrateVulnerabilitiesTable(); err != nil {
db.logger.Warn("迁移vulnerabilities表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if _, err := db.Exec(createIndexes); err != nil { if _, err := db.Exec(createIndexes); err != nil {
return fmt.Errorf("创建索引失败: %w", err) return fmt.Errorf("创建索引失败: %w", err)
@@ -683,6 +701,37 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
return nil return nil
} }
// migrateVulnerabilitiesTable 迁移 vulnerabilities 表,补充标签字段
func (db *DB) migrateVulnerabilitiesTable() error {
columns := []struct {
name string
stmt string
}{
{name: "conversation_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN conversation_tag TEXT"},
{name: "task_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN task_tag TEXT"},
}
for _, col := range columns {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('vulnerabilities') WHERE name=?", col.name).Scan(&count)
if err != nil {
if _, addErr := db.Exec(col.stmt); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加vulnerabilities字段失败", zap.String("field", col.name), zap.Error(addErr))
}
}
continue
}
if count == 0 {
if _, addErr := db.Exec(col.stmt); addErr != nil {
db.logger.Warn("添加vulnerabilities字段失败", zap.String("field", col.name), zap.Error(addErr))
}
}
}
return nil
}
// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表) // NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表)
func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) { func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL") sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL")
+117 -29
View File
@@ -13,6 +13,10 @@ import (
type Vulnerability struct { type Vulnerability struct {
ID string `json:"id"` ID string `json:"id"`
ConversationID string `json:"conversation_id"` ConversationID string `json:"conversation_id"`
ConversationTag string `json:"conversation_tag,omitempty"`
TaskTag string `json:"task_tag,omitempty"`
TaskID string `json:"task_id,omitempty"`
TaskQueueID string `json:"task_queue_id,omitempty"`
Title string `json:"title"` Title string `json:"title"`
Description string `json:"description"` Description string `json:"description"`
Severity string `json:"severity"` // critical, high, medium, low, info Severity string `json:"severity"` // critical, high, medium, low, info
@@ -42,15 +46,15 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
query := ` query := `
INSERT INTO vulnerabilities ( INSERT INTO vulnerabilities (
id, conversation_id, title, description, severity, status, id, conversation_id, conversation_tag, task_tag, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation, vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
` `
_, err := db.Exec( _, err := db.Exec(
query, query,
vuln.ID, vuln.ConversationID, vuln.Title, vuln.Description, vuln.ID, vuln.ConversationID, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
vuln.Severity, vuln.Status, vuln.Type, vuln.Target, vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
vuln.Proof, vuln.Impact, vuln.Recommendation, vuln.Proof, vuln.Impact, vuln.Recommendation,
vuln.CreatedAt, vuln.UpdatedAt, vuln.CreatedAt, vuln.UpdatedAt,
@@ -67,7 +71,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
var vuln Vulnerability var vuln Vulnerability
query := ` query := `
SELECT id, conversation_id, title, description, severity, status, SELECT id, conversation_id, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation, conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
created_at, updated_at created_at, updated_at
FROM vulnerabilities FROM vulnerabilities
WHERE id = ? WHERE id = ?
@@ -75,8 +81,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
err := db.QueryRow(query, id).Scan( err := db.QueryRow(query, id).Scan(
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, &vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target, &vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation, &vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.TaskID, &vuln.TaskQueueID,
&vuln.CreatedAt, &vuln.UpdatedAt, &vuln.CreatedAt, &vuln.UpdatedAt,
) )
if err != nil { if err != nil {
@@ -90,10 +97,12 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
} }
// ListVulnerabilities 列出漏洞 // ListVulnerabilities 列出漏洞
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status string) ([]*Vulnerability, error) { func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) {
query := ` query := `
SELECT id, conversation_id, title, description, severity, status, SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag,
vulnerability_type, target, proof, impact, recommendation, vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
created_at, updated_at created_at, updated_at
FROM vulnerabilities FROM vulnerabilities
WHERE 1=1 WHERE 1=1
@@ -108,6 +117,18 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
query += " AND conversation_id = ?" query += " AND conversation_id = ?"
args = append(args, conversationID) args = append(args, conversationID)
} }
if taskID != "" {
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, taskID, taskID)
}
if conversationTag != "" {
query += " AND conversation_tag = ?"
args = append(args, conversationTag)
}
if taskTag != "" {
query += " AND task_tag = ?"
args = append(args, taskTag)
}
if severity != "" { if severity != "" {
query += " AND severity = ?" query += " AND severity = ?"
args = append(args, severity) args = append(args, severity)
@@ -131,8 +152,9 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
var vuln Vulnerability var vuln Vulnerability
err := rows.Scan( err := rows.Scan(
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description, &vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target, &vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation, &vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.TaskID, &vuln.TaskQueueID,
&vuln.CreatedAt, &vuln.UpdatedAt, &vuln.CreatedAt, &vuln.UpdatedAt,
) )
if err != nil { if err != nil {
@@ -146,7 +168,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
} }
// CountVulnerabilities 统计漏洞总数(支持筛选条件) // CountVulnerabilities 统计漏洞总数(支持筛选条件)
func (db *DB) CountVulnerabilities(id, conversationID, severity, status string) (int, error) { func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) {
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1" query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
args := []interface{}{} args := []interface{}{}
@@ -158,6 +180,18 @@ func (db *DB) CountVulnerabilities(id, conversationID, severity, status string)
query += " AND conversation_id = ?" query += " AND conversation_id = ?"
args = append(args, conversationID) args = append(args, conversationID)
} }
if taskID != "" {
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, taskID, taskID)
}
if conversationTag != "" {
query += " AND conversation_tag = ?"
args = append(args, conversationTag)
}
if taskTag != "" {
query += " AND task_tag = ?"
args = append(args, taskTag)
}
if severity != "" { if severity != "" {
query += " AND severity = ?" query += " AND severity = ?"
args = append(args, severity) args = append(args, severity)
@@ -182,7 +216,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
query := ` query := `
UPDATE vulnerabilities UPDATE vulnerabilities
SET title = ?, description = ?, severity = ?, status = ?, SET conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?,
vulnerability_type = ?, target = ?, proof = ?, impact = ?, vulnerability_type = ?, target = ?, proof = ?, impact = ?,
recommendation = ?, updated_at = ? recommendation = ?, updated_at = ?
WHERE id = ? WHERE id = ?
@@ -190,7 +224,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
_, err := db.Exec( _, err := db.Exec(
query, query,
vuln.Title, vuln.Description, vuln.Severity, vuln.Status, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact, vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
vuln.Recommendation, vuln.UpdatedAt, id, vuln.Recommendation, vuln.UpdatedAt, id,
) )
@@ -210,18 +244,24 @@ func (db *DB) DeleteVulnerability(id string) error {
return nil return nil
} }
// GetVulnerabilityStats 获取漏洞统计 // GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致)
func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface{}, error) { func (db *DB) GetVulnerabilityStats(conversationID, taskID string) (map[string]interface{}, error) {
stats := make(map[string]interface{}) stats := make(map[string]interface{})
where := "WHERE 1=1"
args := []interface{}{}
if conversationID != "" {
where += " AND conversation_id = ?"
args = append(args, conversationID)
}
if taskID != "" {
where += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, taskID, taskID)
}
// 总漏洞数 // 总漏洞数
var totalCount int var totalCount int
query := "SELECT COUNT(*) FROM vulnerabilities" query := "SELECT COUNT(*) FROM vulnerabilities " + where
args := []interface{}{}
if conversationID != "" {
query += " WHERE conversation_id = ?"
args = append(args, conversationID)
}
err := db.QueryRow(query, args...).Scan(&totalCount) err := db.QueryRow(query, args...).Scan(&totalCount)
if err != nil { if err != nil {
return nil, fmt.Errorf("获取总漏洞数失败: %w", err) return nil, fmt.Errorf("获取总漏洞数失败: %w", err)
@@ -229,11 +269,7 @@ func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface
stats["total"] = totalCount stats["total"] = totalCount
// 按严重程度统计 // 按严重程度统计
severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities" severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities " + where + " GROUP BY severity"
if conversationID != "" {
severityQuery += " WHERE conversation_id = ?"
}
severityQuery += " GROUP BY severity"
rows, err := db.Query(severityQuery, args...) rows, err := db.Query(severityQuery, args...)
if err != nil { if err != nil {
@@ -253,11 +289,7 @@ func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface
stats["by_severity"] = severityStats stats["by_severity"] = severityStats
// 按状态统计 // 按状态统计
statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities" statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities " + where + " GROUP BY status"
if conversationID != "" {
statusQuery += " WHERE conversation_id = ?"
}
statusQuery += " GROUP BY status"
rows, err = db.Query(statusQuery, args...) rows, err = db.Query(statusQuery, args...)
if err != nil { if err != nil {
@@ -279,3 +311,59 @@ func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface
return stats, nil return stats, nil
} }
// GetVulnerabilityFilterOptions 获取漏洞筛选建议项
func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) {
collect := func(query string, args ...interface{}) ([]string, error) {
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
items := make([]string, 0)
for rows.Next() {
var val string
if err := rows.Scan(&val); err != nil {
continue
}
if val == "" {
continue
}
items = append(items, val)
}
return items, nil
}
vulnIDs, err := collect(`SELECT DISTINCT id FROM vulnerabilities ORDER BY created_at DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err)
}
conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id <> '' ORDER BY created_at DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询会话ID建议失败: %w", err)
}
taskIDs, err := collect(`SELECT DISTINCT id FROM batch_tasks WHERE id <> '' ORDER BY rowid DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询任务ID建议失败: %w", err)
}
queueIDs, err := collect(`SELECT DISTINCT queue_id FROM batch_tasks WHERE queue_id <> '' ORDER BY rowid DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询队列ID建议失败: %w", err)
}
conversationTags, err := collect(`SELECT DISTINCT conversation_tag FROM vulnerabilities WHERE conversation_tag IS NOT NULL AND conversation_tag <> '' ORDER BY conversation_tag LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询对话标签建议失败: %w", err)
}
taskTags, err := collect(`SELECT DISTINCT task_tag FROM vulnerabilities WHERE task_tag IS NOT NULL AND task_tag <> '' ORDER BY task_tag LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询任务标签建议失败: %w", err)
}
return map[string][]string{
"vulnerability_ids": vulnIDs,
"conversation_ids": conversationIDs,
"task_ids": taskIDs,
"queue_ids": queueIDs,
"conversation_tags": conversationTags,
"task_tags": taskTags,
}, nil
}
+5 -5
View File
@@ -160,17 +160,17 @@ func runMCPToolInvocation(
} }
// UnknownToolReminderHandler 供 compose.ToolsNodeConfig.UnknownToolsHandler 使用: // UnknownToolReminderHandler 供 compose.ToolsNodeConfig.UnknownToolsHandler 使用:
// 模型请求了未注册的工具名时,返回一个「可恢复」的错误,让上层 runner 触发重试与纠错提示 // 模型请求了未注册的工具名时,返回一个「软错误」工具结果(nil error
// 同时避免 UI 永远停留在“执行中”(runner 会在 recoverable 分支 flush 掉 pending 的 tool_call // 让模型在同一轮继续自我修正,避免触发 run-loop 级别的 full rerun
// 不进行名称猜测或映射,避免误执行。 // 不进行名称猜测或映射,避免误执行。
func UnknownToolReminderHandler() func(ctx context.Context, name, input string) (string, error) { func UnknownToolReminderHandler() func(ctx context.Context, name, input string) (string, error) {
return func(ctx context.Context, name, input string) (string, error) { return func(ctx context.Context, name, input string) (string, error) {
_ = ctx _ = ctx
_ = input _ = input
requested := strings.TrimSpace(name) requested := strings.TrimSpace(name)
// Return a recoverable error that still carries a friendly, bilingual hint. // Return a soft tool-result error so the graph keeps running and the LLM
// This will be caught by multiagent runner as "tool not found" and trigger a retry. // can correct tool name/arguments within the same run.
return "", fmt.Errorf("tool %q not found: %s", requested, unknownToolReminderText(requested)) return ToolErrorPrefix + unknownToolReminderText(requested), nil
} }
} }
+84 -92
View File
@@ -497,10 +497,10 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
defer h.hitlManager.DeactivateConversation(conversationID) defer h.hitlManager.DeactivateConversation(conversationID)
} }
// 优先尝试从保存的ReAct数据恢复历史上下文 // 优先尝试从保存的代理轨迹恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil { if err != nil {
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err)) h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err))
// 回退到使用数据库消息表 // 回退到使用数据库消息表
historyMessages, err := h.db.GetMessages(conversationID) historyMessages, err := h.db.GetMessages(conversationID)
if err != nil { if err != nil {
@@ -518,7 +518,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages))) h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
} }
} else { } else {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
} }
// 校验附件数量(非流式) // 校验附件数量(非流式)
@@ -613,12 +613,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
if err != nil { if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err)) h.logger.Error("Agent Loop执行失败", zap.Error(err))
// 即使执行失败,也尝试保存ReAct数据(如果result中有) // 即使执行失败,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if saveErr := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); saveErr != nil { if saveErr := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); saveErr != nil {
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(saveErr)) h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(saveErr))
} else { } else {
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -634,12 +634,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
// 因为AI已经生成了回复,用户应该能看到 // 因为AI已经生成了回复,用户应该能看到
} }
// 保存最后一轮ReAct的输入和输出 // 保存最后一轮代理轨迹与助手输出
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -666,7 +666,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
} }
} }
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil { if err != nil {
historyMessages, getErr := h.db.GetMessages(conversationID) historyMessages, getErr := h.db.GetMessages(conversationID)
if getErr != nil { if getErr != nil {
@@ -722,6 +722,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
"deep", "deep",
) )
if errMA != nil { if errMA != nil {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
errMsg := "执行失败: " + errMA.Error() errMsg := "执行失败: " + errMA.Error()
if assistantMessageID != "" { if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID) _, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID)
@@ -747,8 +748,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err)) h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
} }
} }
if resultMA.LastReActInput != "" || resultMA.LastReActOutput != "" { if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
_ = h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput) _ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
} }
return resultMA.Response, conversationID, nil return resultMA.Response, conversationID, nil
} }
@@ -782,8 +783,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err)) h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
} }
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
_ = h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput) _ = h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
} }
return result.Response, conversationID, nil return result.Response, conversationID, nil
} }
@@ -1359,10 +1360,10 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
} }
ssePublishConversationID = conversationID ssePublishConversationID = conversationID
// 优先尝试从保存的ReAct数据恢复历史上下文 // 优先尝试从保存的代理轨迹恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil { if err != nil {
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err)) h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err))
// 回退到使用数据库消息表 // 回退到使用数据库消息表
historyMessages, err := h.db.GetMessages(conversationID) historyMessages, err := h.db.GetMessages(conversationID)
if err != nil { if err != nil {
@@ -1380,7 +1381,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages))) h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
} }
} else { } else {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages))) h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
} }
// 校验附件数量 // 校验附件数量
@@ -1579,12 +1580,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil) h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
} }
// 即使任务被取消,也尝试保存ReAct数据(如果result中有) // 即使任务被取消,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.Error(err)) h.logger.Warn("保存取消任务的代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存取消任务的ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存取消任务的代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -1614,12 +1615,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil) h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
} }
// 即使任务超时,也尝试保存ReAct数据(如果result中有) // 即使任务超时,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存超时任务的ReAct数据失败", zap.Error(err)) h.logger.Warn("保存超时任务的代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存超时任务的ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存超时任务的代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -1649,12 +1650,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil) h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil)
} }
// 即使任务失败,也尝试保存ReAct数据(如果result中有) // 即使任务失败,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(err)) h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -1694,12 +1695,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
} }
} }
// 保存最后一轮ReAct的输入和输出 // 保存最后一轮代理轨迹与助手输出
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} else { } else {
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID)) h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID))
} }
} }
@@ -2499,6 +2500,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
cancel() cancel()
if runErr != nil { if runErr != nil {
if useRunResult {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
}
// 检查是否是取消错误 // 检查是否是取消错误
// 1. 直接检查是否是 context.Canceled(包括包装后的错误) // 1. 直接检查是否是 context.Canceled(包括包装后的错误)
// 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字 // 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字
@@ -2542,14 +2546,14 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg)) h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
} }
} }
// 保存ReAct数据(如果存在) // 保存代理轨迹(如果存在)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") { if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} }
} else if useRunResult && resultMA != nil && (resultMA.LastReActInput != "" || resultMA.LastReActOutput != "") { } else if useRunResult && resultMA != nil && (resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "") {
if err := h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} }
} }
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID) h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
@@ -2581,13 +2585,13 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
if useRunResult { if useRunResult {
resText = resultMA.Response resText = resultMA.Response
mcpIDs = resultMA.MCPExecutionIDs mcpIDs = resultMA.MCPExecutionIDs
lastIn = resultMA.LastReActInput lastIn = resultMA.LastAgentTraceInput
lastOut = resultMA.LastReActOutput lastOut = resultMA.LastAgentTraceOutput
} else { } else {
resText = result.Response resText = result.Response
mcpIDs = result.MCPExecutionIDs mcpIDs = result.MCPExecutionIDs
lastIn = result.LastReActInput lastIn = result.LastAgentTraceInput
lastOut = result.LastReActOutput lastOut = result.LastAgentTraceOutput
} }
// 更新助手消息内容 // 更新助手消息内容
@@ -2618,12 +2622,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
} }
} }
// 保存ReAct数据 // 保存代理轨迹
if lastIn != "" || lastOut != "" { if lastIn != "" || lastOut != "" {
if err := h.db.SaveReActData(conversationID, lastIn, lastOut); err != nil { if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} else { } else {
h.logger.Info("已保存ReAct数据", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID)) h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
} }
} }
@@ -2642,36 +2646,33 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
} }
} }
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文 // loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
// 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退消息表 // 逻辑与攻击链一致:优先用保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) { func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
// 获取保存的ReAct输入和输出 traceInputJSON, assistantOut, err := h.db.GetAgentTrace(conversationID)
reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID)
if err != nil { if err != nil {
return nil, fmt.Errorf("获取ReAct数据失败: %w", err) return nil, fmt.Errorf("获取代理轨迹失败: %w", err)
} }
// 如果last_react_input为空,回退到使用消息表(与攻击链生成逻辑一致) if traceInputJSON == "" {
if reactInputJSON == "" { return nil, fmt.Errorf("代理轨迹为空,将使用消息表")
return nil, fmt.Errorf("ReAct数据为空,将使用消息表")
} }
dataSource := "database_last_react_input" dataSource := "database_last_agent_trace"
// 解析JSON格式的messages数组
var messagesArray []map[string]interface{} var messagesArray []map[string]interface{}
if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil { if err := json.Unmarshal([]byte(traceInputJSON), &messagesArray); err != nil {
return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err) return nil, fmt.Errorf("解析代理轨迹 JSON 失败: %w", err)
} }
messageCount := len(messagesArray) messageCount := len(messagesArray)
h.logger.Info("使用保存的ReAct数据恢复历史上下文", h.logger.Info("使用保存的代理轨迹恢复历史上下文",
zap.String("conversationId", conversationID), zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource), zap.String("dataSource", dataSource),
zap.Int("reactInputSize", len(reactInputJSON)), zap.Int("traceInputSize", len(traceInputJSON)),
zap.Int("messageCount", messageCount), zap.Int("messageCount", messageCount),
zap.Int("reactOutputSize", len(reactOutput)), zap.Int("assistantOutSize", len(assistantOut)),
) )
// fmt.Println("messagesArray:", messagesArray)//debug // fmt.Println("messagesArray:", messagesArray)//debug
@@ -2755,53 +2756,44 @@ func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.
agentMessages = append(agentMessages, msg) agentMessages = append(agentMessages, msg)
} }
// 如果存在last_react_output,需要将其作为最后一条assistant消息 // 存在 last_react_output(助手摘要),合并为最后一条 assistant(与保存格式一致)
// 因为last_react_input是在迭代开始前保存的,不包含最后一轮的最终输出 if assistantOut != "" {
if reactOutput != "" {
// 检查最后一条消息是否是assistant消息且没有tool_calls
// 如果有tool_calls,说明后面应该还有tool消息和最终的assistant回复
if len(agentMessages) > 0 { if len(agentMessages) > 0 {
lastMsg := &agentMessages[len(agentMessages)-1] lastMsg := &agentMessages[len(agentMessages)-1]
if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 { if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 {
// 最后一条是assistant消息且没有tool_calls,用最终输出更新其content lastMsg.Content = assistantOut
lastMsg.Content = reactOutput
} else { } else {
// 最后一条不是assistant消息,或者有tool_calls,添加最终输出作为新的assistant消息
agentMessages = append(agentMessages, agent.ChatMessage{ agentMessages = append(agentMessages, agent.ChatMessage{
Role: "assistant", Role: "assistant",
Content: reactOutput, Content: assistantOut,
}) })
} }
} else { } else {
// 如果没有消息,直接添加最终输出
agentMessages = append(agentMessages, agent.ChatMessage{ agentMessages = append(agentMessages, agent.ChatMessage{
Role: "assistant", Role: "assistant",
Content: reactOutput, Content: assistantOut,
}) })
} }
} }
if len(agentMessages) == 0 { if len(agentMessages) == 0 {
return nil, fmt.Errorf("从ReAct数据解析的消息为空") return nil, fmt.Errorf("从代理轨迹解析的消息为空")
} }
// 修复可能存在的失配tool消息,避免OpenAI报错
// 这可以防止出现"messages with role 'tool' must be a response to a preceeding message with 'tool_calls'"错误
if h.agent != nil { if h.agent != nil {
if fixed := h.agent.RepairOrphanToolMessages(&agentMessages); fixed { if fixed := h.agent.RepairOrphanToolMessages(&agentMessages); fixed {
h.logger.Info("修复了从ReAct数据恢复的历史消息中的失配tool消息", h.logger.Info("修复了从代理轨迹恢复的历史消息中的失配 tool 消息",
zap.String("conversationId", conversationID), zap.String("conversationId", conversationID),
) )
} }
} }
h.logger.Info("从ReAct数据恢复历史消息完成", h.logger.Info("从代理轨迹恢复历史消息完成",
zap.String("conversationId", conversationID), zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource), zap.String("dataSource", dataSource),
zap.Int("originalMessageCount", messageCount), zap.Int("originalMessageCount", messageCount),
zap.Int("finalMessageCount", len(agentMessages)), zap.Int("finalMessageCount", len(agentMessages)),
zap.Bool("hasReactOutput", reactOutput != ""), zap.Bool("hasAssistantOut", assistantOut != ""),
) )
fmt.Println("agentMessages:", agentMessages) //debug
return agentMessages, nil return agentMessages, nil
} }
-1
View File
@@ -170,4 +170,3 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
c.JSON(http.StatusOK, chain) c.JSON(http.StatusOK, chain)
} }
+39 -14
View File
@@ -17,6 +17,7 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/knowledge" "cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
@@ -90,6 +91,7 @@ type AttackChainUpdater interface {
type AgentUpdater interface { type AgentUpdater interface {
UpdateConfig(cfg *config.OpenAIConfig) UpdateConfig(cfg *config.OpenAIConfig)
UpdateMaxIterations(maxIterations int) UpdateMaxIterations(maxIterations int)
UpdateToolDescriptionMode(mode string)
} }
// NewConfigHandler 创建新的配置处理器 // NewConfigHandler 创建新的配置处理器
@@ -232,13 +234,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
if configToolMap[mcpTool.Name] { if configToolMap[mcpTool.Name] {
continue continue
} }
description := mcpTool.ShortDescription description := h.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description)
if description == "" {
description = mcpTool.Description
}
if len(description) > 10000 {
description = description[:10000] + "..."
}
tools = append(tools, ToolConfigInfo{ tools = append(tools, ToolConfigInfo{
Name: mcpTool.Name, Name: mcpTool.Name,
Description: description, Description: description,
@@ -275,6 +271,11 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
SubAgentCount: subAgentCount, SubAgentCount: subAgentCount,
Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration), Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration),
PlanExecuteLoopMaxIterations: h.config.MultiAgent.PlanExecuteLoopMaxIterations, PlanExecuteLoopMaxIterations: h.config.MultiAgent.PlanExecuteLoopMaxIterations,
ToolSearchAlwaysVisibleTools: append([]string(nil), h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools...),
ToolSearchAlwaysVisibleEffectiveTools: mergeToolNameLists(
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools,
builtin.GetAllBuiltinTools(),
),
} }
c.JSON(http.StatusOK, GetConfigResponse{ c.JSON(http.StatusOK, GetConfigResponse{
@@ -430,13 +431,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
continue continue
} }
description := mcpTool.ShortDescription description := h.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description)
if description == "" {
description = mcpTool.Description
}
if len(description) > 10000 {
description = description[:10000] + "..."
}
toolInfo := ToolConfigInfo{ toolInfo := ToolConfigInfo{
Name: mcpTool.Name, Name: mcpTool.Name,
@@ -689,11 +684,13 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil { if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
} }
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(req.MultiAgent.ToolSearchAlwaysVisibleTools)
h.logger.Info("更新多代理配置", h.logger.Info("更新多代理配置",
zap.Bool("enabled", h.config.MultiAgent.Enabled), zap.Bool("enabled", h.config.MultiAgent.Enabled),
zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent), zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent),
zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent), zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent),
zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations), zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations),
zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)),
) )
} }
@@ -1061,6 +1058,7 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
if h.agent != nil { if h.agent != nil {
h.agent.UpdateConfig(&h.config.OpenAI) h.agent.UpdateConfig(&h.config.OpenAI)
h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations) h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations)
h.agent.UpdateToolDescriptionMode(h.config.Security.ToolDescriptionMode)
h.logger.Info("Agent配置已更新") h.logger.Info("Agent配置已更新")
} }
@@ -1383,6 +1381,33 @@ func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
setBoolInMap(maNode, "robot_use_multi_agent", cfg.RobotUseMultiAgent) setBoolInMap(maNode, "robot_use_multi_agent", cfg.RobotUseMultiAgent)
setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent) setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent)
setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations) setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations)
mwNode := ensureMap(maNode, "eino_middleware")
setFlowStringSliceInMap(mwNode, "tool_search_always_visible_tools", dedupeToolNameList(cfg.EinoMiddleware.ToolSearchAlwaysVisibleTools))
}
func dedupeToolNameList(in []string) []string {
if len(in) == 0 {
return []string{}
}
seen := make(map[string]struct{}, len(in))
out := make([]string, 0, len(in))
for _, name := range in {
n := strings.TrimSpace(name)
if n == "" {
continue
}
key := strings.ToLower(n)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, n)
}
return out
}
func mergeToolNameLists(a, b []string) []string {
return dedupeToolNameList(append(append([]string{}, a...), b...))
} }
func ensureMap(parent *yaml.Node, path ...string) *yaml.Node { func ensureMap(parent *yaml.Node, path ...string) *yaml.Node {
-1
View File
@@ -230,4 +230,3 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
"message": "ok", "message": "ok",
}) })
} }
+7 -5
View File
@@ -175,6 +175,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
) )
if runErr != nil { if runErr != nil {
h.persistEinoAgentTraceForResume(conversationID, result)
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) { if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled" taskStatus = "cancelled"
@@ -239,9 +240,9 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
) )
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} }
} }
@@ -306,6 +307,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
progressCallback, progressCallback,
) )
if runErr != nil { if runErr != nil {
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()})
return return
} }
@@ -323,8 +325,8 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
prep.AssistantMessageID, prep.AssistantMessageID,
) )
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
_ = h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput) _ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
-3
View File
@@ -213,7 +213,6 @@ func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
return stats return stats
} }
// GetExecution 获取特定执行记录 // GetExecution 获取特定执行记录
func (h *MonitorHandler) GetExecution(c *gin.Context) { func (h *MonitorHandler) GetExecution(c *gin.Context) {
id := c.Param("id") id := c.Param("id")
@@ -416,5 +415,3 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs))) h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"}) c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
} }
+21 -6
View File
@@ -185,6 +185,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
) )
if runErr != nil { if runErr != nil {
h.persistEinoAgentTraceForResume(conversationID, result)
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) { if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled" taskStatus = "cancelled"
@@ -249,9 +250,9 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
) )
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} }
} }
@@ -318,6 +319,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
strings.TrimSpace(req.Orchestration), strings.TrimSpace(req.Orchestration),
) )
if runErr != nil { if runErr != nil {
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr)) h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
errMsg := "执行失败: " + runErr.Error() errMsg := "执行失败: " + runErr.Error()
if prep.AssistantMessageID != "" { if prep.AssistantMessageID != "" {
@@ -341,9 +343,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
) )
} }
if result.LastReActInput != "" || result.LastReActOutput != "" { if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput); err != nil { if err := h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err)) h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} }
} }
@@ -355,6 +357,19 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
}) })
} }
// persistEinoAgentTraceForResume 在 Eino 运行异常结束时写入代理轨迹(库列 last_react_*),供下一请求 loadHistoryFromAgentTrace 软续跑。
func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, result *multiagent.RunResult) {
if h == nil || result == nil {
return
}
if result.LastAgentTraceInput == "" && result.LastAgentTraceOutput == "" {
return
}
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 Eino 续跑上下文失败", zap.String("conversationId", conversationID), zap.Error(err))
}
}
func multiAgentHTTPErrorStatus(err error) (int, string) { func multiAgentHTTPErrorStatus(err error) (int, string) {
msg := err.Error() msg := err.Error()
switch { switch {
+1 -1
View File
@@ -49,7 +49,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
} }
} }
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID) agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil { if err != nil {
historyMessages, getErr := h.db.GetMessages(conversationID) historyMessages, getErr := h.db.GetMessages(conversationID)
if getErr != nil { if getErr != nil {
+1 -1
View File
@@ -6197,7 +6197,7 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) {
} }
// 获取漏洞列表 // 获取漏洞列表
vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "") vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "", "", "", "")
if err != nil { if err != nil {
h.logger.Warn("获取漏洞列表失败", zap.Error(err)) h.logger.Warn("获取漏洞列表失败", zap.Error(err))
vulnList = []*database.Vulnerability{} vulnList = []*database.Vulnerability{}
-1
View File
@@ -109,4 +109,3 @@ func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
<-doneChan <-doneChan
} }
+202 -3
View File
@@ -1,8 +1,11 @@
package handler package handler
import ( import (
"fmt"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"time"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -26,6 +29,8 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability
// CreateVulnerabilityRequest 创建漏洞请求 // CreateVulnerabilityRequest 创建漏洞请求
type CreateVulnerabilityRequest struct { type CreateVulnerabilityRequest struct {
ConversationID string `json:"conversation_id" binding:"required"` ConversationID string `json:"conversation_id" binding:"required"`
ConversationTag string `json:"conversation_tag"`
TaskTag string `json:"task_tag"`
Title string `json:"title" binding:"required"` Title string `json:"title" binding:"required"`
Description string `json:"description"` Description string `json:"description"`
Severity string `json:"severity" binding:"required"` Severity string `json:"severity" binding:"required"`
@@ -47,6 +52,8 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
vuln := &database.Vulnerability{ vuln := &database.Vulnerability{
ConversationID: req.ConversationID, ConversationID: req.ConversationID,
ConversationTag: req.ConversationTag,
TaskTag: req.TaskTag,
Title: req.Title, Title: req.Title,
Description: req.Description, Description: req.Description,
Severity: req.Severity, Severity: req.Severity,
@@ -100,6 +107,9 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
conversationID := c.Query("conversation_id") conversationID := c.Query("conversation_id")
severity := c.Query("severity") severity := c.Query("severity")
status := c.Query("status") status := c.Query("status")
taskID := c.Query("task_id")
conversationTag := c.Query("conversation_tag")
taskTag := c.Query("task_tag")
limit, _ := strconv.Atoi(limitStr) limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr) offset, _ := strconv.Atoi(offsetStr)
@@ -121,7 +131,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
} }
// 获取总数 // 获取总数
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status) total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
if err != nil { if err != nil {
h.logger.Error("获取漏洞总数失败", zap.Error(err)) h.logger.Error("获取漏洞总数失败", zap.Error(err))
// 继续执行,使用0作为总数 // 继续执行,使用0作为总数
@@ -129,7 +139,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
} }
// 获取漏洞列表 // 获取漏洞列表
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status) vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status, taskID, conversationTag, taskTag)
if err != nil { if err != nil {
h.logger.Error("获取漏洞列表失败", zap.Error(err)) h.logger.Error("获取漏洞列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -160,6 +170,8 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
// UpdateVulnerabilityRequest 更新漏洞请求 // UpdateVulnerabilityRequest 更新漏洞请求
type UpdateVulnerabilityRequest struct { type UpdateVulnerabilityRequest struct {
ConversationTag string `json:"conversation_tag"`
TaskTag string `json:"task_tag"`
Title string `json:"title"` Title string `json:"title"`
Description string `json:"description"` Description string `json:"description"`
Severity string `json:"severity"` Severity string `json:"severity"`
@@ -189,6 +201,12 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
} }
// 更新字段 // 更新字段
if req.ConversationTag != "" {
existing.ConversationTag = req.ConversationTag
}
if req.TaskTag != "" {
existing.TaskTag = req.TaskTag
}
if req.Title != "" { if req.Title != "" {
existing.Title = req.Title existing.Title = req.Title
} }
@@ -250,8 +268,9 @@ func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
// GetVulnerabilityStats 获取漏洞统计 // GetVulnerabilityStats 获取漏洞统计
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) { func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
conversationID := c.Query("conversation_id") conversationID := c.Query("conversation_id")
taskID := c.Query("task_id")
stats, err := h.db.GetVulnerabilityStats(conversationID) stats, err := h.db.GetVulnerabilityStats(conversationID, taskID)
if err != nil { if err != nil {
h.logger.Error("获取漏洞统计失败", zap.Error(err)) h.logger.Error("获取漏洞统计失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -261,3 +280,183 @@ func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
c.JSON(http.StatusOK, stats) c.JSON(http.StatusOK, stats)
} }
// GetVulnerabilityFilterOptions 获取漏洞筛选建议项
func (h *VulnerabilityHandler) GetVulnerabilityFilterOptions(c *gin.Context) {
options, err := h.db.GetVulnerabilityFilterOptions()
if err != nil {
h.logger.Error("获取漏洞筛选建议失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, options)
}
// ExportVulnerabilities 导出漏洞(支持按对话/任务分组,汇总或拆分)
func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
groupBy := c.DefaultQuery("group_by", "conversation")
mode := c.DefaultQuery("mode", "summary")
if groupBy != "conversation" && groupBy != "task" {
c.JSON(http.StatusBadRequest, gin.H{"error": "group_by 仅支持 conversation 或 task"})
return
}
if mode != "summary" && mode != "split" {
c.JSON(http.StatusBadRequest, gin.H{"error": "mode 仅支持 summary 或 split"})
return
}
id := c.Query("id")
conversationID := c.Query("conversation_id")
severity := c.Query("severity")
status := c.Query("status")
taskID := c.Query("task_id")
conversationTag := c.Query("conversation_tag")
taskTag := c.Query("task_tag")
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if total == 0 {
c.JSON(http.StatusOK, gin.H{"mode": mode, "group_by": groupBy, "total": 0, "files": []any{}})
return
}
items, err := h.db.ListVulnerabilities(total, 0, id, conversationID, severity, status, taskID, conversationTag, taskTag)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
type exportFile struct {
FileName string `json:"filename"`
Content string `json:"content"`
}
grouped := map[string][]*database.Vulnerability{}
for _, v := range items {
key := v.ConversationID
if groupBy == "conversation" {
if strings.TrimSpace(v.ConversationTag) != "" {
key = strings.TrimSpace(v.ConversationTag)
}
} else {
key = firstNonEmpty(v.TaskTag, v.TaskID, v.TaskQueueID, "unassigned-task")
}
grouped[key] = append(grouped[key], v)
}
files := make([]exportFile, 0)
nowStr := time.Now().Format("20060102-150405")
if mode == "summary" {
var b strings.Builder
b.WriteString("# 漏洞批量导出报告\n\n")
b.WriteString(fmt.Sprintf("- 导出时间: %s\n", time.Now().Format("2006-01-02 15:04:05")))
b.WriteString(fmt.Sprintf("- 分组维度: %s\n", groupBy))
b.WriteString(fmt.Sprintf("- 漏洞总数: %d\n", len(items)))
b.WriteString(fmt.Sprintf("- 分组数: %d\n\n", len(grouped)))
for group, list := range grouped {
b.WriteString(fmt.Sprintf("## %s (%d)\n\n", group, len(list)))
for _, v := range list {
appendVulnerabilityMarkdown(&b, v, "###")
}
}
files = append(files, exportFile{
FileName: fmt.Sprintf("vulnerability-report-%s-%s.md", groupBy, nowStr),
Content: b.String(),
})
} else {
for group, list := range grouped {
var b strings.Builder
b.WriteString(fmt.Sprintf("# 漏洞报告 - %s\n\n", group))
b.WriteString(fmt.Sprintf("- 导出时间: %s\n", time.Now().Format("2006-01-02 15:04:05")))
b.WriteString(fmt.Sprintf("- 漏洞数量: %d\n\n", len(list)))
for _, v := range list {
appendVulnerabilityMarkdown(&b, v, "##")
}
files = append(files, exportFile{
FileName: fmt.Sprintf("vulnerability-%s-%s.md", sanitizeExportName(group), nowStr),
Content: b.String(),
})
}
}
c.JSON(http.StatusOK, gin.H{
"mode": mode,
"group_by": groupBy,
"total": len(items),
"files": files,
})
}
// appendVulnerabilityMarkdown 单条漏洞的 Markdown 片段(与单文件下载字段对齐,缺省字段不写)
func appendVulnerabilityMarkdown(b *strings.Builder, v *database.Vulnerability, titleHeading string) {
b.WriteString(fmt.Sprintf("%s %s\n\n", titleHeading, v.Title))
b.WriteString(fmt.Sprintf("- 漏洞ID: `%s`\n", v.ID))
b.WriteString(fmt.Sprintf("- 严重程度: %s\n", v.Severity))
b.WriteString(fmt.Sprintf("- 状态: %s\n", v.Status))
if v.Type != "" {
b.WriteString(fmt.Sprintf("- 类型: %s\n", v.Type))
}
if v.Target != "" {
b.WriteString(fmt.Sprintf("- 目标: %s\n", v.Target))
}
b.WriteString(fmt.Sprintf("- 对话ID: `%s`\n", v.ConversationID))
if v.ConversationTag != "" {
b.WriteString(fmt.Sprintf("- 对话标签: %s\n", v.ConversationTag))
}
if v.TaskTag != "" {
b.WriteString(fmt.Sprintf("- 任务标签: %s\n", v.TaskTag))
}
if v.TaskID != "" {
b.WriteString(fmt.Sprintf("- 任务ID: `%s`\n", v.TaskID))
}
if v.TaskQueueID != "" {
b.WriteString(fmt.Sprintf("- 任务队列ID: `%s`\n", v.TaskQueueID))
}
if !v.CreatedAt.IsZero() {
b.WriteString(fmt.Sprintf("- 创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05")))
}
if !v.UpdatedAt.IsZero() {
b.WriteString(fmt.Sprintf("- 更新时间: %s\n", v.UpdatedAt.Format("2006-01-02 15:04:05")))
}
if v.Description != "" {
b.WriteString("\n#### 描述\n\n")
b.WriteString(v.Description)
b.WriteString("\n")
}
if v.Proof != "" {
b.WriteString("\n#### 证明(POC\n\n```\n")
b.WriteString(v.Proof)
b.WriteString("\n```\n")
}
if v.Impact != "" {
b.WriteString("\n#### 影响\n\n")
b.WriteString(v.Impact)
b.WriteString("\n")
}
if v.Recommendation != "" {
b.WriteString("\n#### 修复建议\n\n")
b.WriteString(v.Recommendation)
b.WriteString("\n")
}
b.WriteString("\n")
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
trimmed := strings.TrimSpace(v)
if trimmed != "" {
return trimmed
}
}
return ""
}
func sanitizeExportName(raw string) string {
name := strings.TrimSpace(raw)
if name == "" {
return "unknown"
}
replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-")
return replacer.Replace(name)
}
+1 -1
View File
@@ -8,8 +8,8 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
) )
+1 -1
View File
@@ -11,9 +11,9 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
fileloader "github.com/cloudwego/eino-ext/components/document/loader/file" fileloader "github.com/cloudwego/eino-ext/components/document/loader/file"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/components/document" "github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/components/indexer" "github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
"go.uber.org/zap" "go.uber.org/zap"
) )
+139 -80
View File
@@ -6,6 +6,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
@@ -109,10 +110,11 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
}() }()
var lastRunMsgs []adk.Message
var lastAssistant string var lastAssistant string
var lastPlanExecuteExecutor string var lastPlanExecuteExecutor string
var retryHints []adk.Message msgs := append([]adk.Message(nil), baseMsgs...)
runAccumulatedMsgs := append([]adk.Message(nil), msgs...)
baseAccumulatedCount := len(runAccumulatedMsgs)
emptyHint := strings.TrimSpace(args.EmptyResponseMessage) emptyHint := strings.TrimSpace(args.EmptyResponseMessage)
if emptyHint == "" { if emptyHint == "" {
@@ -120,18 +122,6 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"(Eino 会话已完成,但未捕获到助手文本输出。请查看过程详情或日志。)" "(Eino 会话已完成,但未捕获到助手文本输出。请查看过程详情或日志。)"
} }
attemptLoop:
for attempt := 0; attempt < maxToolCallRecoveryAttempts; attempt++ {
msgs := make([]adk.Message, 0, len(baseMsgs)+len(retryHints))
msgs = append(msgs, baseMsgs...)
msgs = append(msgs, retryHints...)
if attempt > 0 {
mcpIDsMu.Lock()
*mcpIDs = (*mcpIDs)[:0]
mcpIDsMu.Unlock()
}
lastAssistant = "" lastAssistant = ""
lastPlanExecuteExecutor = "" lastPlanExecuteExecutor = ""
var reasoningStreamSeq int64 var reasoningStreamSeq int64
@@ -204,6 +194,8 @@ attemptLoop:
Agent: da, Agent: da,
EnableStreaming: true, EnableStreaming: true,
} }
var cpStore *fileCheckPointStore
var checkPointID string
if cp := strings.TrimSpace(args.CheckpointDir); cp != "" { if cp := strings.TrimSpace(args.CheckpointDir); cp != "" {
cpDir := filepath.Join(cp, sanitizeEinoPathSegment(conversationID)) cpDir := filepath.Join(cp, sanitizeEinoPathSegment(conversationID))
st, stErr := newFileCheckPointStore(cpDir) st, stErr := newFileCheckPointStore(cpDir)
@@ -212,17 +204,65 @@ attemptLoop:
logger.Warn("eino checkpoint store disabled", zap.String("dir", cpDir), zap.Error(stErr)) logger.Warn("eino checkpoint store disabled", zap.String("dir", cpDir), zap.Error(stErr))
} }
} else { } else {
cpStore = st
checkPointID = buildEinoCheckpointID(orchMode)
runnerCfg.CheckPointStore = st runnerCfg.CheckPointStore = st
if logger != nil { if logger != nil {
logger.Info("eino runner: checkpoint store enabled", zap.String("dir", cpDir)) logger.Info("eino runner: checkpoint store enabled",
zap.String("dir", cpDir),
zap.String("checkPointID", checkPointID))
} }
} }
} }
runner := adk.NewRunner(ctx, runnerCfg) runner := adk.NewRunner(ctx, runnerCfg)
iter := runner.Run(ctx, msgs) var iter *adk.AsyncIterator[*adk.AgentEvent]
handleRunErr := func(runErr error, attempt int, reasonOverride string) (retry bool, retErr error) { if cpStore != nil && checkPointID != "" {
if _, existed, getErr := cpStore.Get(ctx, checkPointID); getErr != nil {
if logger != nil {
logger.Warn("eino checkpoint preflight get failed", zap.String("checkPointID", checkPointID), zap.Error(getErr))
}
} else if existed {
if progress != nil {
progress("progress", "检测到断点,正在从中断节点恢复执行...", map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"orchestration": orchMode,
"checkPointID": checkPointID,
})
}
if logger != nil {
logger.Info("eino runner: resume from checkpoint", zap.String("checkPointID", checkPointID))
}
resumeIter, resumeErr := runner.Resume(ctx, checkPointID)
if resumeErr == nil {
iter = resumeIter
} else {
if logger != nil {
logger.Warn("eino runner: resume failed, fallback to fresh run",
zap.String("checkPointID", checkPointID),
zap.Error(resumeErr))
}
if progress != nil {
progress("progress", "断点恢复失败,已回退为全新执行。", map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"orchestration": orchMode,
"checkPointID": checkPointID,
})
}
}
}
}
if iter == nil {
if checkPointID != "" {
iter = runner.Run(ctx, msgs, adk.WithCheckPointID(checkPointID))
} else {
iter = runner.Run(ctx, msgs)
}
}
handleRunErr := func(runErr error) error {
if runErr == nil { if runErr == nil {
return false, nil return nil
} }
if errors.Is(runErr, context.DeadlineExceeded) { if errors.Is(runErr, context.DeadlineExceeded) {
flushAllPendingAsFailed(runErr) flushAllPendingAsFailed(runErr)
@@ -233,7 +273,7 @@ attemptLoop:
"errorKind": "timeout", "errorKind": "timeout",
}) })
} }
return false, runErr return runErr
} }
// context.Canceled 是唯一应当直接终止编排的错误(用户关闭页面、主动停止等)。 // context.Canceled 是唯一应当直接终止编排的错误(用户关闭页面、主动停止等)。
if errors.Is(runErr, context.Canceled) { if errors.Is(runErr, context.Canceled) {
@@ -244,7 +284,7 @@ attemptLoop:
"source": "eino", "source": "eino",
}) })
} }
return false, runErr return runErr
} }
if isEinoIterationLimitError(runErr) { if isEinoIterationLimitError(runErr) {
flushAllPendingAsFailed(runErr) flushAllPendingAsFailed(runErr)
@@ -260,12 +300,8 @@ attemptLoop:
"errorKind": "iteration_limit", "errorKind": "iteration_limit",
}) })
} }
return false, runErr return runErr
} }
canRetry := attempt+1 < maxToolCallRecoveryAttempts
if !canRetry {
// 重试次数已耗尽,终止。
flushAllPendingAsFailed(runErr) flushAllPendingAsFailed(runErr)
if progress != nil { if progress != nil {
progress("error", runErr.Error(), map[string]interface{}{ progress("error", runErr.Error(), map[string]interface{}{
@@ -273,44 +309,17 @@ attemptLoop:
"source": "eino", "source": "eino",
}) })
} }
return false, runErr return runErr
} }
// 区分错误类型以选择最合适的纠错提示,但无论哪种都执行重试(default-soft)。 takePartial := func(runErr error) (*RunResult, error) {
var hint *schema.Message if len(runAccumulatedMsgs) <= baseAccumulatedCount {
var reason, timelineMsg string return nil, runErr
switch {
case strings.TrimSpace(reasonOverride) != "":
hint = toolExecutionRetryHint()
reason = strings.TrimSpace(reasonOverride)
timelineMsg = toolExecutionRecoveryTimelineMessage(attempt)
case isRecoverableToolCallArgumentsJSONError(runErr):
hint = toolCallArgumentsJSONRetryHint()
reason = "invalid_tool_arguments_json"
timelineMsg = toolCallArgumentsJSONRecoveryTimelineMessage(attempt)
default:
hint = toolExecutionRetryHint()
reason = "tool_execution_error"
timelineMsg = toolExecutionRecoveryTimelineMessage(attempt)
} }
ids := snapshotMCPIDs()
if logger != nil { return buildEinoRunResultFromAccumulated(
logger.Warn("eino: recoverable error, will retry with corrective hint", orchMode, runAccumulatedMsgs, lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, true,
zap.Error(runErr), zap.Int("attempt", attempt), zap.String("reason", reason)) ), runErr
}
flushAllPendingAsFailed(runErr)
retryHints = append(retryHints, hint)
if progress != nil {
progress("eino_recovery", timelineMsg, map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"einoRetry": attempt,
"runIndex": attempt + 1,
"maxRuns": maxToolCallRecoveryAttempts,
"reason": reason,
})
}
return true, nil
} }
for { for {
@@ -324,12 +333,25 @@ attemptLoop:
"source": "eino", "source": "eino",
}) })
} }
return nil, ctx.Err() return takePartial(ctx.Err())
default: default:
} }
ev, ok := iter.Next() ev, ok := iter.Next()
if !ok { if !ok {
// iter 结束并不总是“正常完成”:
// 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。
// 此时必须保留 checkpoint,避免后续恢复时被误判为“无断点”而全量重跑。
if ctxErr := ctx.Err(); ctxErr != nil {
flushAllPendingAsFailed(ctxErr)
if progress != nil {
progress("error", ctxErr.Error(), map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
})
}
return takePartial(ctxErr)
}
if len(pendingByID) > 0 { if len(pendingByID) > 0 {
orphanCount := len(pendingByID) orphanCount := len(pendingByID)
flushAllPendingAsFailed(errors.New("pending tool call missing result before run completion")) flushAllPendingAsFailed(errors.New("pending tool call missing result before run completion"))
@@ -342,17 +364,21 @@ attemptLoop:
}) })
} }
} }
lastRunMsgs = msgs if cpStore != nil && checkPointID != "" {
break attemptLoop if p, pErr := cpStore.path(checkPointID); pErr == nil {
if rmErr := os.Remove(p); rmErr != nil && !os.IsNotExist(rmErr) && logger != nil {
logger.Warn("eino checkpoint cleanup failed", zap.String("path", p), zap.Error(rmErr))
}
}
}
break
} }
if ev == nil { if ev == nil {
continue continue
} }
if ev.Err != nil { if ev.Err != nil {
if retry, retErr := handleRunErr(ev.Err, attempt, ""); retErr != nil { if retErr := handleRunErr(ev.Err); retErr != nil {
return nil, retErr return takePartial(retErr)
} else if retry {
continue attemptLoop
} }
} }
if ev.AgentName != "" && progress != nil { if ev.AgentName != "" && progress != nil {
@@ -489,6 +515,7 @@ attemptLoop:
if streamsMainAssistant(ev.AgentName) { if streamsMainAssistant(ev.AgentName) {
if s := strings.TrimSpace(mainAssistantBuf.String()); s != "" { if s := strings.TrimSpace(mainAssistantBuf.String()); s != "" {
lastAssistant = s lastAssistant = s
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") { if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(s) lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(s)
} }
@@ -528,10 +555,8 @@ attemptLoop:
"einoRole": einoRoleTag(ev.AgentName), "einoRole": einoRoleTag(ev.AgentName),
}) })
} }
if retry, retErr := handleRunErr(streamRecvErr, attempt, "stream_recv_error"); retErr != nil { if retErr := handleRunErr(streamRecvErr); retErr != nil {
return nil, retErr return takePartial(retErr)
} else if retry {
continue attemptLoop
} }
} }
continue continue
@@ -541,6 +566,7 @@ attemptLoop:
if gerr != nil || msg == nil { if gerr != nil || msg == nil {
continue continue
} }
runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending) tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
if mv.Role == schema.Assistant { if mv.Role == schema.Assistant {
@@ -640,13 +666,32 @@ attemptLoop:
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data) progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
} }
} }
}
mcpIDsMu.Lock() mcpIDsMu.Lock()
ids := append([]string(nil), *mcpIDs...) ids := append([]string(nil), *mcpIDs...)
mcpIDsMu.Unlock() mcpIDsMu.Unlock()
histJSON, _ := json.Marshal(lastRunMsgs) out := buildEinoRunResultFromAccumulated(
orchMode, runAccumulatedMsgs, lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
)
return out, nil
}
func einoPartialRunLastOutputHint() string {
return "[执行未正常结束(用户停止、超时或异常)。续跑时请基于上文已产生的工具与结果继续,勿重复已完成步骤。]\n" +
"[Run ended abnormally; continue from the trace above without repeating completed steps.]"
}
func buildEinoRunResultFromAccumulated(
orchMode string,
runAccumulatedMsgs []adk.Message,
lastAssistant string,
lastPlanExecuteExecutor string,
emptyHint string,
mcpIDs []string,
partial bool,
) *RunResult {
histJSON, _ := json.Marshal(runAccumulatedMsgs)
cleaned := strings.TrimSpace(lastAssistant) cleaned := strings.TrimSpace(lastAssistant)
if orchMode == "plan_execute" { if orchMode == "plan_execute" {
if e := strings.TrimSpace(lastPlanExecuteExecutor); e != "" { if e := strings.TrimSpace(lastPlanExecuteExecutor); e != "" {
@@ -662,15 +707,29 @@ attemptLoop:
if rs := []rune(cleaned); len(rs) > maxResponseRunes { if rs := []rune(cleaned); len(rs) > maxResponseRunes {
cleaned = string(rs[:maxResponseRunes]) + "\n\n... (response truncated / 响应已截断)" cleaned = string(rs[:maxResponseRunes]) + "\n\n... (response truncated / 响应已截断)"
} }
lastOut := cleaned
resp := cleaned
if partial && cleaned == "" {
lastOut = einoPartialRunLastOutputHint()
resp = emptyHint
}
out := &RunResult{ out := &RunResult{
Response: cleaned, Response: resp,
MCPExecutionIDs: ids, MCPExecutionIDs: mcpIDs,
LastReActInput: string(histJSON), LastAgentTraceInput: string(histJSON),
LastReActOutput: cleaned, LastAgentTraceOutput: lastOut,
} }
if out.Response == "" { if !partial && out.Response == "" {
out.Response = emptyHint out.Response = emptyHint
out.LastReActOutput = out.Response out.LastAgentTraceOutput = out.Response
} }
return out, nil return out
}
func buildEinoCheckpointID(orchMode string) string {
mode := sanitizeEinoPathSegment(strings.TrimSpace(orchMode))
if mode == "" {
mode = "default"
}
return "runner-" + mode
} }
+133
View File
@@ -0,0 +1,133 @@
package multiagent
import (
"context"
"strings"
"cyberstrike-ai/internal/agent"
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
type einoModelInputTelemetryMiddleware struct {
adk.BaseChatModelAgentMiddleware
logger *zap.Logger
modelName string
conversationID string
phase string
}
func newEinoModelInputTelemetryMiddleware(
logger *zap.Logger,
modelName string,
conversationID string,
phase string,
) adk.ChatModelAgentMiddleware {
if logger == nil {
return nil
}
return &einoModelInputTelemetryMiddleware{
logger: logger,
modelName: strings.TrimSpace(modelName),
conversationID: strings.TrimSpace(conversationID),
phase: strings.TrimSpace(phase),
}
}
func (m *einoModelInputTelemetryMiddleware) BeforeModelRewriteState(
ctx context.Context,
state *adk.ChatModelAgentState,
mc *adk.ModelContext,
) (context.Context, *adk.ChatModelAgentState, error) {
if m == nil || m.logger == nil || state == nil {
return ctx, state, nil
}
tokens := estimateTokensForMessagesAndTools(ctx, m.modelName, state.Messages, mcTools(mc))
m.logger.Info("eino model input estimated",
zap.String("phase", m.phase),
zap.String("conversation_id", m.conversationID),
zap.Int("messages", len(state.Messages)),
zap.Int("tools", len(mcTools(mc))),
zap.Int("input_tokens_estimated", tokens),
)
return ctx, state, nil
}
func mcTools(mc *adk.ModelContext) []*schema.ToolInfo {
if mc == nil || len(mc.Tools) == 0 {
return nil
}
return mc.Tools
}
func estimateTokensForMessagesAndTools(
_ context.Context,
modelName string,
messages []adk.Message,
tools []*schema.ToolInfo,
) int {
var sb strings.Builder
for _, msg := range messages {
if msg == nil {
continue
}
sb.WriteString(string(msg.Role))
sb.WriteByte('\n')
sb.WriteString(msg.Content)
sb.WriteByte('\n')
if msg.ReasoningContent != "" {
sb.WriteString(msg.ReasoningContent)
sb.WriteByte('\n')
}
if len(msg.ToolCalls) > 0 {
if b, err := sonic.Marshal(msg.ToolCalls); err == nil {
sb.Write(b)
sb.WriteByte('\n')
}
}
}
for _, tl := range tools {
if tl == nil {
continue
}
cp := *tl
cp.Extra = nil
if text, err := sonic.MarshalString(cp); err == nil {
sb.WriteString(text)
sb.WriteByte('\n')
}
}
text := sb.String()
if text == "" {
return 0
}
tc := agent.NewTikTokenCounter()
if n, err := tc.Count(modelName, text); err == nil {
return n
}
return (len(text) + 3) / 4
}
func logPlanExecuteModelInputEstimate(
logger *zap.Logger,
modelName string,
conversationID string,
phase string,
msgs []adk.Message,
) {
if logger == nil {
return
}
tokens := estimateTokensForMessagesAndTools(context.Background(), modelName, msgs, nil)
logger.Info("eino model input estimated",
zap.String("phase", phase),
zap.String("conversation_id", strings.TrimSpace(conversationID)),
zap.Int("messages", len(msgs)),
zap.Int("tools", 0),
zap.Int("input_tokens_estimated", tokens),
)
}
+64 -1
View File
@@ -8,6 +8,7 @@ import (
"strings" "strings"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp/builtin"
localbk "github.com/cloudwego/eino-ext/adk/backend/local" localbk "github.com/cloudwego/eino-ext/adk/backend/local"
"github.com/cloudwego/eino/adk" "github.com/cloudwego/eino/adk"
@@ -65,6 +66,66 @@ func splitToolsForToolSearch(all []tool.BaseTool, alwaysVisible int) (static []t
return append([]tool.BaseTool(nil), all[:alwaysVisible]...), append([]tool.BaseTool(nil), all[alwaysVisible:]...), true return append([]tool.BaseTool(nil), all[:alwaysVisible]...), append([]tool.BaseTool(nil), all[alwaysVisible:]...), true
} }
func splitToolsForToolSearchByNames(all []tool.BaseTool, names []string, fallbackAlwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) {
nameSet := make(map[string]struct{}, len(names))
for _, n := range names {
n = strings.TrimSpace(strings.ToLower(n))
if n == "" {
continue
}
nameSet[n] = struct{}{}
}
if len(nameSet) == 0 {
return splitToolsForToolSearch(all, fallbackAlwaysVisible)
}
static = make([]tool.BaseTool, 0, len(all))
dynamic = make([]tool.BaseTool, 0, len(all))
for _, t := range all {
if t == nil {
continue
}
info, err := t.Info(context.Background())
name := ""
if err == nil && info != nil {
name = strings.TrimSpace(strings.ToLower(info.Name))
}
if _, keep := nameSet[name]; keep {
static = append(static, t)
continue
}
dynamic = append(dynamic, t)
}
if len(static) == 0 || len(dynamic) == 0 {
// fallback: preserve previous behavior when whitelist misses all or includes all.
return splitToolsForToolSearch(all, fallbackAlwaysVisible)
}
return static, dynamic, true
}
func mergeAlwaysVisibleToolNames(configured []string) []string {
merged := make([]string, 0, len(configured)+32)
seen := make(map[string]struct{}, len(configured)+32)
add := func(name string) {
n := strings.TrimSpace(strings.ToLower(name))
if n == "" {
return
}
if _, ok := seen[n]; ok {
return
}
seen[n] = struct{}{}
merged = append(merged, n)
}
for _, n := range configured {
add(n)
}
// Always include hardcoded backend builtin MCP tools from constants.
for _, n := range builtin.GetAllBuiltinTools() {
add(n)
}
return merged
}
func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddlewareConfig, convID string, loc *localbk.Local, logger *zap.Logger) (adk.ChatModelAgentMiddleware, error) { func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddlewareConfig, convID string, loc *localbk.Local, logger *zap.Logger) (adk.ChatModelAgentMiddleware, error) {
if loc == nil { if loc == nil {
return nil, fmt.Errorf("reduction: local backend nil") return nil, fmt.Errorf("reduction: local backend nil")
@@ -87,6 +148,8 @@ func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddl
RootDir: root, RootDir: root,
ReadFileToolName: "read_file", ReadFileToolName: "read_file",
ClearExcludeTools: excl, ClearExcludeTools: excl,
MaxLengthForTrunc: mw.ReductionMaxLengthForTruncEffective(),
MaxTokensForClear: int64(mw.ReductionMaxTokensForClearEffective()),
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -142,7 +205,7 @@ func prependEinoMiddlewares(
alwaysVis = 12 alwaysVis = 12
} }
if mw.ToolSearchEnable && len(tools) >= minTools { if mw.ToolSearchEnable && len(tools) >= minTools {
static, dynamic, split := splitToolsForToolSearch(tools, alwaysVis) static, dynamic, split := splitToolsForToolSearchByNames(tools, mergeAlwaysVisibleToolNames(mw.ToolSearchAlwaysVisibleTools), alwaysVis)
if split && len(dynamic) > 0 { if split && len(dynamic) > 0 {
ts, terr := toolsearch.New(ctx, &toolsearch.Config{DynamicTools: dynamic}) ts, terr := toolsearch.New(ctx, &toolsearch.Config{DynamicTools: dynamic})
if terr != nil { if terr != nil {
@@ -0,0 +1,38 @@
package multiagent
import (
"context"
"fmt"
"github.com/cloudwego/eino/adk"
)
func applyBeforeModelRewriteHandlers(
ctx context.Context,
msgs []adk.Message,
handlers []adk.ChatModelAgentMiddleware,
) ([]adk.Message, error) {
if len(msgs) == 0 || len(handlers) == 0 {
return msgs, nil
}
state := &adk.ChatModelAgentState{Messages: msgs}
modelCtx := &adk.ModelContext{}
curCtx := ctx
for _, h := range handlers {
if h == nil {
continue
}
nextCtx, nextState, err := h.BeforeModelRewriteState(curCtx, state, modelCtx)
if err != nil {
return nil, fmt.Errorf("before model rewrite: %w", err)
}
if nextCtx != nil {
curCtx = nextCtx
}
if nextState != nil {
state = nextState
}
}
return state.Messages, nil
}
+167 -19
View File
@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"github.com/cloudwego/eino-ext/components/model/openai" "github.com/cloudwego/eino-ext/components/model/openai"
@@ -25,7 +26,12 @@ type PlanExecuteRootArgs struct {
LoopMaxIter int LoopMaxIter int
// AppCfg / Logger 非空时为 Executor 挂载与 Deep/Supervisor 一致的 Eino summarization 中间件。 // AppCfg / Logger 非空时为 Executor 挂载与 Deep/Supervisor 一致的 Eino summarization 中间件。
AppCfg *config.Config AppCfg *config.Config
MwCfg *config.MultiAgentEinoMiddlewareConfig
// ConversationID is used for transcript/isolation paths in middleware.
ConversationID string
Logger *zap.Logger Logger *zap.Logger
// ModelName is used for model input token estimation logs.
ModelName string
// ExecPreMiddlewares 是由 prependEinoMiddlewares 构建的前置中间件(patchtoolcalls, reduction, toolsearch, plantask), // ExecPreMiddlewares 是由 prependEinoMiddlewares 构建的前置中间件(patchtoolcalls, reduction, toolsearch, plantask),
// 与 Deep/Supervisor 主代理的 mainOrchestratorPre 一致。 // 与 Deep/Supervisor 主代理的 mainOrchestratorPre 一致。
ExecPreMiddlewares []adk.ChatModelAgentMiddleware ExecPreMiddlewares []adk.ChatModelAgentMiddleware
@@ -33,6 +39,8 @@ type PlanExecuteRootArgs struct {
SkillMiddleware adk.ChatModelAgentMiddleware SkillMiddleware adk.ChatModelAgentMiddleware
// FilesystemMiddleware 是 Eino filesystem 中间件,当 eino_skills.filesystem_tools 启用时提供本机文件读写与 Shell 能力(可选)。 // FilesystemMiddleware 是 Eino filesystem 中间件,当 eino_skills.filesystem_tools 启用时提供本机文件读写与 Shell 能力(可选)。
FilesystemMiddleware adk.ChatModelAgentMiddleware FilesystemMiddleware adk.ChatModelAgentMiddleware
// PlannerReplannerRewriteHandlers applies BeforeModelRewriteState pipeline for planner/replanner input.
PlannerReplannerRewriteHandlers []adk.ChatModelAgentMiddleware
} }
// NewPlanExecuteRoot 返回 plan → execute → replan 预置编排根节点(与 Deep / Supervisor 并列)。 // NewPlanExecuteRoot 返回 plan → execute → replan 预置编排根节点(与 Deep / Supervisor 并列)。
@@ -50,7 +58,7 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
plannerCfg := &planexecute.PlannerConfig{ plannerCfg := &planexecute.PlannerConfig{
ToolCallingChatModel: tcm, ToolCallingChatModel: tcm,
} }
if fn := planExecutePlannerGenInput(a.OrchInstruction); fn != nil { if fn := planExecutePlannerGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID, a.PlannerReplannerRewriteHandlers); fn != nil {
plannerCfg.GenInputFn = fn plannerCfg.GenInputFn = fn
} }
planner, err := planexecute.NewPlanner(ctx, plannerCfg) planner, err := planexecute.NewPlanner(ctx, plannerCfg)
@@ -59,7 +67,7 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
} }
replanner, err := planexecute.NewReplanner(ctx, &planexecute.ReplannerConfig{ replanner, err := planexecute.NewReplanner(ctx, &planexecute.ReplannerConfig{
ChatModel: tcm, ChatModel: tcm,
GenInputFn: planExecuteReplannerGenInput(a.OrchInstruction), GenInputFn: planExecuteReplannerGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID, a.PlannerReplannerRewriteHandlers),
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("plan_execute replanner: %w", err) return nil, fmt.Errorf("plan_execute replanner: %w", err)
@@ -81,17 +89,20 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
} }
// 4. summarization(最后,与 Deep/Supervisor 一致) // 4. summarization(最后,与 Deep/Supervisor 一致)
if a.AppCfg != nil { if a.AppCfg != nil {
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.Logger) sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.Logger)
if sumErr != nil { if sumErr != nil {
return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr) return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr)
} }
execHandlers = append(execHandlers, sumMw) execHandlers = append(execHandlers, sumMw)
} }
if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil {
execHandlers = append(execHandlers, teleMw)
}
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{ executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
Model: a.ExecModel, Model: a.ExecModel,
ToolsConfig: a.ToolsCfg, ToolsConfig: a.ToolsCfg,
MaxIterations: a.ExecMaxIter, MaxIterations: a.ExecMaxIter,
GenInputFn: planExecuteExecutorGenInput(a.OrchInstruction), GenInputFn: planExecuteExecutorGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID),
}, execHandlers) }, execHandlers)
if err != nil { if err != nil {
return nil, fmt.Errorf("plan_execute executor: %w", err) return nil, fmt.Errorf("plan_execute executor: %w", err)
@@ -110,20 +121,42 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
// planExecutePlannerGenInput 将 orchestrator instruction 作为 SystemMessage 注入 planner 输入。 // planExecutePlannerGenInput 将 orchestrator instruction 作为 SystemMessage 注入 planner 输入。
// 返回 nil 时 Eino 使用内置默认 planner prompt。 // 返回 nil 时 Eino 使用内置默认 planner prompt。
func planExecutePlannerGenInput(orchInstruction string) planexecute.GenPlannerModelInputFn { func planExecutePlannerGenInput(
orchInstruction string,
appCfg *config.Config,
mwCfg *config.MultiAgentEinoMiddlewareConfig,
logger *zap.Logger,
modelName string,
conversationID string,
rewriteHandlers []adk.ChatModelAgentMiddleware,
) planexecute.GenPlannerModelInputFn {
oi := strings.TrimSpace(orchInstruction) oi := strings.TrimSpace(orchInstruction)
if oi == "" { if oi == "" && appCfg == nil {
return nil return nil
} }
return func(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) { return func(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) {
userInput = capPlanExecuteUserInputMessages(userInput, appCfg, mwCfg)
msgs := make([]adk.Message, 0, 1+len(userInput)) msgs := make([]adk.Message, 0, 1+len(userInput))
if oi != "" {
msgs = append(msgs, schema.SystemMessage(oi)) msgs = append(msgs, schema.SystemMessage(oi))
}
msgs = append(msgs, userInput...) msgs = append(msgs, userInput...)
if rewritten, rerr := applyBeforeModelRewriteHandlers(ctx, msgs, rewriteHandlers); rerr == nil && len(rewritten) > 0 {
msgs = rewritten
}
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_planner", msgs)
return msgs, nil return msgs, nil
} }
} }
func planExecuteExecutorGenInput(orchInstruction string) planexecute.GenModelInputFn { func planExecuteExecutorGenInput(
orchInstruction string,
appCfg *config.Config,
mwCfg *config.MultiAgentEinoMiddlewareConfig,
logger *zap.Logger,
modelName string,
conversationID string,
) planexecute.GenModelInputFn {
oi := strings.TrimSpace(orchInstruction) oi := strings.TrimSpace(orchInstruction)
return func(ctx context.Context, in *planexecute.ExecutionContext) ([]adk.Message, error) { return func(ctx context.Context, in *planexecute.ExecutionContext) ([]adk.Message, error) {
planContent, err := in.Plan.MarshalJSON() planContent, err := in.Plan.MarshalJSON()
@@ -131,9 +164,9 @@ func planExecuteExecutorGenInput(orchInstruction string) planexecute.GenModelInp
return nil, err return nil, err
} }
userMsgs, err := planexecute.ExecutorPrompt.Format(ctx, map[string]any{ userMsgs, err := planexecute.ExecutorPrompt.Format(ctx, map[string]any{
"input": planExecuteFormatInput(in.UserInput), "input": planExecuteFormatInput(capPlanExecuteUserInputMessages(in.UserInput, appCfg, mwCfg)),
"plan": string(planContent), "plan": string(planContent),
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps), "executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps, appCfg, mwCfg),
"step": in.Plan.FirstStep(), "step": in.Plan.FirstStep(),
}) })
if err != nil { if err != nil {
@@ -142,6 +175,7 @@ func planExecuteExecutorGenInput(orchInstruction string) planexecute.GenModelInp
if oi != "" { if oi != "" {
userMsgs = append([]adk.Message{schema.SystemMessage(oi)}, userMsgs...) userMsgs = append([]adk.Message{schema.SystemMessage(oi)}, userMsgs...)
} }
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_executor_gen_input", userMsgs)
return userMsgs, nil return userMsgs, nil
} }
} }
@@ -155,18 +189,22 @@ func planExecuteFormatInput(input []adk.Message) string {
return sb.String() return sb.String()
} }
func planExecuteFormatExecutedSteps(results []planexecute.ExecutedStep) string { func planExecuteFormatExecutedSteps(results []planexecute.ExecutedStep, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) string {
capped := capPlanExecuteExecutedSteps(results) capped := capPlanExecuteExecutedStepsWithConfig(results, mwCfg)
var sb strings.Builder return renderPlanExecuteStepsByBudget(capped, appCfg, mwCfg)
for _, result := range capped {
sb.WriteString(fmt.Sprintf("Step: %s\nResult: %s\n\n", result.Step, result.Result))
}
return sb.String()
} }
// planExecuteReplannerGenInput 与 Eino 默认 Replanner 输入一致,但 executed_steps 经 cap 后再写入 prompt // planExecuteReplannerGenInput 与 Eino 默认 Replanner 输入一致,但 executed_steps 经 cap 后再写入 prompt
// 且在 orchInstruction 非空时 prepend SystemMessage 使 replanner 也能接收全局指令。 // 且在 orchInstruction 非空时 prepend SystemMessage 使 replanner 也能接收全局指令。
func planExecuteReplannerGenInput(orchInstruction string) planexecute.GenModelInputFn { func planExecuteReplannerGenInput(
orchInstruction string,
appCfg *config.Config,
mwCfg *config.MultiAgentEinoMiddlewareConfig,
logger *zap.Logger,
modelName string,
conversationID string,
rewriteHandlers []adk.ChatModelAgentMiddleware,
) planexecute.GenModelInputFn {
oi := strings.TrimSpace(orchInstruction) oi := strings.TrimSpace(orchInstruction)
return func(ctx context.Context, in *planexecute.ExecutionContext) ([]adk.Message, error) { return func(ctx context.Context, in *planexecute.ExecutionContext) ([]adk.Message, error) {
planContent, err := in.Plan.MarshalJSON() planContent, err := in.Plan.MarshalJSON()
@@ -175,8 +213,8 @@ func planExecuteReplannerGenInput(orchInstruction string) planexecute.GenModelIn
} }
msgs, err := planexecute.ReplannerPrompt.Format(ctx, map[string]any{ msgs, err := planexecute.ReplannerPrompt.Format(ctx, map[string]any{
"plan": string(planContent), "plan": string(planContent),
"input": planExecuteFormatInput(in.UserInput), "input": planExecuteFormatInput(capPlanExecuteUserInputMessages(in.UserInput, appCfg, mwCfg)),
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps), "executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps, appCfg, mwCfg),
"plan_tool": planexecute.PlanToolInfo.Name, "plan_tool": planexecute.PlanToolInfo.Name,
"respond_tool": planexecute.RespondToolInfo.Name, "respond_tool": planexecute.RespondToolInfo.Name,
}) })
@@ -186,10 +224,120 @@ func planExecuteReplannerGenInput(orchInstruction string) planexecute.GenModelIn
if oi != "" { if oi != "" {
msgs = append([]adk.Message{schema.SystemMessage(oi)}, msgs...) msgs = append([]adk.Message{schema.SystemMessage(oi)}, msgs...)
} }
if rewritten, rerr := applyBeforeModelRewriteHandlers(ctx, msgs, rewriteHandlers); rerr == nil && len(rewritten) > 0 {
msgs = rewritten
}
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_replanner", msgs)
return msgs, nil return msgs, nil
} }
} }
func capPlanExecuteUserInputMessages(input []adk.Message, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) []adk.Message {
if len(input) == 0 {
return input
}
maxTotal := 120000
modelName := "gpt-4o"
if appCfg != nil {
if appCfg.OpenAI.MaxTotalTokens > 0 {
maxTotal = appCfg.OpenAI.MaxTotalTokens
}
if m := strings.TrimSpace(appCfg.OpenAI.Model); m != "" {
modelName = m
}
}
// Reserve most tokens for planner/replanner prompt and tool schema.
ratio := 0.35
if mwCfg != nil {
ratio = mwCfg.PlanExecuteUserInputBudgetRatioEffective()
}
budget := int(float64(maxTotal) * ratio)
if budget < 4096 {
budget = 4096
}
tc := agent.NewTikTokenCounter()
out := make([]adk.Message, 0, len(input))
used := 0
for i := len(input) - 1; i >= 0; i-- {
msg := input[i]
if msg == nil {
continue
}
n, err := tc.Count(modelName, string(msg.Role)+"\n"+msg.Content)
if err != nil {
n = (len(msg.Content) + 3) / 4
}
if n <= 0 {
n = 1
}
if used+n > budget {
break
}
used += n
out = append(out, msg)
}
for i, j := 0, len(out)-1; i < j; i, j = i+1, j-1 {
out[i], out[j] = out[j], out[i]
}
if len(out) == 0 {
// Keep the latest user message at least.
return []adk.Message{input[len(input)-1]}
}
return out
}
func renderPlanExecuteStepsByBudget(steps []planexecute.ExecutedStep, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) string {
if len(steps) == 0 {
return ""
}
maxTotal := 120000
modelName := "gpt-4o"
if appCfg != nil {
if appCfg.OpenAI.MaxTotalTokens > 0 {
maxTotal = appCfg.OpenAI.MaxTotalTokens
}
if m := strings.TrimSpace(appCfg.OpenAI.Model); m != "" {
modelName = m
}
}
ratio := 0.2
if mwCfg != nil {
ratio = mwCfg.PlanExecuteExecutedStepsBudgetRatioEffective()
}
budget := int(float64(maxTotal) * ratio)
if budget < 3072 {
budget = 3072
}
tc := agent.NewTikTokenCounter()
var kept []string
used := 0
skipped := 0
for i := len(steps) - 1; i >= 0; i-- {
block := fmt.Sprintf("Step: %s\nResult: %s\n\n", steps[i].Step, steps[i].Result)
n, err := tc.Count(modelName, block)
if err != nil {
n = (len(block) + 3) / 4
}
if n <= 0 {
n = 1
}
if used+n > budget {
skipped = i + 1
break
}
used += n
kept = append(kept, block)
}
var sb strings.Builder
if skipped > 0 {
sb.WriteString(fmt.Sprintf("Earlier executed steps omitted due to context budget: %d steps.\n\n", skipped))
}
for i := len(kept) - 1; i >= 0; i-- {
sb.WriteString(kept[i])
}
return sb.String()
}
// planExecuteStreamsMainAssistant 将规划/执行/重规划各阶段助手流式输出映射到主对话区。 // planExecuteStreamsMainAssistant 将规划/执行/重规划各阶段助手流式输出映射到主对话区。
func planExecuteStreamsMainAssistant(agent string) bool { func planExecuteStreamsMainAssistant(agent string) bool {
if agent == "" { if agent == "" {
+24 -3
View File
@@ -125,7 +125,7 @@ func RunEinoSingleChatModelAgent(
return nil, fmt.Errorf("eino single 模型: %w", err) return nil, fmt.Errorf("eino single 模型: %w", err)
} }
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, logger) mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("eino single summarization: %w", err) return nil, fmt.Errorf("eino single summarization: %w", err)
} }
@@ -145,6 +145,9 @@ func RunEinoSingleChatModelAgent(
handlers = append(handlers, einoSkillMW) handlers = append(handlers, einoSkillMW)
} }
handlers = append(handlers, mainSumMw) handlers = append(handlers, mainSumMw)
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "eino_single"); teleMw != nil {
handlers = append(handlers, teleMw)
}
maxIter := ma.MaxIteration maxIter := ma.MaxIteration
if maxIter <= 0 { if maxIter <= 0 {
@@ -165,11 +168,29 @@ func RunEinoSingleChatModelAgent(
}, },
EmitInternalEvents: true, EmitInternalEvents: true,
} }
ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools)
if logger != nil {
names := collectToolNames(ctx, mainTools)
mountedNames := collectToolNames(ctx, mainToolsForCfg)
hasToolSearch := false
for _, n := range names {
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
hasToolSearch = true
break
}
}
logger.Info("eino tool-name injection",
zap.String("scope", "eino_single"),
zap.Int("tool_names", len(names)),
zap.Int("mounted_tool_names", len(mountedNames)),
zap.Bool("has_tool_search", hasToolSearch),
)
}
chatCfg := &adk.ChatModelAgentConfig{ chatCfg := &adk.ChatModelAgentConfig{
Name: einoSingleAgentName, Name: einoSingleAgentName,
Description: "Eino ADK ChatModelAgent with MCP tools for authorized security testing.", Description: "Eino ADK ChatModelAgent with MCP tools for authorized security testing.",
Instruction: ag.EinoSingleAgentSystemInstruction(), Instruction: ins,
Model: mainModel, Model: mainModel,
ToolsConfig: mainToolsCfg, ToolsConfig: mainToolsCfg,
MaxIterations: maxIter, MaxIterations: maxIter,
@@ -188,7 +209,7 @@ func RunEinoSingleChatModelAgent(
return nil, fmt.Errorf("eino single NewChatModelAgent: %w", err) return nil, fmt.Errorf("eino single NewChatModelAgent: %w", err)
} }
baseMsgs := historyToMessages(history) baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage)) baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
streamsMainAssistant := func(agent string) bool { streamsMainAssistant := func(agent string) bool {
+145 -3
View File
@@ -3,6 +3,8 @@ package multiagent
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"path/filepath"
"strings" "strings"
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
@@ -32,6 +34,8 @@ func newEinoSummarizationMiddleware(
ctx context.Context, ctx context.Context,
summaryModel model.BaseChatModel, summaryModel model.BaseChatModel,
appCfg *config.Config, appCfg *config.Config,
mwCfg *config.MultiAgentEinoMiddlewareConfig,
conversationID string,
logger *zap.Logger, logger *zap.Logger,
) (adk.ChatModelAgentMiddleware, error) { ) (adk.ChatModelAgentMiddleware, error) {
if summaryModel == nil || appCfg == nil { if summaryModel == nil || appCfg == nil {
@@ -41,7 +45,14 @@ func newEinoSummarizationMiddleware(
if maxTotal <= 0 { if maxTotal <= 0 {
maxTotal = 120000 maxTotal = 120000
} }
trigger := int(float64(maxTotal) * 0.9) triggerRatio := 0.8
emitInternalEvents := true
if mwCfg != nil {
triggerRatio = mwCfg.SummarizationTriggerRatioEffective()
emitInternalEvents = mwCfg.SummarizationEmitInternalEventsEffective()
}
// Keep enough safety margin for tokenizer/model-side accounting mismatch.
trigger := int(float64(maxTotal) * triggerRatio)
if trigger < 4096 { if trigger < 4096 {
trigger = maxTotal trigger = maxTotal
if trigger < 4096 { if trigger < 4096 {
@@ -57,28 +68,57 @@ func newEinoSummarizationMiddleware(
if modelName == "" { if modelName == "" {
modelName = "gpt-4o" modelName = "gpt-4o"
} }
tokenCounter := einoSummarizationTokenCounter(modelName)
recentTrailMax := trigger / 4
if recentTrailMax < 2048 {
recentTrailMax = 2048
}
if recentTrailMax > trigger/2 {
recentTrailMax = trigger / 2
}
transcriptPath := ""
if conv := strings.TrimSpace(conversationID); conv != "" {
baseRoot := filepath.Join(os.TempDir(), "cyberstrike-summarization")
if dbPath := strings.TrimSpace(appCfg.Database.Path); dbPath != "" {
// Persist with the same lifecycle as local conversation storage.
baseRoot = filepath.Join(filepath.Dir(dbPath), "conversation_artifacts", sanitizeEinoPathSegment(conv), "summarization")
}
base := baseRoot
if mkErr := os.MkdirAll(base, 0o755); mkErr == nil {
transcriptPath = filepath.Join(base, "transcript.txt")
}
}
mw, err := summarization.New(ctx, &summarization.Config{ mw, err := summarization.New(ctx, &summarization.Config{
Model: summaryModel, Model: summaryModel,
Trigger: &summarization.TriggerCondition{ Trigger: &summarization.TriggerCondition{
ContextTokens: trigger, ContextTokens: trigger,
}, },
TokenCounter: einoSummarizationTokenCounter(modelName), TokenCounter: tokenCounter,
UserInstruction: einoSummarizeUserInstruction, UserInstruction: einoSummarizeUserInstruction,
EmitInternalEvents: false, EmitInternalEvents: emitInternalEvents,
TranscriptFilePath: transcriptPath,
PreserveUserMessages: &summarization.PreserveUserMessages{ PreserveUserMessages: &summarization.PreserveUserMessages{
Enabled: true, Enabled: true,
MaxTokens: preserveMax, MaxTokens: preserveMax,
}, },
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
return summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax)
},
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error { Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
if logger == nil { if logger == nil {
return nil return nil
} }
beforeTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: before.Messages})
afterTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: after.Messages})
logger.Info("eino summarization 已压缩上下文", logger.Info("eino summarization 已压缩上下文",
zap.Int("messages_before", len(before.Messages)), zap.Int("messages_before", len(before.Messages)),
zap.Int("messages_after", len(after.Messages)), zap.Int("messages_after", len(after.Messages)),
zap.Int("tokens_before_estimated", beforeTokens),
zap.Int("tokens_after_estimated", afterTokens),
zap.Int("max_total_tokens", maxTotal), zap.Int("max_total_tokens", maxTotal),
zap.Int("trigger_context_tokens", trigger), zap.Int("trigger_context_tokens", trigger),
zap.String("transcript_file", transcriptPath),
) )
return nil return nil
}, },
@@ -89,6 +129,108 @@ func newEinoSummarizationMiddleware(
return mw, nil return mw, nil
} }
// summarizeFinalizeWithRecentAssistantToolTrail 在摘要消息后保留最近 assistant/tool 轨迹,避免压缩后执行链断裂。
func summarizeFinalizeWithRecentAssistantToolTrail(
ctx context.Context,
originalMessages []adk.Message,
summary adk.Message,
tokenCounter summarization.TokenCounterFunc,
recentTrailTokenBudget int,
) ([]adk.Message, error) {
systemMsgs := make([]adk.Message, 0, len(originalMessages))
nonSystem := make([]adk.Message, 0, len(originalMessages))
for _, msg := range originalMessages {
if msg == nil {
continue
}
if msg.Role == schema.System {
systemMsgs = append(systemMsgs, msg)
continue
}
nonSystem = append(nonSystem, msg)
}
if recentTrailTokenBudget <= 0 || len(nonSystem) == 0 {
out := make([]adk.Message, 0, len(systemMsgs)+1)
out = append(out, systemMsgs...)
out = append(out, summary)
return out, nil
}
selectedReverse := make([]adk.Message, 0, 8)
seen := make(map[adk.Message]struct{})
totalTokens := 0
assistantToolKept := 0
const minAssistantToolTrail = 4
tryKeep := func(msg adk.Message) (bool, error) {
if msg == nil {
return false, nil
}
if _, ok := seen[msg]; ok {
return false, nil
}
n, err := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: []adk.Message{msg}})
if err != nil {
return false, err
}
if n <= 0 {
n = 1
}
if totalTokens+n > recentTrailTokenBudget {
return false, nil
}
totalTokens += n
selectedReverse = append(selectedReverse, msg)
seen[msg] = struct{}{}
return true, nil
}
// 优先保留最近 assistant/tool,确保执行轨迹可续跑。
for i := len(nonSystem) - 1; i >= 0; i-- {
msg := nonSystem[i]
if msg.Role != schema.Assistant && msg.Role != schema.Tool {
continue
}
ok, err := tryKeep(msg)
if err != nil {
return nil, err
}
if ok {
assistantToolKept++
}
if assistantToolKept >= minAssistantToolTrail {
break
}
}
// 在预算内回填更多最近消息,保持短链路上下文。
for i := len(nonSystem) - 1; i >= 0; i-- {
_, exists := seen[nonSystem[i]]
if exists {
continue
}
ok, err := tryKeep(nonSystem[i])
if err != nil {
return nil, err
}
if !ok {
break
}
}
selected := make([]adk.Message, 0, len(selectedReverse))
for i := len(selectedReverse) - 1; i >= 0; i-- {
selected = append(selected, selectedReverse[i])
}
out := make([]adk.Message, 0, len(systemMsgs)+1+len(selected))
out = append(out, systemMsgs...)
out = append(out, summary)
out = append(out, selected...)
return out, nil
}
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc { func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
tc := agent.NewTikTokenCounter() tc := agent.NewTikTokenCounter()
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) { return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
@@ -0,0 +1,73 @@
package multiagent
import (
"context"
"strings"
"github.com/cloudwego/eino/components/tool"
)
// injectToolNamesOnlyInstruction prepends a compact tool-name-only section into
// the system instruction so the model can reference current callable names.
func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, tools []tool.BaseTool) string {
names := collectToolNames(ctx, tools)
if len(names) == 0 {
return strings.TrimSpace(instruction)
}
hasToolSearch := false
for _, n := range names {
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
hasToolSearch = true
break
}
}
var sb strings.Builder
sb.WriteString("以下是当前会话中可调用的工具名称列表(仅名称,无参数定义):\n")
for _, name := range names {
sb.WriteString("- ")
sb.WriteString(name)
sb.WriteByte('\n')
}
sb.WriteString("\n使用规则:\n")
sb.WriteString("1) 上述仅为名称列表,不包含参数定义。\n")
if hasToolSearch {
sb.WriteString("2) 在调用具体工具前,应先使用 tool_search 查看工具详情与参数要求,再发起调用。\n")
} else {
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求;不确定时先澄清再调用。\n")
}
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
if s := strings.TrimSpace(instruction); s != "" {
sb.WriteString(s)
}
return sb.String()
}
func collectToolNames(ctx context.Context, tools []tool.BaseTool) []string {
if len(tools) == 0 {
return nil
}
seen := make(map[string]struct{}, len(tools))
out := make([]string, 0, len(tools))
for _, t := range tools {
if t == nil {
continue
}
info, err := t.Info(ctx)
if err != nil || info == nil {
continue
}
name := strings.TrimSpace(info.Name)
if name == "" {
continue
}
key := strings.ToLower(name)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, name)
}
return out
}
-1
View File
@@ -59,4 +59,3 @@ func (m *noNestedTaskMiddleware) WrapInvokableToolCall(
return endpoint(ctx2, argumentsInJSON, opts...) return endpoint(ctx2, argumentsInJSON, opts...)
}, nil }, nil
} }
+1 -1
View File
@@ -71,7 +71,7 @@ func planExecuteDefaultGenExecutorInput(ctx context.Context, in *planexecute.Exe
return planexecute.ExecutorPrompt.Format(ctx, map[string]any{ return planexecute.ExecutorPrompt.Format(ctx, map[string]any{
"input": planExecuteFormatInput(in.UserInput), "input": planExecuteFormatInput(in.UserInput),
"plan": string(planContent), "plan": string(planContent),
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps), "executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps, nil, nil),
"step": in.Plan.FirstStep(), "step": in.Plan.FirstStep(),
}) })
} }
+22 -7
View File
@@ -5,6 +5,8 @@ import (
"strings" "strings"
"unicode/utf8" "unicode/utf8"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/adk/prebuilt/planexecute" "github.com/cloudwego/eino/adk/prebuilt/planexecute"
) )
@@ -12,8 +14,11 @@ import (
// 此处仅约束「写入模型 prompt 的视图」,不修改 Eino session 中的原始 ExecutedSteps。 // 此处仅约束「写入模型 prompt 的视图」,不修改 Eino session 中的原始 ExecutedSteps。
const ( const (
planExecuteMaxStepResultRunes = 12000 defaultPlanExecuteMaxStepResultRunes = 4000
planExecuteKeepLastSteps = 16 defaultPlanExecuteKeepLastSteps = 8
// Backward-compatible aliases for tests and existing references.
planExecuteMaxStepResultRunes = defaultPlanExecuteMaxStepResultRunes
planExecuteKeepLastSteps = defaultPlanExecuteKeepLastSteps
) )
func truncateRunesWithSuffix(s string, maxRunes int, suffix string) string { func truncateRunesWithSuffix(s string, maxRunes int, suffix string) string {
@@ -29,16 +34,26 @@ func truncateRunesWithSuffix(s string, maxRunes int, suffix string) string {
// capPlanExecuteExecutedSteps 折叠较早步骤、截断单步过长结果,供 prompt 使用。 // capPlanExecuteExecutedSteps 折叠较早步骤、截断单步过长结果,供 prompt 使用。
func capPlanExecuteExecutedSteps(steps []planexecute.ExecutedStep) []planexecute.ExecutedStep { func capPlanExecuteExecutedSteps(steps []planexecute.ExecutedStep) []planexecute.ExecutedStep {
return capPlanExecuteExecutedStepsWithConfig(steps, nil)
}
func capPlanExecuteExecutedStepsWithConfig(steps []planexecute.ExecutedStep, mwCfg *config.MultiAgentEinoMiddlewareConfig) []planexecute.ExecutedStep {
if len(steps) == 0 { if len(steps) == 0 {
return steps return steps
} }
maxStepResultRunes := defaultPlanExecuteMaxStepResultRunes
keepLastSteps := defaultPlanExecuteKeepLastSteps
if mwCfg != nil {
maxStepResultRunes = mwCfg.PlanExecuteMaxStepResultRunesEffective()
keepLastSteps = mwCfg.PlanExecuteKeepLastStepsEffective()
}
out := make([]planexecute.ExecutedStep, 0, len(steps)+1) out := make([]planexecute.ExecutedStep, 0, len(steps)+1)
start := 0 start := 0
if len(steps) > planExecuteKeepLastSteps { if len(steps) > keepLastSteps {
start = len(steps) - planExecuteKeepLastSteps start = len(steps) - keepLastSteps
var b strings.Builder var b strings.Builder
b.WriteString(fmt.Sprintf("(上文已完成 %d 步;此处仅保留步骤标题以节省上下文,完整输出已省略。后续 %d 步仍保留正文。)\n", b.WriteString(fmt.Sprintf("(上文已完成 %d 步;此处仅保留步骤标题以节省上下文,完整输出已省略。后续 %d 步仍保留正文。)\n",
start, planExecuteKeepLastSteps)) start, keepLastSteps))
for i := 0; i < start; i++ { for i := 0; i < start; i++ {
b.WriteString(fmt.Sprintf("- %s\n", steps[i].Step)) b.WriteString(fmt.Sprintf("- %s\n", steps[i].Step))
} }
@@ -50,8 +65,8 @@ func capPlanExecuteExecutedSteps(steps []planexecute.ExecutedStep) []planexecute
suffix := "\n…[step result truncated]" suffix := "\n…[step result truncated]"
for i := start; i < len(steps); i++ { for i := start; i < len(steps); i++ {
e := steps[i] e := steps[i]
if utf8.RuneCountInString(e.Result) > planExecuteMaxStepResultRunes { if utf8.RuneCountInString(e.Result) > maxStepResultRunes {
e.Result = truncateRunesWithSuffix(e.Result, planExecuteMaxStepResultRunes, suffix) e.Result = truncateRunesWithSuffix(e.Result, maxStepResultRunes, suffix)
} }
out = append(out, e) out = append(out, e)
} }
+109 -12
View File
@@ -32,8 +32,8 @@ import (
type RunResult struct { type RunResult struct {
Response string Response string
MCPExecutionIDs []string MCPExecutionIDs []string
LastReActInput string LastAgentTraceInput string // 已序列化的消息带(JSON):原生循环或 Eino 均写入,供续跑/攻击链等恢复上下文
LastReActOutput string LastAgentTraceOutput string // 本轮助手侧对外展示文本(摘要或最终回复)
} }
// toolCallPendingInfo tracks a tool_call emitted to the UI so we can later // toolCallPendingInfo tracks a tool_call emitted to the UI so we can later
@@ -237,7 +237,7 @@ func RunDeepAgent(
subMax = subDefaultIter subMax = subDefaultIter
} }
subSumMw, err := newEinoSummarizationMiddleware(ctx, subModel, appCfg, logger) subSumMw, err := newEinoSummarizationMiddleware(ctx, subModel, appCfg, &ma.EinoMiddleware, conversationID, logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("子代理 %q summarization 中间件: %w", id, err) return nil, fmt.Errorf("子代理 %q summarization 中间件: %w", id, err)
} }
@@ -257,11 +257,33 @@ func RunDeepAgent(
subHandlers = append(subHandlers, einoSkillMW) subHandlers = append(subHandlers, einoSkillMW)
} }
subHandlers = append(subHandlers, subSumMw) subHandlers = append(subHandlers, subSumMw)
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "sub_agent"); teleMw != nil {
subHandlers = append(subHandlers, teleMw)
}
subInstrFinal := injectToolNamesOnlyInstruction(ctx, instr, subTools)
if logger != nil {
subNames := collectToolNames(ctx, subTools)
mountedNames := collectToolNames(ctx, subToolsForCfg)
hasToolSearch := false
for _, n := range subNames {
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
hasToolSearch = true
break
}
}
logger.Info("eino tool-name injection",
zap.String("scope", "sub_agent"),
zap.String("agent", id),
zap.Int("tool_names", len(subNames)),
zap.Int("mounted_tool_names", len(mountedNames)),
zap.Bool("has_tool_search", hasToolSearch),
)
}
sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{ sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
Name: id, Name: id,
Description: desc, Description: desc,
Instruction: instr, Instruction: subInstrFinal,
Model: subModel, Model: subModel,
ToolsConfig: adk.ToolsConfig{ ToolsConfig: adk.ToolsConfig{
ToolsNodeConfig: compose.ToolsNodeConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{
@@ -289,7 +311,7 @@ func RunDeepAgent(
return nil, fmt.Errorf("多代理主模型: %w", err) return nil, fmt.Errorf("多代理主模型: %w", err)
} }
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, logger) mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, logger)
if err != nil { if err != nil {
return nil, fmt.Errorf("多代理主 summarization 中间件: %w", err) return nil, fmt.Errorf("多代理主 summarization 中间件: %w", err)
} }
@@ -313,6 +335,25 @@ func RunDeepAgent(
orchDescription = d orchDescription = d
} }
} }
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools)
if logger != nil {
mainNames := collectToolNames(ctx, mainTools)
mountedNames := collectToolNames(ctx, mainToolsForCfg)
hasToolSearch := false
for _, n := range mainNames {
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
hasToolSearch = true
break
}
}
logger.Info("eino tool-name injection",
zap.String("scope", "orchestrator"),
zap.String("orchestration", orchMode),
zap.Int("tool_names", len(mainNames)),
zap.Int("mounted_tool_names", len(mountedNames)),
zap.Bool("has_tool_search", hasToolSearch),
)
}
supInstr := strings.TrimSpace(orchInstruction) supInstr := strings.TrimSpace(orchInstruction)
if orchMode == "supervisor" { if orchMode == "supervisor" {
@@ -352,6 +393,9 @@ func RunDeepAgent(
deepHandlers = append(deepHandlers, einoSkillMW) deepHandlers = append(deepHandlers, einoSkillMW)
} }
deepHandlers = append(deepHandlers, mainSumMw) deepHandlers = append(deepHandlers, mainSumMw)
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil {
deepHandlers = append(deepHandlers, teleMw)
}
supHandlers := []adk.ChatModelAgentMiddleware{} supHandlers := []adk.ChatModelAgentMiddleware{}
if len(mainOrchestratorPre) > 0 { if len(mainOrchestratorPre) > 0 {
@@ -361,6 +405,9 @@ func RunDeepAgent(
supHandlers = append(supHandlers, einoSkillMW) supHandlers = append(supHandlers, einoSkillMW)
} }
supHandlers = append(supHandlers, mainSumMw) supHandlers = append(supHandlers, mainSumMw)
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil {
supHandlers = append(supHandlers, teleMw)
}
mainToolsCfg := adk.ToolsConfig{ mainToolsCfg := adk.ToolsConfig{
ToolsNodeConfig: compose.ToolsNodeConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{
@@ -399,10 +446,17 @@ func RunDeepAgent(
ExecMaxIter: deepMaxIter, ExecMaxIter: deepMaxIter,
LoopMaxIter: ma.PlanExecuteLoopMaxIterations, LoopMaxIter: ma.PlanExecuteLoopMaxIterations,
AppCfg: appCfg, AppCfg: appCfg,
MwCfg: &ma.EinoMiddleware,
ConversationID: conversationID,
Logger: logger, Logger: logger,
ModelName: appCfg.OpenAI.Model,
ExecPreMiddlewares: mainOrchestratorPre, ExecPreMiddlewares: mainOrchestratorPre,
SkillMiddleware: einoSkillMW, SkillMiddleware: einoSkillMW,
FilesystemMiddleware: peFsMw, FilesystemMiddleware: peFsMw,
PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{
mainSumMw,
newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "plan_execute_planner_replanner_rewrite"),
},
}) })
if perr != nil { if perr != nil {
return nil, perr return nil, perr
@@ -468,7 +522,7 @@ func RunDeepAgent(
da = dDeep da = dDeep
} }
baseMsgs := historyToMessages(history) baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage)) baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
streamsMainAssistant := func(agent string) bool { streamsMainAssistant := func(agent string) bool {
@@ -505,34 +559,77 @@ func RunDeepAgent(
}, baseMsgs) }, baseMsgs)
} }
func historyToMessages(history []agent.ChatMessage) []adk.Message { func historyToMessages(history []agent.ChatMessage, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) []adk.Message {
if len(history) == 0 { if len(history) == 0 {
return nil return nil
} }
// 放宽条数上限:跨轮历史交给 Eino Summarization(阈值对齐 openai.max_total_tokens)在调用模型前压缩,避免在入队前硬截断为 40 条。 // Keep a bounded tail first; then enforce a token budget.
const maxHistoryMessages = 300 const maxHistoryMessages = 200
start := 0 start := 0
if len(history) > maxHistoryMessages { if len(history) > maxHistoryMessages {
start = len(history) - maxHistoryMessages start = len(history) - maxHistoryMessages
} }
out := make([]adk.Message, 0, len(history[start:])) raw := make([]adk.Message, 0, len(history[start:]))
for _, h := range history[start:] { for _, h := range history[start:] {
switch h.Role { switch h.Role {
case "user": case "user":
if strings.TrimSpace(h.Content) != "" { if strings.TrimSpace(h.Content) != "" {
out = append(out, schema.UserMessage(h.Content)) raw = append(raw, schema.UserMessage(h.Content))
} }
case "assistant": case "assistant":
if strings.TrimSpace(h.Content) == "" && len(h.ToolCalls) > 0 { if strings.TrimSpace(h.Content) == "" && len(h.ToolCalls) > 0 {
continue continue
} }
if strings.TrimSpace(h.Content) != "" { if strings.TrimSpace(h.Content) != "" {
out = append(out, schema.AssistantMessage(h.Content, nil)) raw = append(raw, schema.AssistantMessage(h.Content, nil))
} }
default: default:
continue continue
} }
} }
if len(raw) == 0 {
return raw
}
maxTotal := 120000
modelName := "gpt-4o"
if appCfg != nil {
if appCfg.OpenAI.MaxTotalTokens > 0 {
maxTotal = appCfg.OpenAI.MaxTotalTokens
}
if m := strings.TrimSpace(appCfg.OpenAI.Model); m != "" {
modelName = m
}
}
ratio := 0.35
if mwCfg != nil {
ratio = mwCfg.HistoryInputBudgetRatioEffective()
}
budget := int(float64(maxTotal) * ratio)
if budget < 4096 {
budget = 4096
}
tc := agent.NewTikTokenCounter()
outRev := make([]adk.Message, 0, len(raw))
used := 0
for i := len(raw) - 1; i >= 0; i-- {
msg := raw[i]
n, err := tc.Count(modelName, string(msg.Role)+"\n"+msg.Content)
if err != nil {
n = (len(msg.Content) + 3) / 4
}
if n <= 0 {
n = 1
}
if used+n > budget {
break
}
used += n
outRev = append(outRev, msg)
}
out := make([]adk.Message, 0, len(outRev))
for i := len(outRev) - 1; i >= 0; i-- {
out = append(out, outRev[i])
}
return out return out
} }
@@ -1,51 +0,0 @@
package multiagent
import (
"fmt"
"strings"
"github.com/cloudwego/eino/schema"
)
// maxToolCallRecoveryAttempts 含首次运行:首次 + 自动重试次数。
// 例如为 3 表示最多共 3 次完整 DeepAgent 运行(2 次失败后各追加一条纠错提示)。
// 该常量同时用于 JSON 参数错误和工具执行错误(如子代理名称不存在)的恢复重试。
const maxToolCallRecoveryAttempts = 5
// toolCallArgumentsJSONRetryHint 追加在用户消息后,提示模型输出合法 JSON 工具参数(部分云厂商会在流式阶段校验 arguments)。
func toolCallArgumentsJSONRetryHint() *schema.Message {
return schema.UserMessage(`[系统提示] 上一次输出中工具调用的 function.arguments 不是合法 JSON接口已拒绝请重新生成每个 tool call arguments 必须是完整可解析的 JSON 对象字符串键名用双引号无多余逗号括号配对不要输出截断或不完整的 JSON
[System] Your previous tool call used invalid JSON in function.arguments and was rejected by the API. Regenerate with strictly valid JSON objects only (double-quoted keys, matched braces, no trailing commas).`)
}
// toolCallArgumentsJSONRecoveryTimelineMessage 供 eino_recovery 事件落库与前端时间线展示。
func toolCallArgumentsJSONRecoveryTimelineMessage(attempt int) string {
return fmt.Sprintf(
"接口拒绝了无效的工具参数 JSON。已向对话追加系统提示并要求模型重新生成合法的 function.arguments。"+
"当前为第 %d/%d 轮完整运行。\n\n"+
"The API rejected invalid JSON in tool arguments. A system hint was appended. This is full run %d of %d.",
attempt+1, maxToolCallRecoveryAttempts, attempt+1, maxToolCallRecoveryAttempts,
)
}
// isRecoverableToolCallArgumentsJSONError 判断是否为「工具参数非合法 JSON」类流式错误,可通过追加提示后重跑一轮。
func isRecoverableToolCallArgumentsJSONError(err error) bool {
if err == nil {
return false
}
s := strings.ToLower(err.Error())
if !strings.Contains(s, "json") {
return false
}
if strings.Contains(s, "function.arguments") || strings.Contains(s, "function arguments") {
return true
}
if strings.Contains(s, "invalidparameter") && strings.Contains(s, "json") {
return true
}
if strings.Contains(s, "must be in json format") {
return true
}
return false
}
@@ -1,17 +0,0 @@
package multiagent
import (
"errors"
"testing"
)
func TestIsRecoverableToolCallArgumentsJSONError(t *testing.T) {
yes := errors.New(`failed to receive stream chunk: error, <400> InternalError.Algo.InvalidParameter: The "function.arguments" parameter of the code model must be in JSON format.`)
if !isRecoverableToolCallArgumentsJSONError(yes) {
t.Fatal("expected recoverable for function.arguments + JSON")
}
no := errors.New("unrelated network failure")
if isRecoverableToolCallArgumentsJSONError(no) {
t.Fatal("expected not recoverable")
}
}
@@ -1,44 +0,0 @@
package multiagent
import (
"fmt"
"github.com/cloudwego/eino/schema"
)
// toolExecutionRetryHint returns a user message appended to the conversation to prompt
// the LLM to adjust after a tool execution error (tool not found, binary missing,
// runtime failure, network error, etc.).
func toolExecutionRetryHint() *schema.Message {
return schema.UserMessage(`[System] Your previous tool call failed. Possible causes:
- The tool or sub-agent name does not exist (typo or unregistered name).
- The tool call arguments were not valid JSON.
- The tool's underlying binary is not installed or not in PATH.
- The tool encountered a runtime error (timeout, network failure, permission denied, etc.).
Please review the error message above, check available tools, and either:
1. Retry with corrected arguments or a different tool, OR
2. Inform the user about the limitation and proceed with an alternative approach.
[系统提示] 上一次工具调用失败可能原因
- 工具名或子代理名称不存在拼写错误或未注册
- 工具调用参数不是合法 JSON
- 工具依赖的底层二进制程序未安装或不在 PATH
- 工具运行时遇到错误超时网络故障权限不足等
请根据上述错误信息检查可用工具然后
1. 修正参数或改用其他工具重试或者
2. 告知用户当前限制并采用替代方案继续`)
}
// toolExecutionRecoveryTimelineMessage returns a message for the eino_recovery event
// displayed in the UI timeline when a tool execution error triggers a retry.
func toolExecutionRecoveryTimelineMessage(attempt int) string {
return fmt.Sprintf(
"工具调用执行失败。已向对话追加纠错提示并要求模型调整策略。"+
"当前为第 %d/%d 轮完整运行。\n\n"+
"Tool call execution failed. "+
"A corrective hint was appended. This is full run %d of %d.",
attempt+1, maxToolCallRecoveryAttempts, attempt+1, maxToolCallRecoveryAttempts,
)
}
-1
View File
@@ -265,4 +265,3 @@ func TestPaginateLines(t *testing.T) {
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines)) t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
} }
} }
-1
View File
@@ -162,4 +162,3 @@ func truncateRunes(s string, max int) string {
} }
return string(r[:max]) + "…" return string(r[:max]) + "…"
} }
+100 -11
View File
@@ -4472,6 +4472,25 @@ header {
flex-wrap: wrap; flex-wrap: wrap;
} }
.tool-resident-toggle {
display: inline-flex;
align-items: center;
gap: 4px;
font-size: 0.75rem;
color: var(--text-secondary);
border: 1px solid var(--border-color);
border-radius: 12px;
padding: 2px 8px;
background: var(--bg-secondary);
cursor: pointer;
}
.tool-resident-toggle input[type="checkbox"] {
width: 14px;
height: 14px;
margin: 0;
}
.external-tool-badge { .external-tool-badge {
display: inline-flex; display: inline-flex;
align-items: center; align-items: center;
@@ -8970,7 +8989,7 @@ header {
/* 任务管理 · 队列卡片:单行主网格 + 进度列内统计,降低高度 */ /* 任务管理 · 队列卡片:单行主网格 + 进度列内统计,降低高度 */
.batch-queue-item__inner--grid { .batch-queue-item__inner--grid {
display: grid; display: grid;
grid-template-columns: minmax(0, 1fr) minmax(128px, auto) minmax(88px, 14%) 44px; grid-template-columns: minmax(0, 1fr) minmax(128px, auto) minmax(88px, 14%) minmax(40px, max-content);
grid-template-rows: auto; grid-template-rows: auto;
grid-template-areas: "lead cluster progress actions"; grid-template-areas: "lead cluster progress actions";
column-gap: 22px; column-gap: 22px;
@@ -9051,6 +9070,12 @@ header {
justify-self: end; justify-self: end;
align-self: center; align-self: center;
padding-left: 6px; padding-left: 6px;
display: flex;
flex-direction: row;
flex-wrap: nowrap;
align-items: center;
justify-content: flex-end;
gap: 6px;
} }
.batch-queue-item__idline--lead { .batch-queue-item__idline--lead {
@@ -9137,6 +9162,12 @@ header {
} }
.batch-queue-icon-btn:hover { .batch-queue-icon-btn:hover {
color: var(--accent-color, #0066ff);
border-color: rgba(0, 102, 255, 0.35);
background: rgba(0, 102, 255, 0.08);
}
.batch-queue-icon-btn--danger:hover {
color: var(--error-color, #dc3545); color: var(--error-color, #dc3545);
border-color: rgba(220, 53, 69, 0.35); border-color: rgba(220, 53, 69, 0.35);
background: rgba(220, 53, 69, 0.06); background: rgba(220, 53, 69, 0.06);
@@ -13575,29 +13606,87 @@ header {
.vulnerability-details { .vulnerability-details {
display: grid; display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); grid-template-columns: repeat(2, minmax(0, 1fr));
gap: 12px; gap: 14px 16px;
margin-bottom: 16px; margin-bottom: 16px;
padding: 12px; padding: 12px;
background: var(--bg-secondary); background: var(--bg-secondary);
border-radius: 6px; border-radius: 6px;
align-items: start;
} }
.detail-item { /* 元数据条数为奇数时,最后一项占满一行,长 URL/队列 ID 更易读 */
.vulnerability-details .vuln-detail-field:last-child:nth-child(odd) {
grid-column: 1 / -1;
}
@media (max-width: 768px) {
.vulnerability-details {
grid-template-columns: 1fr;
}
.vulnerability-details .vuln-detail-field:last-child:nth-child(odd) {
grid-column: auto;
}
}
/* 漏洞详情字段:标签与值分行,长 ID/URL 可换行、可选中复制 */
.vuln-detail-field {
min-width: 0;
font-size: 0.875rem; font-size: 0.875rem;
} }
.detail-item strong { .vuln-detail-field__label {
color: var(--text-secondary); color: var(--text-secondary);
margin-right: 4px; font-weight: 600;
font-size: 0.75rem;
margin-bottom: 6px;
text-transform: none;
letter-spacing: normal;
} }
.detail-item code { .vuln-detail-field__row {
display: flex;
align-items: flex-start;
gap: 8px;
min-width: 0;
}
.vuln-detail-field-value {
flex: 1;
min-width: 0;
margin: 0;
padding: 8px 10px;
border-radius: 6px;
background: var(--bg-tertiary); background: var(--bg-tertiary);
padding: 2px 6px; border: 1px solid var(--border-color);
border-radius: 4px; font-size: 0.8125rem;
font-size: 0.8rem; line-height: 1.45;
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace; word-break: break-word;
overflow-wrap: anywhere;
white-space: pre-wrap;
user-select: text;
-webkit-user-select: text;
color: var(--text-primary);
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', monospace;
}
.vuln-detail-field__copy {
flex-shrink: 0;
margin-top: 2px;
padding: 6px;
line-height: 0;
border-radius: 6px;
color: var(--text-secondary);
border: 1px solid transparent;
background: transparent;
cursor: pointer;
}
.vuln-detail-field__copy:hover {
color: var(--accent-color);
background: var(--bg-primary);
border-color: var(--border-color);
} }
.vulnerability-proof, .vulnerability-proof,
+61 -2
View File
@@ -178,7 +178,6 @@
"taskCancelled": "Task cancelled", "taskCancelled": "Task cancelled",
"unknownTool": "Unknown tool", "unknownTool": "Unknown tool",
"einoAgentReplyTitle": "Sub-agent reply", "einoAgentReplyTitle": "Sub-agent reply",
"einoRecoveryTitle": "🔄 Invalid tool JSON · run {{n}}/{{max}} (hint appended)",
"einoStreamErrorTitle": "⚠️ Eino stream interrupted ({{agent}})", "einoStreamErrorTitle": "⚠️ Eino stream interrupted ({{agent}})",
"einoStreamErrorMessage": "Streaming read failed; the system will retry or terminate according to policy.", "einoStreamErrorMessage": "Streaming read failed; the system will retry or terminate according to policy.",
"iterationLimitReachedTitle": "⛔ Iteration limit reached", "iterationLimitReachedTitle": "⛔ Iteration limit reached",
@@ -304,6 +303,8 @@
"clearHistory": "Clear history", "clearHistory": "Clear history",
"cancelTask": "Cancel task", "cancelTask": "Cancel task",
"viewConversation": "View conversation", "viewConversation": "View conversation",
"viewVulnerabilities": "View vulnerabilities",
"viewVulnerabilitiesQueueTitle": "View vulnerabilities: open management filtered to this queue",
"retryTask": "Retry", "retryTask": "Retry",
"conversationIdLabel": "Conversation ID", "conversationIdLabel": "Conversation ID",
"statusPending": "Pending", "statusPending": "Pending",
@@ -575,6 +576,10 @@
"addExternal": "Add external MCP", "addExternal": "Add external MCP",
"toolConfig": "MCP tool config", "toolConfig": "MCP tool config",
"saveToolConfig": "Save tool config", "saveToolConfig": "Save tool config",
"alwaysVisibleLabel": "Pinned",
"alwaysVisibleHint": "Always keep visible in Tool Search results",
"alwaysVisibleBuiltinLabel": "Builtin default",
"alwaysVisibleBuiltinHint": "Backend builtin tool is pinned by default and cannot be disabled",
"externalConfig": "External MCP config", "externalConfig": "External MCP config",
"loadingTools": "Loading tools...", "loadingTools": "Loading tools...",
"loadToolsTimeout": "Tools load timeout. External MCP may be slow. Click Refresh to retry or check connection.", "loadToolsTimeout": "Tools load timeout. External MCP may be slow. Click Refresh to retry or check connection.",
@@ -1313,6 +1318,12 @@
"clear": "Clear", "clear": "Clear",
"vulnId": "Vuln ID", "vulnId": "Vuln ID",
"conversationId": "Conversation ID", "conversationId": "Conversation ID",
"taskOrQueueId": "Task / queue ID",
"filterTaskOrQueue": "Filter by task or queue ID",
"conversationTag": "Conversation tag",
"filterConversationTag": "Filter by conversation tag",
"taskTag": "Task tag",
"filterTaskTag": "Filter by task tag",
"severity": "Severity", "severity": "Severity",
"status": "Status", "status": "Status",
"statusOpen": "Open", "statusOpen": "Open",
@@ -1322,7 +1333,31 @@
"searchVulnId": "Search vuln ID", "searchVulnId": "Search vuln ID",
"filterConversation": "Filter by conversation", "filterConversation": "Filter by conversation",
"loading": "Loading...", "loading": "Loading...",
"noRecords": "No vulnerability records" "loadListFailed": "Failed to load",
"noRecords": "No vulnerability records",
"batchExport": "Batch export",
"downloadMarkdownTitle": "Download Markdown",
"exportNoResults": "No vulnerabilities match the current filters",
"exportStarted": "Started downloading {{count}} file(s)",
"exportFailed": "Export failed",
"saveRequiredFields": "Please fill in conversation ID, title, and severity",
"saveFailed": "Save failed",
"fetchFailed": "Failed to fetch vulnerability",
"deleteFailed": "Delete failed",
"detailVulnId": "Vuln ID",
"detailType": "Type",
"detailTarget": "Target",
"detailConversationId": "Conversation ID",
"detailTaskId": "Task ID",
"detailTaskQueueId": "Task queue ID",
"detailConversationTag": "Conversation tag",
"detailTaskTag": "Task tag",
"detailProof": "Proof",
"detailImpact": "Impact",
"detailRecommendation": "Remediation",
"downloadOkTitle": "Downloaded",
"exportFailedMessage": "Export failed",
"downloadFailed": "Download failed"
}, },
"tasksPage": { "tasksPage": {
"statusFilter": "Status filter", "statusFilter": "Status filter",
@@ -1673,6 +1708,7 @@
}, },
"contextMenu": { "contextMenu": {
"viewAttackChain": "View attack chain", "viewAttackChain": "View attack chain",
"viewVulnerabilities": "View vulnerabilities",
"downloadMarkdown": "Download Markdown", "downloadMarkdown": "Download Markdown",
"downloadMarkdownSummary": "Summary", "downloadMarkdownSummary": "Summary",
"downloadMarkdownFull": "Full", "downloadMarkdownFull": "Full",
@@ -1768,6 +1804,10 @@
"vulnerabilityModal": { "vulnerabilityModal": {
"conversationId": "Conversation ID", "conversationId": "Conversation ID",
"conversationIdPlaceholder": "Enter conversation ID", "conversationIdPlaceholder": "Enter conversation ID",
"conversationTag": "Conversation tag",
"conversationTagPlaceholder": "e.g. engagement A, weekly report",
"taskTag": "Task tag",
"taskTagPlaceholder": "e.g. batch scan Q2, retest",
"title": "Title", "title": "Title",
"titlePlaceholder": "Vulnerability title", "titlePlaceholder": "Vulnerability title",
"description": "Description", "description": "Description",
@@ -1795,6 +1835,25 @@
"recommendation": "Recommendation", "recommendation": "Recommendation",
"recommendationPlaceholder": "Remediation" "recommendationPlaceholder": "Remediation"
}, },
"vulnerabilityMd": {
"headingBasic": "Basic information",
"labelId": "Vulnerability ID",
"labelSeverity": "Severity",
"labelStatus": "Status",
"labelType": "Type",
"labelTarget": "Target",
"labelConversationId": "Conversation ID",
"labelTaskId": "Task ID",
"labelTaskQueueId": "Task queue ID",
"labelConversationTag": "Conversation tag",
"labelTaskTag": "Task tag",
"labelCreated": "Created at",
"labelUpdated": "Updated at",
"headingDescription": "Description",
"headingProof": "Proof (POC)",
"headingImpact": "Impact",
"headingRecommendation": "Remediation"
},
"roleModal": { "roleModal": {
"addRole": "Add role", "addRole": "Add role",
"editRole": "Edit role", "editRole": "Edit role",
+61 -2
View File
@@ -178,7 +178,6 @@
"taskCancelled": "任务已取消", "taskCancelled": "任务已取消",
"unknownTool": "未知工具", "unknownTool": "未知工具",
"einoAgentReplyTitle": "子代理回复", "einoAgentReplyTitle": "子代理回复",
"einoRecoveryTitle": "🔄 工具参数无效 · 第 {{n}}/{{max}} 轮(已追加提示)",
"einoStreamErrorTitle": "⚠️ Eino 流式中断({{agent}}", "einoStreamErrorTitle": "⚠️ Eino 流式中断({{agent}}",
"einoStreamErrorMessage": "流式读取异常,系统将按策略重试或结束。", "einoStreamErrorMessage": "流式读取异常,系统将按策略重试或结束。",
"iterationLimitReachedTitle": "⛔ 达到迭代上限", "iterationLimitReachedTitle": "⛔ 达到迭代上限",
@@ -304,6 +303,8 @@
"clearHistory": "清空历史", "clearHistory": "清空历史",
"cancelTask": "取消任务", "cancelTask": "取消任务",
"viewConversation": "查看对话", "viewConversation": "查看对话",
"viewVulnerabilities": "查看漏洞",
"viewVulnerabilitiesQueueTitle": "查看漏洞:打开漏洞管理并筛选本队列",
"retryTask": "重试", "retryTask": "重试",
"conversationIdLabel": "对话ID", "conversationIdLabel": "对话ID",
"statusPending": "待执行", "statusPending": "待执行",
@@ -575,6 +576,10 @@
"addExternal": "添加外部MCP", "addExternal": "添加外部MCP",
"toolConfig": "MCP 工具配置", "toolConfig": "MCP 工具配置",
"saveToolConfig": "保存工具配置", "saveToolConfig": "保存工具配置",
"alwaysVisibleLabel": "常驻",
"alwaysVisibleHint": "始终常驻在 Tool Search 可见列表(不被 tool_search 隐藏)",
"alwaysVisibleBuiltinLabel": "内置默认",
"alwaysVisibleBuiltinHint": "后端内置工具默认常驻,不可关闭",
"externalConfig": "外部 MCP 配置", "externalConfig": "外部 MCP 配置",
"loadingTools": "正在加载工具列表...", "loadingTools": "正在加载工具列表...",
"loadToolsTimeout": "加载工具列表超时,可能是外部MCP连接较慢。请点击\"刷新\"按钮重试,或检查外部MCP连接状态。", "loadToolsTimeout": "加载工具列表超时,可能是外部MCP连接较慢。请点击\"刷新\"按钮重试,或检查外部MCP连接状态。",
@@ -1313,6 +1318,12 @@
"clear": "清除", "clear": "清除",
"vulnId": "漏洞ID", "vulnId": "漏洞ID",
"conversationId": "会话ID", "conversationId": "会话ID",
"taskOrQueueId": "任务ID/队列ID",
"filterTaskOrQueue": "筛选任务ID或队列ID",
"conversationTag": "对话标签",
"filterConversationTag": "筛选对话标签",
"taskTag": "任务标签",
"filterTaskTag": "筛选任务标签",
"severity": "严重程度", "severity": "严重程度",
"status": "状态", "status": "状态",
"statusOpen": "待处理", "statusOpen": "待处理",
@@ -1322,7 +1333,31 @@
"searchVulnId": "搜索漏洞ID", "searchVulnId": "搜索漏洞ID",
"filterConversation": "筛选特定会话", "filterConversation": "筛选特定会话",
"loading": "加载中...", "loading": "加载中...",
"noRecords": "暂无漏洞记录" "loadListFailed": "加载失败",
"noRecords": "暂无漏洞记录",
"batchExport": "批量导出",
"downloadMarkdownTitle": "下载 Markdown",
"exportNoResults": "当前筛选条件下无可导出漏洞",
"exportStarted": "已开始下载 {{count}} 份报告",
"exportFailed": "导出失败",
"saveRequiredFields": "请填写必填字段:会话ID、标题和严重程度",
"saveFailed": "保存失败",
"fetchFailed": "获取漏洞失败",
"deleteFailed": "删除失败",
"detailVulnId": "漏洞ID",
"detailType": "类型",
"detailTarget": "目标",
"detailConversationId": "会话ID",
"detailTaskId": "任务ID",
"detailTaskQueueId": "任务队列ID",
"detailConversationTag": "对话标签",
"detailTaskTag": "任务标签",
"detailProof": "证明",
"detailImpact": "影响",
"detailRecommendation": "修复建议",
"downloadOkTitle": "下载成功",
"exportFailedMessage": "导出失败",
"downloadFailed": "下载失败"
}, },
"tasksPage": { "tasksPage": {
"statusFilter": "状态筛选", "statusFilter": "状态筛选",
@@ -1673,6 +1708,7 @@
}, },
"contextMenu": { "contextMenu": {
"viewAttackChain": "查看攻击链", "viewAttackChain": "查看攻击链",
"viewVulnerabilities": "查看漏洞",
"downloadMarkdown": "下载 Markdown", "downloadMarkdown": "下载 Markdown",
"downloadMarkdownSummary": "简版", "downloadMarkdownSummary": "简版",
"downloadMarkdownFull": "完整版", "downloadMarkdownFull": "完整版",
@@ -1768,6 +1804,10 @@
"vulnerabilityModal": { "vulnerabilityModal": {
"conversationId": "会话ID", "conversationId": "会话ID",
"conversationIdPlaceholder": "输入会话ID", "conversationIdPlaceholder": "输入会话ID",
"conversationTag": "对话标签",
"conversationTagPlaceholder": "如:红队演练A、客户A周报",
"taskTag": "任务标签",
"taskTagPlaceholder": "如:批量扫描Q2、专项复测",
"title": "标题", "title": "标题",
"titlePlaceholder": "漏洞标题", "titlePlaceholder": "漏洞标题",
"description": "描述", "description": "描述",
@@ -1795,6 +1835,25 @@
"recommendation": "修复建议", "recommendation": "修复建议",
"recommendationPlaceholder": "修复建议" "recommendationPlaceholder": "修复建议"
}, },
"vulnerabilityMd": {
"headingBasic": "基本信息",
"labelId": "漏洞ID",
"labelSeverity": "严重程度",
"labelStatus": "状态",
"labelType": "类型",
"labelTarget": "目标",
"labelConversationId": "会话ID",
"labelTaskId": "任务ID",
"labelTaskQueueId": "任务队列ID",
"labelConversationTag": "对话标签",
"labelTaskTag": "任务标签",
"labelCreated": "创建时间",
"labelUpdated": "更新时间",
"headingDescription": "描述",
"headingProof": "证明(POC",
"headingImpact": "影响",
"headingRecommendation": "修复建议"
},
"roleModal": { "roleModal": {
"addRole": "添加角色", "addRole": "添加角色",
"editRole": "编辑角色", "editRole": "编辑角色",
+11 -4
View File
@@ -2226,10 +2226,6 @@ function renderProcessDetails(messageId, processDetails) {
itemTitle = agPx + execLine; itemTitle = agPx + execLine;
} else if (eventType === 'eino_agent_reply') { } else if (eventType === 'eino_agent_reply') {
itemTitle = agPx + '💬 ' + (typeof window.t === 'function' ? window.t('chat.einoAgentReplyTitle') : '子代理回复'); itemTitle = agPx + '💬 ' + (typeof window.t === 'function' ? window.t('chat.einoAgentReplyTitle') : '子代理回复');
} else if (eventType === 'eino_recovery') {
const ri = data.runIndex != null ? data.runIndex : (data.einoRetry != null ? data.einoRetry + 1 : 1);
const mx = data.maxRuns != null ? data.maxRuns : 3;
itemTitle = (typeof window.t === 'function' ? window.t('chat.einoRecoveryTitle', { n: ri, max: mx }) : ('🔄 第 ' + ri + '/' + mx + ' 轮(已追加提示)'));
} else if (eventType === 'knowledge_retrieval') { } else if (eventType === 'knowledge_retrieval') {
itemTitle = '📚 ' + (typeof window.t === 'function' ? window.t('chat.knowledgeRetrieval') : '知识检索'); itemTitle = '📚 ' + (typeof window.t === 'function' ? window.t('chat.knowledgeRetrieval') : '知识检索');
} else if (eventType === 'error') { } else if (eventType === 'error') {
@@ -6125,6 +6121,17 @@ async function downloadConversationMarkdownFromContext(includeToolDetails = fals
closeContextMenu(); closeContextMenu();
} }
// 从上下文菜单跳转到漏洞管理,并按当前对话 ID 筛选
function navigateToVulnerabilitiesForContextConversation() {
const convId = contextMenuConversationId;
if (!convId) {
closeContextMenu();
return;
}
closeContextMenu();
window.location.hash = 'vulnerabilities?conversation_id=' + encodeURIComponent(convId);
}
// 从上下文菜单删除对话 // 从上下文菜单删除对话
function deleteConversationFromContext() { function deleteConversationFromContext() {
const convId = contextMenuConversationId; const convId = contextMenuConversationId;
-37
View File
@@ -1133,24 +1133,6 @@ function handleStreamEvent(event, progressElement, progressId,
}); });
break; break;
case 'eino_recovery': {
const d = event.data || {};
const runIdx = d.runIndex != null ? d.runIndex : (d.einoRetry != null ? d.einoRetry + 1 : 1);
const maxRuns = d.maxRuns != null ? d.maxRuns : 3;
const title = typeof window.t === 'function'
? window.t('chat.einoRecoveryTitle', { n: runIdx, max: maxRuns })
: ('🔄 工具参数无效 · 第 ' + runIdx + '/' + maxRuns + ' 轮(已追加提示)');
addTimelineItem(timeline, 'eino_recovery', {
title: title,
message: event.message || '',
data: event.data
});
// If the backend triggers a recovery run, any "running" tool_call items in this progress
// should be closed to avoid being stuck forever.
finalizeOutstandingToolCallsForProgress(progressId, 'failed');
break;
}
case 'eino_stream_error': { case 'eino_stream_error': {
const d = event.data || {}; const d = event.data || {};
const agent = d.einoAgent ? String(d.einoAgent) : ''; const agent = d.einoAgent ? String(d.einoAgent) : '';
@@ -2190,15 +2172,6 @@ function addTimelineItem(timeline, type, options) {
if (type === 'progress' && options.message) { if (type === 'progress' && options.message) {
item.dataset.progressMessage = options.message; item.dataset.progressMessage = options.message;
} }
if (type === 'eino_recovery' && options.data) {
const d = options.data;
if (d.runIndex != null) {
item.dataset.recoveryRunIndex = String(d.runIndex);
}
if (d.maxRuns != null) {
item.dataset.recoveryMaxRuns = String(d.maxRuns);
}
}
if (type === 'tool_calls_detected' && options.data && options.data.count != null) { if (type === 'tool_calls_detected' && options.data && options.data.count != null) {
item.dataset.toolCallsCount = String(options.data.count); item.dataset.toolCallsCount = String(options.data.count);
} }
@@ -2309,12 +2282,6 @@ function addTimelineItem(timeline, type, options) {
</div> </div>
</div> </div>
`; `;
} else if (type === 'eino_recovery' && options.message) {
content += `
<div class="timeline-item-content timeline-eino-recovery">
${escapeHtml(options.message).replace(/\n/g, '<br>')}
</div>
`;
} else if (type === 'cancelled') { } else if (type === 'cancelled') {
const taskCancelledLabel = typeof window.t === 'function' ? window.t('chat.taskCancelled') : '任务已取消'; const taskCancelledLabel = typeof window.t === 'function' ? window.t('chat.taskCancelled') : '任务已取消';
content += ` content += `
@@ -3197,10 +3164,6 @@ function refreshProgressAndTimelineI18n() {
titleSpan.textContent = ap + icon + (success ? _t('chat.toolExecComplete', { name: name }) : _t('chat.toolExecFailed', { name: name })); titleSpan.textContent = ap + icon + (success ? _t('chat.toolExecComplete', { name: name }) : _t('chat.toolExecFailed', { name: name }));
} else if (type === 'eino_agent_reply') { } else if (type === 'eino_agent_reply') {
titleSpan.textContent = ap + '\uD83D\uDCAC ' + _t('chat.einoAgentReplyTitle'); titleSpan.textContent = ap + '\uD83D\uDCAC ' + _t('chat.einoAgentReplyTitle');
} else if (type === 'eino_recovery' && item.dataset.recoveryRunIndex) {
const n = parseInt(item.dataset.recoveryRunIndex, 10) || 1;
const mx = parseInt(item.dataset.recoveryMaxRuns, 10) || 3;
titleSpan.textContent = _t('chat.einoRecoveryTitle', { n: n, max: mx });
} else if (type === 'cancelled') { } else if (type === 'cancelled') {
titleSpan.textContent = '\u26D4 ' + _t('chat.taskCancelled'); titleSpan.textContent = '\u26D4 ' + _t('chat.taskCancelled');
} else if (type === 'progress' && item.dataset.progressMessage !== undefined) { } else if (type === 'progress' && item.dataset.progressMessage !== undefined) {
+23 -11
View File
@@ -1,19 +1,19 @@
// 页面路由管理 // 页面路由管理
let currentPage = 'dashboard'; let currentPage = 'dashboard';
/** 仅当停留在 chat 时保留 ?conversation= 等查询串,其它页面只使用 pageId */ /** chat、漏洞管理页在切换时保留当前 hash 上的查询串(如 ?conversation= / ?conversation_id= */
function buildHashForPage(pageId) { function buildHashForPage(pageId) {
if (pageId !== 'chat') { if (pageId !== 'chat' && pageId !== 'vulnerabilities') {
return pageId; return pageId;
} }
const full = window.location.hash.slice(1); const full = window.location.hash.slice(1);
const parts = full.split('?'); const parts = full.split('?');
const curPage = parts[0]; const curPage = parts[0];
const q = parts.length > 1 ? parts.slice(1).join('?') : ''; const q = parts.length > 1 ? parts.slice(1).join('?') : '';
if (curPage === 'chat' && q) { if (curPage === pageId && q) {
return 'chat?' + q; return pageId + '?' + q;
} }
return 'chat'; return pageId;
} }
let chatConversationFromHashSeq = 0; let chatConversationFromHashSeq = 0;
@@ -301,12 +301,7 @@ async function initPage(pageId) {
break; break;
case 'mcp-management': case 'mcp-management':
// 初始化MCP管理 // 初始化MCP管理
// 先加载外部MCP列表(快速),然后加载工具列表 const startLoadMcpTools = () => {
if (typeof loadExternalMCPs === 'function') {
loadExternalMCPs().catch(err => {
console.warn('加载外部MCP列表失败:', err);
});
}
// 加载工具列表(MCP工具配置已移到MCP管理页面) // 加载工具列表(MCP工具配置已移到MCP管理页面)
// 使用异步加载,避免阻塞页面渲染 // 使用异步加载,避免阻塞页面渲染
if (typeof loadToolsList === 'function') { if (typeof loadToolsList === 'function') {
@@ -321,6 +316,23 @@ async function initPage(pageId) {
}); });
}, 100); }, 100);
} }
};
// 先拉取全局配置,确保 tool_search 常驻状态按后端生效集合展示
if (typeof loadConfig === 'function') {
loadConfig(false)
.catch(err => {
console.warn('加载配置失败(将继续加载工具列表):', err);
})
.finally(startLoadMcpTools);
} else {
startLoadMcpTools();
}
// 先加载外部MCP列表(快速),然后加载工具列表
if (typeof loadExternalMCPs === 'function') {
loadExternalMCPs().catch(err => {
console.warn('加载外部MCP列表失败:', err);
});
}
break; break;
case 'vulnerabilities': case 'vulnerabilities':
// 初始化漏洞管理页面 // 初始化漏洞管理页面
+36
View File
@@ -1,6 +1,8 @@
// 设置相关功能 // 设置相关功能
let currentConfig = null; let currentConfig = null;
let allTools = []; let allTools = [];
let alwaysVisibleToolNames = new Set();
let alwaysVisibleBuiltinToolNames = new Set();
// 全局工具状态映射,用于保存用户在所有页面的修改 // 全局工具状态映射,用于保存用户在所有页面的修改
// key: 唯一工具标识符(toolKey),value: { enabled: boolean, is_external: boolean, external_mcp: string } // key: 唯一工具标识符(toolKey),value: { enabled: boolean, is_external: boolean, external_mcp: string }
let toolStateMap = new Map(); let toolStateMap = new Map();
@@ -100,6 +102,14 @@ async function loadConfig(loadTools = true) {
} }
currentConfig = await response.json(); currentConfig = await response.json();
const alwaysVisibleList = currentConfig?.multi_agent?.tool_search_always_visible_effective_tools;
const alwaysVisibleConfigured = currentConfig?.multi_agent?.tool_search_always_visible_tools;
alwaysVisibleToolNames = new Set(Array.isArray(alwaysVisibleList) ? alwaysVisibleList.filter(Boolean) : []);
alwaysVisibleBuiltinToolNames = new Set(
alwaysVisibleToolNames.size > 0 && Array.isArray(alwaysVisibleConfigured)
? Array.from(alwaysVisibleToolNames).filter(name => !alwaysVisibleConfigured.includes(name))
: []
);
// 填充OpenAI配置 // 填充OpenAI配置
const providerEl = document.getElementById('openai-provider'); const providerEl = document.getElementById('openai-provider');
@@ -498,6 +508,8 @@ function renderToolsList() {
is_external: tool.is_external || false, is_external: tool.is_external || false,
external_mcp: tool.external_mcp || '' external_mcp: tool.external_mcp || ''
}; };
const alwaysVisibleChecked = alwaysVisibleToolNames.has(tool.name);
const alwaysVisibleLocked = alwaysVisibleBuiltinToolNames.has(tool.name);
// 外部工具标签,显示来源信息(可点击跳转到对应 MCP 卡片) // 外部工具标签,显示来源信息(可点击跳转到对应 MCP 卡片)
let externalBadge = ''; let externalBadge = '';
@@ -521,6 +533,11 @@ function renderToolsList() {
<div class="tool-item-name"> <div class="tool-item-name">
${escapeHtml(tool.name)} ${escapeHtml(tool.name)}
${externalBadge} ${externalBadge}
<label class="tool-resident-toggle" title="${typeof window.t === 'function' ? window.t('mcp.alwaysVisibleHint') : '始终常驻在 Tool Search 可见列表'}" onclick="event.stopPropagation()">
<input type="checkbox" ${alwaysVisibleChecked ? 'checked' : ''} ${alwaysVisibleLocked ? 'disabled' : ''} onchange="handleToolAlwaysVisibleChange('${escapeHtml(tool.name)}', this.checked)" />
<span>${typeof window.t === 'function' ? window.t('mcp.alwaysVisibleLabel') : '常驻'}</span>
</label>
${alwaysVisibleLocked ? `<span class="external-tool-badge" title="${typeof window.t === 'function' ? window.t('mcp.alwaysVisibleBuiltinHint') : '后端内置工具默认常驻,不可关闭'}">${typeof window.t === 'function' ? window.t('mcp.alwaysVisibleBuiltinLabel') : '内置默认'}</span>` : ''}
<span class="tool-expand-icon"></span> <span class="tool-expand-icon"></span>
</div> </div>
<div class="tool-item-desc">${escapeHtml(tool.description || (typeof window.t === 'function' ? window.t('mcp.noDescription') : '无描述'))}</div> <div class="tool-item-desc">${escapeHtml(tool.description || (typeof window.t === 'function' ? window.t('mcp.noDescription') : '无描述'))}</div>
@@ -716,6 +733,16 @@ function handleToolCheckboxChange(toolKey, enabled) {
updateToolsStats(); updateToolsStats();
} }
function handleToolAlwaysVisibleChange(toolName, alwaysVisible) {
const name = (toolName || '').trim();
if (!name) return;
if (alwaysVisible) {
alwaysVisibleToolNames.add(name);
} else {
alwaysVisibleToolNames.delete(name);
}
}
// 全选工具 // 全选工具
function selectAllTools() { function selectAllTools() {
document.querySelectorAll('#tools-list input[type="checkbox"]').forEach(checkbox => { document.querySelectorAll('#tools-list input[type="checkbox"]').forEach(checkbox => {
@@ -886,9 +913,11 @@ async function updateToolsStats() {
} }
const tStats = typeof window.t === 'function' ? window.t : (k) => k; const tStats = typeof window.t === 'function' ? window.t : (k) => k;
const pinnedCount = alwaysVisibleToolNames.size;
statsEl.innerHTML = ` statsEl.innerHTML = `
<span title="${tStats('mcp.currentPageEnabled')}"> ${tStats('mcp.currentPageEnabled')}: <strong>${currentPageEnabled}</strong> / ${currentPageTotal}</span> <span title="${tStats('mcp.currentPageEnabled')}"> ${tStats('mcp.currentPageEnabled')}: <strong>${currentPageEnabled}</strong> / ${currentPageTotal}</span>
<span title="${tStats('mcp.totalEnabled')}">📊 ${tStats('mcp.totalEnabled')}: <strong>${totalEnabled}</strong> / ${totalTools}</span> <span title="${tStats('mcp.totalEnabled')}">📊 ${tStats('mcp.totalEnabled')}: <strong>${totalEnabled}</strong> / ${totalTools}</span>
<span title="${tStats('mcp.alwaysVisibleHint')}">📌 ${tStats('mcp.alwaysVisibleLabel')}: <strong>${pinnedCount}</strong></span>
`; `;
} }
@@ -1230,6 +1259,13 @@ async function saveToolsConfig() {
const config = { const config = {
openai: currentConfig.openai || {}, openai: currentConfig.openai || {},
agent: currentConfig.agent || {}, agent: currentConfig.agent || {},
multi_agent: {
enabled: currentConfig?.multi_agent?.enabled === true,
robot_use_multi_agent: currentConfig?.multi_agent?.robot_use_multi_agent === true,
batch_use_multi_agent: currentConfig?.multi_agent?.batch_use_multi_agent === true,
plan_execute_loop_max_iterations: Number(currentConfig?.multi_agent?.plan_execute_loop_max_iterations || 0),
tool_search_always_visible_tools: Array.from(alwaysVisibleToolNames).filter(name => !alwaysVisibleBuiltinToolNames.has(name))
},
tools: [] tools: []
}; };
+16 -2
View File
@@ -531,6 +531,7 @@ function renderTaskItem(task, statusMap, isHistory = false) {
${isHistory && completedText ? completedText : timeText} ${isHistory && completedText ? completedText : timeText}
</span> </span>
${canCancel ? `<button class="btn-secondary btn-small" onclick="cancelTask('${task.conversationId}', this)">` + _t('tasks.cancelTask') + `</button>` : ''} ${canCancel ? `<button class="btn-secondary btn-small" onclick="cancelTask('${task.conversationId}', this)">` + _t('tasks.cancelTask') + `</button>` : ''}
${task.conversationId ? `<button class="btn-secondary btn-small" onclick="navigateToVulnerabilitiesFromTasksPage('conversation', '${task.conversationId}')">` + _t('tasks.viewVulnerabilities') + `</button>` : ''}
${task.conversationId ? `<button class="btn-secondary btn-small" onclick="viewConversation('${task.conversationId}')">` + _t('tasks.viewConversation') + `</button>` : ''} ${task.conversationId ? `<button class="btn-secondary btn-small" onclick="viewConversation('${task.conversationId}')">` + _t('tasks.viewConversation') + `</button>` : ''}
</div> </div>
</div> </div>
@@ -708,6 +709,17 @@ function viewConversation(conversationId) {
} }
} }
// 跳转漏洞管理并按对话 ID 或批量队列 ID 筛选(队列 ID 走 task_id,与列表筛选项一致)
function navigateToVulnerabilitiesFromTasksPage(kind, id) {
if (!id) return;
const enc = encodeURIComponent(id);
if (kind === 'queue') {
window.location.hash = 'vulnerabilities?task_id=' + enc;
} else if (kind === 'conversation') {
window.location.hash = 'vulnerabilities?conversation_id=' + enc;
}
}
// 刷新任务列表 // 刷新任务列表
async function refreshTasks() { async function refreshTasks() {
await loadTasks(); await loadTasks();
@@ -1134,6 +1146,8 @@ function renderBatchQueues() {
const progress = stats.total > 0 ? Math.round((stats.completed + stats.failed + stats.cancelled) / stats.total * 100) : 0; const progress = stats.total > 0 ? Math.round((stats.completed + stats.failed + stats.cancelled) / stats.total * 100) : 0;
// 允许删除待执行、已完成或已取消状态的队列 // 允许删除待执行、已完成或已取消状态的队列
const canDelete = queue.status === 'pending' || queue.status === 'completed' || queue.status === 'cancelled'; const canDelete = queue.status === 'pending' || queue.status === 'completed' || queue.status === 'cancelled';
// 操作列常驻「查看漏洞」,不再使用 --no-actions 隐藏整列(否则无法从运行中队列跳转漏洞页)
const noActionsClass = '';
const loadedRoles = batchQueuesState.loadedRoles || []; const loadedRoles = batchQueuesState.loadedRoles || [];
const roleIcon = getRoleIconForDisplay(queue.role, loadedRoles); const roleIcon = getRoleIconForDisplay(queue.role, loadedRoles);
@@ -1157,7 +1171,6 @@ function renderBatchQueues() {
: `<h4 class="batch-queue-card-title batch-queue-card-title--muted">${escapeHtml(_t('tasks.batchQueueUntitled'))}</h4>`; : `<h4 class="batch-queue-card-title batch-queue-card-title--muted">${escapeHtml(_t('tasks.batchQueueUntitled'))}</h4>`;
const doneCount = stats.completed + stats.failed + stats.cancelled; const doneCount = stats.completed + stats.failed + stats.cancelled;
const noActionsClass = canDelete ? '' : ' batch-queue-item--no-actions';
return ` return `
<div class="batch-queue-item batch-queue-item--compact${cardMod}${noActionsClass}" data-queue-id="${queue.id}" onclick="showBatchQueueDetail('${queue.id}')"> <div class="batch-queue-item batch-queue-item--compact${cardMod}${noActionsClass}" data-queue-id="${queue.id}" onclick="showBatchQueueDetail('${queue.id}')">
<div class="batch-queue-item__inner batch-queue-item__inner--grid"> <div class="batch-queue-item__inner batch-queue-item__inner--grid">
@@ -1182,7 +1195,8 @@ function renderBatchQueues() {
</div> </div>
</div> </div>
<div class="batch-queue-item__actions-col" onclick="event.stopPropagation();"> <div class="batch-queue-item__actions-col" onclick="event.stopPropagation();">
${canDelete ? `<button type="button" class="batch-queue-icon-btn" onclick="deleteBatchQueueFromList('${queue.id}')" title="${escapeHtml(_t('tasks.deleteQueue'))}" aria-label="${escapeHtml(_t('tasks.deleteQueue'))}"><svg class="batch-queue-icon-btn__svg" width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><path d="M3 6h18"/><path d="M19 6v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6"/><path d="M8 6V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2"/><path d="M10 11v6"/><path d="M14 11v6"/></svg></button>` : ''} <button type="button" class="batch-queue-icon-btn" onclick="navigateToVulnerabilitiesFromTasksPage('queue', '${queue.id}')" title="${escapeHtml(_t('tasks.viewVulnerabilitiesQueueTitle'))}" aria-label="${escapeHtml(_t('tasks.viewVulnerabilitiesQueueTitle'))}"><svg class="batch-queue-icon-btn__svg" width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><path d="M12 22s8-4 8-10V5l-8-3-8 3v7c0 6 8 10 8 10z"/><path d="M9 12l2 2 4-4"/></svg></button>
${canDelete ? `<button type="button" class="batch-queue-icon-btn batch-queue-icon-btn--danger" onclick="deleteBatchQueueFromList('${queue.id}')" title="${escapeHtml(_t('tasks.deleteQueue'))}" aria-label="${escapeHtml(_t('tasks.deleteQueue'))}"><svg class="batch-queue-icon-btn__svg" width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><path d="M3 6h18"/><path d="M19 6v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6"/><path d="M8 6V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2"/><path d="M10 11v6"/><path d="M14 11v6"/></svg></button>` : ''}
</div> </div>
</div> </div>
</div> </div>
+315 -81
View File
@@ -1,5 +1,43 @@
// 漏洞管理相关功能 // 漏洞管理相关功能
function vulnT(key, opts) {
if (typeof window.t === 'function') {
return window.t(key, opts);
}
return key;
}
function vulnDateLocale() {
try {
const lang = (window.__locale || '').toLowerCase();
if (lang.indexOf('zh') === 0) {
return 'zh-CN';
}
} catch (e) { /* ignore */ }
return 'en-US';
}
function vulnSeverityLabel(code) {
const m = {
critical: 'dashboard.severityCritical',
high: 'dashboard.severityHigh',
medium: 'dashboard.severityMedium',
low: 'dashboard.severityLow',
info: 'dashboard.severityInfo'
};
return m[code] ? vulnT(m[code]) : code;
}
function vulnStatusLabel(code) {
const m = {
open: 'vulnerabilityPage.statusOpen',
confirmed: 'vulnerabilityPage.statusConfirmed',
fixed: 'vulnerabilityPage.statusFixed',
false_positive: 'vulnerabilityPage.statusFalsePositive'
};
return m[code] ? vulnT(m[code]) : code;
}
// 从localStorage读取每页显示数量,默认为20 // 从localStorage读取每页显示数量,默认为20
const getVulnerabilityPageSize = () => { const getVulnerabilityPageSize = () => {
const saved = localStorage.getItem('vulnerabilityPageSize'); const saved = localStorage.getItem('vulnerabilityPageSize');
@@ -10,6 +48,9 @@ let currentVulnerabilityId = null;
let vulnerabilityFilters = { let vulnerabilityFilters = {
id: '', id: '',
conversation_id: '', conversation_id: '',
task_id: '',
conversation_tag: '',
task_tag: '',
severity: '', severity: '',
status: '' status: ''
}; };
@@ -20,10 +61,43 @@ let vulnerabilityPagination = {
totalPages: 1 totalPages: 1
}; };
// 从地址栏 #vulnerabilities?conversation_id= / ?task_id= 同步筛选(对话菜单、任务管理联动)
function syncVulnerabilityFiltersFromLocationHash() {
const hash = window.location.hash.slice(1);
const hashParts = hash.split('?');
if (hashParts[0] !== 'vulnerabilities' || hashParts.length < 2) {
return;
}
const params = new URLSearchParams(hashParts.slice(1).join('?'));
const cid = (params.get('conversation_id') || '').trim();
const tid = (params.get('task_id') || '').trim();
if (!cid && !tid) {
return;
}
vulnerabilityFilters.conversation_id = '';
vulnerabilityFilters.task_id = '';
const convEl = document.getElementById('vulnerability-conversation-filter');
const taskEl = document.getElementById('vulnerability-task-filter');
if (convEl) convEl.value = '';
if (taskEl) taskEl.value = '';
if (cid) {
vulnerabilityFilters.conversation_id = cid;
if (convEl) convEl.value = cid;
}
if (tid) {
vulnerabilityFilters.task_id = tid;
if (taskEl) taskEl.value = tid;
}
vulnerabilityPagination.currentPage = 1;
}
// 初始化漏洞管理页面 // 初始化漏洞管理页面
function initVulnerabilityPage() { function initVulnerabilityPage() {
// 从localStorage加载每页条数设置 // 从localStorage加载每页条数设置
vulnerabilityPagination.pageSize = getVulnerabilityPageSize(); vulnerabilityPagination.pageSize = getVulnerabilityPageSize();
syncVulnerabilityFiltersFromLocationHash();
loadVulnerabilityStats(); loadVulnerabilityStats();
loadVulnerabilities(); loadVulnerabilities();
} }
@@ -41,6 +115,9 @@ async function loadVulnerabilityStats() {
if (vulnerabilityFilters.conversation_id) { if (vulnerabilityFilters.conversation_id) {
params.append('conversation_id', vulnerabilityFilters.conversation_id); params.append('conversation_id', vulnerabilityFilters.conversation_id);
} }
if (vulnerabilityFilters.task_id) {
params.append('task_id', vulnerabilityFilters.task_id);
}
const response = await apiFetch(`/api/vulnerabilities/stats?${params.toString()}`); const response = await apiFetch(`/api/vulnerabilities/stats?${params.toString()}`);
if (!response.ok) { if (!response.ok) {
@@ -82,7 +159,7 @@ function updateVulnerabilityStats(stats) {
// 加载漏洞列表 // 加载漏洞列表
async function loadVulnerabilities(page = null) { async function loadVulnerabilities(page = null) {
const listContainer = document.getElementById('vulnerabilities-list'); const listContainer = document.getElementById('vulnerabilities-list');
listContainer.innerHTML = '<div class="loading-spinner">加载中...</div>'; listContainer.innerHTML = `<div class="loading-spinner">${escapeHtml(vulnT('vulnerabilityPage.loading'))}</div>`;
try { try {
// 检查apiFetch是否可用 // 检查apiFetch是否可用
@@ -106,6 +183,15 @@ async function loadVulnerabilities(page = null) {
if (vulnerabilityFilters.conversation_id) { if (vulnerabilityFilters.conversation_id) {
params.append('conversation_id', vulnerabilityFilters.conversation_id); params.append('conversation_id', vulnerabilityFilters.conversation_id);
} }
if (vulnerabilityFilters.task_id) {
params.append('task_id', vulnerabilityFilters.task_id);
}
if (vulnerabilityFilters.conversation_tag) {
params.append('conversation_tag', vulnerabilityFilters.conversation_tag);
}
if (vulnerabilityFilters.task_tag) {
params.append('task_tag', vulnerabilityFilters.task_tag);
}
if (vulnerabilityFilters.severity) { if (vulnerabilityFilters.severity) {
params.append('severity', vulnerabilityFilters.severity); params.append('severity', vulnerabilityFilters.severity);
} }
@@ -148,7 +234,7 @@ async function loadVulnerabilities(page = null) {
renderVulnerabilityPagination(); renderVulnerabilityPagination();
} catch (error) { } catch (error) {
console.error('加载漏洞列表失败:', error); console.error('加载漏洞列表失败:', error);
listContainer.innerHTML = `<div class="error-message">加载失败: ${error.message}</div>`; listContainer.innerHTML = `<div class="error-message">${escapeHtml(vulnT('vulnerabilityPage.loadListFailed'))}: ${escapeHtml(error.message)}</div>`;
} }
} }
@@ -180,22 +266,12 @@ function renderVulnerabilities(vulnerabilities) {
const html = vulnerabilities.map(vuln => { const html = vulnerabilities.map(vuln => {
const severityClass = `severity-${vuln.severity}`; const severityClass = `severity-${vuln.severity}`;
const severityText = { const severityText = vulnSeverityLabel(vuln.severity);
'critical': '严重', const statusText = vulnStatusLabel(vuln.status);
'high': '高危', const createdDate = new Date(vuln.created_at).toLocaleString(vulnDateLocale());
'medium': '中危', const dlTitle = escapeHtml(vulnT('vulnerabilityPage.downloadMarkdownTitle'));
'low': '低危', const editTitle = escapeHtml(vulnT('common.edit'));
'info': '信息' const deleteTitle = escapeHtml(vulnT('common.delete'));
}[vuln.severity] || vuln.severity;
const statusText = {
'open': '待处理',
'confirmed': '已确认',
'fixed': '已修复',
'false_positive': '误报'
}[vuln.status] || vuln.status;
const createdDate = new Date(vuln.created_at).toLocaleString('zh-CN');
return ` return `
<div class="vulnerability-card ${severityClass}"> <div class="vulnerability-card ${severityClass}">
@@ -214,20 +290,20 @@ function renderVulnerabilities(vulnerabilities) {
</div> </div>
</div> </div>
<div class="vulnerability-actions" onclick="event.stopPropagation();"> <div class="vulnerability-actions" onclick="event.stopPropagation();">
<button class="btn-ghost" onclick="downloadVulnerabilityAsMarkdown('${vuln.id}', event)" title="下载Markdown"> <button class="btn-ghost" onclick="downloadVulnerabilityAsMarkdown('${vuln.id}', event)" title="${dlTitle}">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/> <path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<polyline points="7 10 12 15 17 10" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/> <polyline points="7 10 12 15 17 10" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<line x1="12" y1="15" x2="12" y2="3" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/> <line x1="12" y1="15" x2="12" y2="3" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg> </svg>
</button> </button>
<button class="btn-ghost" onclick="editVulnerability('${vuln.id}')" title="编辑"> <button class="btn-ghost" onclick="editVulnerability('${vuln.id}')" title="${editTitle}">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/> <path d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/> <path d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg> </svg>
</button> </button>
<button class="btn-ghost" onclick="deleteVulnerability('${vuln.id}')" title="删除"> <button class="btn-ghost" onclick="deleteVulnerability('${vuln.id}')" title="${deleteTitle}">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M3 6h18M8 6V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2m3 0v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6h14z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/> <path d="M3 6h18M8 6V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2m3 0v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6h14z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg> </svg>
@@ -237,20 +313,27 @@ function renderVulnerabilities(vulnerabilities) {
<div class="vulnerability-content" id="content-${vuln.id}" style="display: none;"> <div class="vulnerability-content" id="content-${vuln.id}" style="display: none;">
${vuln.description ? `<div class="vulnerability-description">${escapeHtml(vuln.description)}</div>` : ''} ${vuln.description ? `<div class="vulnerability-description">${escapeHtml(vuln.description)}</div>` : ''}
<div class="vulnerability-details"> <div class="vulnerability-details">
<div class="detail-item"><strong>漏洞ID:</strong> <code>${escapeHtml(vuln.id)}</code></div> ${vulnDetailField(vulnT('vulnerabilityPage.detailVulnId'), vuln.id, true)}
${vuln.type ? `<div class="detail-item"><strong>类型:</strong> ${escapeHtml(vuln.type)}</div>` : ''} ${vuln.type ? vulnDetailField(vulnT('vulnerabilityPage.detailType'), vuln.type, false) : ''}
${vuln.target ? `<div class="detail-item"><strong>目标:</strong> ${escapeHtml(vuln.target)}</div>` : ''} ${vuln.target ? vulnDetailField(vulnT('vulnerabilityPage.detailTarget'), vuln.target, false) : ''}
<div class="detail-item"><strong>会话ID:</strong> <code>${escapeHtml(vuln.conversation_id)}</code></div> ${vulnDetailField(vulnT('vulnerabilityPage.detailConversationId'), vuln.conversation_id, true)}
${vuln.task_id ? vulnDetailField(vulnT('vulnerabilityPage.detailTaskId'), vuln.task_id, true) : ''}
${vuln.task_queue_id ? vulnDetailField(vulnT('vulnerabilityPage.detailTaskQueueId'), vuln.task_queue_id, true) : ''}
${vuln.conversation_tag ? vulnDetailField(vulnT('vulnerabilityPage.detailConversationTag'), vuln.conversation_tag, false) : ''}
${vuln.task_tag ? vulnDetailField(vulnT('vulnerabilityPage.detailTaskTag'), vuln.task_tag, false) : ''}
</div> </div>
${vuln.proof ? `<div class="vulnerability-proof"><strong>证明:</strong><pre>${escapeHtml(vuln.proof)}</pre></div>` : ''} ${vuln.proof ? `<div class="vulnerability-proof"><strong>${escapeHtml(vulnT('vulnerabilityPage.detailProof'))}:</strong><pre>${escapeHtml(vuln.proof)}</pre></div>` : ''}
${vuln.impact ? `<div class="vulnerability-impact"><strong>影响:</strong> ${escapeHtml(vuln.impact)}</div>` : ''} ${vuln.impact ? `<div class="vulnerability-impact"><strong>${escapeHtml(vulnT('vulnerabilityPage.detailImpact'))}:</strong> ${escapeHtml(vuln.impact)}</div>` : ''}
${vuln.recommendation ? `<div class="vulnerability-recommendation"><strong>修复建议:</strong> ${escapeHtml(vuln.recommendation)}</div>` : ''} ${vuln.recommendation ? `<div class="vulnerability-recommendation"><strong>${escapeHtml(vulnT('vulnerabilityPage.detailRecommendation'))}:</strong> ${escapeHtml(vuln.recommendation)}</div>` : ''}
</div> </div>
</div> </div>
`; `;
}).join(''); }).join('');
listContainer.innerHTML = html; listContainer.innerHTML = html;
if (typeof window.applyTranslations === 'function') {
window.applyTranslations(listContainer);
}
} }
// 渲染分页控件 // 渲染分页控件
@@ -277,9 +360,9 @@ function renderVulnerabilityPagination() {
// 左侧:显示范围信息和每页数量选择器(参考Skills样式) // 左侧:显示范围信息和每页数量选择器(参考Skills样式)
paginationHTML += ` paginationHTML += `
<div class="pagination-info"> <div class="pagination-info">
<span>显示 ${start}-${end} / ${total} </span> <span>${escapeHtml(vulnT('skillsPage.paginationShow', { start, end, total }))}</span>
<label class="pagination-page-size"> <label class="pagination-page-size">
每页显示 ${escapeHtml(vulnT('skillsPage.perPageLabel'))}
<select id="vulnerability-page-size-pagination" onchange="changeVulnerabilityPageSize()"> <select id="vulnerability-page-size-pagination" onchange="changeVulnerabilityPageSize()">
<option value="10" ${pageSize === 10 ? 'selected' : ''}>10</option> <option value="10" ${pageSize === 10 ? 'selected' : ''}>10</option>
<option value="20" ${pageSize === 20 ? 'selected' : ''}>20</option> <option value="20" ${pageSize === 20 ? 'selected' : ''}>20</option>
@@ -293,17 +376,20 @@ function renderVulnerabilityPagination() {
// 右侧:分页按钮(参考Skills样式:首页、上一页、第X/Y页、下一页、末页) // 右侧:分页按钮(参考Skills样式:首页、上一页、第X/Y页、下一页、末页)
paginationHTML += ` paginationHTML += `
<div class="pagination-controls"> <div class="pagination-controls">
<button class="btn-secondary" onclick="loadVulnerabilities(1)" ${currentPage === 1 || total === 0 ? 'disabled' : ''}>首页</button> <button class="btn-secondary" onclick="loadVulnerabilities(1)" ${currentPage === 1 || total === 0 ? 'disabled' : ''}>${escapeHtml(vulnT('skillsPage.firstPage'))}</button>
<button class="btn-secondary" onclick="loadVulnerabilities(${currentPage - 1})" ${currentPage === 1 || total === 0 ? 'disabled' : ''}>上一页</button> <button class="btn-secondary" onclick="loadVulnerabilities(${currentPage - 1})" ${currentPage === 1 || total === 0 ? 'disabled' : ''}>${escapeHtml(vulnT('skillsPage.prevPage'))}</button>
<span class="pagination-page"> ${currentPage} / ${totalPages || 1} </span> <span class="pagination-page">${escapeHtml(vulnT('skillsPage.pageOf', { current: currentPage, total: totalPages || 1 }))}</span>
<button class="btn-secondary" onclick="loadVulnerabilities(${currentPage + 1})" ${currentPage >= totalPages || total === 0 ? 'disabled' : ''}>下一页</button> <button class="btn-secondary" onclick="loadVulnerabilities(${currentPage + 1})" ${currentPage >= totalPages || total === 0 ? 'disabled' : ''}>${escapeHtml(vulnT('skillsPage.nextPage'))}</button>
<button class="btn-secondary" onclick="loadVulnerabilities(${totalPages || 1})" ${currentPage >= totalPages || total === 0 ? 'disabled' : ''}>末页</button> <button class="btn-secondary" onclick="loadVulnerabilities(${totalPages || 1})" ${currentPage >= totalPages || total === 0 ? 'disabled' : ''}>${escapeHtml(vulnT('skillsPage.lastPage'))}</button>
</div> </div>
`; `;
paginationHTML += '</div>'; paginationHTML += '</div>';
paginationContainer.innerHTML = paginationHTML; paginationContainer.innerHTML = paginationHTML;
if (typeof window.applyTranslations === 'function') {
window.applyTranslations(paginationContainer);
}
} }
// 改变每页显示数量 // 改变每页显示数量
@@ -334,10 +420,12 @@ async function changeVulnerabilityPageSize() {
// 显示添加漏洞模态框 // 显示添加漏洞模态框
function showAddVulnerabilityModal() { function showAddVulnerabilityModal() {
currentVulnerabilityId = null; currentVulnerabilityId = null;
document.getElementById('vulnerability-modal-title').textContent = (typeof window.t === 'function' ? window.t('vulnerability.addVuln') : '添加漏洞'); document.getElementById('vulnerability-modal-title').textContent = vulnT('vulnerability.addVuln');
// 清空表单 // 清空表单
document.getElementById('vulnerability-conversation-id').value = ''; document.getElementById('vulnerability-conversation-id').value = '';
document.getElementById('vulnerability-conversation-tag').value = '';
document.getElementById('vulnerability-task-tag').value = '';
document.getElementById('vulnerability-title').value = ''; document.getElementById('vulnerability-title').value = '';
document.getElementById('vulnerability-description').value = ''; document.getElementById('vulnerability-description').value = '';
document.getElementById('vulnerability-severity').value = ''; document.getElementById('vulnerability-severity').value = '';
@@ -355,14 +443,16 @@ function showAddVulnerabilityModal() {
async function editVulnerability(id) { async function editVulnerability(id) {
try { try {
const response = await apiFetch(`/api/vulnerabilities/${id}`); const response = await apiFetch(`/api/vulnerabilities/${id}`);
if (!response.ok) throw new Error('获取漏洞失败'); if (!response.ok) throw new Error(vulnT('vulnerabilityPage.fetchFailed'));
const vuln = await response.json(); const vuln = await response.json();
currentVulnerabilityId = id; currentVulnerabilityId = id;
document.getElementById('vulnerability-modal-title').textContent = (typeof window.t === 'function' ? window.t('vulnerability.editVuln') : '编辑漏洞'); document.getElementById('vulnerability-modal-title').textContent = vulnT('vulnerability.editVuln');
// 填充表单 // 填充表单
document.getElementById('vulnerability-conversation-id').value = vuln.conversation_id || ''; document.getElementById('vulnerability-conversation-id').value = vuln.conversation_id || '';
document.getElementById('vulnerability-conversation-tag').value = vuln.conversation_tag || '';
document.getElementById('vulnerability-task-tag').value = vuln.task_tag || '';
document.getElementById('vulnerability-title').value = vuln.title || ''; document.getElementById('vulnerability-title').value = vuln.title || '';
document.getElementById('vulnerability-description').value = vuln.description || ''; document.getElementById('vulnerability-description').value = vuln.description || '';
document.getElementById('vulnerability-severity').value = vuln.severity || ''; document.getElementById('vulnerability-severity').value = vuln.severity || '';
@@ -376,7 +466,7 @@ async function editVulnerability(id) {
document.getElementById('vulnerability-modal').style.display = 'block'; document.getElementById('vulnerability-modal').style.display = 'block';
} catch (error) { } catch (error) {
console.error('加载漏洞失败:', error); console.error('加载漏洞失败:', error);
alert('加载漏洞失败: ' + error.message); alert(vulnT('vulnerability.loadFailed') + ': ' + error.message);
} }
} }
@@ -387,12 +477,14 @@ async function saveVulnerability() {
const severity = document.getElementById('vulnerability-severity').value; const severity = document.getElementById('vulnerability-severity').value;
if (!conversationId || !title || !severity) { if (!conversationId || !title || !severity) {
alert('请填写必填字段:会话ID、标题和严重程度'); alert(vulnT('vulnerabilityPage.saveRequiredFields'));
return; return;
} }
const data = { const data = {
conversation_id: conversationId, conversation_id: conversationId,
conversation_tag: document.getElementById('vulnerability-conversation-tag').value.trim(),
task_tag: document.getElementById('vulnerability-task-tag').value.trim(),
title: title, title: title,
description: document.getElementById('vulnerability-description').value.trim(), description: document.getElementById('vulnerability-description').value.trim(),
severity: severity, severity: severity,
@@ -420,7 +512,7 @@ async function saveVulnerability() {
if (!response.ok) { if (!response.ok) {
const error = await response.json(); const error = await response.json();
throw new Error(error.error || '保存失败'); throw new Error(error.error || vulnT('vulnerabilityPage.saveFailed'));
} }
closeVulnerabilityModal(); closeVulnerabilityModal();
@@ -430,13 +522,13 @@ async function saveVulnerability() {
loadVulnerabilities(); loadVulnerabilities();
} catch (error) { } catch (error) {
console.error('保存漏洞失败:', error); console.error('保存漏洞失败:', error);
alert('保存漏洞失败: ' + error.message); alert(vulnT('vulnerabilityPage.saveFailed') + ': ' + error.message);
} }
} }
// 删除漏洞 // 删除漏洞
async function deleteVulnerability(id) { async function deleteVulnerability(id) {
if (!confirm('确定要删除此漏洞吗?')) { if (!confirm(vulnT('vulnerability.deleteConfirm'))) {
return; return;
} }
@@ -445,7 +537,7 @@ async function deleteVulnerability(id) {
method: 'DELETE' method: 'DELETE'
}); });
if (!response.ok) throw new Error('删除失败'); if (!response.ok) throw new Error(vulnT('vulnerabilityPage.deleteFailed'));
loadVulnerabilityStats(); loadVulnerabilityStats();
// 删除后,如果当前页没有数据了,回到上一页 // 删除后,如果当前页没有数据了,回到上一页
@@ -458,7 +550,7 @@ async function deleteVulnerability(id) {
loadVulnerabilities(); loadVulnerabilities();
} catch (error) { } catch (error) {
console.error('删除漏洞失败:', error); console.error('删除漏洞失败:', error);
alert('删除漏洞失败: ' + error.message); alert(vulnT('vulnerabilityPage.deleteFailed') + ': ' + error.message);
} }
} }
@@ -472,6 +564,9 @@ function closeVulnerabilityModal() {
function filterVulnerabilities() { function filterVulnerabilities() {
vulnerabilityFilters.id = document.getElementById('vulnerability-id-filter').value.trim(); vulnerabilityFilters.id = document.getElementById('vulnerability-id-filter').value.trim();
vulnerabilityFilters.conversation_id = document.getElementById('vulnerability-conversation-filter').value.trim(); vulnerabilityFilters.conversation_id = document.getElementById('vulnerability-conversation-filter').value.trim();
vulnerabilityFilters.task_id = document.getElementById('vulnerability-task-filter').value.trim();
vulnerabilityFilters.conversation_tag = document.getElementById('vulnerability-conversation-tag-filter').value.trim();
vulnerabilityFilters.task_tag = document.getElementById('vulnerability-task-tag-filter').value.trim();
vulnerabilityFilters.severity = document.getElementById('vulnerability-severity-filter').value; vulnerabilityFilters.severity = document.getElementById('vulnerability-severity-filter').value;
vulnerabilityFilters.status = document.getElementById('vulnerability-status-filter').value; vulnerabilityFilters.status = document.getElementById('vulnerability-status-filter').value;
@@ -486,12 +581,18 @@ function filterVulnerabilities() {
function clearVulnerabilityFilters() { function clearVulnerabilityFilters() {
document.getElementById('vulnerability-id-filter').value = ''; document.getElementById('vulnerability-id-filter').value = '';
document.getElementById('vulnerability-conversation-filter').value = ''; document.getElementById('vulnerability-conversation-filter').value = '';
document.getElementById('vulnerability-task-filter').value = '';
document.getElementById('vulnerability-conversation-tag-filter').value = '';
document.getElementById('vulnerability-task-tag-filter').value = '';
document.getElementById('vulnerability-severity-filter').value = ''; document.getElementById('vulnerability-severity-filter').value = '';
document.getElementById('vulnerability-status-filter').value = ''; document.getElementById('vulnerability-status-filter').value = '';
vulnerabilityFilters = { vulnerabilityFilters = {
id: '', id: '',
conversation_id: '', conversation_id: '',
task_id: '',
conversation_tag: '',
task_tag: '',
severity: '', severity: '',
status: '' status: ''
}; };
@@ -532,67 +633,193 @@ function escapeHtml(text) {
return div.innerHTML; return div.innerHTML;
} }
// 将漏洞格式化为Markdown /** 复制详情字段(编码由 encodeURIComponent 传入,避免引号截断) */
function vulnerabilityCopyEncoded(evt, encoded) {
if (evt && evt.stopPropagation) {
evt.stopPropagation();
}
let text = '';
try {
text = decodeURIComponent(encoded);
} catch (e) {
return;
}
const done = () => {
if (evt && evt.target && evt.target.closest) {
const btn = evt.target.closest('.vuln-detail-field__copy');
if (btn) {
const t0 = btn.getAttribute('title') || '';
btn.setAttribute('title', vulnT('common.copied'));
setTimeout(() => btn.setAttribute('title', t0), 1600);
}
}
};
if (navigator.clipboard && typeof navigator.clipboard.writeText === 'function') {
navigator.clipboard.writeText(text).then(done).catch(() => {
try {
const ta = document.createElement('textarea');
ta.value = text;
ta.style.position = 'fixed';
ta.style.left = '-9999px';
document.body.appendChild(ta);
ta.select();
document.execCommand('copy');
document.body.removeChild(ta);
done();
} catch (err) {
console.error('copy failed', err);
}
});
} else {
try {
const ta = document.createElement('textarea');
ta.value = text;
ta.style.position = 'fixed';
ta.style.left = '-9999px';
document.body.appendChild(ta);
ta.select();
document.execCommand('copy');
document.body.removeChild(ta);
done();
} catch (err) {
console.error('copy failed', err);
}
}
}
function vulnDetailField(label, value, asCode) {
if (value === undefined || value === null || String(value) === '') {
return '';
}
const s = String(value);
const enc = encodeURIComponent(s);
const copyTitle = escapeHtml(vulnT('common.copy'));
const valueEl = asCode
? `<code class="vuln-detail-field-value">${escapeHtml(s)}</code>`
: `<span class="vuln-detail-field-value">${escapeHtml(s)}</span>`;
const copyBtn = `<button type="button" class="vuln-detail-field__copy" onclick="vulnerabilityCopyEncoded(event, '${enc}')" title="${copyTitle}" aria-label="${copyTitle}">
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"/></svg>
</button>`;
return `<div class="vuln-detail-field">
<div class="vuln-detail-field__label">${escapeHtml(label)}</div>
<div class="vuln-detail-field__row">${valueEl}${copyBtn}</div>
</div>`;
}
// 将漏洞格式化为Markdown(章节标题随界面语言)
function formatVulnerabilityAsMarkdown(vuln) { function formatVulnerabilityAsMarkdown(vuln) {
const severityText = { const severityText = vulnSeverityLabel(vuln.severity);
'critical': '严重', const statusText = vulnStatusLabel(vuln.status);
'high': '高危', const loc = vulnDateLocale();
'medium': '中危', const createdDate = new Date(vuln.created_at).toLocaleString(loc);
'low': '低危', const updatedDate = new Date(vuln.updated_at).toLocaleString(loc);
'info': '信息' const L = (k) => vulnT('vulnerabilityMd.' + k);
}[vuln.severity] || vuln.severity;
const statusText = {
'open': '待处理',
'confirmed': '已确认',
'fixed': '已修复',
'false_positive': '误报'
}[vuln.status] || vuln.status;
const createdDate = new Date(vuln.created_at).toLocaleString('zh-CN');
const updatedDate = new Date(vuln.updated_at).toLocaleString('zh-CN');
let markdown = `# ${vuln.title}\n\n`; let markdown = `# ${vuln.title}\n\n`;
markdown += `## 基本信息\n\n`; markdown += `## ${L('headingBasic')}\n\n`;
markdown += `- **漏洞ID**: \`${vuln.id}\`\n`; markdown += `- **${L('labelId')}**: \`${vuln.id}\`\n`;
markdown += `- **严重程度**: ${severityText}\n`; markdown += `- **${L('labelSeverity')}**: ${severityText}\n`;
markdown += `- **状态**: ${statusText}\n`; markdown += `- **${L('labelStatus')}**: ${statusText}\n`;
if (vuln.type) { if (vuln.type) {
markdown += `- **类型**: ${vuln.type}\n`; markdown += `- **${L('labelType')}**: ${vuln.type}\n`;
} }
if (vuln.target) { if (vuln.target) {
markdown += `- **目标**: ${vuln.target}\n`; markdown += `- **${L('labelTarget')}**: ${vuln.target}\n`;
} }
markdown += `- **会话ID**: \`${vuln.conversation_id}\`\n`; markdown += `- **${L('labelConversationId')}**: \`${vuln.conversation_id}\`\n`;
markdown += `- **创建时间**: ${createdDate}\n`; if (vuln.task_id) {
markdown += `- **更新时间**: ${updatedDate}\n\n`; markdown += `- **${L('labelTaskId')}**: \`${vuln.task_id}\`\n`;
}
if (vuln.task_queue_id) {
markdown += `- **${L('labelTaskQueueId')}**: \`${vuln.task_queue_id}\`\n`;
}
if (vuln.conversation_tag) {
markdown += `- **${L('labelConversationTag')}**: ${vuln.conversation_tag}\n`;
}
if (vuln.task_tag) {
markdown += `- **${L('labelTaskTag')}**: ${vuln.task_tag}\n`;
}
markdown += `- **${L('labelCreated')}**: ${createdDate}\n`;
markdown += `- **${L('labelUpdated')}**: ${updatedDate}\n\n`;
if (vuln.description) { if (vuln.description) {
markdown += `## 描述\n\n${vuln.description}\n\n`; markdown += `## ${L('headingDescription')}\n\n${vuln.description}\n\n`;
} }
if (vuln.proof) { if (vuln.proof) {
markdown += `## 证明(POC\n\n\`\`\`\n${vuln.proof}\n\`\`\`\n\n`; markdown += `## ${L('headingProof')}\n\n\`\`\`\n${vuln.proof}\n\`\`\`\n\n`;
} }
if (vuln.impact) { if (vuln.impact) {
markdown += `## 影响\n\n${vuln.impact}\n\n`; markdown += `## ${L('headingImpact')}\n\n${vuln.impact}\n\n`;
} }
if (vuln.recommendation) { if (vuln.recommendation) {
markdown += `## 修复建议\n\n${vuln.recommendation}\n\n`; markdown += `## ${L('headingRecommendation')}\n\n${vuln.recommendation}\n\n`;
} }
return markdown; return markdown;
} }
function buildVulnerabilityFilterParams() {
const params = new URLSearchParams();
const keys = ['id', 'conversation_id', 'task_id', 'conversation_tag', 'task_tag', 'severity', 'status'];
keys.forEach((k) => {
if (vulnerabilityFilters[k]) {
params.append(k, vulnerabilityFilters[k]);
}
});
return params;
}
function triggerTextDownload(fileName, content) {
const blob = new Blob([content], { type: 'text/markdown;charset=utf-8' });
const url = URL.createObjectURL(blob);
const link = document.createElement('a');
link.href = url;
link.download = fileName;
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
URL.revokeObjectURL(url);
}
async function exportVulnerabilityReports() {
try {
const params = buildVulnerabilityFilterParams();
params.set('mode', 'summary');
params.set('group_by', 'conversation');
const response = await apiFetch(`/api/vulnerabilities/export?${params.toString()}`);
if (!response.ok) {
const error = await response.json().catch(() => ({ error: vulnT('vulnerabilityPage.exportFailedMessage') }));
throw new Error(error.error || vulnT('vulnerabilityPage.exportFailedMessage'));
}
const data = await response.json();
const files = Array.isArray(data.files) ? data.files : [];
if (!files.length) {
alert(vulnT('vulnerabilityPage.exportNoResults'));
return;
}
files.forEach((file, idx) => {
setTimeout(() => triggerTextDownload(file.filename || `vulnerability-export-${idx + 1}.md`, file.content || ''), idx * 120);
});
if (files.length > 1) {
alert(vulnT('vulnerabilityPage.exportStarted', { count: files.length }));
}
} catch (error) {
console.error('导出漏洞报告失败:', error);
alert(vulnT('vulnerabilityPage.exportFailed') + ': ' + error.message);
}
}
// 下载漏洞为Markdown格式 // 下载漏洞为Markdown格式
async function downloadVulnerabilityAsMarkdown(id, event) { async function downloadVulnerabilityAsMarkdown(id, event) {
try { try {
const response = await apiFetch(`/api/vulnerabilities/${id}`); const response = await apiFetch(`/api/vulnerabilities/${id}`);
if (!response.ok) { if (!response.ok) {
throw new Error('获取漏洞失败'); throw new Error(vulnT('vulnerabilityPage.fetchFailed'));
} }
const vuln = await response.json(); const vuln = await response.json();
@@ -626,8 +853,8 @@ async function downloadVulnerabilityAsMarkdown(id, event) {
if (event && event.target) { if (event && event.target) {
const button = event.target.closest('button'); const button = event.target.closest('button');
if (button) { if (button) {
const originalTitle = button.title || '下载Markdown'; const originalTitle = button.title || vulnT('vulnerabilityPage.downloadMarkdownTitle');
button.title = '下载成功!'; button.title = vulnT('vulnerabilityPage.downloadOkTitle');
setTimeout(() => { setTimeout(() => {
button.title = originalTitle; button.title = originalTitle;
}, 2000); }, 2000);
@@ -635,7 +862,7 @@ async function downloadVulnerabilityAsMarkdown(id, event) {
} }
} catch (error) { } catch (error) {
console.error('下载失败:', error); console.error('下载失败:', error);
alert('下载失败: ' + error.message); alert(vulnT('vulnerabilityPage.downloadFailed') + ': ' + error.message);
} }
} }
@@ -645,5 +872,12 @@ window.onclick = function(event) {
if (event.target === modal) { if (event.target === modal) {
closeVulnerabilityModal(); closeVulnerabilityModal();
} }
} };
document.addEventListener('languagechange', function () {
const page = document.getElementById('page-vulnerabilities');
if (page && page.classList.contains('active')) {
loadVulnerabilities();
}
});
-11
View File
@@ -2881,17 +2881,6 @@ function runWebshellAiSend(conn, inputEl, sendBtn, messagesContainer) {
} else if (_et === 'warning') { } else if (_et === 'warning') {
appendTimelineItem('warning', '⚠️ ' + (_em || ''), '', _ed); appendTimelineItem('warning', '⚠️ ' + (_em || ''), '', _ed);
// ─── Eino recovery ───
} else if (_et === 'eino_recovery') {
var runIdx = _ed.runIndex != null ? _ed.runIndex : (_ed.einoRetry != null ? _ed.einoRetry + 1 : 1);
var maxRuns = _ed.maxRuns != null ? _ed.maxRuns : 3;
var recTitle = wsTOr('chat.einoRecoveryTitle', '') ||
('🔄 工具参数无效 · 第 ' + runIdx + '/' + maxRuns + ' 轮(已追加提示)');
if (typeof window.t === 'function') {
try { recTitle = window.t('chat.einoRecoveryTitle', { n: runIdx, max: maxRuns }); } catch (e) { /* */ }
}
appendTimelineItem('eino_recovery', recTitle, _em, _ed);
// ─── Tool calls ─── // ─── Tool calls ───
} else if (_et === 'tool_calls_detected' && _ed) { } else if (_et === 'tool_calls_detected' && _ed) {
var count = _ed.count || 0; var count = _ed.count || 0;
+29 -1
View File
@@ -1097,6 +1097,18 @@
<span data-i18n="vulnerabilityPage.conversationId">会话ID</span> <span data-i18n="vulnerabilityPage.conversationId">会话ID</span>
<input type="text" id="vulnerability-conversation-filter" data-i18n="vulnerabilityPage.filterConversation" data-i18n-attr="placeholder" placeholder="筛选特定会话" /> <input type="text" id="vulnerability-conversation-filter" data-i18n="vulnerabilityPage.filterConversation" data-i18n-attr="placeholder" placeholder="筛选特定会话" />
</label> </label>
<label>
<span data-i18n="vulnerabilityPage.taskOrQueueId">任务ID/队列ID</span>
<input type="text" id="vulnerability-task-filter" data-i18n="vulnerabilityPage.filterTaskOrQueue" data-i18n-attr="placeholder" placeholder="筛选任务ID或队列ID" />
</label>
<label>
<span data-i18n="vulnerabilityPage.conversationTag">对话标签</span>
<input type="text" id="vulnerability-conversation-tag-filter" data-i18n="vulnerabilityPage.filterConversationTag" data-i18n-attr="placeholder" placeholder="筛选对话标签" />
</label>
<label>
<span data-i18n="vulnerabilityPage.taskTag">任务标签</span>
<input type="text" id="vulnerability-task-tag-filter" data-i18n="vulnerabilityPage.filterTaskTag" data-i18n-attr="placeholder" placeholder="筛选任务标签" />
</label>
<label> <label>
<span data-i18n="vulnerabilityPage.severity">严重程度</span> <span data-i18n="vulnerabilityPage.severity">严重程度</span>
<select id="vulnerability-severity-filter"> <select id="vulnerability-severity-filter">
@@ -1120,6 +1132,7 @@
</label> </label>
<button class="btn-secondary" onclick="filterVulnerabilities()" data-i18n="vulnerabilityPage.filter">筛选</button> <button class="btn-secondary" onclick="filterVulnerabilities()" data-i18n="vulnerabilityPage.filter">筛选</button>
<button class="btn-secondary" onclick="clearVulnerabilityFilters()" data-i18n="vulnerabilityPage.clear">清除</button> <button class="btn-secondary" onclick="clearVulnerabilityFilters()" data-i18n="vulnerabilityPage.clear">清除</button>
<button class="btn-primary" onclick="exportVulnerabilityReports()" data-i18n="vulnerabilityPage.batchExport">批量导出</button>
</div> </div>
</div> </div>
@@ -2411,6 +2424,13 @@
</div> </div>
</div> </div>
</div> </div>
<div class="context-menu-item" onclick="navigateToVulnerabilitiesForContextConversation()">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M12 22s8-4 8-10V5l-8-3-8 3v7c0 6 8 10 8 10z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M9 12l2 2 4-4" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
<span data-i18n="contextMenu.viewVulnerabilities">查看漏洞</span>
</div>
<div class="context-menu-divider"></div> <div class="context-menu-divider"></div>
<div class="context-menu-item" onclick="renameConversation()"> <div class="context-menu-item" onclick="renameConversation()">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"> <svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
@@ -2599,6 +2619,14 @@
<label for="vulnerability-conversation-id"><span data-i18n="vulnerabilityModal.conversationId">会话ID</span> <span style="color: red;">*</span></label> <label for="vulnerability-conversation-id"><span data-i18n="vulnerabilityModal.conversationId">会话ID</span> <span style="color: red;">*</span></label>
<input type="text" id="vulnerability-conversation-id" data-i18n="vulnerabilityModal.conversationIdPlaceholder" data-i18n-attr="placeholder" placeholder="输入会话ID" required /> <input type="text" id="vulnerability-conversation-id" data-i18n="vulnerabilityModal.conversationIdPlaceholder" data-i18n-attr="placeholder" placeholder="输入会话ID" required />
</div> </div>
<div class="form-group">
<label for="vulnerability-conversation-tag" data-i18n="vulnerabilityModal.conversationTag">对话标签</label>
<input type="text" id="vulnerability-conversation-tag" data-i18n="vulnerabilityModal.conversationTagPlaceholder" data-i18n-attr="placeholder" placeholder="如:红队演练A、客户A周报" />
</div>
<div class="form-group">
<label for="vulnerability-task-tag" data-i18n="vulnerabilityModal.taskTag">任务标签</label>
<input type="text" id="vulnerability-task-tag" data-i18n="vulnerabilityModal.taskTagPlaceholder" data-i18n-attr="placeholder" placeholder="如:批量扫描Q2、专项复测" />
</div>
<div class="form-group"> <div class="form-group">
<label for="vulnerability-title"><span data-i18n="vulnerabilityModal.title">标题</span> <span style="color: red;">*</span></label> <label for="vulnerability-title"><span data-i18n="vulnerabilityModal.title">标题</span> <span style="color: red;">*</span></label>
<input type="text" id="vulnerability-title" data-i18n="vulnerabilityModal.titlePlaceholder" data-i18n-attr="placeholder" placeholder="漏洞标题" required /> <input type="text" id="vulnerability-title" data-i18n="vulnerabilityModal.titlePlaceholder" data-i18n-attr="placeholder" placeholder="漏洞标题" required />
@@ -2817,7 +2845,7 @@
<script src="/static/js/terminal.js"></script> <script src="/static/js/terminal.js"></script>
<script src="/static/js/knowledge.js"></script> <script src="/static/js/knowledge.js"></script>
<script src="/static/js/skills.js"></script> <script src="/static/js/skills.js"></script>
<script src="/static/js/vulnerability.js?v=4"></script> <script src="/static/js/vulnerability.js?v=7"></script>
<script src="/static/js/webshell.js"></script> <script src="/static/js/webshell.js"></script>
<script src="/static/js/chat-files.js"></script> <script src="/static/js/chat-files.js"></script>
<script src="/static/js/tasks.js"></script> <script src="/static/js/tasks.js"></script>