mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-05-17 13:43:31 +02:00
Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1d9fcfd87e | |||
| 91cb650234 | |||
| 44e7d3b340 | |||
| 531b05299a | |||
| 0de69a6345 | |||
| 6a2a445f32 | |||
| 6aaa21d3e0 | |||
| 5c57d358ef | |||
| 65a3475c02 | |||
| 516ebf7a65 | |||
| 2558be3d7d | |||
| f6bb455313 | |||
| fc64356282 | |||
| 3d4fce9b89 | |||
| 3e41a47abf | |||
| 5b942c7bc8 | |||
| bcfb7b8da1 | |||
| f420ae0265 | |||
| e3f59b29ab | |||
| 87cba37203 | |||
| 4773b9e963 | |||
| eda5f9bba1 | |||
| 1318607813 | |||
| 5100924abe | |||
| 44079674dd | |||
| d959390e27 | |||
| 62a0d8cb71 | |||
| b53cae3a02 |
+15
-5
@@ -10,7 +10,7 @@
|
||||
# ============================================
|
||||
|
||||
# 前端显示的版本号(可选,不填则显示默认版本)
|
||||
version: "v1.5.10"
|
||||
version: "v1.5.16"
|
||||
# 服务器配置
|
||||
server:
|
||||
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
|
||||
@@ -70,7 +70,7 @@ multi_agent:
|
||||
robot_use_multi_agent: true # true 时企业微信/钉钉/飞书机器人也走 Eino 多代理(成本更高)
|
||||
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
|
||||
max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations
|
||||
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。Executor 未暴露 Handlers:patch/reduction/plantask 不作用于 PE,但 tool_search 工具列表拆分仍通过共享 ToolsConfig 作用于执行器。
|
||||
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件。
|
||||
plan_execute_loop_max_iterations: 0
|
||||
sub_agent_max_iterations: 120
|
||||
sub_agent_user_context_max_runes: 0 # 子代理 task 描述中自动注入用户原始请求的字符上限;0=默认2000,负数=禁用
|
||||
@@ -87,15 +87,25 @@ multi_agent:
|
||||
# Eino ADK 中间件与 Deep/Supervisor 调参(结构体见 internal/config/config.go → MultiAgentEinoMiddlewareConfig)
|
||||
eino_middleware:
|
||||
patch_tool_calls: true # true:修补历史中无 tool_result 的悬空 tool_call(流式中断/重试后更稳);false:关闭;字段省略时默认等同 true
|
||||
tool_search_enable: false # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
|
||||
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
|
||||
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
|
||||
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
|
||||
tool_search_always_visible_tools: [read_file, glob, grep, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
|
||||
plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过
|
||||
plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放
|
||||
reduction_enable: false # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
|
||||
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
|
||||
reduction_max_length_for_trunc: 50000 # 单条工具结果超过该字符数(bytes)时截断并落盘(由 reduction 中间件处理)
|
||||
reduction_max_tokens_for_clear: 160000 # 历史工具结果清理阈值(tokens),超阈值时在模型调用前清理旧结果
|
||||
reduction_root_dir: "" # 非空:截断/清理内容落盘根路径;空:使用系统临时目录下按会话隔离的默认路径
|
||||
reduction_clear_exclude: [] # 不参与「清理阶段」的工具名额外列表(会与 task/transfer/exit 等内置排除项合并);需要时用 YAML 列表填写
|
||||
reduction_sub_agents: false # true:子代理也挂 reduction;false:仅编排主代理使用 reduction
|
||||
reduction_sub_agents: true # true:子代理也挂 reduction;false:仅编排主代理使用 reduction
|
||||
summarization_trigger_ratio: 0.8 # summarization 触发比例(max_total_tokens * ratio),建议 0.75~0.85
|
||||
summarization_emit_internal_events: true # true:发出 summarization 内部事件(便于诊断)
|
||||
history_input_budget_ratio: 0.35 # 历史入队预算比例(max_total_tokens * ratio)
|
||||
plan_execute_user_input_budget_ratio: 0.35 # plan_execute 中 userInput 预算比例(planner/replanner/executor 共用)
|
||||
plan_execute_executed_steps_budget_ratio: 0.2 # plan_execute 中 executed_steps 预算比例
|
||||
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
|
||||
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
|
||||
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接
|
||||
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
|
||||
deep_model_retry_max_retries: 0 # >0:ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
|
||||
|
||||
+51
-32
@@ -39,6 +39,7 @@ type Agent struct {
|
||||
toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具)
|
||||
currentConversationID string // 当前对话ID(用于自动传递给工具)
|
||||
promptBaseDir string // 解析 system_prompt_path 时相对路径的基准目录(通常为 config.yaml 所在目录)
|
||||
toolDescriptionMode string // 工具描述模式: "short" | "full",默认 short
|
||||
}
|
||||
|
||||
// ResultStorage 结果存储接口(直接使用 storage 包的类型)
|
||||
@@ -162,6 +163,7 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
|
||||
resultStorage: resultStorage,
|
||||
largeResultThreshold: largeResultThreshold,
|
||||
toolNameMapping: make(map[string]string), // 初始化工具名称映射
|
||||
toolDescriptionMode: "short",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -336,10 +338,10 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
|
||||
|
||||
// AgentLoopResult Agent Loop执行结果
|
||||
type AgentLoopResult struct {
|
||||
Response string
|
||||
MCPExecutionIDs []string
|
||||
LastReActInput string // 最后一轮ReAct的输入(压缩后的messages,JSON格式)
|
||||
LastReActOutput string // 最终大模型的输出
|
||||
Response string
|
||||
MCPExecutionIDs []string
|
||||
LastAgentTraceInput string // 最后一轮代理消息轨迹(压缩后的 messages,JSON;与 multiagent.RunResult 字段对齐)
|
||||
LastAgentTraceOutput string // 最终助手输出文本
|
||||
}
|
||||
|
||||
// ProgressCallback 进度回调函数类型
|
||||
@@ -471,7 +473,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
}
|
||||
|
||||
// 用于保存当前的messages,以便在异常情况下也能保存ReAct输入
|
||||
var currentReActInput string
|
||||
var currentAgentTraceInput string
|
||||
|
||||
maxIterations := a.maxIterations
|
||||
thinkingStreamSeq := 0
|
||||
@@ -490,9 +492,9 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
if err != nil {
|
||||
a.logger.Warn("序列化ReAct输入失败", zap.Error(err))
|
||||
} else {
|
||||
currentReActInput = string(messagesJSON)
|
||||
currentAgentTraceInput = string(messagesJSON)
|
||||
// 更新result中的值,确保始终保存最新的ReAct输入(压缩后的)
|
||||
result.LastReActInput = currentReActInput
|
||||
result.LastAgentTraceInput = currentAgentTraceInput
|
||||
}
|
||||
|
||||
// 检查上下文是否已取消
|
||||
@@ -500,13 +502,13 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
case <-ctx.Done():
|
||||
// 上下文被取消(可能是用户主动暂停或其他原因)
|
||||
a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err()))
|
||||
result.LastReActInput = currentReActInput
|
||||
result.LastAgentTraceInput = currentAgentTraceInput
|
||||
if ctx.Err() == context.Canceled {
|
||||
result.Response = "任务已被取消。"
|
||||
} else {
|
||||
result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err())
|
||||
}
|
||||
result.LastReActOutput = result.Response
|
||||
result.LastAgentTraceOutput = result.Response
|
||||
return result, ctx.Err()
|
||||
default:
|
||||
}
|
||||
@@ -600,10 +602,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
})
|
||||
if err != nil {
|
||||
// API调用失败,保存当前的ReAct输入和错误信息作为输出
|
||||
result.LastReActInput = currentReActInput
|
||||
result.LastAgentTraceInput = currentAgentTraceInput
|
||||
errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err)
|
||||
result.Response = errorMsg
|
||||
result.LastReActOutput = errorMsg
|
||||
result.LastAgentTraceOutput = errorMsg
|
||||
a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err))
|
||||
return result, fmt.Errorf("调用OpenAI失败: %w", err)
|
||||
}
|
||||
@@ -629,19 +631,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
continue
|
||||
}
|
||||
// OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出
|
||||
result.LastReActInput = currentReActInput
|
||||
result.LastAgentTraceInput = currentAgentTraceInput
|
||||
errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message)
|
||||
result.Response = errorMsg
|
||||
result.LastReActOutput = errorMsg
|
||||
result.LastAgentTraceOutput = errorMsg
|
||||
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
|
||||
}
|
||||
|
||||
if len(response.Choices) == 0 {
|
||||
// 没有收到响应,保存当前的ReAct输入和错误信息作为输出
|
||||
result.LastReActInput = currentReActInput
|
||||
result.LastAgentTraceInput = currentAgentTraceInput
|
||||
errorMsg := "没有收到响应"
|
||||
result.Response = errorMsg
|
||||
result.LastReActOutput = errorMsg
|
||||
result.LastAgentTraceOutput = errorMsg
|
||||
return result, fmt.Errorf("没有收到响应")
|
||||
}
|
||||
|
||||
@@ -816,7 +818,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
result.Response = streamText
|
||||
result.LastReActOutput = result.Response
|
||||
result.LastAgentTraceOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
@@ -863,14 +865,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
result.Response = streamText
|
||||
result.LastReActOutput = result.Response
|
||||
result.LastAgentTraceOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
// 如果获取总结失败,使用当前回复作为结果
|
||||
if choice.Message.Content != "" {
|
||||
result.Response = choice.Message.Content
|
||||
result.LastReActOutput = result.Response
|
||||
result.LastAgentTraceOutput = result.Response
|
||||
return result, nil
|
||||
}
|
||||
// 如果都没有内容,跳出循环,让后续逻辑处理
|
||||
@@ -881,7 +883,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
if choice.FinishReason == "stop" {
|
||||
sendProgress("progress", "正在生成最终回复...", nil)
|
||||
result.Response = choice.Message.Content
|
||||
result.LastReActOutput = result.Response
|
||||
result.LastAgentTraceOutput = result.Response
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
@@ -910,19 +912,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
|
||||
})
|
||||
if strings.TrimSpace(streamText) != "" {
|
||||
result.Response = streamText
|
||||
result.LastReActOutput = result.Response
|
||||
result.LastAgentTraceOutput = result.Response
|
||||
sendProgress("progress", "总结生成完成", nil)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 如果无法生成总结,返回友好的提示
|
||||
result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations)
|
||||
result.LastReActOutput = result.Response
|
||||
result.LastAgentTraceOutput = result.Response
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// getAvailableTools 获取可用工具
|
||||
// 从MCP服务器动态获取工具列表,使用简短描述以减少token消耗
|
||||
// 从MCP服务器动态获取工具列表,描述模式由 tool_description_mode 控制
|
||||
// roleTools: 角色配置的工具列表(toolKey格式),如果为空或nil,则使用所有工具(默认角色)
|
||||
func (a *Agent) getAvailableTools(roleTools []string) []Tool {
|
||||
// 构建角色工具集合(用于快速查找)
|
||||
@@ -946,11 +948,7 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
|
||||
continue // 不在角色工具列表中,跳过
|
||||
}
|
||||
}
|
||||
// 使用简短描述(如果存在),否则使用详细描述
|
||||
description := mcpTool.ShortDescription
|
||||
if description == "" {
|
||||
description = mcpTool.Description
|
||||
}
|
||||
description := a.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description)
|
||||
|
||||
// 转换schema中的类型为OpenAI标准类型
|
||||
convertedSchema := a.convertSchemaTypes(mcpTool.InputSchema)
|
||||
@@ -1024,11 +1022,7 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
|
||||
continue
|
||||
}
|
||||
|
||||
// 使用简短描述(如果存在),否则使用详细描述
|
||||
description := externalTool.ShortDescription
|
||||
if description == "" {
|
||||
description = externalTool.Description
|
||||
}
|
||||
description := a.pickToolDescription(externalTool.ShortDescription, externalTool.Description)
|
||||
|
||||
// 转换schema中的类型为OpenAI标准类型
|
||||
convertedSchema := a.convertSchemaTypes(externalTool.InputSchema)
|
||||
@@ -1063,6 +1057,19 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
|
||||
return tools
|
||||
}
|
||||
|
||||
func (a *Agent) pickToolDescription(shortDesc, fullDesc string) string {
|
||||
a.mu.RLock()
|
||||
mode := strings.TrimSpace(strings.ToLower(a.toolDescriptionMode))
|
||||
a.mu.RUnlock()
|
||||
if mode == "full" {
|
||||
return fullDesc
|
||||
}
|
||||
if shortDesc != "" {
|
||||
return shortDesc
|
||||
}
|
||||
return fullDesc
|
||||
}
|
||||
|
||||
// convertSchemaTypes 递归转换schema中的类型为OpenAI标准类型
|
||||
func (a *Agent) convertSchemaTypes(schema map[string]interface{}) map[string]interface{} {
|
||||
if schema == nil {
|
||||
@@ -1665,6 +1672,18 @@ func (a *Agent) UpdateMaxIterations(maxIterations int) {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateToolDescriptionMode 更新工具描述模式(short/full)
|
||||
func (a *Agent) UpdateToolDescriptionMode(mode string) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
mode = strings.TrimSpace(strings.ToLower(mode))
|
||||
if mode != "full" {
|
||||
mode = "short"
|
||||
}
|
||||
a.toolDescriptionMode = mode
|
||||
a.logger.Info("Agent工具描述模式已更新", zap.String("tool_description_mode", mode))
|
||||
}
|
||||
|
||||
// formatToolError 格式化工具错误信息,提供更友好的错误描述
|
||||
func (a *Agent) formatToolError(toolName string, args map[string]interface{}, err error) string {
|
||||
errorMsg := fmt.Sprintf(`工具执行失败
|
||||
|
||||
@@ -18,62 +18,62 @@ import (
|
||||
func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) {
|
||||
logger := zap.NewNop()
|
||||
mcpServer := mcp.NewServer(logger)
|
||||
|
||||
|
||||
openAICfg := &config.OpenAIConfig{
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.test.com/v1",
|
||||
Model: "test-model",
|
||||
}
|
||||
|
||||
|
||||
agentCfg := &config.AgentConfig{
|
||||
MaxIterations: 10,
|
||||
LargeResultThreshold: 100, // 设置较小的阈值便于测试
|
||||
ResultStorageDir: "",
|
||||
}
|
||||
|
||||
|
||||
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10)
|
||||
|
||||
|
||||
// 创建测试存储
|
||||
tmpDir := filepath.Join(os.TempDir(), "test_agent_storage_"+time.Now().Format("20060102_150405"))
|
||||
testStorage, err := storage.NewFileResultStorage(tmpDir, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("创建测试存储失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
agent.SetResultStorage(testStorage)
|
||||
|
||||
|
||||
return agent, testStorage
|
||||
}
|
||||
|
||||
func TestAgent_FormatMinimalNotification(t *testing.T) {
|
||||
agent, testStorage := setupTestAgent(t)
|
||||
_ = testStorage // 避免未使用变量警告
|
||||
|
||||
|
||||
executionID := "test_exec_001"
|
||||
toolName := "nmap_scan"
|
||||
size := 50000
|
||||
lineCount := 1000
|
||||
filePath := "tmp/test_exec_001.txt"
|
||||
|
||||
|
||||
notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount, filePath)
|
||||
|
||||
|
||||
// 验证通知包含必要信息
|
||||
if !strings.Contains(notification, executionID) {
|
||||
t.Errorf("通知中应该包含执行ID: %s", executionID)
|
||||
}
|
||||
|
||||
|
||||
if !strings.Contains(notification, toolName) {
|
||||
t.Errorf("通知中应该包含工具名称: %s", toolName)
|
||||
}
|
||||
|
||||
|
||||
if !strings.Contains(notification, "50000") {
|
||||
t.Errorf("通知中应该包含大小信息")
|
||||
}
|
||||
|
||||
|
||||
if !strings.Contains(notification, "1000") {
|
||||
t.Errorf("通知中应该包含行数信息")
|
||||
}
|
||||
|
||||
|
||||
if !strings.Contains(notification, "query_execution_result") {
|
||||
t.Errorf("通知中应该包含查询工具的使用说明")
|
||||
}
|
||||
@@ -81,7 +81,7 @@ func TestAgent_FormatMinimalNotification(t *testing.T) {
|
||||
|
||||
func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
|
||||
agent, _ := setupTestAgent(t)
|
||||
|
||||
|
||||
// 创建模拟的MCP工具结果(大结果)
|
||||
largeResult := &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
@@ -92,59 +92,59 @@ func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
|
||||
},
|
||||
IsError: false,
|
||||
}
|
||||
|
||||
|
||||
// 模拟MCP服务器返回大结果
|
||||
// 由于我们需要模拟CallTool的行为,这里需要创建一个mock或者使用实际的MCP服务器
|
||||
// 为了简化测试,我们直接测试结果处理逻辑
|
||||
|
||||
|
||||
// 设置阈值
|
||||
agent.mu.Lock()
|
||||
agent.largeResultThreshold = 1000 // 设置较小的阈值
|
||||
agent.mu.Unlock()
|
||||
|
||||
|
||||
// 创建执行ID
|
||||
executionID := "test_exec_large_001"
|
||||
toolName := "test_tool"
|
||||
|
||||
|
||||
// 格式化结果
|
||||
var resultText strings.Builder
|
||||
for _, content := range largeResult.Content {
|
||||
resultText.WriteString(content.Text)
|
||||
resultText.WriteString("\n")
|
||||
}
|
||||
|
||||
|
||||
resultStr := resultText.String()
|
||||
resultSize := len(resultStr)
|
||||
|
||||
|
||||
// 检测大结果并保存
|
||||
agent.mu.RLock()
|
||||
threshold := agent.largeResultThreshold
|
||||
storage := agent.resultStorage
|
||||
agent.mu.RUnlock()
|
||||
|
||||
|
||||
if resultSize > threshold && storage != nil {
|
||||
// 保存大结果
|
||||
err := storage.SaveResult(executionID, toolName, resultStr)
|
||||
if err != nil {
|
||||
t.Fatalf("保存大结果失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 生成通知
|
||||
lines := strings.Split(resultStr, "\n")
|
||||
filePath := storage.GetResultPath(executionID)
|
||||
notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
|
||||
|
||||
|
||||
// 验证通知格式
|
||||
if !strings.Contains(notification, executionID) {
|
||||
t.Errorf("通知中应该包含执行ID")
|
||||
}
|
||||
|
||||
|
||||
// 验证结果已保存
|
||||
savedResult, err := storage.GetResult(executionID)
|
||||
if err != nil {
|
||||
t.Fatalf("获取保存的结果失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if savedResult != resultStr {
|
||||
t.Errorf("保存的结果与原始结果不匹配")
|
||||
}
|
||||
@@ -155,7 +155,7 @@ func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
|
||||
|
||||
func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
|
||||
agent, _ := setupTestAgent(t)
|
||||
|
||||
|
||||
// 创建小结果
|
||||
smallResult := &mcp.ToolResult{
|
||||
Content: []mcp.Content{
|
||||
@@ -166,32 +166,32 @@ func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
|
||||
},
|
||||
IsError: false,
|
||||
}
|
||||
|
||||
|
||||
// 设置较大的阈值
|
||||
agent.mu.Lock()
|
||||
agent.largeResultThreshold = 100000 // 100KB
|
||||
agent.mu.Unlock()
|
||||
|
||||
|
||||
// 格式化结果
|
||||
var resultText strings.Builder
|
||||
for _, content := range smallResult.Content {
|
||||
resultText.WriteString(content.Text)
|
||||
resultText.WriteString("\n")
|
||||
}
|
||||
|
||||
|
||||
resultStr := resultText.String()
|
||||
resultSize := len(resultStr)
|
||||
|
||||
|
||||
// 检测大结果
|
||||
agent.mu.RLock()
|
||||
threshold := agent.largeResultThreshold
|
||||
storage := agent.resultStorage
|
||||
agent.mu.RUnlock()
|
||||
|
||||
|
||||
if resultSize > threshold && storage != nil {
|
||||
t.Fatal("小结果不应该被保存")
|
||||
}
|
||||
|
||||
|
||||
// 小结果应该直接返回
|
||||
if resultSize <= threshold {
|
||||
// 这是预期的行为
|
||||
@@ -203,26 +203,26 @@ func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
|
||||
|
||||
func TestAgent_SetResultStorage(t *testing.T) {
|
||||
agent, _ := setupTestAgent(t)
|
||||
|
||||
|
||||
// 创建新的存储
|
||||
tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405"))
|
||||
newStorage, err := storage.NewFileResultStorage(tmpDir, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("创建新存储失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
// 设置新存储
|
||||
agent.SetResultStorage(newStorage)
|
||||
|
||||
|
||||
// 验证存储已更新
|
||||
agent.mu.RLock()
|
||||
currentStorage := agent.resultStorage
|
||||
agent.mu.RUnlock()
|
||||
|
||||
|
||||
if currentStorage != newStorage {
|
||||
t.Fatal("存储未正确更新")
|
||||
}
|
||||
|
||||
|
||||
// 清理
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
@@ -230,24 +230,24 @@ func TestAgent_SetResultStorage(t *testing.T) {
|
||||
func TestAgent_NewAgent_DefaultValues(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
mcpServer := mcp.NewServer(logger)
|
||||
|
||||
|
||||
openAICfg := &config.OpenAIConfig{
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.test.com/v1",
|
||||
Model: "test-model",
|
||||
}
|
||||
|
||||
|
||||
// 测试默认配置
|
||||
agent := NewAgent(openAICfg, nil, mcpServer, nil, logger, 0)
|
||||
|
||||
|
||||
if agent.maxIterations != 30 {
|
||||
t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations)
|
||||
}
|
||||
|
||||
|
||||
agent.mu.RLock()
|
||||
threshold := agent.largeResultThreshold
|
||||
agent.mu.RUnlock()
|
||||
|
||||
|
||||
if threshold != 50*1024 {
|
||||
t.Errorf("默认阈值不匹配。期望: %d, 实际: %d", 50*1024, threshold)
|
||||
}
|
||||
@@ -256,31 +256,30 @@ func TestAgent_NewAgent_DefaultValues(t *testing.T) {
|
||||
func TestAgent_NewAgent_CustomConfig(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
mcpServer := mcp.NewServer(logger)
|
||||
|
||||
|
||||
openAICfg := &config.OpenAIConfig{
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://api.test.com/v1",
|
||||
Model: "test-model",
|
||||
}
|
||||
|
||||
|
||||
agentCfg := &config.AgentConfig{
|
||||
MaxIterations: 20,
|
||||
LargeResultThreshold: 100 * 1024, // 100KB
|
||||
ResultStorageDir: "custom_tmp",
|
||||
}
|
||||
|
||||
|
||||
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15)
|
||||
|
||||
|
||||
if agent.maxIterations != 15 {
|
||||
t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations)
|
||||
}
|
||||
|
||||
|
||||
agent.mu.RLock()
|
||||
threshold := agent.largeResultThreshold
|
||||
agent.mu.RUnlock()
|
||||
|
||||
|
||||
if threshold != 100*1024 {
|
||||
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -256,11 +256,11 @@ func orchestratorConfigFromOrchestrator(o *OrchestratorMarkdown) config.MultiAge
|
||||
return config.MultiAgentSubConfig{}
|
||||
}
|
||||
return config.MultiAgentSubConfig{
|
||||
ID: o.EinoName,
|
||||
Name: o.DisplayName,
|
||||
Description: o.Description,
|
||||
Instruction: o.Instruction,
|
||||
Kind: "orchestrator",
|
||||
ID: o.EinoName,
|
||||
Name: o.DisplayName,
|
||||
Description: o.Description,
|
||||
Instruction: o.Instruction,
|
||||
Kind: "orchestrator",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -133,6 +133,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
maxIterations = 30 // 默认值
|
||||
}
|
||||
agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations)
|
||||
agent.UpdateToolDescriptionMode(cfg.Security.ToolDescriptionMode)
|
||||
|
||||
// 设置结果存储到Agent
|
||||
agent.SetResultStorage(resultStorage)
|
||||
@@ -317,6 +318,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
}
|
||||
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
|
||||
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
||||
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
||||
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
||||
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
|
||||
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
|
||||
@@ -433,6 +435,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
|
||||
authHandler,
|
||||
agentHandler,
|
||||
monitorHandler,
|
||||
notificationHandler,
|
||||
conversationHandler,
|
||||
robotHandler,
|
||||
groupHandler,
|
||||
@@ -599,6 +602,7 @@ func setupRoutes(
|
||||
authHandler *handler.AuthHandler,
|
||||
agentHandler *handler.AgentHandler,
|
||||
monitorHandler *handler.MonitorHandler,
|
||||
notificationHandler *handler.NotificationHandler,
|
||||
conversationHandler *handler.ConversationHandler,
|
||||
robotHandler *handler.RobotHandler,
|
||||
groupHandler *handler.GroupHandler,
|
||||
@@ -727,6 +731,8 @@ func setupRoutes(
|
||||
protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution)
|
||||
protected.DELETE("/monitor/executions", monitorHandler.DeleteExecutions)
|
||||
protected.GET("/monitor/stats", monitorHandler.GetStats)
|
||||
protected.GET("/notifications/summary", notificationHandler.GetSummary)
|
||||
protected.POST("/notifications/read", notificationHandler.MarkRead)
|
||||
|
||||
// 配置管理
|
||||
protected.GET("/config", configHandler.GetConfig)
|
||||
|
||||
@@ -145,7 +145,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
}
|
||||
|
||||
// 1. 优先尝试从数据库获取保存的最后一轮ReAct输入和输出
|
||||
reactInputJSON, modelOutput, err := b.db.GetReActData(conversationID)
|
||||
reactInputJSON, modelOutput, err := b.db.GetAgentTrace(conversationID)
|
||||
if err != nil {
|
||||
b.logger.Warn("获取保存的ReAct数据失败,将使用消息历史构建", zap.Error(err))
|
||||
// 继续使用原来的逻辑
|
||||
@@ -170,7 +170,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
messageCount = len(tempMessages)
|
||||
}
|
||||
|
||||
dataSource = "database_last_react_input"
|
||||
dataSource = "database_last_agent_trace"
|
||||
b.logger.Info("使用保存的ReAct数据构建攻击链",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("dataSource", dataSource),
|
||||
@@ -183,7 +183,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
// userInput = b.extractUserInputFromReActInput(reactInputJSON)
|
||||
|
||||
// 将JSON格式的messages转换为可读格式
|
||||
reactInputFinal = b.formatReActInputFromJSON(reactInputJSON)
|
||||
reactInputFinal = b.formatAgentTraceInputFromJSON(reactInputJSON)
|
||||
} else {
|
||||
// 2. 如果没有保存的ReAct数据,从对话消息构建
|
||||
dataSource = "messages_table"
|
||||
@@ -201,7 +201,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
}
|
||||
|
||||
// 提取最后一轮ReAct的输入(历史消息+当前用户输入)
|
||||
reactInputFinal = b.buildReActInput(messages)
|
||||
reactInputFinal = b.buildAgentTraceInput(messages)
|
||||
|
||||
// 提取大模型最后的输出(最后一条assistant消息)
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
@@ -212,7 +212,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
|
||||
}
|
||||
}
|
||||
|
||||
// 多代理:保存的 last_react_input 可能仅为首轮用户消息,不含工具轨迹;补充最后一轮助手的过程详情(与单代理「最后一轮 ReAct」对齐)
|
||||
// 多代理:保存的轨迹列可能仅为首轮用户消息,不含工具轨迹;补充最后一轮助手的过程详情(与单代理完整轨迹对齐)
|
||||
hasMCPOnAssistant := false
|
||||
var lastAssistantID string
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
@@ -366,8 +366,8 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
|
||||
return strings.TrimSpace(sb.String())
|
||||
}
|
||||
|
||||
// buildReActInput 构建最后一轮ReAct的输入(历史消息+当前用户输入)
|
||||
func (b *Builder) buildReActInput(messages []database.Message) string {
|
||||
// buildAgentTraceInput 构建最后一轮ReAct的输入(历史消息+当前用户输入)
|
||||
func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
|
||||
var builder strings.Builder
|
||||
for _, msg := range messages {
|
||||
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
|
||||
@@ -396,8 +396,8 @@ func (b *Builder) buildReActInput(messages []database.Message) string {
|
||||
// return ""
|
||||
// }
|
||||
|
||||
// formatReActInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
|
||||
func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string {
|
||||
// formatAgentTraceInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
|
||||
func (b *Builder) formatAgentTraceInputFromJSON(reactInputJSON string) string {
|
||||
var messages []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
|
||||
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
|
||||
|
||||
@@ -72,6 +72,8 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
ToolSearchEnable bool `yaml:"tool_search_enable,omitempty" json:"tool_search_enable,omitempty"`
|
||||
ToolSearchMinTools int `yaml:"tool_search_min_tools,omitempty" json:"tool_search_min_tools,omitempty"` // default 20; applies when len(tools) >= this
|
||||
ToolSearchAlwaysVisible int `yaml:"tool_search_always_visible,omitempty" json:"tool_search_always_visible,omitempty"` // default 12; first N tools stay always visible
|
||||
// ToolSearchAlwaysVisibleTools keeps specified tool names always visible (never hidden by tool_search).
|
||||
ToolSearchAlwaysVisibleTools []string `yaml:"tool_search_always_visible_tools,omitempty" json:"tool_search_always_visible_tools,omitempty"`
|
||||
// Plantask adds TaskCreate/Get/Update/List (file-backed under skills dir); requires eino_skills + local backend.
|
||||
PlantaskEnable bool `yaml:"plantask_enable,omitempty" json:"plantask_enable,omitempty"`
|
||||
// PlantaskRelDir relative to skills_dir for per-conversation task boards (default .eino/plantask).
|
||||
@@ -79,8 +81,24 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
// Reduction truncates/offloads large tool outputs (requires eino local backend for Write).
|
||||
ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"`
|
||||
ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // default: os temp + conversation id
|
||||
ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000
|
||||
ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000
|
||||
ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"`
|
||||
ReductionSubAgents bool `yaml:"reduction_sub_agents,omitempty" json:"reduction_sub_agents,omitempty"` // also attach to sub-agents
|
||||
// SummarizationTriggerRatio controls summarization trigger threshold as max_total_tokens * ratio (default 0.8).
|
||||
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
||||
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
||||
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
||||
// HistoryInputBudgetRatio caps pre-agent history tokens as max_total_tokens * ratio (default 0.35).
|
||||
HistoryInputBudgetRatio float64 `yaml:"history_input_budget_ratio,omitempty" json:"history_input_budget_ratio,omitempty"`
|
||||
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
||||
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
||||
// PlanExecuteExecutedStepsBudgetRatio caps executed_steps prompt budget ratio (default 0.2).
|
||||
PlanExecuteExecutedStepsBudgetRatio float64 `yaml:"plan_execute_executed_steps_budget_ratio,omitempty" json:"plan_execute_executed_steps_budget_ratio,omitempty"`
|
||||
// PlanExecuteMaxStepResultRunes caps each executed step result length for prompt view (default 4000).
|
||||
PlanExecuteMaxStepResultRunes int `yaml:"plan_execute_max_step_result_runes,omitempty" json:"plan_execute_max_step_result_runes,omitempty"`
|
||||
// PlanExecuteKeepLastSteps keeps only the tail steps in prompt view (default 8).
|
||||
PlanExecuteKeepLastSteps int `yaml:"plan_execute_keep_last_steps,omitempty" json:"plan_execute_keep_last_steps,omitempty"`
|
||||
// CheckpointDir when non-empty enables adk.Runner CheckPointStore (file-backed) for interrupt/resume persistence.
|
||||
CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"`
|
||||
// DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off.
|
||||
@@ -91,6 +109,97 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoMiddlewareConfig) SummarizationTriggerRatioEffective() float64 {
|
||||
v := c.SummarizationTriggerRatio
|
||||
if v <= 0 {
|
||||
return 0.8
|
||||
}
|
||||
if v < 0.5 {
|
||||
return 0.5
|
||||
}
|
||||
if v > 0.95 {
|
||||
return 0.95
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoMiddlewareConfig) SummarizationEmitInternalEventsEffective() bool {
|
||||
if c.SummarizationEmitInternalEvents != nil {
|
||||
return *c.SummarizationEmitInternalEvents
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoMiddlewareConfig) HistoryInputBudgetRatioEffective() float64 {
|
||||
v := c.HistoryInputBudgetRatio
|
||||
if v <= 0 {
|
||||
return 0.35
|
||||
}
|
||||
if v < 0.15 {
|
||||
return 0.15
|
||||
}
|
||||
if v > 0.6 {
|
||||
return 0.6
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteUserInputBudgetRatioEffective() float64 {
|
||||
v := c.PlanExecuteUserInputBudgetRatio
|
||||
if v <= 0 {
|
||||
return 0.35
|
||||
}
|
||||
if v < 0.1 {
|
||||
return 0.1
|
||||
}
|
||||
if v > 0.6 {
|
||||
return 0.6
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteExecutedStepsBudgetRatioEffective() float64 {
|
||||
v := c.PlanExecuteExecutedStepsBudgetRatio
|
||||
if v <= 0 {
|
||||
return 0.2
|
||||
}
|
||||
if v < 0.08 {
|
||||
return 0.08
|
||||
}
|
||||
if v > 0.5 {
|
||||
return 0.5
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteMaxStepResultRunesEffective() int {
|
||||
if c.PlanExecuteMaxStepResultRunes > 0 {
|
||||
return c.PlanExecuteMaxStepResultRunes
|
||||
}
|
||||
return 4000
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteKeepLastStepsEffective() int {
|
||||
if c.PlanExecuteKeepLastSteps > 0 {
|
||||
return c.PlanExecuteKeepLastSteps
|
||||
}
|
||||
return 8
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoMiddlewareConfig) ReductionMaxLengthForTruncEffective() int {
|
||||
if c.ReductionMaxLengthForTrunc > 0 {
|
||||
return c.ReductionMaxLengthForTrunc
|
||||
}
|
||||
return 12000
|
||||
}
|
||||
|
||||
func (c MultiAgentEinoMiddlewareConfig) ReductionMaxTokensForClearEffective() int {
|
||||
if c.ReductionMaxTokensForClear > 0 {
|
||||
return c.ReductionMaxTokensForClear
|
||||
}
|
||||
return 50000
|
||||
}
|
||||
|
||||
// MultiAgentEinoSkillsConfig toggles Eino official skill progressive disclosure and host filesystem tools.
|
||||
type MultiAgentEinoSkillsConfig struct {
|
||||
// Disable skips skill middleware (and does not attach local FS tools for Deep).
|
||||
@@ -137,6 +246,8 @@ type MultiAgentPublic struct {
|
||||
SubAgentCount int `json:"sub_agent_count"`
|
||||
Orchestration string `json:"orchestration,omitempty"`
|
||||
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
|
||||
ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"`
|
||||
ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"`
|
||||
}
|
||||
|
||||
// NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。
|
||||
@@ -158,6 +269,7 @@ type MultiAgentAPIUpdate struct {
|
||||
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
|
||||
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
|
||||
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
|
||||
ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"`
|
||||
}
|
||||
|
||||
// RobotsConfig 机器人配置(企业微信、钉钉、飞书等)
|
||||
|
||||
@@ -165,4 +165,3 @@ func (db *DB) DeleteAttackChain(conversationID string) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -308,7 +310,7 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
|
||||
func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
|
||||
|
||||
if search != "" {
|
||||
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
|
||||
searchPattern := "%" + search + "%"
|
||||
@@ -327,7 +329,7 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
|
||||
limit, offset,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询对话列表失败: %w", err)
|
||||
}
|
||||
@@ -416,25 +418,34 @@ func (db *DB) DeleteConversation(id string) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除对话失败: %w", err)
|
||||
}
|
||||
// Best-effort cleanup for conversation-scoped filesystem artifacts
|
||||
// (e.g., summarization transcript, reduction/checkpoint files under conversation_artifacts/<id>).
|
||||
if base := strings.TrimSpace(db.conversationArtifactsDir); base != "" {
|
||||
artDir := filepath.Join(base, id)
|
||||
if rmErr := os.RemoveAll(artDir); rmErr != nil {
|
||||
db.logger.Warn("删除会话 artifacts 目录失败", zap.String("conversationId", id), zap.String("dir", artDir), zap.Error(rmErr))
|
||||
}
|
||||
}
|
||||
|
||||
db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveReActData 保存最后一轮ReAct的输入和输出
|
||||
func (db *DB) SaveReActData(conversationID, reactInput, reactOutput string) error {
|
||||
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
|
||||
// SQLite 列名仍为 last_react_input / last_react_output,与历史库表兼容;语义上为「全模式代理轨迹」,非仅 ReAct。
|
||||
func (db *DB) SaveAgentTrace(conversationID, traceInputJSON, assistantOutput string) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?",
|
||||
reactInput, reactOutput, time.Now(), conversationID,
|
||||
traceInputJSON, assistantOutput, time.Now(), conversationID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("保存ReAct数据失败: %w", err)
|
||||
return fmt.Errorf("保存代理轨迹失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetReActData 获取最后一轮ReAct的输入和输出
|
||||
func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput string, err error) {
|
||||
// GetAgentTrace 读取 conversations 中保存的代理轨迹(列名 last_react_*)。
|
||||
func (db *DB) GetAgentTrace(conversationID string) (traceInputJSON, assistantOutput string, err error) {
|
||||
var input, output sql.NullString
|
||||
err = db.QueryRow(
|
||||
"SELECT last_react_input, last_react_output FROM conversations WHERE id = ?",
|
||||
@@ -444,17 +455,17 @@ func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput strin
|
||||
if err == sql.ErrNoRows {
|
||||
return "", "", fmt.Errorf("对话不存在")
|
||||
}
|
||||
return "", "", fmt.Errorf("获取ReAct数据失败: %w", err)
|
||||
return "", "", fmt.Errorf("获取代理轨迹失败: %w", err)
|
||||
}
|
||||
|
||||
if input.Valid {
|
||||
reactInput = input.String
|
||||
traceInputJSON = input.String
|
||||
}
|
||||
if output.Valid {
|
||||
reactOutput = output.String
|
||||
assistantOutput = output.String
|
||||
}
|
||||
|
||||
return reactInput, reactOutput, nil
|
||||
return traceInputJSON, assistantOutput, nil
|
||||
}
|
||||
|
||||
// ConversationHasToolProcessDetails 对话是否存在已落库的工具调用/结果(用于多代理等场景下 MCP execution id 未汇总时的攻击链判定)。
|
||||
|
||||
@@ -3,6 +3,8 @@ package database
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -21,7 +23,8 @@ func configureDBPool(db *sql.DB) {
|
||||
// DB 数据库连接
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
logger *zap.Logger
|
||||
logger *zap.Logger
|
||||
conversationArtifactsDir string
|
||||
}
|
||||
|
||||
// NewDB 创建数据库连接
|
||||
@@ -41,6 +44,13 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
DB: db,
|
||||
logger: logger,
|
||||
}
|
||||
// Keep conversation-scoped artifacts near database files, so cleanup can follow conversation lifecycle.
|
||||
baseDir := filepath.Join(filepath.Dir(dbPath), "conversation_artifacts")
|
||||
if mkErr := os.MkdirAll(baseDir, 0o755); mkErr == nil {
|
||||
database.conversationArtifactsDir = baseDir
|
||||
} else if logger != nil {
|
||||
logger.Warn("创建 conversation artifacts 目录失败", zap.String("dir", baseDir), zap.Error(mkErr))
|
||||
}
|
||||
|
||||
// 初始化表
|
||||
if err := database.initTables(); err != nil {
|
||||
@@ -52,7 +62,7 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
|
||||
// initTables 初始化数据库表
|
||||
func (db *DB) initTables() error {
|
||||
// 创建对话表
|
||||
// 创建对话表(last_react_input / last_react_output 存「代理消息轨迹」JSON 与助手摘要,列名保留以兼容已有库)
|
||||
createConversationsTable := `
|
||||
CREATE TABLE IF NOT EXISTS conversations (
|
||||
id TEXT PRIMARY KEY,
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
// Vulnerability 漏洞
|
||||
type Vulnerability struct {
|
||||
ID string `json:"id"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
ConversationTag string `json:"conversation_tag,omitempty"`
|
||||
TaskTag string `json:"task_tag,omitempty"`
|
||||
TaskID string `json:"task_id,omitempty"`
|
||||
@@ -367,4 +367,3 @@ func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) {
|
||||
"task_tags": taskTags,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
+84
-92
@@ -497,10 +497,10 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
defer h.hitlManager.DeactivateConversation(conversationID)
|
||||
}
|
||||
|
||||
// 优先尝试从保存的ReAct数据恢复历史上下文
|
||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
||||
// 优先尝试从保存的代理轨迹恢复历史上下文
|
||||
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
|
||||
h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err))
|
||||
// 回退到使用数据库消息表
|
||||
historyMessages, err := h.db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
@@ -518,7 +518,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
} else {
|
||||
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||
h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
|
||||
// 校验附件数量(非流式)
|
||||
@@ -613,12 +613,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
if err != nil {
|
||||
h.logger.Error("Agent Loop执行失败", zap.Error(err))
|
||||
|
||||
// 即使执行失败,也尝试保存ReAct数据(如果result中有)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
if saveErr := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); saveErr != nil {
|
||||
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(saveErr))
|
||||
// 即使执行失败,也尝试保存代理轨迹(如果 result 中有)
|
||||
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
|
||||
if saveErr := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); saveErr != nil {
|
||||
h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(saveErr))
|
||||
} else {
|
||||
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID))
|
||||
h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -634,12 +634,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
|
||||
// 因为AI已经生成了回复,用户应该能看到
|
||||
}
|
||||
|
||||
// 保存最后一轮ReAct的输入和输出
|
||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存ReAct数据失败", zap.Error(err))
|
||||
// 保存最后一轮代理轨迹与助手输出
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID))
|
||||
h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -666,7 +666,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
||||
}
|
||||
}
|
||||
|
||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
||||
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
|
||||
if err != nil {
|
||||
historyMessages, getErr := h.db.GetMessages(conversationID)
|
||||
if getErr != nil {
|
||||
@@ -722,6 +722,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
||||
"deep",
|
||||
)
|
||||
if errMA != nil {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
errMsg := "执行失败: " + errMA.Error()
|
||||
if assistantMessageID != "" {
|
||||
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID)
|
||||
@@ -747,8 +748,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
||||
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
if resultMA.LastReActInput != "" || resultMA.LastReActOutput != "" {
|
||||
_ = h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput)
|
||||
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
|
||||
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
|
||||
}
|
||||
return resultMA.Response, conversationID, nil
|
||||
}
|
||||
@@ -782,8 +783,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
|
||||
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
||||
_ = h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput)
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
_ = h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
|
||||
}
|
||||
return result.Response, conversationID, nil
|
||||
}
|
||||
@@ -1359,10 +1360,10 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
ssePublishConversationID = conversationID
|
||||
|
||||
// 优先尝试从保存的ReAct数据恢复历史上下文
|
||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
||||
// 优先尝试从保存的代理轨迹恢复历史上下文
|
||||
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
|
||||
if err != nil {
|
||||
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
|
||||
h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err))
|
||||
// 回退到使用数据库消息表
|
||||
historyMessages, err := h.db.GetMessages(conversationID)
|
||||
if err != nil {
|
||||
@@ -1380,7 +1381,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
} else {
|
||||
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||
h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
|
||||
}
|
||||
|
||||
// 校验附件数量
|
||||
@@ -1579,12 +1580,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
|
||||
}
|
||||
|
||||
// 即使任务被取消,也尝试保存ReAct数据(如果result中有)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存取消任务的ReAct数据失败", zap.Error(err))
|
||||
// 即使任务被取消,也尝试保存代理轨迹(如果 result 中有)
|
||||
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
|
||||
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存取消任务的代理轨迹失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存取消任务的ReAct数据", zap.String("conversationId", conversationID))
|
||||
h.logger.Info("已保存取消任务的代理轨迹", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1614,12 +1615,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
|
||||
}
|
||||
|
||||
// 即使任务超时,也尝试保存ReAct数据(如果result中有)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存超时任务的ReAct数据失败", zap.Error(err))
|
||||
// 即使任务超时,也尝试保存代理轨迹(如果 result 中有)
|
||||
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
|
||||
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存超时任务的代理轨迹失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存超时任务的ReAct数据", zap.String("conversationId", conversationID))
|
||||
h.logger.Info("已保存超时任务的代理轨迹", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1649,12 +1650,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil)
|
||||
}
|
||||
|
||||
// 即使任务失败,也尝试保存ReAct数据(如果result中有)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(err))
|
||||
// 即使任务失败,也尝试保存代理轨迹(如果 result 中有)
|
||||
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
|
||||
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID))
|
||||
h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1694,12 +1695,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 保存最后一轮ReAct的输入和输出
|
||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存ReAct数据失败", zap.Error(err))
|
||||
// 保存最后一轮代理轨迹与助手输出
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID))
|
||||
h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2499,6 +2500,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
cancel()
|
||||
|
||||
if runErr != nil {
|
||||
if useRunResult {
|
||||
h.persistEinoAgentTraceForResume(conversationID, resultMA)
|
||||
}
|
||||
// 检查是否是取消错误
|
||||
// 1. 直接检查是否是 context.Canceled(包括包装后的错误)
|
||||
// 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字
|
||||
@@ -2542,14 +2546,14 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
|
||||
}
|
||||
}
|
||||
// 保存ReAct数据(如果存在)
|
||||
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
// 保存代理轨迹(如果存在)
|
||||
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
|
||||
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
} else if useRunResult && resultMA != nil && (resultMA.LastReActInput != "" || resultMA.LastReActOutput != "") {
|
||||
if err := h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
} else if useRunResult && resultMA != nil && (resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "") {
|
||||
if err := h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
|
||||
@@ -2581,13 +2585,13 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
if useRunResult {
|
||||
resText = resultMA.Response
|
||||
mcpIDs = resultMA.MCPExecutionIDs
|
||||
lastIn = resultMA.LastReActInput
|
||||
lastOut = resultMA.LastReActOutput
|
||||
lastIn = resultMA.LastAgentTraceInput
|
||||
lastOut = resultMA.LastAgentTraceOutput
|
||||
} else {
|
||||
resText = result.Response
|
||||
mcpIDs = result.MCPExecutionIDs
|
||||
lastIn = result.LastReActInput
|
||||
lastOut = result.LastReActOutput
|
||||
lastIn = result.LastAgentTraceInput
|
||||
lastOut = result.LastAgentTraceOutput
|
||||
}
|
||||
|
||||
// 更新助手消息内容
|
||||
@@ -2618,12 +2622,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
}
|
||||
}
|
||||
|
||||
// 保存ReAct数据
|
||||
// 保存代理轨迹
|
||||
if lastIn != "" || lastOut != "" {
|
||||
if err := h.db.SaveReActData(conversationID, lastIn, lastOut); err != nil {
|
||||
h.logger.Warn("保存ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
|
||||
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
|
||||
} else {
|
||||
h.logger.Info("已保存ReAct数据", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2642,36 +2646,33 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
}
|
||||
}
|
||||
|
||||
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文
|
||||
// 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退到消息表
|
||||
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) {
|
||||
// 获取保存的ReAct输入和输出
|
||||
reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID)
|
||||
// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
|
||||
// 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。
|
||||
func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
|
||||
traceInputJSON, assistantOut, err := h.db.GetAgentTrace(conversationID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取ReAct数据失败: %w", err)
|
||||
return nil, fmt.Errorf("获取代理轨迹失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果last_react_input为空,回退到使用消息表(与攻击链生成逻辑一致)
|
||||
if reactInputJSON == "" {
|
||||
return nil, fmt.Errorf("ReAct数据为空,将使用消息表")
|
||||
if traceInputJSON == "" {
|
||||
return nil, fmt.Errorf("代理轨迹为空,将使用消息表")
|
||||
}
|
||||
|
||||
dataSource := "database_last_react_input"
|
||||
dataSource := "database_last_agent_trace"
|
||||
|
||||
// 解析JSON格式的messages数组
|
||||
var messagesArray []map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil {
|
||||
return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err)
|
||||
if err := json.Unmarshal([]byte(traceInputJSON), &messagesArray); err != nil {
|
||||
return nil, fmt.Errorf("解析代理轨迹 JSON 失败: %w", err)
|
||||
}
|
||||
|
||||
messageCount := len(messagesArray)
|
||||
|
||||
h.logger.Info("使用保存的ReAct数据恢复历史上下文",
|
||||
h.logger.Info("使用保存的代理轨迹恢复历史上下文",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("dataSource", dataSource),
|
||||
zap.Int("reactInputSize", len(reactInputJSON)),
|
||||
zap.Int("traceInputSize", len(traceInputJSON)),
|
||||
zap.Int("messageCount", messageCount),
|
||||
zap.Int("reactOutputSize", len(reactOutput)),
|
||||
zap.Int("assistantOutSize", len(assistantOut)),
|
||||
)
|
||||
// fmt.Println("messagesArray:", messagesArray)//debug
|
||||
|
||||
@@ -2755,53 +2756,44 @@ func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.
|
||||
agentMessages = append(agentMessages, msg)
|
||||
}
|
||||
|
||||
// 如果存在last_react_output,需要将其作为最后一条assistant消息
|
||||
// 因为last_react_input是在迭代开始前保存的,不包含最后一轮的最终输出
|
||||
if reactOutput != "" {
|
||||
// 检查最后一条消息是否是assistant消息且没有tool_calls
|
||||
// 如果有tool_calls,说明后面应该还有tool消息和最终的assistant回复
|
||||
// 若存在 last_react_output(助手摘要),合并为最后一条 assistant(与保存格式一致)
|
||||
if assistantOut != "" {
|
||||
if len(agentMessages) > 0 {
|
||||
lastMsg := &agentMessages[len(agentMessages)-1]
|
||||
if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 {
|
||||
// 最后一条是assistant消息且没有tool_calls,用最终输出更新其content
|
||||
lastMsg.Content = reactOutput
|
||||
lastMsg.Content = assistantOut
|
||||
} else {
|
||||
// 最后一条不是assistant消息,或者有tool_calls,添加最终输出作为新的assistant消息
|
||||
agentMessages = append(agentMessages, agent.ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: reactOutput,
|
||||
Content: assistantOut,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
// 如果没有消息,直接添加最终输出
|
||||
agentMessages = append(agentMessages, agent.ChatMessage{
|
||||
Role: "assistant",
|
||||
Content: reactOutput,
|
||||
Content: assistantOut,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(agentMessages) == 0 {
|
||||
return nil, fmt.Errorf("从ReAct数据解析的消息为空")
|
||||
return nil, fmt.Errorf("从代理轨迹解析的消息为空")
|
||||
}
|
||||
|
||||
// 修复可能存在的失配tool消息,避免OpenAI报错
|
||||
// 这可以防止出现"messages with role 'tool' must be a response to a preceeding message with 'tool_calls'"错误
|
||||
if h.agent != nil {
|
||||
if fixed := h.agent.RepairOrphanToolMessages(&agentMessages); fixed {
|
||||
h.logger.Info("修复了从ReAct数据恢复的历史消息中的失配tool消息",
|
||||
h.logger.Info("修复了从代理轨迹恢复的历史消息中的失配 tool 消息",
|
||||
zap.String("conversationId", conversationID),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("从ReAct数据恢复历史消息完成",
|
||||
h.logger.Info("从代理轨迹恢复历史消息完成",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("dataSource", dataSource),
|
||||
zap.Int("originalMessageCount", messageCount),
|
||||
zap.Int("finalMessageCount", len(agentMessages)),
|
||||
zap.Bool("hasReactOutput", reactOutput != ""),
|
||||
zap.Bool("hasAssistantOut", assistantOut != ""),
|
||||
)
|
||||
fmt.Println("agentMessages:", agentMessages) //debug
|
||||
return agentMessages, nil
|
||||
}
|
||||
|
||||
@@ -83,7 +83,7 @@ func (h *AttackChainHandler) GetAttackChain(c *gin.Context) {
|
||||
// 使用锁机制防止同一对话的并发生成
|
||||
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
|
||||
lock := lockInterface.(*sync.Mutex)
|
||||
|
||||
|
||||
// 尝试获取锁,如果正在生成则返回错误
|
||||
acquired := lock.TryLock()
|
||||
if !acquired {
|
||||
@@ -144,7 +144,7 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
|
||||
// 使用锁机制防止并发生成
|
||||
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
|
||||
lock := lockInterface.(*sync.Mutex)
|
||||
|
||||
|
||||
acquired := lock.TryLock()
|
||||
if !acquired {
|
||||
h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID))
|
||||
@@ -170,4 +170,3 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, chain)
|
||||
}
|
||||
|
||||
|
||||
+39
-14
@@ -17,6 +17,7 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/knowledge"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/openai"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
@@ -90,6 +91,7 @@ type AttackChainUpdater interface {
|
||||
type AgentUpdater interface {
|
||||
UpdateConfig(cfg *config.OpenAIConfig)
|
||||
UpdateMaxIterations(maxIterations int)
|
||||
UpdateToolDescriptionMode(mode string)
|
||||
}
|
||||
|
||||
// NewConfigHandler 创建新的配置处理器
|
||||
@@ -232,13 +234,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
if configToolMap[mcpTool.Name] {
|
||||
continue
|
||||
}
|
||||
description := mcpTool.ShortDescription
|
||||
if description == "" {
|
||||
description = mcpTool.Description
|
||||
}
|
||||
if len(description) > 10000 {
|
||||
description = description[:10000] + "..."
|
||||
}
|
||||
description := h.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description)
|
||||
tools = append(tools, ToolConfigInfo{
|
||||
Name: mcpTool.Name,
|
||||
Description: description,
|
||||
@@ -275,6 +271,11 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
SubAgentCount: subAgentCount,
|
||||
Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration),
|
||||
PlanExecuteLoopMaxIterations: h.config.MultiAgent.PlanExecuteLoopMaxIterations,
|
||||
ToolSearchAlwaysVisibleTools: append([]string(nil), h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools...),
|
||||
ToolSearchAlwaysVisibleEffectiveTools: mergeToolNameLists(
|
||||
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools,
|
||||
builtin.GetAllBuiltinTools(),
|
||||
),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, GetConfigResponse{
|
||||
@@ -430,13 +431,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
description := mcpTool.ShortDescription
|
||||
if description == "" {
|
||||
description = mcpTool.Description
|
||||
}
|
||||
if len(description) > 10000 {
|
||||
description = description[:10000] + "..."
|
||||
}
|
||||
description := h.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description)
|
||||
|
||||
toolInfo := ToolConfigInfo{
|
||||
Name: mcpTool.Name,
|
||||
@@ -689,11 +684,13 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
|
||||
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
|
||||
}
|
||||
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(req.MultiAgent.ToolSearchAlwaysVisibleTools)
|
||||
h.logger.Info("更新多代理配置",
|
||||
zap.Bool("enabled", h.config.MultiAgent.Enabled),
|
||||
zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent),
|
||||
zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent),
|
||||
zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations),
|
||||
zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1061,6 +1058,7 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
|
||||
if h.agent != nil {
|
||||
h.agent.UpdateConfig(&h.config.OpenAI)
|
||||
h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations)
|
||||
h.agent.UpdateToolDescriptionMode(h.config.Security.ToolDescriptionMode)
|
||||
h.logger.Info("Agent配置已更新")
|
||||
}
|
||||
|
||||
@@ -1383,6 +1381,33 @@ func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
|
||||
setBoolInMap(maNode, "robot_use_multi_agent", cfg.RobotUseMultiAgent)
|
||||
setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent)
|
||||
setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations)
|
||||
mwNode := ensureMap(maNode, "eino_middleware")
|
||||
setFlowStringSliceInMap(mwNode, "tool_search_always_visible_tools", dedupeToolNameList(cfg.EinoMiddleware.ToolSearchAlwaysVisibleTools))
|
||||
}
|
||||
|
||||
func dedupeToolNameList(in []string) []string {
|
||||
if len(in) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
seen := make(map[string]struct{}, len(in))
|
||||
out := make([]string, 0, len(in))
|
||||
for _, name := range in {
|
||||
n := strings.TrimSpace(name)
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(n)
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, n)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func mergeToolNameLists(a, b []string) []string {
|
||||
return dedupeToolNameList(append(append([]string{}, a...), b...))
|
||||
}
|
||||
|
||||
func ensureMap(parent *yaml.Node, path ...string) *yaml.Node {
|
||||
|
||||
@@ -230,4 +230,3 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
|
||||
"message": "ok",
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -175,6 +175,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
)
|
||||
|
||||
if runErr != nil {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
taskStatus = "cancelled"
|
||||
@@ -239,9 +240,9 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
)
|
||||
}
|
||||
|
||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err))
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -306,6 +307,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
progressCallback,
|
||||
)
|
||||
if runErr != nil {
|
||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()})
|
||||
return
|
||||
}
|
||||
@@ -323,8 +325,8 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
prep.AssistantMessageID,
|
||||
)
|
||||
}
|
||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
||||
_ = h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput)
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
_ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
||||
@@ -247,7 +247,7 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
|
||||
|
||||
// 先添加一个配置
|
||||
configObj := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: true,
|
||||
}
|
||||
handler.manager.AddOrUpdateConfig("test-delete", configObj)
|
||||
@@ -276,11 +276,11 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
|
||||
|
||||
// 添加多个配置
|
||||
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
ExternalMCPEnable: false,
|
||||
})
|
||||
|
||||
@@ -319,15 +319,15 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
|
||||
|
||||
// 添加配置
|
||||
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Command: "python3",
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
|
||||
@@ -360,7 +360,7 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
|
||||
|
||||
// 添加一个禁用的配置
|
||||
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Command: "python3",
|
||||
})
|
||||
|
||||
// 测试启动(可能会失败,因为没有真实的服务器)
|
||||
@@ -416,7 +416,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
|
||||
router, _, _ := setupTestRouter()
|
||||
|
||||
configObj := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: true,
|
||||
}
|
||||
|
||||
@@ -459,14 +459,14 @@ func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
|
||||
|
||||
// 先添加配置
|
||||
config1 := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: true,
|
||||
}
|
||||
handler.manager.AddOrUpdateConfig("test-update", config1)
|
||||
|
||||
// 更新配置
|
||||
config2 := config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
ExternalMCPEnable: true,
|
||||
}
|
||||
|
||||
|
||||
+21
-11
@@ -85,7 +85,7 @@ CREATE TABLE IF NOT EXISTS hitl_conversation_configs (
|
||||
enabled INTEGER NOT NULL DEFAULT 0,
|
||||
mode TEXT NOT NULL DEFAULT 'off',
|
||||
sensitive_tools TEXT NOT NULL DEFAULT '[]',
|
||||
timeout_seconds INTEGER NOT NULL DEFAULT 300,
|
||||
timeout_seconds INTEGER NOT NULL DEFAULT 0,
|
||||
updated_at DATETIME NOT NULL
|
||||
);`)
|
||||
if err != nil {
|
||||
@@ -133,7 +133,8 @@ func (m *HITLManager) ActivateConversation(conversationID string, req *HITLReque
|
||||
tools[n] = struct{}{}
|
||||
}
|
||||
}
|
||||
timeout := 5 * time.Minute
|
||||
// timeout <= 0 means wait forever (no timeout).
|
||||
timeout := time.Duration(0)
|
||||
if req.TimeoutSeconds > 0 {
|
||||
timeout = time.Duration(req.TimeoutSeconds) * time.Second
|
||||
}
|
||||
@@ -275,8 +276,8 @@ func (m *HITLManager) ensureConversationHITLModePersisted(conversationID, interr
|
||||
}
|
||||
cfg.Enabled = true
|
||||
cfg.Mode = nm
|
||||
if cfg.TimeoutSeconds <= 0 {
|
||||
cfg.TimeoutSeconds = 300
|
||||
if cfg.TimeoutSeconds < 0 {
|
||||
cfg.TimeoutSeconds = 0
|
||||
}
|
||||
return m.SaveConversationConfig(conversationID, cfg)
|
||||
}
|
||||
@@ -341,7 +342,7 @@ func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLReq
|
||||
return errors.New("conversationId is required")
|
||||
}
|
||||
if req == nil {
|
||||
req = &HITLRequest{Enabled: false, Mode: "off", TimeoutSeconds: 300}
|
||||
req = &HITLRequest{Enabled: false, Mode: "off", TimeoutSeconds: 0}
|
||||
}
|
||||
mode := normalizeHitlMode(req.Mode)
|
||||
if !req.Enabled {
|
||||
@@ -349,8 +350,8 @@ func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLReq
|
||||
}
|
||||
tools, _ := json.Marshal(req.SensitiveTools)
|
||||
timeout := req.TimeoutSeconds
|
||||
if timeout <= 0 {
|
||||
timeout = 300
|
||||
if timeout < 0 {
|
||||
timeout = 0
|
||||
}
|
||||
_, err := m.db.Exec(`INSERT INTO hitl_conversation_configs
|
||||
(conversation_id, enabled, mode, sensitive_tools, timeout_seconds, updated_at)
|
||||
@@ -368,11 +369,14 @@ func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLReques
|
||||
err := m.db.QueryRow(`SELECT enabled, mode, sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID).
|
||||
Scan(&enabledInt, &mode, &toolsJSON, &timeout)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 300}, nil
|
||||
return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if timeout < 0 {
|
||||
timeout = 0
|
||||
}
|
||||
tools := make([]string, 0)
|
||||
_ = json.Unmarshal([]byte(toolsJSON), &tools)
|
||||
return &HITLRequest{
|
||||
@@ -389,6 +393,12 @@ func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, tim
|
||||
delete(m.pending, p.InterruptID)
|
||||
m.mu.Unlock()
|
||||
}()
|
||||
var timeoutCh <-chan time.Time
|
||||
if timeout > 0 {
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
timeoutCh = timer.C
|
||||
}
|
||||
select {
|
||||
case d := <-p.decideCh:
|
||||
// 只有 review_edit 模式允许改参;其他模式一律忽略 edited arguments
|
||||
@@ -398,7 +408,7 @@ func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, tim
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=? WHERE id=?`,
|
||||
d.Decision, d.Comment, time.Now(), p.InterruptID)
|
||||
return d, nil
|
||||
case <-time.After(timeout):
|
||||
case <-timeoutCh:
|
||||
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=? WHERE id=?`,
|
||||
time.Now(), p.InterruptID)
|
||||
return hitlDecision{Decision: "approve", Comment: "timeout auto approve"}, nil
|
||||
@@ -718,8 +728,8 @@ func (h *AgentHandler) GetHITLConversationConfig(c *gin.Context) {
|
||||
cfg2 := *cfg
|
||||
cfg2.Enabled = true
|
||||
cfg2.Mode = normalizeHitlMode(pendMode)
|
||||
if cfg2.TimeoutSeconds <= 0 {
|
||||
cfg2.TimeoutSeconds = 300
|
||||
if cfg2.TimeoutSeconds < 0 {
|
||||
cfg2.TimeoutSeconds = 0
|
||||
}
|
||||
cfg = &cfg2
|
||||
}
|
||||
|
||||
@@ -131,16 +131,16 @@ func (h *MarkdownAgentsHandler) GetMarkdownAgent(c *gin.Context) {
|
||||
}
|
||||
|
||||
type markdownAgentBody struct {
|
||||
Filename string `json:"filename"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Tools []string `json:"tools"`
|
||||
Instruction string `json:"instruction"`
|
||||
BindRole string `json:"bind_role"`
|
||||
MaxIterations int `json:"max_iterations"`
|
||||
Kind string `json:"kind"`
|
||||
Raw string `json:"raw"`
|
||||
Filename string `json:"filename"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Tools []string `json:"tools"`
|
||||
Instruction string `json:"instruction"`
|
||||
BindRole string `json:"bind_role"`
|
||||
MaxIterations int `json:"max_iterations"`
|
||||
Kind string `json:"kind"`
|
||||
Raw string `json:"raw"`
|
||||
}
|
||||
|
||||
// CreateMarkdownAgent POST /api/multi-agent/markdown-agents
|
||||
|
||||
@@ -42,11 +42,11 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
|
||||
type MonitorResponse struct {
|
||||
Executions []*mcp.ToolExecution `json:"executions"`
|
||||
Stats map[string]*mcp.ToolStats `json:"stats"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Total int `json:"total,omitempty"`
|
||||
Page int `json:"page,omitempty"`
|
||||
PageSize int `json:"page_size,omitempty"`
|
||||
TotalPages int `json:"total_pages,omitempty"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Total int `json:"total,omitempty"`
|
||||
Page int `json:"page,omitempty"`
|
||||
PageSize int `json:"page_size,omitempty"`
|
||||
TotalPages int `json:"total_pages,omitempty"`
|
||||
}
|
||||
|
||||
// Monitor 获取监控信息
|
||||
@@ -213,7 +213,6 @@ func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
|
||||
return stats
|
||||
}
|
||||
|
||||
|
||||
// GetExecution 获取特定执行记录
|
||||
func (h *MonitorHandler) GetExecution(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
@@ -416,5 +415,3 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
|
||||
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
|
||||
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -185,6 +185,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
)
|
||||
|
||||
if runErr != nil {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, ErrTaskCancelled) {
|
||||
taskStatus = "cancelled"
|
||||
@@ -249,9 +250,9 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
)
|
||||
}
|
||||
|
||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
||||
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err))
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -318,6 +319,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
strings.TrimSpace(req.Orchestration),
|
||||
)
|
||||
if runErr != nil {
|
||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
|
||||
errMsg := "执行失败: " + runErr.Error()
|
||||
if prep.AssistantMessageID != "" {
|
||||
@@ -341,9 +343,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
)
|
||||
}
|
||||
|
||||
if result.LastReActInput != "" || result.LastReActOutput != "" {
|
||||
if err := h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput); err != nil {
|
||||
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err))
|
||||
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
|
||||
if err := h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -355,6 +357,19 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// persistEinoAgentTraceForResume 在 Eino 运行异常结束时写入代理轨迹(库列 last_react_*),供下一请求 loadHistoryFromAgentTrace 软续跑。
|
||||
func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, result *multiagent.RunResult) {
|
||||
if h == nil || result == nil {
|
||||
return
|
||||
}
|
||||
if result.LastAgentTraceInput == "" && result.LastAgentTraceOutput == "" {
|
||||
return
|
||||
}
|
||||
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
|
||||
h.logger.Warn("保存 Eino 续跑上下文失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func multiAgentHTTPErrorStatus(err error) (int, string) {
|
||||
msg := err.Error()
|
||||
switch {
|
||||
|
||||
@@ -49,7 +49,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
|
||||
}
|
||||
}
|
||||
|
||||
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
|
||||
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
|
||||
if err != nil {
|
||||
historyMessages, getErr := h.db.GetMessages(conversationID)
|
||||
if getErr != nil {
|
||||
|
||||
@@ -0,0 +1,642 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/database"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// NotificationHandler 聚合通知(Phase 2:服务端统一计算)
|
||||
type NotificationHandler struct {
|
||||
db *database.DB
|
||||
agentHandler *AgentHandler
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
const notificationReadMaxRows = 150
|
||||
|
||||
// NotificationSummaryItem 通知项
|
||||
type NotificationSummaryItem struct {
|
||||
ID string `json:"id"`
|
||||
Level string `json:"level"` // p0/p1/p2
|
||||
Type string `json:"type"`
|
||||
Title string `json:"title"`
|
||||
Desc string `json:"desc"`
|
||||
Ts string `json:"ts"` // RFC3339
|
||||
Count int `json:"count,omitempty"`
|
||||
Actionable bool `json:"actionable"`
|
||||
Read bool `json:"read"`
|
||||
// 以下字段用于前端深链跳转(通知即入口)
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
VulnerabilityID string `json:"vulnerabilityId,omitempty"`
|
||||
ExecutionID string `json:"executionId,omitempty"`
|
||||
InterruptID string `json:"interruptId,omitempty"`
|
||||
}
|
||||
|
||||
// NotificationSummaryResponse 聚合响应
|
||||
type NotificationSummaryResponse struct {
|
||||
SinceMs int64 `json:"sinceMs"`
|
||||
GeneratedAt string `json:"generatedAt"`
|
||||
P0Count int `json:"p0Count"`
|
||||
UnreadCount int `json:"unreadCount"`
|
||||
Counts map[string]int `json:"counts"`
|
||||
Items []NotificationSummaryItem `json:"items"`
|
||||
}
|
||||
|
||||
func NewNotificationHandler(db *database.DB, agentHandler *AgentHandler, logger *zap.Logger) *NotificationHandler {
|
||||
return &NotificationHandler{
|
||||
db: db,
|
||||
agentHandler: agentHandler,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func parseSinceMs(raw string) int64 {
|
||||
v := strings.TrimSpace(raw)
|
||||
if v == "" {
|
||||
return 0
|
||||
}
|
||||
if ms, err := strconv.ParseInt(v, 10, 64); err == nil && ms > 0 {
|
||||
return ms
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, v); err == nil {
|
||||
return t.UnixMilli()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func unixSecToRFC3339(sec int64) string {
|
||||
if sec <= 0 {
|
||||
return time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
return time.Unix(sec, 0).UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
func normalizedSinceSec(sinceMs int64) int64 {
|
||||
sec := sinceMs / 1000
|
||||
// SQLite 默认时间精度到秒;给 1s 回看窗口,避免“同秒内新增”被漏算。
|
||||
if sec > 0 {
|
||||
return sec - 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func normalizeSinceMs(raw int64) int64 {
|
||||
if raw > 0 {
|
||||
return raw
|
||||
}
|
||||
// 默认仅看最近 24 小时,避免首次打开拉全量历史噪音。
|
||||
return time.Now().Add(-24 * time.Hour).UnixMilli()
|
||||
}
|
||||
|
||||
func levelBySeverity(sev string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(sev)) {
|
||||
case "critical", "high":
|
||||
return "p0"
|
||||
case "medium":
|
||||
return "p1"
|
||||
default:
|
||||
return "p2"
|
||||
}
|
||||
}
|
||||
|
||||
func requestWantsEnglish(c *gin.Context) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
lang := strings.ToLower(strings.TrimSpace(c.Query("lang")))
|
||||
if lang == "" {
|
||||
lang = strings.ToLower(strings.TrimSpace(c.GetHeader("Accept-Language")))
|
||||
}
|
||||
return strings.HasPrefix(lang, "en")
|
||||
}
|
||||
|
||||
func i18nText(english bool, zh string, en string) string {
|
||||
if english {
|
||||
return en
|
||||
}
|
||||
return zh
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) loadPendingHITLItems(limit int, english bool) ([]NotificationSummaryItem, error) {
|
||||
rows, err := h.db.Query(`
|
||||
SELECT
|
||||
id,
|
||||
conversation_id,
|
||||
tool_name,
|
||||
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
|
||||
FROM hitl_interrupts
|
||||
WHERE status = 'pending'
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
`, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := make([]NotificationSummaryItem, 0, limit)
|
||||
for rows.Next() {
|
||||
var id, conversationID, toolName string
|
||||
var createdSec int64
|
||||
if err := rows.Scan(&id, &conversationID, &toolName, &createdSec); err != nil {
|
||||
continue
|
||||
}
|
||||
desc := i18nText(english, "会话 "+conversationID+" 的审批中断待处理", "Conversation "+conversationID+" has pending HITL approval")
|
||||
if strings.TrimSpace(toolName) != "" {
|
||||
desc = i18nText(english, "工具 "+toolName+" 等待审批", "Tool "+toolName+" is waiting for approval")
|
||||
}
|
||||
items = append(items, NotificationSummaryItem{
|
||||
ID: "hitl:" + id,
|
||||
Level: "p0",
|
||||
Type: "hitl_pending",
|
||||
Title: i18nText(english, "HITL 待审批", "HITL Pending Approval"),
|
||||
Desc: desc,
|
||||
Ts: unixSecToRFC3339(createdSec),
|
||||
Count: 1,
|
||||
Actionable: true,
|
||||
Read: false,
|
||||
ConversationID: conversationID,
|
||||
InterruptID: id,
|
||||
})
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) loadVulnerabilityItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, map[string]int, error) {
|
||||
sinceSec := normalizedSinceSec(sinceMs)
|
||||
rows, err := h.db.Query(`
|
||||
SELECT
|
||||
id,
|
||||
title,
|
||||
severity,
|
||||
conversation_id,
|
||||
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
|
||||
FROM vulnerabilities
|
||||
WHERE CAST(strftime('%s', created_at) AS INTEGER) > ?
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ?
|
||||
`, sinceSec, limit)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := make([]NotificationSummaryItem, 0, limit)
|
||||
counts := map[string]int{
|
||||
"newCriticalVulns": 0,
|
||||
"newHighVulns": 0,
|
||||
"newMediumVulns": 0,
|
||||
"newLowVulns": 0,
|
||||
"newInfoVulns": 0,
|
||||
}
|
||||
for rows.Next() {
|
||||
var id, title, severity, conversationID string
|
||||
var createdSec int64
|
||||
if err := rows.Scan(&id, &title, &severity, &conversationID, &createdSec); err != nil {
|
||||
continue
|
||||
}
|
||||
switch strings.ToLower(strings.TrimSpace(severity)) {
|
||||
case "critical":
|
||||
counts["newCriticalVulns"]++
|
||||
case "high":
|
||||
counts["newHighVulns"]++
|
||||
case "medium":
|
||||
counts["newMediumVulns"]++
|
||||
case "low":
|
||||
counts["newLowVulns"]++
|
||||
default:
|
||||
counts["newInfoVulns"]++
|
||||
}
|
||||
sevUpper := strings.ToUpper(strings.TrimSpace(severity))
|
||||
if sevUpper == "" {
|
||||
sevUpper = "INFO"
|
||||
}
|
||||
finalTitle := i18nText(english, "新漏洞("+sevUpper+")", "New Vulnerability ("+sevUpper+")")
|
||||
finalDesc := strings.TrimSpace(title)
|
||||
if finalDesc == "" {
|
||||
finalDesc = i18nText(english, "(无标题)", "(Untitled)")
|
||||
}
|
||||
items = append(items, NotificationSummaryItem{
|
||||
ID: "vuln:" + id,
|
||||
Level: levelBySeverity(severity),
|
||||
Type: "vulnerability_created",
|
||||
Title: finalTitle,
|
||||
Desc: finalDesc,
|
||||
Ts: unixSecToRFC3339(createdSec),
|
||||
Count: 1,
|
||||
Actionable: false,
|
||||
Read: false,
|
||||
ConversationID: conversationID,
|
||||
VulnerabilityID: id,
|
||||
})
|
||||
}
|
||||
return items, counts, nil
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) loadFailedExecutionItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) {
|
||||
sinceSec := normalizedSinceSec(sinceMs)
|
||||
rows, err := h.db.Query(`
|
||||
SELECT
|
||||
id,
|
||||
tool_name,
|
||||
COALESCE(CAST(strftime('%s', start_time) AS INTEGER), 0)
|
||||
FROM tool_executions
|
||||
WHERE status = 'failed'
|
||||
AND CAST(strftime('%s', start_time) AS INTEGER) > ?
|
||||
ORDER BY start_time DESC
|
||||
LIMIT ?
|
||||
`, sinceSec, limit)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
items := make([]NotificationSummaryItem, 0, limit)
|
||||
count := 0
|
||||
for rows.Next() {
|
||||
var id, toolName string
|
||||
var startSec int64
|
||||
if err := rows.Scan(&id, &toolName, &startSec); err != nil {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
if strings.TrimSpace(toolName) == "" {
|
||||
toolName = i18nText(english, "未知工具", "unknown")
|
||||
}
|
||||
items = append(items, NotificationSummaryItem{
|
||||
ID: "exec_failed:" + id,
|
||||
Level: "p0",
|
||||
Type: "task_failed",
|
||||
Title: i18nText(english, "任务执行失败", "Task Execution Failed"),
|
||||
Desc: i18nText(english, "工具 "+toolName+" 执行失败", "Tool "+toolName+" execution failed"),
|
||||
Ts: unixSecToRFC3339(startSec),
|
||||
Count: 1,
|
||||
Actionable: false,
|
||||
Read: false,
|
||||
ExecutionID: id,
|
||||
})
|
||||
}
|
||||
return items, count, nil
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) summarizeLongRunningTasks(threshold time.Duration, english bool) ([]NotificationSummaryItem, int) {
|
||||
if h.agentHandler == nil || h.agentHandler.tasks == nil {
|
||||
return nil, 0
|
||||
}
|
||||
tasks := h.agentHandler.tasks.GetActiveTasks()
|
||||
now := time.Now()
|
||||
items := make([]NotificationSummaryItem, 0, len(tasks))
|
||||
for _, t := range tasks {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
if now.Sub(t.StartedAt) >= threshold {
|
||||
items = append(items, NotificationSummaryItem{
|
||||
ID: "task_long:" + t.ConversationID,
|
||||
Level: "p1",
|
||||
Type: "long_running_tasks",
|
||||
Title: i18nText(english, "长时间运行任务", "Long Running Task"),
|
||||
Desc: i18nText(english, "会话 "+t.ConversationID+" 运行超过 15 分钟", "Conversation "+t.ConversationID+" has been running over 15 minutes"),
|
||||
Ts: t.StartedAt.UTC().Format(time.RFC3339),
|
||||
Count: 1,
|
||||
Actionable: true,
|
||||
Read: false,
|
||||
ConversationID: t.ConversationID,
|
||||
})
|
||||
}
|
||||
}
|
||||
return items, len(items)
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) summarizeCompletedTasksSince(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int) {
|
||||
if h.agentHandler == nil || h.agentHandler.tasks == nil {
|
||||
return nil, 0
|
||||
}
|
||||
since := time.UnixMilli(sinceMs)
|
||||
completed := h.agentHandler.tasks.GetCompletedTasks()
|
||||
items := make([]NotificationSummaryItem, 0, limit)
|
||||
for _, t := range completed {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
if t.CompletedAt.After(since) {
|
||||
items = append(items, NotificationSummaryItem{
|
||||
ID: "task_completed:" + t.ConversationID + ":" + strconv.FormatInt(t.CompletedAt.Unix(), 10),
|
||||
Level: "p2",
|
||||
Type: "task_completed",
|
||||
Title: i18nText(english, "任务完成", "Task Completed"),
|
||||
Desc: i18nText(english, "会话 "+t.ConversationID+" 已完成", "Conversation "+t.ConversationID+" completed"),
|
||||
Ts: t.CompletedAt.UTC().Format(time.RFC3339),
|
||||
Count: 1,
|
||||
Actionable: false,
|
||||
Read: false,
|
||||
ConversationID: t.ConversationID,
|
||||
})
|
||||
if len(items) >= limit {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return items, len(items)
|
||||
}
|
||||
|
||||
func buildPlaceholders(n int) string {
|
||||
if n <= 0 {
|
||||
return ""
|
||||
}
|
||||
out := make([]string, 0, n)
|
||||
for i := 0; i < n; i++ {
|
||||
out = append(out, "?")
|
||||
}
|
||||
return strings.Join(out, ",")
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) readStatesByIDs(ids []string) (map[string]bool, error) {
|
||||
result := make(map[string]bool, len(ids))
|
||||
if len(ids) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
holders := buildPlaceholders(len(ids))
|
||||
query := "SELECT event_id FROM notification_reads WHERE event_id IN (" + holders + ")"
|
||||
args := make([]interface{}, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
args = append(args, id)
|
||||
}
|
||||
rows, err := h.db.Query(query, args...)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var id string
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
continue
|
||||
}
|
||||
result[id] = true
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (h *NotificationHandler) applyReadStates(items []NotificationSummaryItem) ([]NotificationSummaryItem, error) {
|
||||
markableIDs := make([]string, 0, len(items))
|
||||
for _, item := range items {
|
||||
if item.Actionable {
|
||||
continue
|
||||
}
|
||||
markableIDs = append(markableIDs, item.ID)
|
||||
}
|
||||
readMap, err := h.readStatesByIDs(markableIDs)
|
||||
if err != nil {
|
||||
return items, err
|
||||
}
|
||||
for i := range items {
|
||||
if items[i].Actionable {
|
||||
items[i].Read = false
|
||||
continue
|
||||
}
|
||||
items[i].Read = readMap[items[i].ID]
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func filterVisibleItems(items []NotificationSummaryItem) []NotificationSummaryItem {
|
||||
out := make([]NotificationSummaryItem, 0, len(items))
|
||||
for _, item := range items {
|
||||
if item.Actionable || !item.Read {
|
||||
out = append(out, item)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func countP0(items []NotificationSummaryItem) int {
|
||||
total := 0
|
||||
for _, item := range items {
|
||||
if item.Level == "p0" {
|
||||
if item.Count > 0 {
|
||||
total += item.Count
|
||||
} else {
|
||||
total++
|
||||
}
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func countUnread(items []NotificationSummaryItem) int {
|
||||
total := 0
|
||||
for _, item := range items {
|
||||
if item.Actionable || !item.Read {
|
||||
if item.Count > 0 {
|
||||
total += item.Count
|
||||
} else {
|
||||
total++
|
||||
}
|
||||
}
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func createNotificationReadTableIfNeeded(db *database.DB) error {
|
||||
if db == nil {
|
||||
return fmt.Errorf("db is nil")
|
||||
}
|
||||
_, err := db.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS notification_reads (
|
||||
event_id TEXT PRIMARY KEY,
|
||||
read_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, idxErr := db.Exec(`CREATE INDEX IF NOT EXISTS idx_notification_reads_read_at ON notification_reads(read_at DESC);`)
|
||||
return idxErr
|
||||
}
|
||||
|
||||
func pruneNotificationReads(db *database.DB, maxRows int) error {
|
||||
if db == nil {
|
||||
return fmt.Errorf("db is nil")
|
||||
}
|
||||
if maxRows <= 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := db.Exec(`
|
||||
DELETE FROM notification_reads
|
||||
WHERE event_id NOT IN (
|
||||
SELECT event_id
|
||||
FROM notification_reads
|
||||
ORDER BY read_at DESC, rowid DESC
|
||||
LIMIT ?
|
||||
)
|
||||
`, maxRows)
|
||||
return err
|
||||
}
|
||||
|
||||
type markReadRequest struct {
|
||||
EventIDs []string `json:"eventIds"`
|
||||
}
|
||||
|
||||
func normalizeMarkableEventID(id string) (string, bool) {
|
||||
v := strings.TrimSpace(id)
|
||||
if v == "" {
|
||||
return "", false
|
||||
}
|
||||
// 仅允许“可读后隐藏”的信息类事件;Actionable 事件不参与 read 标记。
|
||||
allowedPrefixes := []string{
|
||||
"vuln:",
|
||||
"exec_failed:",
|
||||
"task_completed:",
|
||||
}
|
||||
for _, prefix := range allowedPrefixes {
|
||||
if strings.HasPrefix(v, prefix) {
|
||||
return v, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// MarkRead 按事件 ID 标记已读
|
||||
func (h *NotificationHandler) MarkRead(c *gin.Context) {
|
||||
if err := createNotificationReadTableIfNeeded(h.db); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to prepare notification read table"})
|
||||
return
|
||||
}
|
||||
var req markReadRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
|
||||
return
|
||||
}
|
||||
if len(req.EventIDs) == 0 {
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true, "marked": 0})
|
||||
return
|
||||
}
|
||||
tx, err := h.db.Begin()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to begin transaction"})
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = tx.Rollback()
|
||||
}()
|
||||
stmt, err := tx.Prepare(`
|
||||
INSERT INTO notification_reads(event_id, read_at)
|
||||
VALUES(?, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT(event_id) DO UPDATE SET read_at = CURRENT_TIMESTAMP
|
||||
`)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to prepare statement"})
|
||||
return
|
||||
}
|
||||
defer stmt.Close()
|
||||
marked := 0
|
||||
for _, raw := range req.EventIDs {
|
||||
id, ok := normalizeMarkableEventID(raw)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, err := stmt.Exec(id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to mark read"})
|
||||
return
|
||||
}
|
||||
marked++
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to commit read marks"})
|
||||
return
|
||||
}
|
||||
if err := pruneNotificationReads(h.db, notificationReadMaxRows); err != nil {
|
||||
h.logger.Warn("裁剪通知已读记录失败", zap.Error(err))
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true, "marked": marked})
|
||||
}
|
||||
|
||||
// GetSummary 返回通知聚合视图(用于头部铃铛)
|
||||
func (h *NotificationHandler) GetSummary(c *gin.Context) {
|
||||
if h.db == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := createNotificationReadTableIfNeeded(h.db); err != nil {
|
||||
h.logger.Warn("初始化通知已读表失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initialize notification read table"})
|
||||
return
|
||||
}
|
||||
|
||||
english := requestWantsEnglish(c)
|
||||
sinceMs := normalizeSinceMs(parseSinceMs(c.Query("since")))
|
||||
limit, _ := strconv.Atoi(strings.TrimSpace(c.DefaultQuery("limit", "50")))
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
if limit > 200 {
|
||||
limit = 200
|
||||
}
|
||||
|
||||
hitlItems, err := h.loadPendingHITLItems(limit, english)
|
||||
if err != nil {
|
||||
h.logger.Warn("加载 HITL 通知失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize hitl notifications"})
|
||||
return
|
||||
}
|
||||
|
||||
vulnItems, vulnCounts, err := h.loadVulnerabilityItems(sinceMs, limit, english)
|
||||
if err != nil {
|
||||
h.logger.Warn("加载漏洞通知失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize vulnerabilities"})
|
||||
return
|
||||
}
|
||||
|
||||
longRunningItems, longRunningCount := h.summarizeLongRunningTasks(15*time.Minute, english)
|
||||
completedItems, completedCount := h.summarizeCompletedTasksSince(sinceMs, limit, english)
|
||||
|
||||
items := make([]NotificationSummaryItem, 0, len(hitlItems)+len(vulnItems)+len(longRunningItems)+len(completedItems))
|
||||
items = append(items, hitlItems...)
|
||||
items = append(items, vulnItems...)
|
||||
items = append(items, longRunningItems...)
|
||||
items = append(items, completedItems...)
|
||||
|
||||
items, err = h.applyReadStates(items)
|
||||
if err != nil {
|
||||
h.logger.Warn("加载通知已读状态失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load notification read states"})
|
||||
return
|
||||
}
|
||||
items = filterVisibleItems(items)
|
||||
|
||||
sort.Slice(items, func(i, j int) bool {
|
||||
ti, errI := time.Parse(time.RFC3339, items[i].Ts)
|
||||
tj, errJ := time.Parse(time.RFC3339, items[j].Ts)
|
||||
if errI != nil || errJ != nil {
|
||||
return i < j
|
||||
}
|
||||
return ti.After(tj)
|
||||
})
|
||||
|
||||
p0Count := countP0(items)
|
||||
unreadCount := countUnread(items)
|
||||
c.JSON(http.StatusOK, NotificationSummaryResponse{
|
||||
SinceMs: sinceMs,
|
||||
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
P0Count: p0Count,
|
||||
UnreadCount: unreadCount,
|
||||
Counts: map[string]int{
|
||||
"hitlPending": len(hitlItems),
|
||||
"newCriticalVulns": vulnCounts["newCriticalVulns"],
|
||||
"newHighVulns": vulnCounts["newHighVulns"],
|
||||
"newMediumVulns": vulnCounts["newMediumVulns"],
|
||||
"newLowVulns": vulnCounts["newLowVulns"],
|
||||
"newInfoVulns": vulnCounts["newInfoVulns"],
|
||||
"failedExecutions": 0,
|
||||
"longRunningTasks": longRunningCount,
|
||||
"completedTasks": completedCount,
|
||||
},
|
||||
Items: items,
|
||||
})
|
||||
}
|
||||
+26
-26
@@ -4445,7 +4445,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"messageId"},
|
||||
"properties": map[string]interface{}{
|
||||
"messageId": map[string]interface{}{
|
||||
@@ -4689,7 +4689,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"scheduleEnabled"},
|
||||
"properties": map[string]interface{}{
|
||||
"scheduleEnabled": map[string]interface{}{"type": "boolean", "description": "是否启用自动调度"},
|
||||
@@ -4761,7 +4761,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"query"},
|
||||
"properties": map[string]interface{}{
|
||||
"query": map[string]interface{}{"type": "string", "description": "FOFA查询语法", "example": "domain=\"example.com\""},
|
||||
@@ -4810,7 +4810,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"text"},
|
||||
"properties": map[string]interface{}{
|
||||
"text": map[string]interface{}{"type": "string", "description": "自然语言描述", "example": "查找使用WordPress的网站"},
|
||||
@@ -4853,7 +4853,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"api_key", "model"},
|
||||
"properties": map[string]interface{}{
|
||||
"provider": map[string]interface{}{"type": "string", "description": "LLM提供商(openai/claude)", "example": "openai"},
|
||||
@@ -4900,7 +4900,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"command"},
|
||||
"properties": map[string]interface{}{
|
||||
"command": map[string]interface{}{"type": "string", "description": "要执行的命令"},
|
||||
@@ -4943,7 +4943,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"command"},
|
||||
"properties": map[string]interface{}{
|
||||
"command": map[string]interface{}{"type": "string", "description": "要执行的命令"},
|
||||
@@ -5027,7 +5027,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"url"},
|
||||
"properties": map[string]interface{}{
|
||||
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
|
||||
@@ -5231,7 +5231,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"url", "command"},
|
||||
"properties": map[string]interface{}{
|
||||
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
|
||||
@@ -5277,7 +5277,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"url", "action", "path"},
|
||||
"properties": map[string]interface{}{
|
||||
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
|
||||
@@ -5339,14 +5339,14 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"items": map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"relativePath": map[string]interface{}{"type": "string"},
|
||||
"absolutePath": map[string]interface{}{"type": "string"},
|
||||
"name": map[string]interface{}{"type": "string"},
|
||||
"size": map[string]interface{}{"type": "integer"},
|
||||
"modifiedUnix": map[string]interface{}{"type": "integer"},
|
||||
"date": map[string]interface{}{"type": "string"},
|
||||
"relativePath": map[string]interface{}{"type": "string"},
|
||||
"absolutePath": map[string]interface{}{"type": "string"},
|
||||
"name": map[string]interface{}{"type": "string"},
|
||||
"size": map[string]interface{}{"type": "integer"},
|
||||
"modifiedUnix": map[string]interface{}{"type": "integer"},
|
||||
"date": map[string]interface{}{"type": "string"},
|
||||
"conversationId": map[string]interface{}{"type": "string"},
|
||||
"subPath": map[string]interface{}{"type": "string"},
|
||||
"subPath": map[string]interface{}{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -5369,7 +5369,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"multipart/form-data": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"file"},
|
||||
"properties": map[string]interface{}{
|
||||
"file": map[string]interface{}{"type": "string", "format": "binary", "description": "上传的文件"},
|
||||
@@ -5410,7 +5410,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"path"},
|
||||
"properties": map[string]interface{}{
|
||||
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
|
||||
@@ -5485,7 +5485,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"path", "content"},
|
||||
"properties": map[string]interface{}{
|
||||
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
|
||||
@@ -5512,7 +5512,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"name"},
|
||||
"properties": map[string]interface{}{
|
||||
"parent": map[string]interface{}{"type": "string", "description": "父目录相对路径"},
|
||||
@@ -5552,7 +5552,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"path", "newName"},
|
||||
"properties": map[string]interface{}{
|
||||
"path": map[string]interface{}{"type": "string", "description": "当前文件相对路径"},
|
||||
@@ -5646,7 +5646,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"platform", "text"},
|
||||
"properties": map[string]interface{}{
|
||||
"platform": map[string]interface{}{"type": "string", "description": "平台类型", "enum": []string{"dingtalk", "lark", "wecom"}},
|
||||
@@ -5712,7 +5712,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"name"},
|
||||
"properties": map[string]interface{}{
|
||||
"filename": map[string]interface{}{"type": "string", "description": "文件名(可选,自动生成)"},
|
||||
@@ -5932,7 +5932,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"path"},
|
||||
"properties": map[string]interface{}{
|
||||
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
|
||||
@@ -5974,7 +5974,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
|
||||
"content": map[string]interface{}{
|
||||
"application/json": map[string]interface{}{
|
||||
"schema": map[string]interface{}{
|
||||
"type": "object",
|
||||
"type": "object",
|
||||
"required": []string{"ids"},
|
||||
"properties": map[string]interface{}{
|
||||
"ids": map[string]interface{}{
|
||||
|
||||
@@ -26,7 +26,7 @@ var apiDocI18nSummaryToKey = map[string]string{
|
||||
"创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup",
|
||||
"删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup",
|
||||
"从分组移除对话": "removeConversationFromGroup",
|
||||
"列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats",
|
||||
"列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats",
|
||||
"获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability",
|
||||
"列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole",
|
||||
"获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill",
|
||||
@@ -52,9 +52,9 @@ var apiDocI18nSummaryToKey = map[string]string{
|
||||
"重跑批量任务队列": "rerunBatchQueue", "修改队列元数据": "updateBatchQueueMetadata",
|
||||
"修改队列调度配置": "updateBatchQueueSchedule", "开关Cron自动调度": "setBatchQueueScheduleEnabled",
|
||||
"获取所有分组映射": "getAllGroupMappings",
|
||||
"FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse",
|
||||
"FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse",
|
||||
"测试OpenAI API连接": "testOpenAI",
|
||||
"执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS",
|
||||
"执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS",
|
||||
"列出WebShell连接": "listWebshellConnections", "创建WebShell连接": "createWebshellConnection",
|
||||
"更新WebShell连接": "updateWebshellConnection", "删除WebShell连接": "deleteWebshellConnection",
|
||||
"获取连接状态": "getWebshellConnectionState", "保存连接状态": "saveWebshellConnectionState",
|
||||
@@ -69,7 +69,7 @@ var apiDocI18nSummaryToKey = map[string]string{
|
||||
"获取Markdown代理详情": "getMarkdownAgent", "更新Markdown代理": "updateMarkdownAgent", "删除Markdown代理": "deleteMarkdownAgent",
|
||||
"列出技能包文件": "listSkillPackageFiles", "获取技能包文件内容": "getSkillPackageFile", "写入技能包文件": "putSkillPackageFile",
|
||||
"批量获取工具名称": "batchGetToolNames",
|
||||
"获取知识库统计": "getKnowledgeStats",
|
||||
"获取知识库统计": "getKnowledgeStats",
|
||||
}
|
||||
|
||||
var apiDocI18nResponseDescToKey = map[string]string{
|
||||
@@ -78,7 +78,7 @@ var apiDocI18nResponseDescToKey = map[string]string{
|
||||
"对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty",
|
||||
"请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound",
|
||||
"请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig",
|
||||
"请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed",
|
||||
"请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed",
|
||||
"登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess",
|
||||
"密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid",
|
||||
"对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess",
|
||||
@@ -89,7 +89,7 @@ var apiDocI18nResponseDescToKey = map[string]string{
|
||||
"消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events)": "streamResponse",
|
||||
// 新增缺失端点响应
|
||||
"参数错误或删除失败": "badRequestOrDeleteFailed",
|
||||
"参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun",
|
||||
"参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun",
|
||||
"参数错误或队列正在运行中": "badRequestOrQueueRunning", "设置成功": "setSuccess",
|
||||
"搜索成功": "searchSuccess", "解析成功": "parseSuccess", "测试结果": "testResult",
|
||||
"执行完成": "executionDone", "SSE事件流": "sseEventStream", "WebSocket连接已建立": "wsEstablished",
|
||||
|
||||
+14
-14
@@ -28,20 +28,20 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
robotCmdHelp = "帮助"
|
||||
robotCmdList = "列表"
|
||||
robotCmdListAlt = "对话列表"
|
||||
robotCmdSwitch = "切换"
|
||||
robotCmdContinue = "继续"
|
||||
robotCmdNew = "新对话"
|
||||
robotCmdClear = "清空"
|
||||
robotCmdCurrent = "当前"
|
||||
robotCmdStop = "停止"
|
||||
robotCmdRoles = "角色"
|
||||
robotCmdRolesList = "角色列表"
|
||||
robotCmdSwitchRole = "切换角色"
|
||||
robotCmdDelete = "删除"
|
||||
robotCmdVersion = "版本"
|
||||
robotCmdHelp = "帮助"
|
||||
robotCmdList = "列表"
|
||||
robotCmdListAlt = "对话列表"
|
||||
robotCmdSwitch = "切换"
|
||||
robotCmdContinue = "继续"
|
||||
robotCmdNew = "新对话"
|
||||
robotCmdClear = "清空"
|
||||
robotCmdCurrent = "当前"
|
||||
robotCmdStop = "停止"
|
||||
robotCmdRoles = "角色"
|
||||
robotCmdRolesList = "角色列表"
|
||||
robotCmdSwitchRole = "切换角色"
|
||||
robotCmdDelete = "删除"
|
||||
robotCmdVersion = "版本"
|
||||
)
|
||||
|
||||
// RobotHandler 企业微信/钉钉/飞书等机器人回调处理
|
||||
|
||||
+13
-13
@@ -65,19 +65,19 @@ func (h *SkillsHandler) GetSkills(c *gin.Context) {
|
||||
allSkillsInfo := make([]map[string]interface{}, 0, len(allSummaries))
|
||||
for _, s := range allSummaries {
|
||||
skillInfo := map[string]interface{}{
|
||||
"id": s.ID,
|
||||
"name": s.Name,
|
||||
"dir_name": s.DirName,
|
||||
"description": s.Description,
|
||||
"version": s.Version,
|
||||
"path": s.Path,
|
||||
"tags": s.Tags,
|
||||
"triggers": s.Triggers,
|
||||
"script_count": s.ScriptCount,
|
||||
"file_count": s.FileCount,
|
||||
"progressive": s.Progressive,
|
||||
"file_size": s.FileSize,
|
||||
"mod_time": s.ModTime,
|
||||
"id": s.ID,
|
||||
"name": s.Name,
|
||||
"dir_name": s.DirName,
|
||||
"description": s.Description,
|
||||
"version": s.Version,
|
||||
"path": s.Path,
|
||||
"tags": s.Tags,
|
||||
"triggers": s.Triggers,
|
||||
"script_count": s.ScriptCount,
|
||||
"file_count": s.FileCount,
|
||||
"progressive": s.Progressive,
|
||||
"file_size": s.FileSize,
|
||||
"mod_time": s.ModTime,
|
||||
}
|
||||
allSkillsInfo = append(allSkillsInfo, skillInfo)
|
||||
}
|
||||
|
||||
@@ -109,4 +109,3 @@ func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
|
||||
|
||||
<-doneChan
|
||||
}
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability
|
||||
|
||||
// CreateVulnerabilityRequest 创建漏洞请求
|
||||
type CreateVulnerabilityRequest struct {
|
||||
ConversationID string `json:"conversation_id" binding:"required"`
|
||||
ConversationID string `json:"conversation_id" binding:"required"`
|
||||
ConversationTag string `json:"conversation_tag"`
|
||||
TaskTag string `json:"task_tag"`
|
||||
Title string `json:"title" binding:"required"`
|
||||
@@ -51,18 +51,18 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
|
||||
}
|
||||
|
||||
vuln := &database.Vulnerability{
|
||||
ConversationID: req.ConversationID,
|
||||
ConversationID: req.ConversationID,
|
||||
ConversationTag: req.ConversationTag,
|
||||
TaskTag: req.TaskTag,
|
||||
Title: req.Title,
|
||||
Description: req.Description,
|
||||
Severity: req.Severity,
|
||||
Status: req.Status,
|
||||
Type: req.Type,
|
||||
Target: req.Target,
|
||||
Proof: req.Proof,
|
||||
Impact: req.Impact,
|
||||
Recommendation: req.Recommendation,
|
||||
Title: req.Title,
|
||||
Description: req.Description,
|
||||
Severity: req.Severity,
|
||||
Status: req.Status,
|
||||
Type: req.Type,
|
||||
Target: req.Target,
|
||||
Proof: req.Proof,
|
||||
Impact: req.Impact,
|
||||
Recommendation: req.Recommendation,
|
||||
}
|
||||
|
||||
created, err := h.db.CreateVulnerability(vuln)
|
||||
@@ -172,15 +172,15 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
|
||||
type UpdateVulnerabilityRequest struct {
|
||||
ConversationTag string `json:"conversation_tag"`
|
||||
TaskTag string `json:"task_tag"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Proof string `json:"proof"`
|
||||
Impact string `json:"impact"`
|
||||
Recommendation string `json:"recommendation"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
Severity string `json:"severity"`
|
||||
Status string `json:"status"`
|
||||
Type string `json:"type"`
|
||||
Target string `json:"target"`
|
||||
Proof string `json:"proof"`
|
||||
Impact string `json:"impact"`
|
||||
Recommendation string `json:"recommendation"`
|
||||
}
|
||||
|
||||
// UpdateVulnerability 更新漏洞
|
||||
@@ -460,4 +460,3 @@ func sanitizeExportName(raw string) string {
|
||||
replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-")
|
||||
return replacer.Replace(name)
|
||||
}
|
||||
|
||||
|
||||
@@ -16,9 +16,9 @@ const (
|
||||
|
||||
// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo].
|
||||
const (
|
||||
DSLRiskType = "risk_type"
|
||||
DSLSimilarityThreshold = "similarity_threshold"
|
||||
DSLSubIndexFilter = "sub_index_filter"
|
||||
DSLRiskType = "risk_type"
|
||||
DSLSimilarityThreshold = "similarity_threshold"
|
||||
DSLSubIndexFilter = "sub_index_filter"
|
||||
)
|
||||
|
||||
// FormatEmbeddingInput matches the historical indexing format so existing embeddings
|
||||
|
||||
@@ -8,8 +8,8 @@ import (
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/components/document"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
fileloader "github.com/cloudwego/eino-ext/components/document/loader/file"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/components/document"
|
||||
"github.com/cloudwego/eino/components/indexer"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@@ -35,14 +35,14 @@ type Indexer struct {
|
||||
lastErrorTime time.Time
|
||||
errorCount int
|
||||
|
||||
rebuildMu sync.RWMutex
|
||||
isRebuilding bool
|
||||
rebuildTotalItems int
|
||||
rebuildCurrent int
|
||||
rebuildFailed int
|
||||
rebuildStartTime time.Time
|
||||
rebuildLastItemID string
|
||||
rebuildLastChunks int
|
||||
rebuildMu sync.RWMutex
|
||||
isRebuilding bool
|
||||
rebuildTotalItems int
|
||||
rebuildCurrent int
|
||||
rebuildFailed int
|
||||
rebuildStartTime time.Time
|
||||
rebuildLastItemID string
|
||||
rebuildLastChunks int
|
||||
}
|
||||
|
||||
// NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。
|
||||
|
||||
@@ -108,9 +108,9 @@ func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
|
||||
|
||||
// CategoryWithItems 分类及其下的知识项(用于按分类分页)
|
||||
type CategoryWithItems struct {
|
||||
Category string `json:"category"` // 分类名称
|
||||
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
|
||||
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
|
||||
Category string `json:"category"` // 分类名称
|
||||
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
|
||||
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
|
||||
}
|
||||
|
||||
// SearchRequest 搜索请求
|
||||
|
||||
+10
-10
@@ -55,14 +55,14 @@ func New(level, output string) *Logger {
|
||||
}
|
||||
|
||||
func (l *Logger) Fatal(msg string, fields ...interface{}) {
|
||||
zapFields := make([]zap.Field, 0, len(fields))
|
||||
for _, f := range fields {
|
||||
switch v := f.(type) {
|
||||
case error:
|
||||
zapFields = append(zapFields, zap.Error(v))
|
||||
default:
|
||||
zapFields = append(zapFields, zap.Any("field", v))
|
||||
}
|
||||
}
|
||||
l.Logger.Fatal(msg, zapFields...)
|
||||
zapFields := make([]zap.Field, 0, len(fields))
|
||||
for _, f := range fields {
|
||||
switch v := f.(type) {
|
||||
case error:
|
||||
zapFields = append(zapFields, zap.Error(v))
|
||||
default:
|
||||
zapFields = append(zapFields, zap.Any("field", v))
|
||||
}
|
||||
}
|
||||
l.Logger.Fatal(msg, zapFields...)
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func TestExternalMCPManager_RemoveConfig(t *testing.T) {
|
||||
manager := NewExternalMCPManager(logger)
|
||||
|
||||
cfg := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: false,
|
||||
}
|
||||
|
||||
@@ -86,17 +86,17 @@ func TestExternalMCPManager_GetStats(t *testing.T) {
|
||||
|
||||
// 添加多个配置
|
||||
manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
|
||||
manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
ExternalMCPEnable: true,
|
||||
})
|
||||
|
||||
manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: false,
|
||||
})
|
||||
|
||||
@@ -122,11 +122,11 @@ func TestExternalMCPManager_LoadConfigs(t *testing.T) {
|
||||
externalMCPConfig := config.ExternalMCPConfig{
|
||||
Servers: map[string]config.ExternalMCPServerConfig{
|
||||
"loaded1": {
|
||||
Command: "python3",
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: true,
|
||||
},
|
||||
"loaded2": {
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
URL: "http://127.0.0.1:8081/mcp",
|
||||
ExternalMCPEnable: false,
|
||||
},
|
||||
},
|
||||
@@ -153,9 +153,9 @@ func TestLazySDKClient_InitializeFails(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
// 使用不存在的 HTTP 地址,Initialize 应失败
|
||||
cfg := config.ExternalMCPServerConfig{
|
||||
Type: "http",
|
||||
URL: "http://127.0.0.1:19999/nonexistent",
|
||||
Timeout: 2,
|
||||
Type: "http",
|
||||
URL: "http://127.0.0.1:19999/nonexistent",
|
||||
Timeout: 2,
|
||||
}
|
||||
c := newLazySDKClient(cfg, logger)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
@@ -176,7 +176,7 @@ func TestExternalMCPManager_StartStopClient(t *testing.T) {
|
||||
|
||||
// 添加一个禁用的配置
|
||||
cfg := config.ExternalMCPServerConfig{
|
||||
Command: "python3",
|
||||
Command: "python3",
|
||||
ExternalMCPEnable: false,
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -109,11 +110,11 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
}()
|
||||
|
||||
var lastRunMsgs []adk.Message
|
||||
var lastAssistant string
|
||||
var lastPlanExecuteExecutor string
|
||||
msgs := append([]adk.Message(nil), baseMsgs...)
|
||||
runAccumulatedMsgs := append([]adk.Message(nil), msgs...)
|
||||
baseAccumulatedCount := len(runAccumulatedMsgs)
|
||||
|
||||
emptyHint := strings.TrimSpace(args.EmptyResponseMessage)
|
||||
if emptyHint == "" {
|
||||
@@ -132,125 +133,150 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
pendingByID := make(map[string]toolCallPendingInfo)
|
||||
pendingQueueByAgent := make(map[string][]string)
|
||||
markPending := func(tc toolCallPendingInfo) {
|
||||
if tc.ToolCallID == "" {
|
||||
return
|
||||
}
|
||||
pendingByID[tc.ToolCallID] = tc
|
||||
pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID)
|
||||
if tc.ToolCallID == "" {
|
||||
return
|
||||
}
|
||||
popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) {
|
||||
q := pendingQueueByAgent[agentName]
|
||||
for len(q) > 0 {
|
||||
id := q[0]
|
||||
q = q[1:]
|
||||
pendingQueueByAgent[agentName] = q
|
||||
if tc, ok := pendingByID[id]; ok {
|
||||
delete(pendingByID, id)
|
||||
return tc, true
|
||||
}
|
||||
pendingByID[tc.ToolCallID] = tc
|
||||
pendingQueueByAgent[tc.EinoAgent] = append(pendingQueueByAgent[tc.EinoAgent], tc.ToolCallID)
|
||||
}
|
||||
popNextPendingForAgent := func(agentName string) (toolCallPendingInfo, bool) {
|
||||
q := pendingQueueByAgent[agentName]
|
||||
for len(q) > 0 {
|
||||
id := q[0]
|
||||
q = q[1:]
|
||||
pendingQueueByAgent[agentName] = q
|
||||
if tc, ok := pendingByID[id]; ok {
|
||||
delete(pendingByID, id)
|
||||
return tc, true
|
||||
}
|
||||
return toolCallPendingInfo{}, false
|
||||
}
|
||||
removePendingByID := func(toolCallID string) {
|
||||
if toolCallID == "" {
|
||||
return
|
||||
}
|
||||
delete(pendingByID, toolCallID)
|
||||
return toolCallPendingInfo{}, false
|
||||
}
|
||||
removePendingByID := func(toolCallID string) {
|
||||
if toolCallID == "" {
|
||||
return
|
||||
}
|
||||
flushAllPendingAsFailed := func(err error) {
|
||||
if progress == nil {
|
||||
pendingByID = make(map[string]toolCallPendingInfo)
|
||||
pendingQueueByAgent = make(map[string][]string)
|
||||
return
|
||||
}
|
||||
msg := ""
|
||||
if err != nil {
|
||||
msg = err.Error()
|
||||
}
|
||||
for _, tc := range pendingByID {
|
||||
toolName := tc.ToolName
|
||||
if strings.TrimSpace(toolName) == "" {
|
||||
toolName = "unknown"
|
||||
}
|
||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), map[string]interface{}{
|
||||
"toolName": toolName,
|
||||
"success": false,
|
||||
"isError": true,
|
||||
"result": msg,
|
||||
"resultPreview": msg,
|
||||
"toolCallId": tc.ToolCallID,
|
||||
"conversationId": conversationID,
|
||||
"einoAgent": tc.EinoAgent,
|
||||
"einoRole": tc.EinoRole,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
delete(pendingByID, toolCallID)
|
||||
}
|
||||
flushAllPendingAsFailed := func(err error) {
|
||||
if progress == nil {
|
||||
pendingByID = make(map[string]toolCallPendingInfo)
|
||||
pendingQueueByAgent = make(map[string][]string)
|
||||
return
|
||||
}
|
||||
msg := ""
|
||||
if err != nil {
|
||||
msg = err.Error()
|
||||
}
|
||||
for _, tc := range pendingByID {
|
||||
toolName := tc.ToolName
|
||||
if strings.TrimSpace(toolName) == "" {
|
||||
toolName = "unknown"
|
||||
}
|
||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), map[string]interface{}{
|
||||
"toolName": toolName,
|
||||
"success": false,
|
||||
"isError": true,
|
||||
"result": msg,
|
||||
"resultPreview": msg,
|
||||
"toolCallId": tc.ToolCallID,
|
||||
"conversationId": conversationID,
|
||||
"einoAgent": tc.EinoAgent,
|
||||
"einoRole": tc.EinoRole,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
pendingByID = make(map[string]toolCallPendingInfo)
|
||||
pendingQueueByAgent = make(map[string][]string)
|
||||
}
|
||||
|
||||
runnerCfg := adk.RunnerConfig{
|
||||
Agent: da,
|
||||
EnableStreaming: true,
|
||||
}
|
||||
if cp := strings.TrimSpace(args.CheckpointDir); cp != "" {
|
||||
cpDir := filepath.Join(cp, sanitizeEinoPathSegment(conversationID))
|
||||
st, stErr := newFileCheckPointStore(cpDir)
|
||||
if stErr != nil {
|
||||
if logger != nil {
|
||||
logger.Warn("eino checkpoint store disabled", zap.String("dir", cpDir), zap.Error(stErr))
|
||||
}
|
||||
} else {
|
||||
runnerCfg.CheckPointStore = st
|
||||
if logger != nil {
|
||||
logger.Info("eino runner: checkpoint store enabled", zap.String("dir", cpDir))
|
||||
}
|
||||
runnerCfg := adk.RunnerConfig{
|
||||
Agent: da,
|
||||
EnableStreaming: true,
|
||||
}
|
||||
var cpStore *fileCheckPointStore
|
||||
var checkPointID string
|
||||
if cp := strings.TrimSpace(args.CheckpointDir); cp != "" {
|
||||
cpDir := filepath.Join(cp, sanitizeEinoPathSegment(conversationID))
|
||||
st, stErr := newFileCheckPointStore(cpDir)
|
||||
if stErr != nil {
|
||||
if logger != nil {
|
||||
logger.Warn("eino checkpoint store disabled", zap.String("dir", cpDir), zap.Error(stErr))
|
||||
}
|
||||
} else {
|
||||
cpStore = st
|
||||
checkPointID = buildEinoCheckpointID(orchMode)
|
||||
runnerCfg.CheckPointStore = st
|
||||
if logger != nil {
|
||||
logger.Info("eino runner: checkpoint store enabled",
|
||||
zap.String("dir", cpDir),
|
||||
zap.String("checkPointID", checkPointID))
|
||||
}
|
||||
}
|
||||
}
|
||||
runner := adk.NewRunner(ctx, runnerCfg)
|
||||
iter := runner.Run(ctx, msgs)
|
||||
handleRunErr := func(runErr error) error {
|
||||
if runErr == nil {
|
||||
return nil
|
||||
var iter *adk.AsyncIterator[*adk.AgentEvent]
|
||||
if cpStore != nil && checkPointID != "" {
|
||||
if _, existed, getErr := cpStore.Get(ctx, checkPointID); getErr != nil {
|
||||
if logger != nil {
|
||||
logger.Warn("eino checkpoint preflight get failed", zap.String("checkPointID", checkPointID), zap.Error(getErr))
|
||||
}
|
||||
if errors.Is(runErr, context.DeadlineExceeded) {
|
||||
flushAllPendingAsFailed(runErr)
|
||||
if progress != nil {
|
||||
progress("error", runErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"errorKind": "timeout",
|
||||
})
|
||||
} else if existed {
|
||||
if progress != nil {
|
||||
progress("progress", "检测到断点,正在从中断节点恢复执行...", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"orchestration": orchMode,
|
||||
"checkPointID": checkPointID,
|
||||
})
|
||||
}
|
||||
if logger != nil {
|
||||
logger.Info("eino runner: resume from checkpoint", zap.String("checkPointID", checkPointID))
|
||||
}
|
||||
resumeIter, resumeErr := runner.Resume(ctx, checkPointID)
|
||||
if resumeErr == nil {
|
||||
iter = resumeIter
|
||||
} else {
|
||||
if logger != nil {
|
||||
logger.Warn("eino runner: resume failed, fallback to fresh run",
|
||||
zap.String("checkPointID", checkPointID),
|
||||
zap.Error(resumeErr))
|
||||
}
|
||||
return runErr
|
||||
}
|
||||
// context.Canceled 是唯一应当直接终止编排的错误(用户关闭页面、主动停止等)。
|
||||
if errors.Is(runErr, context.Canceled) {
|
||||
flushAllPendingAsFailed(runErr)
|
||||
if progress != nil {
|
||||
progress("error", runErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
return runErr
|
||||
}
|
||||
if isEinoIterationLimitError(runErr) {
|
||||
flushAllPendingAsFailed(runErr)
|
||||
if progress != nil {
|
||||
progress("iteration_limit_reached", runErr.Error(), map[string]interface{}{
|
||||
progress("progress", "断点恢复失败,已回退为全新执行。", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
progress("error", runErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"errorKind": "iteration_limit",
|
||||
"checkPointID": checkPointID,
|
||||
})
|
||||
}
|
||||
return runErr
|
||||
}
|
||||
}
|
||||
}
|
||||
if iter == nil {
|
||||
if checkPointID != "" {
|
||||
iter = runner.Run(ctx, msgs, adk.WithCheckPointID(checkPointID))
|
||||
} else {
|
||||
iter = runner.Run(ctx, msgs)
|
||||
}
|
||||
}
|
||||
handleRunErr := func(runErr error) error {
|
||||
if runErr == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(runErr, context.DeadlineExceeded) {
|
||||
flushAllPendingAsFailed(runErr)
|
||||
if progress != nil {
|
||||
progress("error", runErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"errorKind": "timeout",
|
||||
})
|
||||
}
|
||||
return runErr
|
||||
}
|
||||
// context.Canceled 是唯一应当直接终止编排的错误(用户关闭页面、主动停止等)。
|
||||
if errors.Is(runErr, context.Canceled) {
|
||||
flushAllPendingAsFailed(runErr)
|
||||
if progress != nil {
|
||||
progress("error", runErr.Error(), map[string]interface{}{
|
||||
@@ -260,249 +286,190 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
return runErr
|
||||
}
|
||||
|
||||
for {
|
||||
// 检测 context 取消(用户关闭浏览器、请求超时等),flush pending 工具状态避免 UI 卡在 "执行中"。
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
flushAllPendingAsFailed(ctx.Err())
|
||||
if progress != nil {
|
||||
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
ev, ok := iter.Next()
|
||||
if !ok {
|
||||
if len(pendingByID) > 0 {
|
||||
orphanCount := len(pendingByID)
|
||||
flushAllPendingAsFailed(errors.New("pending tool call missing result before run completion"))
|
||||
if progress != nil {
|
||||
progress("eino_pending_orphaned", "pending tool calls were force-closed at run end", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"orchestration": orchMode,
|
||||
"pendingCount": orphanCount,
|
||||
})
|
||||
}
|
||||
}
|
||||
lastRunMsgs = runAccumulatedMsgs
|
||||
break
|
||||
}
|
||||
if ev == nil {
|
||||
continue
|
||||
}
|
||||
if ev.Err != nil {
|
||||
if retErr := handleRunErr(ev.Err); retErr != nil {
|
||||
return nil, retErr
|
||||
}
|
||||
}
|
||||
if ev.AgentName != "" && progress != nil {
|
||||
iterEinoAgent := orchestratorName
|
||||
if orchMode == "plan_execute" {
|
||||
if a := strings.TrimSpace(ev.AgentName); a != "" {
|
||||
iterEinoAgent = a
|
||||
}
|
||||
}
|
||||
if streamsMainAssistant(ev.AgentName) {
|
||||
if einoMainRound == 0 {
|
||||
einoMainRound = 1
|
||||
progress("iteration", "", map[string]interface{}{
|
||||
"iteration": 1,
|
||||
"einoScope": "main",
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": iterEinoAgent,
|
||||
"orchestration": orchMode,
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
} else if einoLastAgent != "" && !streamsMainAssistant(einoLastAgent) {
|
||||
einoMainRound++
|
||||
progress("iteration", "", map[string]interface{}{
|
||||
"iteration": einoMainRound,
|
||||
"einoScope": "main",
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": iterEinoAgent,
|
||||
"orchestration": orchMode,
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
}
|
||||
einoLastAgent = ev.AgentName
|
||||
progress("progress", fmt.Sprintf("[Eino] %s", ev.AgentName), map[string]interface{}{
|
||||
if isEinoIterationLimitError(runErr) {
|
||||
flushAllPendingAsFailed(runErr)
|
||||
if progress != nil {
|
||||
progress("iteration_limit_reached", runErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"source": "eino",
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
progress("error", runErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"errorKind": "iteration_limit",
|
||||
})
|
||||
}
|
||||
if ev.Output == nil || ev.Output.MessageOutput == nil {
|
||||
continue
|
||||
}
|
||||
mv := ev.Output.MessageOutput
|
||||
return runErr
|
||||
}
|
||||
flushAllPendingAsFailed(runErr)
|
||||
if progress != nil {
|
||||
progress("error", runErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
return runErr
|
||||
}
|
||||
|
||||
if mv.IsStreaming && mv.MessageStream != nil {
|
||||
streamHeaderSent := false
|
||||
var reasoningStreamID string
|
||||
var toolStreamFragments []schema.ToolCall
|
||||
var subAssistantBuf strings.Builder
|
||||
var subReplyStreamID string
|
||||
var mainAssistantBuf strings.Builder
|
||||
var streamRecvErr error
|
||||
for {
|
||||
chunk, rerr := mv.MessageStream.Recv()
|
||||
if rerr != nil {
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if logger != nil {
|
||||
logger.Warn("eino stream recv error, flushing incomplete stream",
|
||||
zap.Error(rerr),
|
||||
zap.String("agent", ev.AgentName),
|
||||
zap.Int("toolFragments", len(toolStreamFragments)))
|
||||
}
|
||||
streamRecvErr = rerr
|
||||
break
|
||||
}
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
if progress != nil && strings.TrimSpace(chunk.ReasoningContent) != "" {
|
||||
if reasoningStreamID == "" {
|
||||
reasoningStreamID = fmt.Sprintf("eino-reasoning-%s-%d", conversationID, atomic.AddInt64(&reasoningStreamSeq, 1))
|
||||
progress("thinking_stream_start", " ", map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
progress("thinking_stream_delta", chunk.ReasoningContent, map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
})
|
||||
}
|
||||
if chunk.Content != "" {
|
||||
if progress != nil && streamsMainAssistant(ev.AgentName) {
|
||||
if !streamHeaderSent {
|
||||
progress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"messageGeneratedBy": "eino:" + ev.AgentName,
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
streamHeaderSent = true
|
||||
}
|
||||
progress("response_delta", chunk.Content, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
mainAssistantBuf.WriteString(chunk.Content)
|
||||
} else if !streamsMainAssistant(ev.AgentName) {
|
||||
if progress != nil {
|
||||
if subReplyStreamID == "" {
|
||||
subReplyStreamID = fmt.Sprintf("eino-sub-reply-%s-%d", conversationID, atomic.AddInt64(&einoSubReplyStreamSeq, 1))
|
||||
progress("eino_agent_reply_stream_start", "", map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": "sub",
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
progress("eino_agent_reply_stream_delta", chunk.Content, map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
}
|
||||
subAssistantBuf.WriteString(chunk.Content)
|
||||
}
|
||||
}
|
||||
if len(chunk.ToolCalls) > 0 {
|
||||
toolStreamFragments = append(toolStreamFragments, chunk.ToolCalls...)
|
||||
}
|
||||
}
|
||||
if streamsMainAssistant(ev.AgentName) {
|
||||
if s := strings.TrimSpace(mainAssistantBuf.String()); s != "" {
|
||||
lastAssistant = s
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
if subAssistantBuf.Len() > 0 && progress != nil {
|
||||
if s := strings.TrimSpace(subAssistantBuf.String()); s != "" {
|
||||
if subReplyStreamID != "" {
|
||||
progress("eino_agent_reply_stream_end", s, map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": "sub",
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
} else {
|
||||
progress("eino_agent_reply", s, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": "sub",
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
var lastToolChunk *schema.Message
|
||||
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
|
||||
lastToolChunk = &schema.Message{ToolCalls: merged}
|
||||
}
|
||||
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
|
||||
if streamRecvErr != nil {
|
||||
if progress != nil {
|
||||
progress("eino_stream_error", streamRecvErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
})
|
||||
}
|
||||
if retErr := handleRunErr(streamRecvErr); retErr != nil {
|
||||
return nil, retErr
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
takePartial := func(runErr error) (*RunResult, error) {
|
||||
if len(runAccumulatedMsgs) <= baseAccumulatedCount {
|
||||
return nil, runErr
|
||||
}
|
||||
ids := snapshotMCPIDs()
|
||||
return buildEinoRunResultFromAccumulated(
|
||||
orchMode, runAccumulatedMsgs, lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, true,
|
||||
), runErr
|
||||
}
|
||||
|
||||
msg, gerr := mv.GetMessage()
|
||||
if gerr != nil || msg == nil {
|
||||
continue
|
||||
for {
|
||||
// 检测 context 取消(用户关闭浏览器、请求超时等),flush pending 工具状态避免 UI 卡在 "执行中"。
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
flushAllPendingAsFailed(ctx.Err())
|
||||
if progress != nil {
|
||||
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
|
||||
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
|
||||
return takePartial(ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
if mv.Role == schema.Assistant {
|
||||
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
||||
progress("thinking", strings.TrimSpace(msg.ReasoningContent), map[string]interface{}{
|
||||
ev, ok := iter.Next()
|
||||
if !ok {
|
||||
// iter 结束并不总是“正常完成”:
|
||||
// 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。
|
||||
// 此时必须保留 checkpoint,避免后续恢复时被误判为“无断点”而全量重跑。
|
||||
if ctxErr := ctx.Err(); ctxErr != nil {
|
||||
flushAllPendingAsFailed(ctxErr)
|
||||
if progress != nil {
|
||||
progress("error", ctxErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
body := strings.TrimSpace(msg.Content)
|
||||
if body != "" {
|
||||
if streamsMainAssistant(ev.AgentName) {
|
||||
if progress != nil {
|
||||
return takePartial(ctxErr)
|
||||
}
|
||||
if len(pendingByID) > 0 {
|
||||
orphanCount := len(pendingByID)
|
||||
flushAllPendingAsFailed(errors.New("pending tool call missing result before run completion"))
|
||||
if progress != nil {
|
||||
progress("eino_pending_orphaned", "pending tool calls were force-closed at run end", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"orchestration": orchMode,
|
||||
"pendingCount": orphanCount,
|
||||
})
|
||||
}
|
||||
}
|
||||
if cpStore != nil && checkPointID != "" {
|
||||
if p, pErr := cpStore.path(checkPointID); pErr == nil {
|
||||
if rmErr := os.Remove(p); rmErr != nil && !os.IsNotExist(rmErr) && logger != nil {
|
||||
logger.Warn("eino checkpoint cleanup failed", zap.String("path", p), zap.Error(rmErr))
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
if ev == nil {
|
||||
continue
|
||||
}
|
||||
if ev.Err != nil {
|
||||
if retErr := handleRunErr(ev.Err); retErr != nil {
|
||||
return takePartial(retErr)
|
||||
}
|
||||
}
|
||||
if ev.AgentName != "" && progress != nil {
|
||||
iterEinoAgent := orchestratorName
|
||||
if orchMode == "plan_execute" {
|
||||
if a := strings.TrimSpace(ev.AgentName); a != "" {
|
||||
iterEinoAgent = a
|
||||
}
|
||||
}
|
||||
if streamsMainAssistant(ev.AgentName) {
|
||||
if einoMainRound == 0 {
|
||||
einoMainRound = 1
|
||||
progress("iteration", "", map[string]interface{}{
|
||||
"iteration": 1,
|
||||
"einoScope": "main",
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": iterEinoAgent,
|
||||
"orchestration": orchMode,
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
} else if einoLastAgent != "" && !streamsMainAssistant(einoLastAgent) {
|
||||
einoMainRound++
|
||||
progress("iteration", "", map[string]interface{}{
|
||||
"iteration": einoMainRound,
|
||||
"einoScope": "main",
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": iterEinoAgent,
|
||||
"orchestration": orchMode,
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
}
|
||||
einoLastAgent = ev.AgentName
|
||||
progress("progress", fmt.Sprintf("[Eino] %s", ev.AgentName), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
if ev.Output == nil || ev.Output.MessageOutput == nil {
|
||||
continue
|
||||
}
|
||||
mv := ev.Output.MessageOutput
|
||||
|
||||
if mv.IsStreaming && mv.MessageStream != nil {
|
||||
streamHeaderSent := false
|
||||
var reasoningStreamID string
|
||||
var toolStreamFragments []schema.ToolCall
|
||||
var subAssistantBuf strings.Builder
|
||||
var subReplyStreamID string
|
||||
var mainAssistantBuf strings.Builder
|
||||
var streamRecvErr error
|
||||
for {
|
||||
chunk, rerr := mv.MessageStream.Recv()
|
||||
if rerr != nil {
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if logger != nil {
|
||||
logger.Warn("eino stream recv error, flushing incomplete stream",
|
||||
zap.Error(rerr),
|
||||
zap.String("agent", ev.AgentName),
|
||||
zap.Int("toolFragments", len(toolStreamFragments)))
|
||||
}
|
||||
streamRecvErr = rerr
|
||||
break
|
||||
}
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
if progress != nil && strings.TrimSpace(chunk.ReasoningContent) != "" {
|
||||
if reasoningStreamID == "" {
|
||||
reasoningStreamID = fmt.Sprintf("eino-reasoning-%s-%d", conversationID, atomic.AddInt64(&reasoningStreamSeq, 1))
|
||||
progress("thinking_stream_start", " ", map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
progress("thinking_stream_delta", chunk.ReasoningContent, map[string]interface{}{
|
||||
"streamId": reasoningStreamID,
|
||||
})
|
||||
}
|
||||
if chunk.Content != "" {
|
||||
if progress != nil && streamsMainAssistant(ev.AgentName) {
|
||||
if !streamHeaderSent {
|
||||
progress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
@@ -511,20 +478,61 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
progress("response_delta", body, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
streamHeaderSent = true
|
||||
}
|
||||
progress("response_delta", chunk.Content, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
mainAssistantBuf.WriteString(chunk.Content)
|
||||
} else if !streamsMainAssistant(ev.AgentName) {
|
||||
if progress != nil {
|
||||
if subReplyStreamID == "" {
|
||||
subReplyStreamID = fmt.Sprintf("eino-sub-reply-%s-%d", conversationID, atomic.AddInt64(&einoSubReplyStreamSeq, 1))
|
||||
progress("eino_agent_reply_stream_start", "", map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": "sub",
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
progress("eino_agent_reply_stream_delta", chunk.Content, map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
"conversationId": conversationID,
|
||||
})
|
||||
}
|
||||
lastAssistant = body
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(body)
|
||||
}
|
||||
} else if progress != nil {
|
||||
progress("eino_agent_reply", body, map[string]interface{}{
|
||||
subAssistantBuf.WriteString(chunk.Content)
|
||||
}
|
||||
}
|
||||
if len(chunk.ToolCalls) > 0 {
|
||||
toolStreamFragments = append(toolStreamFragments, chunk.ToolCalls...)
|
||||
}
|
||||
}
|
||||
if streamsMainAssistant(ev.AgentName) {
|
||||
if s := strings.TrimSpace(mainAssistantBuf.String()); s != "" {
|
||||
lastAssistant = s
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, schema.AssistantMessage(s, nil))
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
if subAssistantBuf.Len() > 0 && progress != nil {
|
||||
if s := strings.TrimSpace(subAssistantBuf.String()); s != "" {
|
||||
if subReplyStreamID != "" {
|
||||
progress("eino_agent_reply_stream_end", s, map[string]interface{}{
|
||||
"streamId": subReplyStreamID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": "sub",
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
} else {
|
||||
progress("eino_agent_reply", s, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": "sub",
|
||||
@@ -533,65 +541,157 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mv.Role == schema.Tool && progress != nil {
|
||||
toolName := msg.ToolName
|
||||
if toolName == "" {
|
||||
toolName = mv.ToolName
|
||||
var lastToolChunk *schema.Message
|
||||
if merged := mergeStreamingToolCallFragments(toolStreamFragments); len(merged) > 0 {
|
||||
lastToolChunk = &schema.Message{ToolCalls: merged}
|
||||
}
|
||||
tryEmitToolCallsOnce(lastToolChunk, ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
|
||||
if streamRecvErr != nil {
|
||||
if progress != nil {
|
||||
progress("eino_stream_error", streamRecvErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
})
|
||||
}
|
||||
|
||||
content := msg.Content
|
||||
isErr := false
|
||||
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||
isErr = true
|
||||
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
||||
if retErr := handleRunErr(streamRecvErr); retErr != nil {
|
||||
return takePartial(retErr)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
preview := content
|
||||
if len(preview) > 200 {
|
||||
preview = preview[:200] + "..."
|
||||
}
|
||||
data := map[string]interface{}{
|
||||
"toolName": toolName,
|
||||
"success": !isErr,
|
||||
"isError": isErr,
|
||||
"result": content,
|
||||
"resultPreview": preview,
|
||||
msg, gerr := mv.GetMessage()
|
||||
if gerr != nil || msg == nil {
|
||||
continue
|
||||
}
|
||||
runAccumulatedMsgs = append(runAccumulatedMsgs, msg)
|
||||
tryEmitToolCallsOnce(mergeMessageToolCalls(msg), ev.AgentName, orchestratorName, conversationID, progress, toolEmitSeen, subAgentToolStep, markPending)
|
||||
|
||||
if mv.Role == schema.Assistant {
|
||||
if progress != nil && strings.TrimSpace(msg.ReasoningContent) != "" {
|
||||
progress("thinking", strings.TrimSpace(msg.ReasoningContent), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"source": "eino",
|
||||
}
|
||||
toolCallID := strings.TrimSpace(msg.ToolCallID)
|
||||
if toolCallID == "" {
|
||||
if inferred, ok := popNextPendingForAgent(ev.AgentName); ok {
|
||||
toolCallID = inferred.ToolCallID
|
||||
} else if inferred, ok := popNextPendingForAgent(orchestratorName); ok {
|
||||
toolCallID = inferred.ToolCallID
|
||||
} else if inferred, ok := popNextPendingForAgent(""); ok {
|
||||
toolCallID = inferred.ToolCallID
|
||||
} else {
|
||||
for id := range pendingByID {
|
||||
toolCallID = id
|
||||
delete(pendingByID, id)
|
||||
break
|
||||
}
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
body := strings.TrimSpace(msg.Content)
|
||||
if body != "" {
|
||||
if streamsMainAssistant(ev.AgentName) {
|
||||
if progress != nil {
|
||||
progress("response_start", "", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"messageGeneratedBy": "eino:" + ev.AgentName,
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
progress("response_delta", body, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"mcpExecutionIds": snapshotMCPIDs(),
|
||||
"einoRole": "orchestrator",
|
||||
"einoAgent": ev.AgentName,
|
||||
"orchestration": orchMode,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
removePendingByID(toolCallID)
|
||||
lastAssistant = body
|
||||
if orchMode == "plan_execute" && strings.EqualFold(strings.TrimSpace(ev.AgentName), "executor") {
|
||||
lastPlanExecuteExecutor = UnwrapPlanExecuteUserText(body)
|
||||
}
|
||||
} else if progress != nil {
|
||||
progress("eino_agent_reply", body, map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": "sub",
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
if toolCallID != "" {
|
||||
data["toolCallId"] = toolCallID
|
||||
}
|
||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
|
||||
}
|
||||
}
|
||||
|
||||
if mv.Role == schema.Tool && progress != nil {
|
||||
toolName := msg.ToolName
|
||||
if toolName == "" {
|
||||
toolName = mv.ToolName
|
||||
}
|
||||
|
||||
content := msg.Content
|
||||
isErr := false
|
||||
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||
isErr = true
|
||||
content = strings.TrimPrefix(content, einomcp.ToolErrorPrefix)
|
||||
}
|
||||
|
||||
preview := content
|
||||
if len(preview) > 200 {
|
||||
preview = preview[:200] + "..."
|
||||
}
|
||||
data := map[string]interface{}{
|
||||
"toolName": toolName,
|
||||
"success": !isErr,
|
||||
"isError": isErr,
|
||||
"result": content,
|
||||
"resultPreview": preview,
|
||||
"conversationId": conversationID,
|
||||
"einoAgent": ev.AgentName,
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
"source": "eino",
|
||||
}
|
||||
toolCallID := strings.TrimSpace(msg.ToolCallID)
|
||||
if toolCallID == "" {
|
||||
if inferred, ok := popNextPendingForAgent(ev.AgentName); ok {
|
||||
toolCallID = inferred.ToolCallID
|
||||
} else if inferred, ok := popNextPendingForAgent(orchestratorName); ok {
|
||||
toolCallID = inferred.ToolCallID
|
||||
} else if inferred, ok := popNextPendingForAgent(""); ok {
|
||||
toolCallID = inferred.ToolCallID
|
||||
} else {
|
||||
for id := range pendingByID {
|
||||
toolCallID = id
|
||||
delete(pendingByID, id)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
removePendingByID(toolCallID)
|
||||
}
|
||||
if toolCallID != "" {
|
||||
data["toolCallId"] = toolCallID
|
||||
}
|
||||
progress("tool_result", fmt.Sprintf("工具结果 (%s)", toolName), data)
|
||||
}
|
||||
}
|
||||
|
||||
mcpIDsMu.Lock()
|
||||
ids := append([]string(nil), *mcpIDs...)
|
||||
mcpIDsMu.Unlock()
|
||||
|
||||
histJSON, _ := json.Marshal(lastRunMsgs)
|
||||
out := buildEinoRunResultFromAccumulated(
|
||||
orchMode, runAccumulatedMsgs, lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
|
||||
)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func einoPartialRunLastOutputHint() string {
|
||||
return "[执行未正常结束(用户停止、超时或异常)。续跑时请基于上文已产生的工具与结果继续,勿重复已完成步骤。]\n" +
|
||||
"[Run ended abnormally; continue from the trace above without repeating completed steps.]"
|
||||
}
|
||||
|
||||
func buildEinoRunResultFromAccumulated(
|
||||
orchMode string,
|
||||
runAccumulatedMsgs []adk.Message,
|
||||
lastAssistant string,
|
||||
lastPlanExecuteExecutor string,
|
||||
emptyHint string,
|
||||
mcpIDs []string,
|
||||
partial bool,
|
||||
) *RunResult {
|
||||
histJSON, _ := json.Marshal(runAccumulatedMsgs)
|
||||
cleaned := strings.TrimSpace(lastAssistant)
|
||||
if orchMode == "plan_execute" {
|
||||
if e := strings.TrimSpace(lastPlanExecuteExecutor); e != "" {
|
||||
@@ -607,15 +707,29 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
if rs := []rune(cleaned); len(rs) > maxResponseRunes {
|
||||
cleaned = string(rs[:maxResponseRunes]) + "\n\n... (response truncated / 响应已截断)"
|
||||
}
|
||||
lastOut := cleaned
|
||||
resp := cleaned
|
||||
if partial && cleaned == "" {
|
||||
lastOut = einoPartialRunLastOutputHint()
|
||||
resp = emptyHint
|
||||
}
|
||||
out := &RunResult{
|
||||
Response: cleaned,
|
||||
MCPExecutionIDs: ids,
|
||||
LastReActInput: string(histJSON),
|
||||
LastReActOutput: cleaned,
|
||||
Response: resp,
|
||||
MCPExecutionIDs: mcpIDs,
|
||||
LastAgentTraceInput: string(histJSON),
|
||||
LastAgentTraceOutput: lastOut,
|
||||
}
|
||||
if out.Response == "" {
|
||||
if !partial && out.Response == "" {
|
||||
out.Response = emptyHint
|
||||
out.LastReActOutput = out.Response
|
||||
out.LastAgentTraceOutput = out.Response
|
||||
}
|
||||
return out, nil
|
||||
return out
|
||||
}
|
||||
|
||||
func buildEinoCheckpointID(orchMode string) string {
|
||||
mode := sanitizeEinoPathSegment(strings.TrimSpace(orchMode))
|
||||
if mode == "" {
|
||||
mode = "default"
|
||||
}
|
||||
return "runner-" + mode
|
||||
}
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
|
||||
"github.com/bytedance/sonic"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type einoModelInputTelemetryMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
logger *zap.Logger
|
||||
modelName string
|
||||
conversationID string
|
||||
phase string
|
||||
}
|
||||
|
||||
func newEinoModelInputTelemetryMiddleware(
|
||||
logger *zap.Logger,
|
||||
modelName string,
|
||||
conversationID string,
|
||||
phase string,
|
||||
) adk.ChatModelAgentMiddleware {
|
||||
if logger == nil {
|
||||
return nil
|
||||
}
|
||||
return &einoModelInputTelemetryMiddleware{
|
||||
logger: logger,
|
||||
modelName: strings.TrimSpace(modelName),
|
||||
conversationID: strings.TrimSpace(conversationID),
|
||||
phase: strings.TrimSpace(phase),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *einoModelInputTelemetryMiddleware) BeforeModelRewriteState(
|
||||
ctx context.Context,
|
||||
state *adk.ChatModelAgentState,
|
||||
mc *adk.ModelContext,
|
||||
) (context.Context, *adk.ChatModelAgentState, error) {
|
||||
if m == nil || m.logger == nil || state == nil {
|
||||
return ctx, state, nil
|
||||
}
|
||||
tokens := estimateTokensForMessagesAndTools(ctx, m.modelName, state.Messages, mcTools(mc))
|
||||
m.logger.Info("eino model input estimated",
|
||||
zap.String("phase", m.phase),
|
||||
zap.String("conversation_id", m.conversationID),
|
||||
zap.Int("messages", len(state.Messages)),
|
||||
zap.Int("tools", len(mcTools(mc))),
|
||||
zap.Int("input_tokens_estimated", tokens),
|
||||
)
|
||||
return ctx, state, nil
|
||||
}
|
||||
|
||||
func mcTools(mc *adk.ModelContext) []*schema.ToolInfo {
|
||||
if mc == nil || len(mc.Tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
return mc.Tools
|
||||
}
|
||||
|
||||
func estimateTokensForMessagesAndTools(
|
||||
_ context.Context,
|
||||
modelName string,
|
||||
messages []adk.Message,
|
||||
tools []*schema.ToolInfo,
|
||||
) int {
|
||||
var sb strings.Builder
|
||||
for _, msg := range messages {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
sb.WriteString(string(msg.Role))
|
||||
sb.WriteByte('\n')
|
||||
sb.WriteString(msg.Content)
|
||||
sb.WriteByte('\n')
|
||||
if msg.ReasoningContent != "" {
|
||||
sb.WriteString(msg.ReasoningContent)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
if b, err := sonic.Marshal(msg.ToolCalls); err == nil {
|
||||
sb.Write(b)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, tl := range tools {
|
||||
if tl == nil {
|
||||
continue
|
||||
}
|
||||
cp := *tl
|
||||
cp.Extra = nil
|
||||
if text, err := sonic.MarshalString(cp); err == nil {
|
||||
sb.WriteString(text)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
text := sb.String()
|
||||
if text == "" {
|
||||
return 0
|
||||
}
|
||||
tc := agent.NewTikTokenCounter()
|
||||
if n, err := tc.Count(modelName, text); err == nil {
|
||||
return n
|
||||
}
|
||||
return (len(text) + 3) / 4
|
||||
}
|
||||
|
||||
func logPlanExecuteModelInputEstimate(
|
||||
logger *zap.Logger,
|
||||
modelName string,
|
||||
conversationID string,
|
||||
phase string,
|
||||
msgs []adk.Message,
|
||||
) {
|
||||
if logger == nil {
|
||||
return
|
||||
}
|
||||
tokens := estimateTokensForMessagesAndTools(context.Background(), modelName, msgs, nil)
|
||||
logger.Info("eino model input estimated",
|
||||
zap.String("phase", phase),
|
||||
zap.String("conversation_id", strings.TrimSpace(conversationID)),
|
||||
zap.Int("messages", len(msgs)),
|
||||
zap.Int("tools", 0),
|
||||
zap.Int("input_tokens_estimated", tokens),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
|
||||
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
|
||||
"github.com/cloudwego/eino/adk"
|
||||
@@ -65,6 +66,66 @@ func splitToolsForToolSearch(all []tool.BaseTool, alwaysVisible int) (static []t
|
||||
return append([]tool.BaseTool(nil), all[:alwaysVisible]...), append([]tool.BaseTool(nil), all[alwaysVisible:]...), true
|
||||
}
|
||||
|
||||
func splitToolsForToolSearchByNames(all []tool.BaseTool, names []string, fallbackAlwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) {
|
||||
nameSet := make(map[string]struct{}, len(names))
|
||||
for _, n := range names {
|
||||
n = strings.TrimSpace(strings.ToLower(n))
|
||||
if n == "" {
|
||||
continue
|
||||
}
|
||||
nameSet[n] = struct{}{}
|
||||
}
|
||||
if len(nameSet) == 0 {
|
||||
return splitToolsForToolSearch(all, fallbackAlwaysVisible)
|
||||
}
|
||||
static = make([]tool.BaseTool, 0, len(all))
|
||||
dynamic = make([]tool.BaseTool, 0, len(all))
|
||||
for _, t := range all {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
info, err := t.Info(context.Background())
|
||||
name := ""
|
||||
if err == nil && info != nil {
|
||||
name = strings.TrimSpace(strings.ToLower(info.Name))
|
||||
}
|
||||
if _, keep := nameSet[name]; keep {
|
||||
static = append(static, t)
|
||||
continue
|
||||
}
|
||||
dynamic = append(dynamic, t)
|
||||
}
|
||||
if len(static) == 0 || len(dynamic) == 0 {
|
||||
// fallback: preserve previous behavior when whitelist misses all or includes all.
|
||||
return splitToolsForToolSearch(all, fallbackAlwaysVisible)
|
||||
}
|
||||
return static, dynamic, true
|
||||
}
|
||||
|
||||
func mergeAlwaysVisibleToolNames(configured []string) []string {
|
||||
merged := make([]string, 0, len(configured)+32)
|
||||
seen := make(map[string]struct{}, len(configured)+32)
|
||||
add := func(name string) {
|
||||
n := strings.TrimSpace(strings.ToLower(name))
|
||||
if n == "" {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[n]; ok {
|
||||
return
|
||||
}
|
||||
seen[n] = struct{}{}
|
||||
merged = append(merged, n)
|
||||
}
|
||||
for _, n := range configured {
|
||||
add(n)
|
||||
}
|
||||
// Always include hardcoded backend builtin MCP tools from constants.
|
||||
for _, n := range builtin.GetAllBuiltinTools() {
|
||||
add(n)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddlewareConfig, convID string, loc *localbk.Local, logger *zap.Logger) (adk.ChatModelAgentMiddleware, error) {
|
||||
if loc == nil {
|
||||
return nil, fmt.Errorf("reduction: local backend nil")
|
||||
@@ -87,6 +148,8 @@ func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddl
|
||||
RootDir: root,
|
||||
ReadFileToolName: "read_file",
|
||||
ClearExcludeTools: excl,
|
||||
MaxLengthForTrunc: mw.ReductionMaxLengthForTruncEffective(),
|
||||
MaxTokensForClear: int64(mw.ReductionMaxTokensForClearEffective()),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -142,7 +205,7 @@ func prependEinoMiddlewares(
|
||||
alwaysVis = 12
|
||||
}
|
||||
if mw.ToolSearchEnable && len(tools) >= minTools {
|
||||
static, dynamic, split := splitToolsForToolSearch(tools, alwaysVis)
|
||||
static, dynamic, split := splitToolsForToolSearchByNames(tools, mergeAlwaysVisibleToolNames(mw.ToolSearchAlwaysVisibleTools), alwaysVis)
|
||||
if split && len(dynamic) > 0 {
|
||||
ts, terr := toolsearch.New(ctx, &toolsearch.Config{DynamicTools: dynamic})
|
||||
if terr != nil {
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
)
|
||||
|
||||
func applyBeforeModelRewriteHandlers(
|
||||
ctx context.Context,
|
||||
msgs []adk.Message,
|
||||
handlers []adk.ChatModelAgentMiddleware,
|
||||
) ([]adk.Message, error) {
|
||||
if len(msgs) == 0 || len(handlers) == 0 {
|
||||
return msgs, nil
|
||||
}
|
||||
state := &adk.ChatModelAgentState{Messages: msgs}
|
||||
modelCtx := &adk.ModelContext{}
|
||||
curCtx := ctx
|
||||
for _, h := range handlers {
|
||||
if h == nil {
|
||||
continue
|
||||
}
|
||||
nextCtx, nextState, err := h.BeforeModelRewriteState(curCtx, state, modelCtx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("before model rewrite: %w", err)
|
||||
}
|
||||
if nextCtx != nil {
|
||||
curCtx = nextCtx
|
||||
}
|
||||
if nextState != nil {
|
||||
state = nextState
|
||||
}
|
||||
}
|
||||
return state.Messages, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino-ext/components/model/openai"
|
||||
@@ -25,7 +26,12 @@ type PlanExecuteRootArgs struct {
|
||||
LoopMaxIter int
|
||||
// AppCfg / Logger 非空时为 Executor 挂载与 Deep/Supervisor 一致的 Eino summarization 中间件。
|
||||
AppCfg *config.Config
|
||||
MwCfg *config.MultiAgentEinoMiddlewareConfig
|
||||
// ConversationID is used for transcript/isolation paths in middleware.
|
||||
ConversationID string
|
||||
Logger *zap.Logger
|
||||
// ModelName is used for model input token estimation logs.
|
||||
ModelName string
|
||||
// ExecPreMiddlewares 是由 prependEinoMiddlewares 构建的前置中间件(patchtoolcalls, reduction, toolsearch, plantask),
|
||||
// 与 Deep/Supervisor 主代理的 mainOrchestratorPre 一致。
|
||||
ExecPreMiddlewares []adk.ChatModelAgentMiddleware
|
||||
@@ -33,6 +39,8 @@ type PlanExecuteRootArgs struct {
|
||||
SkillMiddleware adk.ChatModelAgentMiddleware
|
||||
// FilesystemMiddleware 是 Eino filesystem 中间件,当 eino_skills.filesystem_tools 启用时提供本机文件读写与 Shell 能力(可选)。
|
||||
FilesystemMiddleware adk.ChatModelAgentMiddleware
|
||||
// PlannerReplannerRewriteHandlers applies BeforeModelRewriteState pipeline for planner/replanner input.
|
||||
PlannerReplannerRewriteHandlers []adk.ChatModelAgentMiddleware
|
||||
}
|
||||
|
||||
// NewPlanExecuteRoot 返回 plan → execute → replan 预置编排根节点(与 Deep / Supervisor 并列)。
|
||||
@@ -50,7 +58,7 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
||||
plannerCfg := &planexecute.PlannerConfig{
|
||||
ToolCallingChatModel: tcm,
|
||||
}
|
||||
if fn := planExecutePlannerGenInput(a.OrchInstruction); fn != nil {
|
||||
if fn := planExecutePlannerGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID, a.PlannerReplannerRewriteHandlers); fn != nil {
|
||||
plannerCfg.GenInputFn = fn
|
||||
}
|
||||
planner, err := planexecute.NewPlanner(ctx, plannerCfg)
|
||||
@@ -59,7 +67,7 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
||||
}
|
||||
replanner, err := planexecute.NewReplanner(ctx, &planexecute.ReplannerConfig{
|
||||
ChatModel: tcm,
|
||||
GenInputFn: planExecuteReplannerGenInput(a.OrchInstruction),
|
||||
GenInputFn: planExecuteReplannerGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID, a.PlannerReplannerRewriteHandlers),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("plan_execute replanner: %w", err)
|
||||
@@ -81,17 +89,20 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
||||
}
|
||||
// 4. summarization(最后,与 Deep/Supervisor 一致)
|
||||
if a.AppCfg != nil {
|
||||
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.Logger)
|
||||
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.Logger)
|
||||
if sumErr != nil {
|
||||
return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr)
|
||||
}
|
||||
execHandlers = append(execHandlers, sumMw)
|
||||
}
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil {
|
||||
execHandlers = append(execHandlers, teleMw)
|
||||
}
|
||||
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
|
||||
Model: a.ExecModel,
|
||||
ToolsConfig: a.ToolsCfg,
|
||||
MaxIterations: a.ExecMaxIter,
|
||||
GenInputFn: planExecuteExecutorGenInput(a.OrchInstruction),
|
||||
GenInputFn: planExecuteExecutorGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID),
|
||||
}, execHandlers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("plan_execute executor: %w", err)
|
||||
@@ -110,20 +121,42 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
||||
|
||||
// planExecutePlannerGenInput 将 orchestrator instruction 作为 SystemMessage 注入 planner 输入。
|
||||
// 返回 nil 时 Eino 使用内置默认 planner prompt。
|
||||
func planExecutePlannerGenInput(orchInstruction string) planexecute.GenPlannerModelInputFn {
|
||||
func planExecutePlannerGenInput(
|
||||
orchInstruction string,
|
||||
appCfg *config.Config,
|
||||
mwCfg *config.MultiAgentEinoMiddlewareConfig,
|
||||
logger *zap.Logger,
|
||||
modelName string,
|
||||
conversationID string,
|
||||
rewriteHandlers []adk.ChatModelAgentMiddleware,
|
||||
) planexecute.GenPlannerModelInputFn {
|
||||
oi := strings.TrimSpace(orchInstruction)
|
||||
if oi == "" {
|
||||
if oi == "" && appCfg == nil {
|
||||
return nil
|
||||
}
|
||||
return func(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) {
|
||||
userInput = capPlanExecuteUserInputMessages(userInput, appCfg, mwCfg)
|
||||
msgs := make([]adk.Message, 0, 1+len(userInput))
|
||||
msgs = append(msgs, schema.SystemMessage(oi))
|
||||
if oi != "" {
|
||||
msgs = append(msgs, schema.SystemMessage(oi))
|
||||
}
|
||||
msgs = append(msgs, userInput...)
|
||||
if rewritten, rerr := applyBeforeModelRewriteHandlers(ctx, msgs, rewriteHandlers); rerr == nil && len(rewritten) > 0 {
|
||||
msgs = rewritten
|
||||
}
|
||||
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_planner", msgs)
|
||||
return msgs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func planExecuteExecutorGenInput(orchInstruction string) planexecute.GenModelInputFn {
|
||||
func planExecuteExecutorGenInput(
|
||||
orchInstruction string,
|
||||
appCfg *config.Config,
|
||||
mwCfg *config.MultiAgentEinoMiddlewareConfig,
|
||||
logger *zap.Logger,
|
||||
modelName string,
|
||||
conversationID string,
|
||||
) planexecute.GenModelInputFn {
|
||||
oi := strings.TrimSpace(orchInstruction)
|
||||
return func(ctx context.Context, in *planexecute.ExecutionContext) ([]adk.Message, error) {
|
||||
planContent, err := in.Plan.MarshalJSON()
|
||||
@@ -131,9 +164,9 @@ func planExecuteExecutorGenInput(orchInstruction string) planexecute.GenModelInp
|
||||
return nil, err
|
||||
}
|
||||
userMsgs, err := planexecute.ExecutorPrompt.Format(ctx, map[string]any{
|
||||
"input": planExecuteFormatInput(in.UserInput),
|
||||
"input": planExecuteFormatInput(capPlanExecuteUserInputMessages(in.UserInput, appCfg, mwCfg)),
|
||||
"plan": string(planContent),
|
||||
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps),
|
||||
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps, appCfg, mwCfg),
|
||||
"step": in.Plan.FirstStep(),
|
||||
})
|
||||
if err != nil {
|
||||
@@ -142,6 +175,7 @@ func planExecuteExecutorGenInput(orchInstruction string) planexecute.GenModelInp
|
||||
if oi != "" {
|
||||
userMsgs = append([]adk.Message{schema.SystemMessage(oi)}, userMsgs...)
|
||||
}
|
||||
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_executor_gen_input", userMsgs)
|
||||
return userMsgs, nil
|
||||
}
|
||||
}
|
||||
@@ -155,18 +189,22 @@ func planExecuteFormatInput(input []adk.Message) string {
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func planExecuteFormatExecutedSteps(results []planexecute.ExecutedStep) string {
|
||||
capped := capPlanExecuteExecutedSteps(results)
|
||||
var sb strings.Builder
|
||||
for _, result := range capped {
|
||||
sb.WriteString(fmt.Sprintf("Step: %s\nResult: %s\n\n", result.Step, result.Result))
|
||||
}
|
||||
return sb.String()
|
||||
func planExecuteFormatExecutedSteps(results []planexecute.ExecutedStep, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) string {
|
||||
capped := capPlanExecuteExecutedStepsWithConfig(results, mwCfg)
|
||||
return renderPlanExecuteStepsByBudget(capped, appCfg, mwCfg)
|
||||
}
|
||||
|
||||
// planExecuteReplannerGenInput 与 Eino 默认 Replanner 输入一致,但 executed_steps 经 cap 后再写入 prompt,
|
||||
// 且在 orchInstruction 非空时 prepend SystemMessage 使 replanner 也能接收全局指令。
|
||||
func planExecuteReplannerGenInput(orchInstruction string) planexecute.GenModelInputFn {
|
||||
func planExecuteReplannerGenInput(
|
||||
orchInstruction string,
|
||||
appCfg *config.Config,
|
||||
mwCfg *config.MultiAgentEinoMiddlewareConfig,
|
||||
logger *zap.Logger,
|
||||
modelName string,
|
||||
conversationID string,
|
||||
rewriteHandlers []adk.ChatModelAgentMiddleware,
|
||||
) planexecute.GenModelInputFn {
|
||||
oi := strings.TrimSpace(orchInstruction)
|
||||
return func(ctx context.Context, in *planexecute.ExecutionContext) ([]adk.Message, error) {
|
||||
planContent, err := in.Plan.MarshalJSON()
|
||||
@@ -175,8 +213,8 @@ func planExecuteReplannerGenInput(orchInstruction string) planexecute.GenModelIn
|
||||
}
|
||||
msgs, err := planexecute.ReplannerPrompt.Format(ctx, map[string]any{
|
||||
"plan": string(planContent),
|
||||
"input": planExecuteFormatInput(in.UserInput),
|
||||
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps),
|
||||
"input": planExecuteFormatInput(capPlanExecuteUserInputMessages(in.UserInput, appCfg, mwCfg)),
|
||||
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps, appCfg, mwCfg),
|
||||
"plan_tool": planexecute.PlanToolInfo.Name,
|
||||
"respond_tool": planexecute.RespondToolInfo.Name,
|
||||
})
|
||||
@@ -186,10 +224,120 @@ func planExecuteReplannerGenInput(orchInstruction string) planexecute.GenModelIn
|
||||
if oi != "" {
|
||||
msgs = append([]adk.Message{schema.SystemMessage(oi)}, msgs...)
|
||||
}
|
||||
if rewritten, rerr := applyBeforeModelRewriteHandlers(ctx, msgs, rewriteHandlers); rerr == nil && len(rewritten) > 0 {
|
||||
msgs = rewritten
|
||||
}
|
||||
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_replanner", msgs)
|
||||
return msgs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func capPlanExecuteUserInputMessages(input []adk.Message, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) []adk.Message {
|
||||
if len(input) == 0 {
|
||||
return input
|
||||
}
|
||||
maxTotal := 120000
|
||||
modelName := "gpt-4o"
|
||||
if appCfg != nil {
|
||||
if appCfg.OpenAI.MaxTotalTokens > 0 {
|
||||
maxTotal = appCfg.OpenAI.MaxTotalTokens
|
||||
}
|
||||
if m := strings.TrimSpace(appCfg.OpenAI.Model); m != "" {
|
||||
modelName = m
|
||||
}
|
||||
}
|
||||
// Reserve most tokens for planner/replanner prompt and tool schema.
|
||||
ratio := 0.35
|
||||
if mwCfg != nil {
|
||||
ratio = mwCfg.PlanExecuteUserInputBudgetRatioEffective()
|
||||
}
|
||||
budget := int(float64(maxTotal) * ratio)
|
||||
if budget < 4096 {
|
||||
budget = 4096
|
||||
}
|
||||
tc := agent.NewTikTokenCounter()
|
||||
out := make([]adk.Message, 0, len(input))
|
||||
used := 0
|
||||
for i := len(input) - 1; i >= 0; i-- {
|
||||
msg := input[i]
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
n, err := tc.Count(modelName, string(msg.Role)+"\n"+msg.Content)
|
||||
if err != nil {
|
||||
n = (len(msg.Content) + 3) / 4
|
||||
}
|
||||
if n <= 0 {
|
||||
n = 1
|
||||
}
|
||||
if used+n > budget {
|
||||
break
|
||||
}
|
||||
used += n
|
||||
out = append(out, msg)
|
||||
}
|
||||
for i, j := 0, len(out)-1; i < j; i, j = i+1, j-1 {
|
||||
out[i], out[j] = out[j], out[i]
|
||||
}
|
||||
if len(out) == 0 {
|
||||
// Keep the latest user message at least.
|
||||
return []adk.Message{input[len(input)-1]}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func renderPlanExecuteStepsByBudget(steps []planexecute.ExecutedStep, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) string {
|
||||
if len(steps) == 0 {
|
||||
return ""
|
||||
}
|
||||
maxTotal := 120000
|
||||
modelName := "gpt-4o"
|
||||
if appCfg != nil {
|
||||
if appCfg.OpenAI.MaxTotalTokens > 0 {
|
||||
maxTotal = appCfg.OpenAI.MaxTotalTokens
|
||||
}
|
||||
if m := strings.TrimSpace(appCfg.OpenAI.Model); m != "" {
|
||||
modelName = m
|
||||
}
|
||||
}
|
||||
ratio := 0.2
|
||||
if mwCfg != nil {
|
||||
ratio = mwCfg.PlanExecuteExecutedStepsBudgetRatioEffective()
|
||||
}
|
||||
budget := int(float64(maxTotal) * ratio)
|
||||
if budget < 3072 {
|
||||
budget = 3072
|
||||
}
|
||||
tc := agent.NewTikTokenCounter()
|
||||
var kept []string
|
||||
used := 0
|
||||
skipped := 0
|
||||
for i := len(steps) - 1; i >= 0; i-- {
|
||||
block := fmt.Sprintf("Step: %s\nResult: %s\n\n", steps[i].Step, steps[i].Result)
|
||||
n, err := tc.Count(modelName, block)
|
||||
if err != nil {
|
||||
n = (len(block) + 3) / 4
|
||||
}
|
||||
if n <= 0 {
|
||||
n = 1
|
||||
}
|
||||
if used+n > budget {
|
||||
skipped = i + 1
|
||||
break
|
||||
}
|
||||
used += n
|
||||
kept = append(kept, block)
|
||||
}
|
||||
var sb strings.Builder
|
||||
if skipped > 0 {
|
||||
sb.WriteString(fmt.Sprintf("Earlier executed steps omitted due to context budget: %d steps.\n\n", skipped))
|
||||
}
|
||||
for i := len(kept) - 1; i >= 0; i-- {
|
||||
sb.WriteString(kept[i])
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// planExecuteStreamsMainAssistant 将规划/执行/重规划各阶段助手流式输出映射到主对话区。
|
||||
func planExecuteStreamsMainAssistant(agent string) bool {
|
||||
if agent == "" {
|
||||
|
||||
@@ -125,7 +125,7 @@ func RunEinoSingleChatModelAgent(
|
||||
return nil, fmt.Errorf("eino single 模型: %w", err)
|
||||
}
|
||||
|
||||
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, logger)
|
||||
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("eino single summarization: %w", err)
|
||||
}
|
||||
@@ -145,6 +145,9 @@ func RunEinoSingleChatModelAgent(
|
||||
handlers = append(handlers, einoSkillMW)
|
||||
}
|
||||
handlers = append(handlers, mainSumMw)
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "eino_single"); teleMw != nil {
|
||||
handlers = append(handlers, teleMw)
|
||||
}
|
||||
|
||||
maxIter := ma.MaxIteration
|
||||
if maxIter <= 0 {
|
||||
@@ -165,11 +168,29 @@ func RunEinoSingleChatModelAgent(
|
||||
},
|
||||
EmitInternalEvents: true,
|
||||
}
|
||||
ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools)
|
||||
if logger != nil {
|
||||
names := collectToolNames(ctx, mainTools)
|
||||
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
||||
hasToolSearch := false
|
||||
for _, n := range names {
|
||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||
hasToolSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
logger.Info("eino tool-name injection",
|
||||
zap.String("scope", "eino_single"),
|
||||
zap.Int("tool_names", len(names)),
|
||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||
zap.Bool("has_tool_search", hasToolSearch),
|
||||
)
|
||||
}
|
||||
|
||||
chatCfg := &adk.ChatModelAgentConfig{
|
||||
Name: einoSingleAgentName,
|
||||
Description: "Eino ADK ChatModelAgent with MCP tools for authorized security testing.",
|
||||
Instruction: ag.EinoSingleAgentSystemInstruction(),
|
||||
Instruction: ins,
|
||||
Model: mainModel,
|
||||
ToolsConfig: mainToolsCfg,
|
||||
MaxIterations: maxIter,
|
||||
@@ -188,7 +209,7 @@ func RunEinoSingleChatModelAgent(
|
||||
return nil, fmt.Errorf("eino single NewChatModelAgent: %w", err)
|
||||
}
|
||||
|
||||
baseMsgs := historyToMessages(history)
|
||||
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
|
||||
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
|
||||
|
||||
streamsMainAssistant := func(agent string) bool {
|
||||
|
||||
@@ -3,6 +3,8 @@ package multiagent
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
@@ -32,6 +34,8 @@ func newEinoSummarizationMiddleware(
|
||||
ctx context.Context,
|
||||
summaryModel model.BaseChatModel,
|
||||
appCfg *config.Config,
|
||||
mwCfg *config.MultiAgentEinoMiddlewareConfig,
|
||||
conversationID string,
|
||||
logger *zap.Logger,
|
||||
) (adk.ChatModelAgentMiddleware, error) {
|
||||
if summaryModel == nil || appCfg == nil {
|
||||
@@ -41,7 +45,14 @@ func newEinoSummarizationMiddleware(
|
||||
if maxTotal <= 0 {
|
||||
maxTotal = 120000
|
||||
}
|
||||
trigger := int(float64(maxTotal) * 0.9)
|
||||
triggerRatio := 0.8
|
||||
emitInternalEvents := true
|
||||
if mwCfg != nil {
|
||||
triggerRatio = mwCfg.SummarizationTriggerRatioEffective()
|
||||
emitInternalEvents = mwCfg.SummarizationEmitInternalEventsEffective()
|
||||
}
|
||||
// Keep enough safety margin for tokenizer/model-side accounting mismatch.
|
||||
trigger := int(float64(maxTotal) * triggerRatio)
|
||||
if trigger < 4096 {
|
||||
trigger = maxTotal
|
||||
if trigger < 4096 {
|
||||
@@ -65,6 +76,18 @@ func newEinoSummarizationMiddleware(
|
||||
if recentTrailMax > trigger/2 {
|
||||
recentTrailMax = trigger / 2
|
||||
}
|
||||
transcriptPath := ""
|
||||
if conv := strings.TrimSpace(conversationID); conv != "" {
|
||||
baseRoot := filepath.Join(os.TempDir(), "cyberstrike-summarization")
|
||||
if dbPath := strings.TrimSpace(appCfg.Database.Path); dbPath != "" {
|
||||
// Persist with the same lifecycle as local conversation storage.
|
||||
baseRoot = filepath.Join(filepath.Dir(dbPath), "conversation_artifacts", sanitizeEinoPathSegment(conv), "summarization")
|
||||
}
|
||||
base := baseRoot
|
||||
if mkErr := os.MkdirAll(base, 0o755); mkErr == nil {
|
||||
transcriptPath = filepath.Join(base, "transcript.txt")
|
||||
}
|
||||
}
|
||||
|
||||
mw, err := summarization.New(ctx, &summarization.Config{
|
||||
Model: summaryModel,
|
||||
@@ -73,7 +96,8 @@ func newEinoSummarizationMiddleware(
|
||||
},
|
||||
TokenCounter: tokenCounter,
|
||||
UserInstruction: einoSummarizeUserInstruction,
|
||||
EmitInternalEvents: false,
|
||||
EmitInternalEvents: emitInternalEvents,
|
||||
TranscriptFilePath: transcriptPath,
|
||||
PreserveUserMessages: &summarization.PreserveUserMessages{
|
||||
Enabled: true,
|
||||
MaxTokens: preserveMax,
|
||||
@@ -85,11 +109,16 @@ func newEinoSummarizationMiddleware(
|
||||
if logger == nil {
|
||||
return nil
|
||||
}
|
||||
beforeTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: before.Messages})
|
||||
afterTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: after.Messages})
|
||||
logger.Info("eino summarization 已压缩上下文",
|
||||
zap.Int("messages_before", len(before.Messages)),
|
||||
zap.Int("messages_after", len(after.Messages)),
|
||||
zap.Int("tokens_before_estimated", beforeTokens),
|
||||
zap.Int("tokens_after_estimated", afterTokens),
|
||||
zap.Int("max_total_tokens", maxTotal),
|
||||
zap.Int("trigger_context_tokens", trigger),
|
||||
zap.String("transcript_file", transcriptPath),
|
||||
)
|
||||
return nil
|
||||
},
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/components/tool"
|
||||
)
|
||||
|
||||
// injectToolNamesOnlyInstruction prepends a compact tool-name-only section into
|
||||
// the system instruction so the model can reference current callable names.
|
||||
func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, tools []tool.BaseTool) string {
|
||||
names := collectToolNames(ctx, tools)
|
||||
if len(names) == 0 {
|
||||
return strings.TrimSpace(instruction)
|
||||
}
|
||||
hasToolSearch := false
|
||||
for _, n := range names {
|
||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||
hasToolSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var sb strings.Builder
|
||||
sb.WriteString("以下是当前会话中可调用的工具名称列表(仅名称,无参数定义):\n")
|
||||
for _, name := range names {
|
||||
sb.WriteString("- ")
|
||||
sb.WriteString(name)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
sb.WriteString("\n使用规则:\n")
|
||||
sb.WriteString("1) 上述仅为名称列表,不包含参数定义。\n")
|
||||
if hasToolSearch {
|
||||
sb.WriteString("2) 在调用具体工具前,应先使用 tool_search 查看工具详情与参数要求,再发起调用。\n")
|
||||
} else {
|
||||
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求;不确定时先澄清再调用。\n")
|
||||
}
|
||||
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
|
||||
if s := strings.TrimSpace(instruction); s != "" {
|
||||
sb.WriteString(s)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func collectToolNames(ctx context.Context, tools []tool.BaseTool) []string {
|
||||
if len(tools) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(tools))
|
||||
out := make([]string, 0, len(tools))
|
||||
for _, t := range tools {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
info, err := t.Info(ctx)
|
||||
if err != nil || info == nil {
|
||||
continue
|
||||
}
|
||||
name := strings.TrimSpace(info.Name)
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(name)
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
out = append(out, name)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -59,4 +59,3 @@ func (m *noNestedTaskMiddleware) WrapInvokableToolCall(
|
||||
return endpoint(ctx2, argumentsInJSON, opts...)
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ func planExecuteDefaultGenExecutorInput(ctx context.Context, in *planexecute.Exe
|
||||
return planexecute.ExecutorPrompt.Format(ctx, map[string]any{
|
||||
"input": planExecuteFormatInput(in.UserInput),
|
||||
"plan": string(planContent),
|
||||
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps),
|
||||
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps, nil, nil),
|
||||
"step": in.Plan.FirstStep(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
|
||||
"github.com/cloudwego/eino/adk/prebuilt/planexecute"
|
||||
)
|
||||
|
||||
@@ -12,8 +14,11 @@ import (
|
||||
// 此处仅约束「写入模型 prompt 的视图」,不修改 Eino session 中的原始 ExecutedSteps。
|
||||
|
||||
const (
|
||||
planExecuteMaxStepResultRunes = 12000
|
||||
planExecuteKeepLastSteps = 16
|
||||
defaultPlanExecuteMaxStepResultRunes = 4000
|
||||
defaultPlanExecuteKeepLastSteps = 8
|
||||
// Backward-compatible aliases for tests and existing references.
|
||||
planExecuteMaxStepResultRunes = defaultPlanExecuteMaxStepResultRunes
|
||||
planExecuteKeepLastSteps = defaultPlanExecuteKeepLastSteps
|
||||
)
|
||||
|
||||
func truncateRunesWithSuffix(s string, maxRunes int, suffix string) string {
|
||||
@@ -29,16 +34,26 @@ func truncateRunesWithSuffix(s string, maxRunes int, suffix string) string {
|
||||
|
||||
// capPlanExecuteExecutedSteps 折叠较早步骤、截断单步过长结果,供 prompt 使用。
|
||||
func capPlanExecuteExecutedSteps(steps []planexecute.ExecutedStep) []planexecute.ExecutedStep {
|
||||
return capPlanExecuteExecutedStepsWithConfig(steps, nil)
|
||||
}
|
||||
|
||||
func capPlanExecuteExecutedStepsWithConfig(steps []planexecute.ExecutedStep, mwCfg *config.MultiAgentEinoMiddlewareConfig) []planexecute.ExecutedStep {
|
||||
if len(steps) == 0 {
|
||||
return steps
|
||||
}
|
||||
maxStepResultRunes := defaultPlanExecuteMaxStepResultRunes
|
||||
keepLastSteps := defaultPlanExecuteKeepLastSteps
|
||||
if mwCfg != nil {
|
||||
maxStepResultRunes = mwCfg.PlanExecuteMaxStepResultRunesEffective()
|
||||
keepLastSteps = mwCfg.PlanExecuteKeepLastStepsEffective()
|
||||
}
|
||||
out := make([]planexecute.ExecutedStep, 0, len(steps)+1)
|
||||
start := 0
|
||||
if len(steps) > planExecuteKeepLastSteps {
|
||||
start = len(steps) - planExecuteKeepLastSteps
|
||||
if len(steps) > keepLastSteps {
|
||||
start = len(steps) - keepLastSteps
|
||||
var b strings.Builder
|
||||
b.WriteString(fmt.Sprintf("(上文已完成 %d 步;此处仅保留步骤标题以节省上下文,完整输出已省略。后续 %d 步仍保留正文。)\n",
|
||||
start, planExecuteKeepLastSteps))
|
||||
start, keepLastSteps))
|
||||
for i := 0; i < start; i++ {
|
||||
b.WriteString(fmt.Sprintf("- %s\n", steps[i].Step))
|
||||
}
|
||||
@@ -50,8 +65,8 @@ func capPlanExecuteExecutedSteps(steps []planexecute.ExecutedStep) []planexecute
|
||||
suffix := "\n…[step result truncated]"
|
||||
for i := start; i < len(steps); i++ {
|
||||
e := steps[i]
|
||||
if utf8.RuneCountInString(e.Result) > planExecuteMaxStepResultRunes {
|
||||
e.Result = truncateRunesWithSuffix(e.Result, planExecuteMaxStepResultRunes, suffix)
|
||||
if utf8.RuneCountInString(e.Result) > maxStepResultRunes {
|
||||
e.Result = truncateRunesWithSuffix(e.Result, maxStepResultRunes, suffix)
|
||||
}
|
||||
out = append(out, e)
|
||||
}
|
||||
|
||||
+111
-14
@@ -30,10 +30,10 @@ import (
|
||||
|
||||
// RunResult 与单 Agent 循环结果字段对齐,便于复用存储与 SSE 收尾逻辑。
|
||||
type RunResult struct {
|
||||
Response string
|
||||
MCPExecutionIDs []string
|
||||
LastReActInput string
|
||||
LastReActOutput string
|
||||
Response string
|
||||
MCPExecutionIDs []string
|
||||
LastAgentTraceInput string // 已序列化的消息带(JSON):原生循环或 Eino 均写入,供续跑/攻击链等恢复上下文
|
||||
LastAgentTraceOutput string // 本轮助手侧对外展示文本(摘要或最终回复)
|
||||
}
|
||||
|
||||
// toolCallPendingInfo tracks a tool_call emitted to the UI so we can later
|
||||
@@ -237,7 +237,7 @@ func RunDeepAgent(
|
||||
subMax = subDefaultIter
|
||||
}
|
||||
|
||||
subSumMw, err := newEinoSummarizationMiddleware(ctx, subModel, appCfg, logger)
|
||||
subSumMw, err := newEinoSummarizationMiddleware(ctx, subModel, appCfg, &ma.EinoMiddleware, conversationID, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("子代理 %q summarization 中间件: %w", id, err)
|
||||
}
|
||||
@@ -257,11 +257,33 @@ func RunDeepAgent(
|
||||
subHandlers = append(subHandlers, einoSkillMW)
|
||||
}
|
||||
subHandlers = append(subHandlers, subSumMw)
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "sub_agent"); teleMw != nil {
|
||||
subHandlers = append(subHandlers, teleMw)
|
||||
}
|
||||
|
||||
subInstrFinal := injectToolNamesOnlyInstruction(ctx, instr, subTools)
|
||||
if logger != nil {
|
||||
subNames := collectToolNames(ctx, subTools)
|
||||
mountedNames := collectToolNames(ctx, subToolsForCfg)
|
||||
hasToolSearch := false
|
||||
for _, n := range subNames {
|
||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||
hasToolSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
logger.Info("eino tool-name injection",
|
||||
zap.String("scope", "sub_agent"),
|
||||
zap.String("agent", id),
|
||||
zap.Int("tool_names", len(subNames)),
|
||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||
zap.Bool("has_tool_search", hasToolSearch),
|
||||
)
|
||||
}
|
||||
sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
|
||||
Name: id,
|
||||
Description: desc,
|
||||
Instruction: instr,
|
||||
Instruction: subInstrFinal,
|
||||
Model: subModel,
|
||||
ToolsConfig: adk.ToolsConfig{
|
||||
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||
@@ -289,7 +311,7 @@ func RunDeepAgent(
|
||||
return nil, fmt.Errorf("多代理主模型: %w", err)
|
||||
}
|
||||
|
||||
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, logger)
|
||||
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("多代理主 summarization 中间件: %w", err)
|
||||
}
|
||||
@@ -313,6 +335,25 @@ func RunDeepAgent(
|
||||
orchDescription = d
|
||||
}
|
||||
}
|
||||
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools)
|
||||
if logger != nil {
|
||||
mainNames := collectToolNames(ctx, mainTools)
|
||||
mountedNames := collectToolNames(ctx, mainToolsForCfg)
|
||||
hasToolSearch := false
|
||||
for _, n := range mainNames {
|
||||
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
|
||||
hasToolSearch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
logger.Info("eino tool-name injection",
|
||||
zap.String("scope", "orchestrator"),
|
||||
zap.String("orchestration", orchMode),
|
||||
zap.Int("tool_names", len(mainNames)),
|
||||
zap.Int("mounted_tool_names", len(mountedNames)),
|
||||
zap.Bool("has_tool_search", hasToolSearch),
|
||||
)
|
||||
}
|
||||
|
||||
supInstr := strings.TrimSpace(orchInstruction)
|
||||
if orchMode == "supervisor" {
|
||||
@@ -352,6 +393,9 @@ func RunDeepAgent(
|
||||
deepHandlers = append(deepHandlers, einoSkillMW)
|
||||
}
|
||||
deepHandlers = append(deepHandlers, mainSumMw)
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil {
|
||||
deepHandlers = append(deepHandlers, teleMw)
|
||||
}
|
||||
|
||||
supHandlers := []adk.ChatModelAgentMiddleware{}
|
||||
if len(mainOrchestratorPre) > 0 {
|
||||
@@ -361,6 +405,9 @@ func RunDeepAgent(
|
||||
supHandlers = append(supHandlers, einoSkillMW)
|
||||
}
|
||||
supHandlers = append(supHandlers, mainSumMw)
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil {
|
||||
supHandlers = append(supHandlers, teleMw)
|
||||
}
|
||||
|
||||
mainToolsCfg := adk.ToolsConfig{
|
||||
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||
@@ -399,10 +446,17 @@ func RunDeepAgent(
|
||||
ExecMaxIter: deepMaxIter,
|
||||
LoopMaxIter: ma.PlanExecuteLoopMaxIterations,
|
||||
AppCfg: appCfg,
|
||||
MwCfg: &ma.EinoMiddleware,
|
||||
ConversationID: conversationID,
|
||||
Logger: logger,
|
||||
ModelName: appCfg.OpenAI.Model,
|
||||
ExecPreMiddlewares: mainOrchestratorPre,
|
||||
SkillMiddleware: einoSkillMW,
|
||||
FilesystemMiddleware: peFsMw,
|
||||
PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{
|
||||
mainSumMw,
|
||||
newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "plan_execute_planner_replanner_rewrite"),
|
||||
},
|
||||
})
|
||||
if perr != nil {
|
||||
return nil, perr
|
||||
@@ -468,7 +522,7 @@ func RunDeepAgent(
|
||||
da = dDeep
|
||||
}
|
||||
|
||||
baseMsgs := historyToMessages(history)
|
||||
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
|
||||
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
|
||||
|
||||
streamsMainAssistant := func(agent string) bool {
|
||||
@@ -505,34 +559,77 @@ func RunDeepAgent(
|
||||
}, baseMsgs)
|
||||
}
|
||||
|
||||
func historyToMessages(history []agent.ChatMessage) []adk.Message {
|
||||
func historyToMessages(history []agent.ChatMessage, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) []adk.Message {
|
||||
if len(history) == 0 {
|
||||
return nil
|
||||
}
|
||||
// 放宽条数上限:跨轮历史交给 Eino Summarization(阈值对齐 openai.max_total_tokens)在调用模型前压缩,避免在入队前硬截断为 40 条。
|
||||
const maxHistoryMessages = 300
|
||||
// Keep a bounded tail first; then enforce a token budget.
|
||||
const maxHistoryMessages = 200
|
||||
start := 0
|
||||
if len(history) > maxHistoryMessages {
|
||||
start = len(history) - maxHistoryMessages
|
||||
}
|
||||
out := make([]adk.Message, 0, len(history[start:]))
|
||||
raw := make([]adk.Message, 0, len(history[start:]))
|
||||
for _, h := range history[start:] {
|
||||
switch h.Role {
|
||||
case "user":
|
||||
if strings.TrimSpace(h.Content) != "" {
|
||||
out = append(out, schema.UserMessage(h.Content))
|
||||
raw = append(raw, schema.UserMessage(h.Content))
|
||||
}
|
||||
case "assistant":
|
||||
if strings.TrimSpace(h.Content) == "" && len(h.ToolCalls) > 0 {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(h.Content) != "" {
|
||||
out = append(out, schema.AssistantMessage(h.Content, nil))
|
||||
raw = append(raw, schema.AssistantMessage(h.Content, nil))
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(raw) == 0 {
|
||||
return raw
|
||||
}
|
||||
maxTotal := 120000
|
||||
modelName := "gpt-4o"
|
||||
if appCfg != nil {
|
||||
if appCfg.OpenAI.MaxTotalTokens > 0 {
|
||||
maxTotal = appCfg.OpenAI.MaxTotalTokens
|
||||
}
|
||||
if m := strings.TrimSpace(appCfg.OpenAI.Model); m != "" {
|
||||
modelName = m
|
||||
}
|
||||
}
|
||||
ratio := 0.35
|
||||
if mwCfg != nil {
|
||||
ratio = mwCfg.HistoryInputBudgetRatioEffective()
|
||||
}
|
||||
budget := int(float64(maxTotal) * ratio)
|
||||
if budget < 4096 {
|
||||
budget = 4096
|
||||
}
|
||||
tc := agent.NewTikTokenCounter()
|
||||
outRev := make([]adk.Message, 0, len(raw))
|
||||
used := 0
|
||||
for i := len(raw) - 1; i >= 0; i-- {
|
||||
msg := raw[i]
|
||||
n, err := tc.Count(modelName, string(msg.Role)+"\n"+msg.Content)
|
||||
if err != nil {
|
||||
n = (len(msg.Content) + 3) / 4
|
||||
}
|
||||
if n <= 0 {
|
||||
n = 1
|
||||
}
|
||||
if used+n > budget {
|
||||
break
|
||||
}
|
||||
used += n
|
||||
outRev = append(outRev, msg)
|
||||
}
|
||||
out := make([]adk.Message, 0, len(outRev))
|
||||
for i := len(outRev) - 1; i >= 0; i-- {
|
||||
out = append(out, outRev[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
|
||||
@@ -192,13 +192,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) {
|
||||
fnName, _ := fn["name"].(string)
|
||||
fnArgs, _ := fn["arguments"]
|
||||
|
||||
// 防御:缺少 name 或 id 的 tool_call 会被 Claude 拒绝
|
||||
if strings.TrimSpace(fnName) == "" {
|
||||
fnName = "unknown_function"
|
||||
}
|
||||
if strings.TrimSpace(tcID) == "" {
|
||||
tcID = fmt.Sprintf("call_%d", time.Now().UnixNano())
|
||||
}
|
||||
// 防御:缺少 name 或 id 的 tool_call 会被 Claude 拒绝
|
||||
if strings.TrimSpace(fnName) == "" {
|
||||
fnName = "unknown_function"
|
||||
}
|
||||
if strings.TrimSpace(tcID) == "" {
|
||||
tcID = fmt.Sprintf("call_%d", time.Now().UnixNano())
|
||||
}
|
||||
|
||||
var inputRaw json.RawMessage
|
||||
switch v := fnArgs.(type) {
|
||||
@@ -752,25 +752,33 @@ func isClaudeProvider(cfg *config.OpenAIConfig) bool {
|
||||
// Eino HTTP Client Bridge
|
||||
// ============================================================
|
||||
|
||||
// NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个支持 Claude 自动桥接的 http.Client。
|
||||
// 当 cfg.Provider 为 claude 时,会拦截 /chat/completions 请求,透明转换为 Anthropic Messages API。
|
||||
// NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个 http.Client,包含两层 transport 包装:
|
||||
// 1. 当 cfg.Provider 为 claude 时,最内层套 claudeRoundTripper,把 OpenAI /chat/completions 透明
|
||||
// 桥接为 Anthropic /v1/messages(并把 Claude SSE 翻译回 OpenAI SSE 格式)。
|
||||
// 2. 最外层无条件套 einoSSESanitizingRoundTripper,吞掉中转站发的 SSE 心跳/注释/控制行
|
||||
// (": keepalive" / "event: ping" / "retry: 3000" 等),避免 Eino 用的 meguminnnnnnnnn/go-openai
|
||||
// SDK 在累计超过 300 个非 "data:" 行后抛 "stream has sent too many empty messages"。
|
||||
//
|
||||
// 两层都对调用方完全透明:普通 JSON 响应原样透传,仅当响应 Content-Type 为 text/event-stream 时
|
||||
// sanitizer 才会接管 body;data: payload (含 [DONE]、{"error":...}) 一字节不改。
|
||||
func NewEinoHTTPClient(cfg *config.OpenAIConfig, base *http.Client) *http.Client {
|
||||
if base == nil {
|
||||
base = http.DefaultClient
|
||||
}
|
||||
if !isClaudeProvider(cfg) {
|
||||
return base
|
||||
}
|
||||
|
||||
cloned := *base
|
||||
transport := base.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
cloned.Transport = &claudeRoundTripper{
|
||||
base: transport,
|
||||
config: cfg,
|
||||
if isClaudeProvider(cfg) {
|
||||
transport = &claudeRoundTripper{
|
||||
base: transport,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
transport = &einoSSESanitizingRoundTripper{base: transport}
|
||||
cloned.Transport = transport
|
||||
return &cloned
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
package openai
|
||||
|
||||
// eino_sse_sanitizer.go 解决 Eino 走 meguminnnnnnnnn/go-openai SDK 时,
|
||||
// 中转站心跳/SSE 控制行累计 > 300 行触发 ErrTooManyEmptyStreamMessages
|
||||
// (报错文案: "stream has sent too many empty messages")的问题。
|
||||
//
|
||||
// 触发链路:
|
||||
// einoopenai.NewChatModel
|
||||
// → eino-ext/libs/acl/openai → meguminnnnnnnnn/go-openai
|
||||
// → streamReader.processLines() 对所有非 "data:" 行计数, > 300 即抛错。
|
||||
//
|
||||
// 中转站常见的非 data: 行(合法 SSE 但 SDK 不接受):
|
||||
// ":" / ": keepalive" / ": ping" / "event: ping" / "retry: 3000"
|
||||
// 以及思考型模型 prefill 期间穿插的大量心跳。
|
||||
//
|
||||
// 兜底策略: 在 HTTP transport 层把响应 Body 包一层 reader, 只放行 "data:"
|
||||
// 开头的行, 把心跳/注释/事件类型行就地吞掉。下游 SDK 永远见不到非 data: 行,
|
||||
// 计数器始终为 0, 该错误不可能再发生。
|
||||
//
|
||||
// 该层对调用方完全透明:
|
||||
// - 仅当响应 Content-Type 是 text/event-stream 时介入;普通 JSON 响应原样透传
|
||||
// - data: payload (含 [DONE] 与 {"error":...}) 一字节不改
|
||||
// - 上游真断流 (EOF / connection reset / context cancel) 原样透传
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// einoSSEReaderBufSize 给 bufio 一个较大的初始缓冲, 避免单行大 JSON chunk
|
||||
// (含工具调用 arguments / reasoning_content) 频繁触发缓冲区扩容。
|
||||
einoSSEReaderBufSize = 64 * 1024
|
||||
)
|
||||
|
||||
// einoSSESanitizingRoundTripper 包装下游 RoundTripper, 对 SSE 响应做行级清洗。
|
||||
type einoSSESanitizingRoundTripper struct {
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *einoSSESanitizingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
resp, err := rt.base.RoundTrip(req)
|
||||
if err != nil || resp == nil {
|
||||
return resp, err
|
||||
}
|
||||
if !isSSEResponse(resp) {
|
||||
return resp, nil
|
||||
}
|
||||
resp.Body = newEinoSSESanitizingBody(resp.Body)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// isSSEResponse 仅对 200 + text/event-stream 的响应做清洗;
|
||||
// 错误响应 (4xx/5xx 通常是 application/json) 不动, 由 SDK 走原错误路径。
|
||||
func isSSEResponse(resp *http.Response) bool {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return false
|
||||
}
|
||||
ct := resp.Header.Get("Content-Type")
|
||||
if ct == "" {
|
||||
return false
|
||||
}
|
||||
ct = strings.ToLower(strings.TrimSpace(ct))
|
||||
// 兼容 "text/event-stream", "text/event-stream; charset=utf-8" 等。
|
||||
return strings.HasPrefix(ct, "text/event-stream")
|
||||
}
|
||||
|
||||
// einoSSESanitizingBody 是包装后的响应体: 只放行 data: 行, 其它行吞掉。
|
||||
type einoSSESanitizingBody struct {
|
||||
upstream io.ReadCloser
|
||||
reader *bufio.Reader
|
||||
pending []byte // 已清洗、待返回给下游的字节 (永远以 \n 结尾的完整 data: 行)
|
||||
err error // upstream 终态错误 (io.EOF 或网络错误)
|
||||
}
|
||||
|
||||
func newEinoSSESanitizingBody(body io.ReadCloser) *einoSSESanitizingBody {
|
||||
return &einoSSESanitizingBody{
|
||||
upstream: body,
|
||||
reader: bufio.NewReaderSize(body, einoSSEReaderBufSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *einoSSESanitizingBody) Read(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if len(b.pending) > 0 {
|
||||
n := copy(p, b.pending)
|
||||
b.pending = b.pending[n:]
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// 从上游读, 直到攒出一行 data: 或拿到终态。
|
||||
// 单次循环可能丢弃任意多行心跳, 但只放行至多一行 data: 后退出,
|
||||
// 避免一次 Read 阻塞过久 / pending 缓冲过大。
|
||||
for b.err == nil {
|
||||
line, err := b.reader.ReadBytes('\n')
|
||||
if len(line) > 0 {
|
||||
if isPassThroughSSELine(line) {
|
||||
if line[len(line)-1] != '\n' {
|
||||
line = append(line, '\n')
|
||||
}
|
||||
b.pending = line
|
||||
if err != nil {
|
||||
b.err = err
|
||||
}
|
||||
break
|
||||
}
|
||||
// 非 data: 行 (空行 / ":" 注释 / event: / retry: / id: / 任何裸文本)
|
||||
// 全部吞掉, 不向下游透出, 继续循环读下一行。
|
||||
}
|
||||
if err != nil {
|
||||
b.err = err
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(b.pending) > 0 {
|
||||
n := copy(p, b.pending)
|
||||
b.pending = b.pending[n:]
|
||||
return n, nil
|
||||
}
|
||||
return 0, b.err
|
||||
}
|
||||
|
||||
func (b *einoSSESanitizingBody) Close() error {
|
||||
return b.upstream.Close()
|
||||
}
|
||||
|
||||
// isPassThroughSSELine 判定该行是否需要原样放行给下游 SDK。
|
||||
// 仅 "data:" (大小写不敏感, 可有任意前导空白) 开头的行需要保留。
|
||||
// 注意: 不能用 TrimSpace 去尾部换行后再判, 否则 " data: x" 会被误判;
|
||||
// 我们只 trim 前导空白, 与 SDK 内部 TrimSpace 后再正则 ^data:\s* 的语义一致。
|
||||
func isPassThroughSSELine(line []byte) bool {
|
||||
trimmed := bytes.TrimLeft(line, " \t")
|
||||
if len(trimmed) < 5 {
|
||||
return false
|
||||
}
|
||||
// 大小写不敏感比较前 5 字节是否为 "data:"。SSE 规范要求字段名小写,
|
||||
// 但宽松匹配可以兼容个别中转站的非规范实现。
|
||||
return (trimmed[0] == 'd' || trimmed[0] == 'D') &&
|
||||
(trimmed[1] == 'a' || trimmed[1] == 'A') &&
|
||||
(trimmed[2] == 't' || trimmed[2] == 'T') &&
|
||||
(trimmed[3] == 'a' || trimmed[3] == 'A') &&
|
||||
trimmed[4] == ':'
|
||||
}
|
||||
@@ -0,0 +1,303 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// 复现 meguminnnnnnnnn/go-openai 的 SSE 行计数算法 (默认 limit=300):
|
||||
// - 逐行读
|
||||
// - 非 "data:" 行 (空行 / ":" 注释 / event: / retry:) 累计 emptyMessagesCount
|
||||
// - > 300 抛 ErrTooManyEmptyStreamMessages
|
||||
// - 遇到 data: 行 reset, 返回 payload
|
||||
//
|
||||
// 这一算法与上游 SDK 的 stream_reader.go processLines() 严格一致 (验证依据见
|
||||
// /Users/temp/go/pkg/mod/github.com/meguminnnnnnnnn/go-openai@v0.1.2/stream_reader.go)。
|
||||
// 测试中只复刻 "限制触发" 这一行为, 用来回归验证 sanitizer 的根因修复。
|
||||
var errTooManyEmptyStreamMessages = errors.New("stream has sent too many empty messages")
|
||||
|
||||
func sdkLikeRecvAll(body io.Reader, limit uint) ([]string, error) {
|
||||
headerData := regexp.MustCompile(`^data:\s*`)
|
||||
r := bufio.NewReader(body)
|
||||
var payloads []string
|
||||
for {
|
||||
var emptyMessagesCount uint
|
||||
var payload []byte
|
||||
for {
|
||||
line, err := r.ReadBytes('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return payloads, nil
|
||||
}
|
||||
return payloads, err
|
||||
}
|
||||
noSpace := bytes.TrimSpace(line)
|
||||
if !headerData.Match(noSpace) {
|
||||
emptyMessagesCount++
|
||||
if emptyMessagesCount > limit {
|
||||
return payloads, errTooManyEmptyStreamMessages
|
||||
}
|
||||
continue
|
||||
}
|
||||
payload = headerData.ReplaceAll(noSpace, nil)
|
||||
break
|
||||
}
|
||||
if string(payload) == "[DONE]" {
|
||||
return payloads, nil
|
||||
}
|
||||
payloads = append(payloads, string(payload))
|
||||
}
|
||||
}
|
||||
|
||||
func newSSEServer(t *testing.T, body string, contentType string, status int) *httptest.Server {
|
||||
t.Helper()
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
if contentType != "" {
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
}
|
||||
w.WriteHeader(status)
|
||||
_, _ = io.WriteString(w, body)
|
||||
}))
|
||||
}
|
||||
|
||||
func sanitizingClient(base *http.Client) *http.Client {
|
||||
if base == nil {
|
||||
base = &http.Client{}
|
||||
}
|
||||
cloned := *base
|
||||
transport := base.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
cloned.Transport = &einoSSESanitizingRoundTripper{base: transport}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func readAll(t *testing.T, body io.ReadCloser) string {
|
||||
t.Helper()
|
||||
defer body.Close()
|
||||
out, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
t.Fatalf("read body: %v", err)
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
// 1) 仅 data: 行 → 一字节不改地透传。
|
||||
func TestSSESanitizer_PassesDataLinesUnchanged(t *testing.T) {
|
||||
body := "data: {\"a\":1}\ndata: {\"b\":2}\ndata: [DONE]\n"
|
||||
srv := newSSEServer(t, body, "text/event-stream", 200)
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := sanitizingClient(nil).Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
got := readAll(t, resp.Body)
|
||||
if got != body {
|
||||
t.Fatalf("body mismatch:\nwant %q\ngot %q", body, got)
|
||||
}
|
||||
}
|
||||
|
||||
// 2) 心跳/注释/事件类型行被吞掉, 仅保留 data: 行。
|
||||
func TestSSESanitizer_DropsHeartbeatsAndControlLines(t *testing.T) {
|
||||
body := strings.Join([]string{
|
||||
": keepalive",
|
||||
"",
|
||||
"event: ping",
|
||||
"retry: 3000",
|
||||
"id: 42",
|
||||
"data: {\"x\":1}",
|
||||
": ping",
|
||||
"",
|
||||
"data: {\"x\":2}",
|
||||
"data: [DONE]",
|
||||
"",
|
||||
}, "\n")
|
||||
srv := newSSEServer(t, body, "text/event-stream", 200)
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := sanitizingClient(nil).Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
got := readAll(t, resp.Body)
|
||||
want := "data: {\"x\":1}\ndata: {\"x\":2}\ndata: [DONE]\n"
|
||||
if got != want {
|
||||
t.Fatalf("sanitized body mismatch:\nwant %q\ngot %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 根因回归: 上游堆 500 行心跳后才发 data:, 原始 SDK 算法会抛
|
||||
// ErrTooManyEmptyStreamMessages, sanitize 之后必须能正常拿到所有 data:。
|
||||
func TestSSESanitizer_ProtectsAgainstTooManyEmptyMessages(t *testing.T) {
|
||||
const heartbeats = 500
|
||||
var buf bytes.Buffer
|
||||
for i := 0; i < heartbeats; i++ {
|
||||
buf.WriteString(": keepalive\n")
|
||||
}
|
||||
buf.WriteString("data: {\"chunk\":1}\n")
|
||||
buf.WriteString("data: {\"chunk\":2}\n")
|
||||
buf.WriteString("data: [DONE]\n")
|
||||
|
||||
t.Run("baseline_without_sanitizer_must_fail", func(t *testing.T) {
|
||||
_, err := sdkLikeRecvAll(bytes.NewReader(buf.Bytes()), 300)
|
||||
if !errors.Is(err, errTooManyEmptyStreamMessages) {
|
||||
t.Fatalf("expected ErrTooManyEmptyStreamMessages, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with_sanitizer_must_succeed", func(t *testing.T) {
|
||||
srv := newSSEServer(t, buf.String(), "text/event-stream", 200)
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := sanitizingClient(nil).Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
payloads, err := sdkLikeRecvAll(resp.Body, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("sdk-like recv after sanitize: %v", err)
|
||||
}
|
||||
want := []string{`{"chunk":1}`, `{"chunk":2}`}
|
||||
if len(payloads) != len(want) {
|
||||
t.Fatalf("payload count mismatch: want %d got %d (%v)", len(want), len(payloads), payloads)
|
||||
}
|
||||
for i, w := range want {
|
||||
if payloads[i] != w {
|
||||
t.Fatalf("payload[%d] mismatch: want %q got %q", i, w, payloads[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 4) 心跳穿插在 data: 之间也能正确清洗 (思考型模型 prefill 期间常见)。
|
||||
func TestSSESanitizer_HeartbeatsInterleavedWithData(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("data: {\"chunk\":1}\n")
|
||||
for i := 0; i < 400; i++ {
|
||||
buf.WriteString(": keepalive\n")
|
||||
}
|
||||
buf.WriteString("data: {\"chunk\":2}\n")
|
||||
buf.WriteString("data: [DONE]\n")
|
||||
|
||||
srv := newSSEServer(t, buf.String(), "text/event-stream", 200)
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := sanitizingClient(nil).Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
payloads, err := sdkLikeRecvAll(resp.Body, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("sdk-like recv: %v", err)
|
||||
}
|
||||
if got, want := len(payloads), 2; got != want {
|
||||
t.Fatalf("payload count: want %d got %d", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// 5) 非 SSE 响应 (例如非流式 JSON) 不应被 sanitizer 介入。
|
||||
func TestSSESanitizer_PassesNonSSEResponseUntouched(t *testing.T) {
|
||||
body := `{"id":"x","object":"chat.completion","choices":[]}`
|
||||
srv := newSSEServer(t, body, "application/json", 200)
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := sanitizingClient(nil).Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
got := readAll(t, resp.Body)
|
||||
if got != body {
|
||||
t.Fatalf("non-SSE body must be untouched:\nwant %q\ngot %q", body, got)
|
||||
}
|
||||
}
|
||||
|
||||
// 6) 错误响应 (4xx/5xx) 不应被 sanitize, 即使 Content-Type 是 SSE 也不动,
|
||||
// 避免吞掉类似 "data: " 之外的错误正文。
|
||||
func TestSSESanitizer_PassesNon200Untouched(t *testing.T) {
|
||||
body := `{"error":{"message":"rate limit"}}`
|
||||
srv := newSSEServer(t, body, "text/event-stream", 429)
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := sanitizingClient(nil).Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
got := readAll(t, resp.Body)
|
||||
if got != body {
|
||||
t.Fatalf("error body must be untouched:\nwant %q\ngot %q", body, got)
|
||||
}
|
||||
}
|
||||
|
||||
// 7) data: 行末尾若缺 \n (异常上游) sanitizer 也补齐, 保证下游按行解析。
|
||||
func TestSSESanitizer_AppendsTrailingNewlineIfMissing(t *testing.T) {
|
||||
body := "data: {\"a\":1}"
|
||||
srv := newSSEServer(t, body, "text/event-stream", 200)
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := sanitizingClient(nil).Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
got := readAll(t, resp.Body)
|
||||
want := "data: {\"a\":1}\n"
|
||||
if got != want {
|
||||
t.Fatalf("trailing newline:\nwant %q\ngot %q", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
// 8) 大 chunk (一行数十 KB) 也能完整透传, 不被切断。
|
||||
func TestSSESanitizer_LargeDataLinePassesIntact(t *testing.T) {
|
||||
huge := strings.Repeat("x", 80*1024)
|
||||
body := "data: {\"big\":\"" + huge + "\"}\ndata: [DONE]\n"
|
||||
srv := newSSEServer(t, body, "text/event-stream", 200)
|
||||
defer srv.Close()
|
||||
|
||||
resp, err := sanitizingClient(nil).Get(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
got := readAll(t, resp.Body)
|
||||
if got != body {
|
||||
t.Fatalf("large body length mismatch: want %d got %d", len(body), len(got))
|
||||
}
|
||||
}
|
||||
|
||||
// 9) isPassThroughSSELine 单元覆盖。
|
||||
func TestIsPassThroughSSELine(t *testing.T) {
|
||||
cases := []struct {
|
||||
line string
|
||||
want bool
|
||||
}{
|
||||
{"data: {\"a\":1}\n", true},
|
||||
{"DATA: x\n", true},
|
||||
{" data: x\n", true},
|
||||
{"data:\n", true},
|
||||
{"\n", false},
|
||||
{"\r\n", false},
|
||||
{": keepalive\n", false},
|
||||
{":\n", false},
|
||||
{"event: ping\n", false},
|
||||
{"retry: 3000\n", false},
|
||||
{"id: 42\n", false},
|
||||
{"datax: y\n", false},
|
||||
{"da", false},
|
||||
}
|
||||
for _, c := range cases {
|
||||
if got := isPassThroughSSELine([]byte(c.line)); got != c.want {
|
||||
t.Errorf("isPassThroughSSELine(%q) = %v, want %v", c.line, got, c.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
+14
-14
@@ -281,9 +281,9 @@ func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{},
|
||||
|
||||
// StreamToolCall 流式工具调用的累积结果(arguments 以字符串形式拼接,留给上层再解析为 JSON)。
|
||||
type StreamToolCall struct {
|
||||
Index int
|
||||
ID string
|
||||
Type string
|
||||
Index int
|
||||
ID string
|
||||
Type string
|
||||
FunctionName string
|
||||
FunctionArgsStr string
|
||||
}
|
||||
@@ -348,10 +348,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
|
||||
Arguments string `json:"arguments,omitempty"`
|
||||
}
|
||||
type toolCallDelta struct {
|
||||
Index int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function toolCallFunctionDelta `json:"function,omitempty"`
|
||||
Index int `json:"index,omitempty"`
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function toolCallFunctionDelta `json:"function,omitempty"`
|
||||
}
|
||||
type streamDelta2 struct {
|
||||
Content string `json:"content,omitempty"`
|
||||
@@ -371,10 +371,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
|
||||
}
|
||||
|
||||
type toolCallAccum struct {
|
||||
id string
|
||||
typ string
|
||||
name string
|
||||
args strings.Builder
|
||||
id string
|
||||
typ string
|
||||
name string
|
||||
args strings.Builder
|
||||
}
|
||||
toolCallAccums := make(map[int]*toolCallAccum)
|
||||
|
||||
@@ -475,9 +475,9 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
|
||||
for _, idx := range indices {
|
||||
acc := toolCallAccums[idx]
|
||||
tc := StreamToolCall{
|
||||
Index: idx,
|
||||
ID: acc.id,
|
||||
Type: acc.typ,
|
||||
Index: idx,
|
||||
ID: acc.id,
|
||||
Type: acc.typ,
|
||||
FunctionName: acc.name,
|
||||
FunctionArgsStr: acc.args.String(),
|
||||
}
|
||||
|
||||
@@ -19,11 +19,11 @@ import (
|
||||
func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
|
||||
logger := zap.NewNop()
|
||||
mcpServer := mcp.NewServer(logger)
|
||||
|
||||
|
||||
cfg := &config.SecurityConfig{
|
||||
Tools: []config.ToolConfig{},
|
||||
}
|
||||
|
||||
|
||||
executor := NewExecutor(cfg, mcpServer, logger)
|
||||
return executor, mcpServer
|
||||
}
|
||||
@@ -32,12 +32,12 @@ func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
|
||||
func setupTestStorage(t *testing.T) *storage.FileResultStorage {
|
||||
tmpDir := filepath.Join(os.TempDir(), "test_executor_storage_"+time.Now().Format("20060102_150405"))
|
||||
logger := zap.NewNop()
|
||||
|
||||
|
||||
storage, err := storage.NewFileResultStorage(tmpDir, logger)
|
||||
if err != nil {
|
||||
t.Fatalf("创建测试存储失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
return storage
|
||||
}
|
||||
|
||||
@@ -45,46 +45,46 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
|
||||
executor, _ := setupTestExecutor(t)
|
||||
testStorage := setupTestStorage(t)
|
||||
executor.SetResultStorage(testStorage)
|
||||
|
||||
|
||||
// 准备测试数据
|
||||
executionID := "test_exec_001"
|
||||
toolName := "nmap_scan"
|
||||
result := "Line 1: Port 22 open\nLine 2: Port 80 open\nLine 3: Port 443 open\nLine 4: error occurred"
|
||||
|
||||
|
||||
// 保存测试结果
|
||||
err := testStorage.SaveResult(executionID, toolName, result)
|
||||
if err != nil {
|
||||
t.Fatalf("保存测试结果失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
// 测试1: 基本查询(第一页)
|
||||
args := map[string]interface{}{
|
||||
"execution_id": executionID,
|
||||
"page": float64(1),
|
||||
"limit": float64(2),
|
||||
}
|
||||
|
||||
|
||||
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
|
||||
if err != nil {
|
||||
t.Fatalf("执行查询失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if toolResult.IsError {
|
||||
t.Fatalf("查询应该成功,但返回了错误: %s", toolResult.Content[0].Text)
|
||||
}
|
||||
|
||||
|
||||
// 验证结果包含预期内容
|
||||
resultText := toolResult.Content[0].Text
|
||||
if !strings.Contains(resultText, executionID) {
|
||||
t.Errorf("结果中应该包含执行ID: %s", executionID)
|
||||
}
|
||||
|
||||
|
||||
if !strings.Contains(resultText, "第 1/") {
|
||||
t.Errorf("结果中应该包含分页信息")
|
||||
}
|
||||
|
||||
|
||||
// 测试2: 搜索功能
|
||||
args2 := map[string]interface{}{
|
||||
"execution_id": executionID,
|
||||
@@ -92,21 +92,21 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
|
||||
"page": float64(1),
|
||||
"limit": float64(10),
|
||||
}
|
||||
|
||||
|
||||
toolResult2, err := executor.executeQueryExecutionResult(ctx, args2)
|
||||
if err != nil {
|
||||
t.Fatalf("执行搜索失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if toolResult2.IsError {
|
||||
t.Fatalf("搜索应该成功,但返回了错误: %s", toolResult2.Content[0].Text)
|
||||
}
|
||||
|
||||
|
||||
resultText2 := toolResult2.Content[0].Text
|
||||
if !strings.Contains(resultText2, "error") {
|
||||
t.Errorf("搜索结果中应该包含关键词: error")
|
||||
}
|
||||
|
||||
|
||||
// 测试3: 过滤功能
|
||||
args3 := map[string]interface{}{
|
||||
"execution_id": executionID,
|
||||
@@ -114,46 +114,46 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
|
||||
"page": float64(1),
|
||||
"limit": float64(10),
|
||||
}
|
||||
|
||||
|
||||
toolResult3, err := executor.executeQueryExecutionResult(ctx, args3)
|
||||
if err != nil {
|
||||
t.Fatalf("执行过滤失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if toolResult3.IsError {
|
||||
t.Fatalf("过滤应该成功,但返回了错误: %s", toolResult3.Content[0].Text)
|
||||
}
|
||||
|
||||
|
||||
resultText3 := toolResult3.Content[0].Text
|
||||
if !strings.Contains(resultText3, "Port") {
|
||||
t.Errorf("过滤结果中应该包含关键词: Port")
|
||||
}
|
||||
|
||||
|
||||
// 测试4: 缺少必需参数
|
||||
args4 := map[string]interface{}{
|
||||
"page": float64(1),
|
||||
}
|
||||
|
||||
|
||||
toolResult4, err := executor.executeQueryExecutionResult(ctx, args4)
|
||||
if err != nil {
|
||||
t.Fatalf("执行查询失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if !toolResult4.IsError {
|
||||
t.Fatal("缺少execution_id应该返回错误")
|
||||
}
|
||||
|
||||
|
||||
// 测试5: 不存在的执行ID
|
||||
args5 := map[string]interface{}{
|
||||
"execution_id": "nonexistent_id",
|
||||
"page": float64(1),
|
||||
}
|
||||
|
||||
|
||||
toolResult5, err := executor.executeQueryExecutionResult(ctx, args5)
|
||||
if err != nil {
|
||||
t.Fatalf("执行查询失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if !toolResult5.IsError {
|
||||
t.Fatal("不存在的执行ID应该返回错误")
|
||||
}
|
||||
@@ -161,22 +161,22 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
|
||||
|
||||
func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
|
||||
executor, _ := setupTestExecutor(t)
|
||||
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"test": "value",
|
||||
}
|
||||
|
||||
|
||||
// 测试未知的内部工具类型
|
||||
toolResult, err := executor.executeInternalTool(ctx, "unknown_tool", "internal:unknown_tool", args)
|
||||
if err != nil {
|
||||
t.Fatalf("执行内部工具失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if !toolResult.IsError {
|
||||
t.Fatal("未知的工具类型应该返回错误")
|
||||
}
|
||||
|
||||
|
||||
if !strings.Contains(toolResult.Content[0].Text, "未知的内部工具类型") {
|
||||
t.Errorf("错误消息应该包含'未知的内部工具类型'")
|
||||
}
|
||||
@@ -185,21 +185,21 @@ func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
|
||||
func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) {
|
||||
executor, _ := setupTestExecutor(t)
|
||||
// 不设置存储,测试未初始化的情况
|
||||
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"execution_id": "test_id",
|
||||
}
|
||||
|
||||
|
||||
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
|
||||
if err != nil {
|
||||
t.Fatalf("执行查询失败: %v", err)
|
||||
}
|
||||
|
||||
|
||||
if !toolResult.IsError {
|
||||
t.Fatal("未初始化的存储应该返回错误")
|
||||
}
|
||||
|
||||
|
||||
if !strings.Contains(toolResult.Content[0].Text, "结果存储未初始化") {
|
||||
t.Errorf("错误消息应该包含'结果存储未初始化'")
|
||||
}
|
||||
@@ -207,7 +207,7 @@ func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) {
|
||||
|
||||
func TestPaginateLines(t *testing.T) {
|
||||
lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"}
|
||||
|
||||
|
||||
// 测试第一页
|
||||
page := paginateLines(lines, 1, 2)
|
||||
if page.Page != 1 {
|
||||
@@ -225,7 +225,7 @@ func TestPaginateLines(t *testing.T) {
|
||||
if len(page.Lines) != 2 {
|
||||
t.Errorf("第一页行数不匹配。期望: 2, 实际: %d", len(page.Lines))
|
||||
}
|
||||
|
||||
|
||||
// 测试第二页
|
||||
page2 := paginateLines(lines, 2, 2)
|
||||
if len(page2.Lines) != 2 {
|
||||
@@ -234,13 +234,13 @@ func TestPaginateLines(t *testing.T) {
|
||||
if page2.Lines[0] != "Line 3" {
|
||||
t.Errorf("第二页第一行不匹配。期望: Line 3, 实际: %s", page2.Lines[0])
|
||||
}
|
||||
|
||||
|
||||
// 测试最后一页
|
||||
page3 := paginateLines(lines, 3, 2)
|
||||
if len(page3.Lines) != 1 {
|
||||
t.Errorf("第三页行数不匹配。期望: 1, 实际: %d", len(page3.Lines))
|
||||
}
|
||||
|
||||
|
||||
// 测试超出范围的页码(应该返回最后一页)
|
||||
page4 := paginateLines(lines, 4, 2)
|
||||
if page4.Page != 3 {
|
||||
@@ -249,13 +249,13 @@ func TestPaginateLines(t *testing.T) {
|
||||
if len(page4.Lines) != 1 {
|
||||
t.Errorf("最后一页应该只有1行。实际: %d行", len(page4.Lines))
|
||||
}
|
||||
|
||||
|
||||
// 测试无效页码(小于1)
|
||||
page0 := paginateLines(lines, 0, 2)
|
||||
if page0.Page != 1 {
|
||||
t.Errorf("无效页码应该被修正为1。实际: %d", page0.Page)
|
||||
}
|
||||
|
||||
|
||||
// 测试空列表
|
||||
emptyPage := paginateLines([]string{}, 1, 10)
|
||||
if emptyPage.TotalLines != 0 {
|
||||
@@ -265,4 +265,3 @@ func TestPaginateLines(t *testing.T) {
|
||||
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,10 +16,10 @@ type rateLimitEntry struct {
|
||||
|
||||
// RateLimiter 基于 IP 的滑动窗口速率限制器
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
entries map[string]*rateLimitEntry
|
||||
limit int // 窗口内允许的最大请求数
|
||||
window time.Duration // 窗口时长
|
||||
mu sync.Mutex
|
||||
entries map[string]*rateLimitEntry
|
||||
limit int // 窗口内允许的最大请求数
|
||||
window time.Duration // 窗口时长
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建速率限制器
|
||||
|
||||
@@ -162,4 +162,3 @@ func truncateRunes(s string, max int) string {
|
||||
}
|
||||
return string(r[:max]) + "…"
|
||||
}
|
||||
|
||||
|
||||
@@ -49,12 +49,12 @@ func ParseSkillMD(raw []byte) (*SkillManifest, string, error) {
|
||||
}
|
||||
|
||||
type skillFrontMatterExport struct {
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
License string `yaml:"license,omitempty"`
|
||||
Compatibility string `yaml:"compatibility,omitempty"`
|
||||
Metadata map[string]any `yaml:"metadata,omitempty"`
|
||||
AllowedTools string `yaml:"allowed-tools,omitempty"`
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
License string `yaml:"license,omitempty"`
|
||||
Compatibility string `yaml:"compatibility,omitempty"`
|
||||
Metadata map[string]any `yaml:"metadata,omitempty"`
|
||||
AllowedTools string `yaml:"allowed-tools,omitempty"`
|
||||
}
|
||||
|
||||
// BuildSkillMD serializes SKILL.md per agentskills.io.
|
||||
|
||||
@@ -9,10 +9,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
maxPackageFiles = 4000
|
||||
maxPackageDepth = 24
|
||||
maxScriptsDepth = 24
|
||||
defaultMaxRead = 10 << 20
|
||||
maxPackageFiles = 4000
|
||||
maxPackageDepth = 24
|
||||
maxScriptsDepth = 24
|
||||
defaultMaxRead = 10 << 20
|
||||
)
|
||||
|
||||
// SafeRelPath resolves rel inside root (no ..).
|
||||
|
||||
+1329
-71
File diff suppressed because it is too large
Load Diff
+107
-3
@@ -20,7 +20,13 @@
|
||||
"copied": "Copied",
|
||||
"copyFailed": "Copy failed",
|
||||
"view": "View",
|
||||
"actions": "Actions"
|
||||
"actions": "Actions",
|
||||
"loadFailed": "Load failed",
|
||||
"untitled": "Untitled",
|
||||
"justNow": "Just now",
|
||||
"minutesAgo": "{{n}} min ago",
|
||||
"hoursAgo": "{{n}} h ago",
|
||||
"daysAgo": "{{n}} d ago"
|
||||
},
|
||||
"header": {
|
||||
"title": "CyberStrikeAI",
|
||||
@@ -33,6 +39,13 @@
|
||||
"version": "Current version",
|
||||
"toggleSidebar": "Collapse/expand sidebar"
|
||||
},
|
||||
"notifications": {
|
||||
"title": "Notifications",
|
||||
"empty": "No new events",
|
||||
"markAllRead": "Mark all read",
|
||||
"markSingleRead": "Read",
|
||||
"itemDefaultTitle": "Notification"
|
||||
},
|
||||
"login": {
|
||||
"title": "Sign in to CyberStrikeAI",
|
||||
"subtitle": "Enter the access password from config",
|
||||
@@ -81,6 +94,7 @@
|
||||
"severityMedium": "Medium",
|
||||
"severityLow": "Low",
|
||||
"severityInfo": "Info",
|
||||
"totalVulns": "Total vulnerabilities",
|
||||
"runOverview": "Run overview",
|
||||
"batchQueues": "Batch task queues",
|
||||
"pending": "Pending",
|
||||
@@ -107,7 +121,80 @@
|
||||
"toUse": "To use",
|
||||
"active": "Active",
|
||||
"highFreq": "High frequency",
|
||||
"noCallData": "No call data"
|
||||
"noCallData": "No call data",
|
||||
"lastUpdated": "Last updated",
|
||||
"viewAll": "View all →",
|
||||
"recentVulns": "Recent vulnerabilities",
|
||||
"noVulnYet": "No recent vulnerabilities",
|
||||
"capabilities": "Capabilities",
|
||||
"mcpTools": "MCP tools",
|
||||
"rolesLabel": "Roles",
|
||||
"agentsLabel": "Agents",
|
||||
"webshellLabel": "WebShell",
|
||||
"pendingCountLabel": "{{count}} pending",
|
||||
"highCountLabel": "High {{count}}",
|
||||
"toolsCountLabel_one": "{{count}} tool",
|
||||
"toolsCountLabel_other": "{{count}} tools",
|
||||
"failedNCalls_one": "{{count}} failed",
|
||||
"failedNCalls_other": "{{count}} failed",
|
||||
"noCallYet": "No calls yet",
|
||||
"allClear": "No new risks",
|
||||
"allIdle": "System idle",
|
||||
"executingNow": "Running",
|
||||
"healthyStatus": "Healthy",
|
||||
"normalStatus": "Mostly OK",
|
||||
"degradedStatus": "Needs attention",
|
||||
"alertTitle": "Heads up",
|
||||
"alertWarningTitle": "Needs attention",
|
||||
"alertDangerTitle": "Action required",
|
||||
"alertCriticalReason_one": "{{count}} open critical vulnerability — please review immediately",
|
||||
"alertCriticalReason_other": "{{count}} open critical vulnerabilities — please review immediately",
|
||||
"alertFailedReason_one": "Tool success rate is low ({{count}} failed call) — check MCP monitor",
|
||||
"alertFailedReason_other": "Tool success rate is low ({{count}} failed calls) — check MCP monitor",
|
||||
"alertHitlReason_one": "{{count}} HITL request pending — Agent is waiting for your decision",
|
||||
"alertHitlReason_other": "{{count}} HITL requests pending — Agent is waiting for your decision",
|
||||
"alertMcpDownReason_one": "{{count}} External MCP server is down — related tools are unavailable",
|
||||
"alertMcpDownReason_other": "{{count}} External MCP servers are down — related tools are unavailable",
|
||||
"alertDismiss": "Dismiss (this session)",
|
||||
"openHighCountLabel": "Open high {{count}}",
|
||||
"allHandled": "All high severity handled",
|
||||
"viewVulns": "View vulnerabilities",
|
||||
"viewMonitor": "View monitor",
|
||||
"viewHitl": "Approve",
|
||||
"viewMcpManagement": "Manage MCP",
|
||||
"statusOpen": "Open",
|
||||
"statusConfirmed": "Confirmed",
|
||||
"statusFixed": "Fixed",
|
||||
"statusFalsePositive": "False positive",
|
||||
"fixRate": "Fix rate",
|
||||
"dataStale": "Data may be stale — please refresh",
|
||||
"recommendedActions": "Recommended Actions",
|
||||
"recommendedActionsHint": "Generated based on current state",
|
||||
"recoFixCritical_one": "Fix {{count}} open critical vulnerability",
|
||||
"recoFixCritical_other": "Fix {{count}} open critical vulnerabilities",
|
||||
"recoFixCriticalDesc": "Critical-level vulnerabilities should be addressed first",
|
||||
"recoApproveHitl_one": "Approve {{count}} HITL request",
|
||||
"recoApproveHitl_other": "Approve {{count}} HITL requests",
|
||||
"recoApproveHitlDesc": "Agent needs your decision to proceed",
|
||||
"recoRestartMcp_one": "Check {{count}} stopped External MCP",
|
||||
"recoRestartMcp_other": "Check {{count}} stopped External MCPs",
|
||||
"recoRestartMcpDesc": "Related tools are unavailable until MCP recovers",
|
||||
"recoCheckMonitor_one": "Investigate {{count}} failed tool call",
|
||||
"recoCheckMonitor_other": "Investigate {{count}} failed tool calls",
|
||||
"recoCheckMonitorDesc": "View failed request details in MCP monitor",
|
||||
"recoSetupMcp": "Configure your first MCP tool",
|
||||
"recoSetupMcpDesc": "Install MCP server before Agent can invoke specific capabilities",
|
||||
"recoStartScan": "Start a scan from chat",
|
||||
"recoStartScanDesc": "Describe your target in chat, AI will help execute",
|
||||
"recentEvents": "Recent Events",
|
||||
"eventUntitled": "Event",
|
||||
"externalMcpServers": "External MCP",
|
||||
"mcpAllRunning": "All running",
|
||||
"mcpPartialDown_one": "{{count}} stopped",
|
||||
"mcpPartialDown_other": "{{count}} stopped",
|
||||
"mcpAllDown": "All stopped",
|
||||
"noVulnDesc": "This list shows recent records; new results appear here when detection completes in chat",
|
||||
"startScanBtn": "Go to chat to scan"
|
||||
},
|
||||
"chat": {
|
||||
"newChat": "New chat",
|
||||
@@ -239,7 +326,20 @@
|
||||
},
|
||||
"hitl": {
|
||||
"pageTitle": "HITL approvals",
|
||||
"pendingTitle": "Pending approvals"
|
||||
"pendingTitle": "Pending approvals",
|
||||
"loading": "Loading...",
|
||||
"emptyState": "No pending approvals",
|
||||
"dismiss": "Dismiss",
|
||||
"conversationLabel": "Conversation:",
|
||||
"reviewEditHelp": "Review & edit mode: provide a JSON object to override tool arguments. Example: {\"command\":\"ls -la\"}",
|
||||
"approvalHelp": "Approval mode: only approve/reject, argument editing is disabled.",
|
||||
"commentHelp": "Comment (optional): briefly note the approval reason.",
|
||||
"commentPlaceholder": "e.g. allow read-only command",
|
||||
"reject": "Reject",
|
||||
"approve": "Approve",
|
||||
"loadFailed": "Failed to load",
|
||||
"invalidJson": "Invalid JSON arguments",
|
||||
"submitFailedPrefix": "Submit failed:"
|
||||
},
|
||||
"progress": {
|
||||
"callingAI": "Calling AI model...",
|
||||
@@ -576,6 +676,10 @@
|
||||
"addExternal": "Add external MCP",
|
||||
"toolConfig": "MCP tool config",
|
||||
"saveToolConfig": "Save tool config",
|
||||
"alwaysVisibleLabel": "Pinned",
|
||||
"alwaysVisibleHint": "Always keep visible in Tool Search results",
|
||||
"alwaysVisibleBuiltinLabel": "Builtin default",
|
||||
"alwaysVisibleBuiltinHint": "Backend builtin tool is pinned by default and cannot be disabled",
|
||||
"externalConfig": "External MCP config",
|
||||
"loadingTools": "Loading tools...",
|
||||
"loadToolsTimeout": "Tools load timeout. External MCP may be slow. Click Refresh to retry or check connection.",
|
||||
|
||||
@@ -20,7 +20,13 @@
|
||||
"copied": "已复制",
|
||||
"copyFailed": "复制失败",
|
||||
"view": "查看",
|
||||
"actions": "操作"
|
||||
"actions": "操作",
|
||||
"loadFailed": "加载失败",
|
||||
"untitled": "未命名",
|
||||
"justNow": "刚刚",
|
||||
"minutesAgo": "{{n}} 分钟前",
|
||||
"hoursAgo": "{{n}} 小时前",
|
||||
"daysAgo": "{{n}} 天前"
|
||||
},
|
||||
"header": {
|
||||
"title": "CyberStrikeAI",
|
||||
@@ -33,6 +39,13 @@
|
||||
"version": "当前版本",
|
||||
"toggleSidebar": "折叠/展开侧边栏"
|
||||
},
|
||||
"notifications": {
|
||||
"title": "事件通知",
|
||||
"empty": "暂无新事件",
|
||||
"markAllRead": "标记已读",
|
||||
"markSingleRead": "已读",
|
||||
"itemDefaultTitle": "通知"
|
||||
},
|
||||
"login": {
|
||||
"title": "登录 CyberStrikeAI",
|
||||
"subtitle": "请输入配置中的访问密码",
|
||||
@@ -81,6 +94,7 @@
|
||||
"severityMedium": "中危",
|
||||
"severityLow": "低危",
|
||||
"severityInfo": "信息",
|
||||
"totalVulns": "总漏洞数",
|
||||
"runOverview": "运行概览",
|
||||
"batchQueues": "批量任务队列",
|
||||
"pending": "待执行",
|
||||
@@ -107,7 +121,69 @@
|
||||
"toUse": "待使用",
|
||||
"active": "活跃",
|
||||
"highFreq": "高频",
|
||||
"noCallData": "暂无调用数据"
|
||||
"noCallData": "暂无调用数据",
|
||||
"lastUpdated": "上次更新",
|
||||
"viewAll": "查看全部 →",
|
||||
"recentVulns": "最近漏洞",
|
||||
"noVulnYet": "暂无最近漏洞",
|
||||
"capabilities": "能力总览",
|
||||
"mcpTools": "MCP 工具",
|
||||
"rolesLabel": "角色",
|
||||
"agentsLabel": "Agents",
|
||||
"webshellLabel": "WebShell",
|
||||
"pendingCountLabel": "{{count}} 待执行",
|
||||
"highCountLabel": "高危 {{count}}",
|
||||
"toolsCountLabel": "{{count}} 个工具",
|
||||
"failedNCalls": "{{count}} 次失败",
|
||||
"noCallYet": "暂无调用",
|
||||
"allClear": "暂无新增风险",
|
||||
"allIdle": "系统空闲",
|
||||
"executingNow": "正在执行",
|
||||
"healthyStatus": "运行平稳",
|
||||
"normalStatus": "基本正常",
|
||||
"degradedStatus": "需要关注",
|
||||
"alertTitle": "需要关注",
|
||||
"alertWarningTitle": "需要关注",
|
||||
"alertDangerTitle": "需要立即处理",
|
||||
"alertCriticalReason": "存在 {{count}} 个待处理的严重漏洞,建议立即处置",
|
||||
"alertFailedReason": "工具调用成功率偏低({{count}} 次失败),请检查 MCP 监控",
|
||||
"alertHitlReason": "有 {{count}} 个待审批的人机协同请求,Agent 正在等待你的决策",
|
||||
"alertMcpDownReason": "External MCP 服务器有 {{count}} 个未运行,相关工具不可用",
|
||||
"alertDismiss": "忽略此提醒(仅本次会话)",
|
||||
"openHighCountLabel": "待处理高危 {{count}}",
|
||||
"allHandled": "高严重度已全部处置",
|
||||
"viewVulns": "查看漏洞",
|
||||
"viewMonitor": "查看监控",
|
||||
"viewHitl": "前往审批",
|
||||
"viewMcpManagement": "管理 MCP",
|
||||
"statusOpen": "待处理",
|
||||
"statusConfirmed": "已确认",
|
||||
"statusFixed": "已修复",
|
||||
"statusFalsePositive": "误报",
|
||||
"fixRate": "修复率",
|
||||
"dataStale": "数据可能已过期,请手动刷新",
|
||||
"recommendedActions": "推荐操作",
|
||||
"recommendedActionsHint": "基于当前状态自动生成",
|
||||
"recoFixCritical": "修复 {{count}} 个待处理严重漏洞",
|
||||
"recoFixCriticalDesc": "严重等级的漏洞应优先处置",
|
||||
"recoApproveHitl": "审批 {{count}} 个 HITL 请求",
|
||||
"recoApproveHitlDesc": "Agent 正在等待你的决策才能继续",
|
||||
"recoRestartMcp": "检查 {{count}} 个未运行的 External MCP",
|
||||
"recoRestartMcpDesc": "相关工具在 MCP 服务恢复前不可用",
|
||||
"recoCheckMonitor": "排查 {{count}} 次工具调用失败",
|
||||
"recoCheckMonitorDesc": "在 MCP 监控中查看失败的请求详情",
|
||||
"recoSetupMcp": "配置首个 MCP 工具",
|
||||
"recoSetupMcpDesc": "安装 MCP 服务后 Agent 才能调用具体能力",
|
||||
"recoStartScan": "在对话中发起扫描",
|
||||
"recoStartScanDesc": "在对话中描述目标,让 AI 协助执行",
|
||||
"recentEvents": "最近事件",
|
||||
"eventUntitled": "事件",
|
||||
"externalMcpServers": "External MCP",
|
||||
"mcpAllRunning": "全部运行",
|
||||
"mcpPartialDown": "{{count}} 个未运行",
|
||||
"mcpAllDown": "全部未运行",
|
||||
"noVulnDesc": "此处展示近期漏洞记录;在对话中完成检测后,新结果会出现在这里",
|
||||
"startScanBtn": "前往对话发起扫描"
|
||||
},
|
||||
"chat": {
|
||||
"newChat": "新对话",
|
||||
@@ -239,7 +315,20 @@
|
||||
},
|
||||
"hitl": {
|
||||
"pageTitle": "人机协同审批",
|
||||
"pendingTitle": "待处理审批"
|
||||
"pendingTitle": "待处理审批",
|
||||
"loading": "加载中...",
|
||||
"emptyState": "暂无待审批项",
|
||||
"dismiss": "忽略",
|
||||
"conversationLabel": "会话:",
|
||||
"reviewEditHelp": "审查编辑模式:可填写 JSON 对象覆盖参数。示例:{\"command\":\"ls -la\"}",
|
||||
"approvalHelp": "审批模式:仅通过/拒绝,不支持改参。",
|
||||
"commentHelp": "备注(可选):建议写审批依据。",
|
||||
"commentPlaceholder": "例如:允许只读命令",
|
||||
"reject": "拒绝",
|
||||
"approve": "通过",
|
||||
"loadFailed": "加载失败",
|
||||
"invalidJson": "JSON 参数格式错误",
|
||||
"submitFailedPrefix": "提交失败:"
|
||||
},
|
||||
"progress": {
|
||||
"callingAI": "正在调用AI模型...",
|
||||
@@ -576,6 +665,10 @@
|
||||
"addExternal": "添加外部MCP",
|
||||
"toolConfig": "MCP 工具配置",
|
||||
"saveToolConfig": "保存工具配置",
|
||||
"alwaysVisibleLabel": "常驻",
|
||||
"alwaysVisibleHint": "始终常驻在 Tool Search 可见列表(不被 tool_search 隐藏)",
|
||||
"alwaysVisibleBuiltinLabel": "内置默认",
|
||||
"alwaysVisibleBuiltinHint": "后端内置工具默认常驻,不可关闭",
|
||||
"externalConfig": "外部 MCP 配置",
|
||||
"loadingTools": "正在加载工具列表...",
|
||||
"loadToolsTimeout": "加载工具列表超时,可能是外部MCP连接较慢。请点击\"刷新\"按钮重试,或检查外部MCP连接状态。",
|
||||
|
||||
+1157
-98
File diff suppressed because it is too large
Load Diff
+42
-16
@@ -7,6 +7,19 @@ function hitlModeNormalize(m) {
|
||||
return allowed.indexOf(v) >= 0 ? v : 'off';
|
||||
}
|
||||
|
||||
function hitlT(key, fallback, params) {
|
||||
const fullKey = 'hitl.' + key;
|
||||
try {
|
||||
if (typeof window.t === 'function') {
|
||||
const translated = window.t(fullKey, params || {});
|
||||
if (typeof translated === 'string' && translated && translated !== fullKey) {
|
||||
return translated;
|
||||
}
|
||||
}
|
||||
} catch (e) {}
|
||||
return fallback;
|
||||
}
|
||||
|
||||
function hitlEffectiveEnabled(cfg) {
|
||||
if (!cfg) return false;
|
||||
if (cfg.enabled === true) return true;
|
||||
@@ -36,6 +49,18 @@ function hitlSensitiveToolsToArray(config) {
|
||||
return [];
|
||||
}
|
||||
|
||||
function normalizeHitlTimeoutSeconds(v, fallback) {
|
||||
const n = Number(v);
|
||||
if (Number.isFinite(n)) {
|
||||
return n > 0 ? Math.floor(n) : 0;
|
||||
}
|
||||
const f = Number(fallback);
|
||||
if (Number.isFinite(f)) {
|
||||
return f > 0 ? Math.floor(f) : 0;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
function getCurrentConversationIdForHitl() {
|
||||
if (typeof window.currentConversationId === 'string' && window.currentConversationId) {
|
||||
return window.currentConversationId;
|
||||
@@ -84,6 +109,7 @@ async function saveHitlConversationConfig(conversationId, config) {
|
||||
const mode = hitlModeNormalize(config.mode || 'off');
|
||||
const enabled = typeof config.enabled === 'boolean' ? config.enabled : (mode !== 'off');
|
||||
const sensitiveTools = hitlSensitiveToolsToArray(config);
|
||||
const timeoutSeconds = normalizeHitlTimeoutSeconds(config.timeoutSeconds, 0);
|
||||
const resp = await hitlApiFetch('/api/hitl/config', {
|
||||
method: 'PUT',
|
||||
credentials: 'same-origin',
|
||||
@@ -93,7 +119,7 @@ async function saveHitlConversationConfig(conversationId, config) {
|
||||
enabled: enabled,
|
||||
mode: mode,
|
||||
sensitiveTools: sensitiveTools,
|
||||
timeoutSeconds: config.timeoutSeconds || 300
|
||||
timeoutSeconds: timeoutSeconds
|
||||
})
|
||||
});
|
||||
if (!resp.ok) {
|
||||
@@ -126,7 +152,7 @@ async function syncHitlConfigFromServer(conversationId) {
|
||||
enabled: true,
|
||||
mode: localMode,
|
||||
sensitiveTools: localToolsStr.split(/[,\n\r]+/).map(function (s) { return s.trim(); }).filter(Boolean),
|
||||
timeoutSeconds: cfg.timeoutSeconds || 300
|
||||
timeoutSeconds: normalizeHitlTimeoutSeconds(cfg.timeoutSeconds, 0)
|
||||
};
|
||||
saveHitlConversationConfig(conversationId, {
|
||||
mode: localMode,
|
||||
@@ -146,7 +172,7 @@ async function syncHitlConfigFromServer(conversationId) {
|
||||
enabled: true,
|
||||
mode: glMode,
|
||||
sensitiveTools: glToolsStr.split(/[,\n\r]+/).map(function (s) { return s.trim(); }).filter(Boolean),
|
||||
timeoutSeconds: cfg.timeoutSeconds || 300
|
||||
timeoutSeconds: normalizeHitlTimeoutSeconds(cfg.timeoutSeconds, 0)
|
||||
};
|
||||
saveHitlConversationConfig(conversationId, {
|
||||
mode: glMode,
|
||||
@@ -265,7 +291,7 @@ async function followAgentRunAfterHitlDecision(conversationId) {
|
||||
async function refreshHitlPending() {
|
||||
const container = document.getElementById('hitl-pending-list');
|
||||
if (!container) return;
|
||||
container.innerHTML = '<div class="loading-spinner">Loading...</div>';
|
||||
container.innerHTML = '<div class="loading-spinner">' + escapeHtml(hitlT('loading', 'Loading...')) + '</div>';
|
||||
try {
|
||||
const resp = await hitlApiFetch('/api/hitl/pending', { credentials: 'same-origin' });
|
||||
if (!resp.ok) {
|
||||
@@ -274,7 +300,7 @@ async function refreshHitlPending() {
|
||||
const data = await resp.json();
|
||||
const items = Array.isArray(data.items) ? data.items : [];
|
||||
if (!items.length) {
|
||||
container.innerHTML = '<div class="empty-state">暂无待审批项</div>';
|
||||
container.innerHTML = '<div class="empty-state">' + escapeHtml(hitlT('emptyState', 'No pending approvals')) + '</div>';
|
||||
return;
|
||||
}
|
||||
container.innerHTML = items.map(function (item) {
|
||||
@@ -292,25 +318,25 @@ async function refreshHitlPending() {
|
||||
'<span class="hitl-tool-badge">' + escapeHtml(item.toolName || '-') + '</span>' +
|
||||
'<span class="hitl-mode-tag hitl-mode-tag--' + escapeHtml(mode) + '">' + escapeHtml(item.mode || '-') + '</span>' +
|
||||
'</div>' +
|
||||
'<button class="hitl-dismiss-btn" title="忽略" onclick="dismissHitlItem(' + qId + ')">×</button>' +
|
||||
'<button class="hitl-dismiss-btn" title="' + escapeHtml(hitlT('dismiss', 'Dismiss')) + '" onclick="dismissHitlItem(' + qId + ')">×</button>' +
|
||||
'</div>' +
|
||||
'<div class="hitl-pending-meta">会话:' + escapeHtml(item.conversationId || '-') + '</div>' +
|
||||
'<div class="hitl-pending-meta">' + escapeHtml(hitlT('conversationLabel', 'Conversation:')) + ' ' + escapeHtml(item.conversationId || '-') + '</div>' +
|
||||
'<pre class="hitl-pending-payload">' + escapeHtml(preview) + '</pre>' +
|
||||
(allowEdit
|
||||
? ('<div class="hitl-input-help">审查编辑模式:可填写 JSON 对象覆盖参数。示例:{"command":"ls -la"}</div>' +
|
||||
? ('<div class="hitl-input-help">' + escapeHtml(hitlT('reviewEditHelp', 'Review & edit mode: provide a JSON object to override tool arguments. Example: {"command":"ls -la"}')) + '</div>' +
|
||||
'<textarea id="hitl-edit-' + escId + '" class="hitl-edit-args" placeholder=\'{"command":"ls -la"}\'></textarea>')
|
||||
: '<div class="hitl-input-help">审批模式:仅通过/拒绝,不支持改参。</div>') +
|
||||
'<div class="hitl-input-help">备注(可选):建议写审批依据。</div>' +
|
||||
'<input id="hitl-comment-' + escId + '" class="hitl-config-input hitl-inline-comment" type="text" placeholder="例如:允许只读命令">' +
|
||||
: '<div class="hitl-input-help">' + escapeHtml(hitlT('approvalHelp', 'Approval mode: only approve/reject, argument editing is disabled.')) + '</div>') +
|
||||
'<div class="hitl-input-help">' + escapeHtml(hitlT('commentHelp', 'Comment (optional): briefly note the approval reason.')) + '</div>' +
|
||||
'<input id="hitl-comment-' + escId + '" class="hitl-config-input hitl-inline-comment" type="text" placeholder="' + escapeHtml(hitlT('commentPlaceholder', 'e.g. allow read-only command')) + '">' +
|
||||
'<div class="hitl-pending-actions">' +
|
||||
'<button class="btn-secondary" onclick="submitHitlDecision(' + qId + ',"reject",' + qConv + ')">拒绝</button>' +
|
||||
'<button class="btn-primary" onclick="submitHitlDecision(' + qId + ',"approve",' + qConv + ')">通过</button>' +
|
||||
'<button class="btn-secondary" onclick="submitHitlDecision(' + qId + ',"reject",' + qConv + ')">' + escapeHtml(hitlT('reject', 'Reject')) + '</button>' +
|
||||
'<button class="btn-primary" onclick="submitHitlDecision(' + qId + ',"approve",' + qConv + ')">' + escapeHtml(hitlT('approve', 'Approve')) + '</button>' +
|
||||
'</div>' +
|
||||
'</div>'
|
||||
);
|
||||
}).join('');
|
||||
} catch (e) {
|
||||
container.innerHTML = '<div class="empty-state">加载失败</div>';
|
||||
container.innerHTML = '<div class="empty-state">' + escapeHtml(hitlT('loadFailed', 'Failed to load')) + '</div>';
|
||||
}
|
||||
}
|
||||
|
||||
@@ -323,7 +349,7 @@ async function submitHitlDecision(interruptId, decision, conversationIdOpt) {
|
||||
try {
|
||||
editedArguments = JSON.parse(editBox.value.trim());
|
||||
} catch (e) {
|
||||
alert('JSON 参数格式错误');
|
||||
alert(hitlT('invalidJson', 'Invalid JSON arguments'));
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -344,7 +370,7 @@ async function submitHitlDecisionWithPayload(interruptId, decision, comment, edi
|
||||
await dismissHitlItem(interruptId, true);
|
||||
return true;
|
||||
}
|
||||
alert('提交失败:' + errText);
|
||||
alert(hitlT('submitFailedPrefix', 'Submit failed:') + ' ' + errText);
|
||||
return false;
|
||||
}
|
||||
refreshHitlPending();
|
||||
|
||||
@@ -2348,9 +2348,28 @@ function renderActiveTasks(tasks) {
|
||||
bar.style.display = 'flex';
|
||||
bar.innerHTML = '';
|
||||
|
||||
function openActiveTaskConversation(conversationId) {
|
||||
if (!conversationId) return;
|
||||
if (typeof switchPage === 'function') {
|
||||
switchPage('chat');
|
||||
}
|
||||
if (typeof window.loadConversation === 'function') {
|
||||
setTimeout(function () {
|
||||
window.loadConversation(conversationId);
|
||||
}, 120);
|
||||
return;
|
||||
}
|
||||
window.location.hash = 'chat?conversation=' + encodeURIComponent(conversationId);
|
||||
}
|
||||
|
||||
normalizedTasks.forEach(task => {
|
||||
const item = document.createElement('div');
|
||||
item.className = 'active-task-item';
|
||||
item.className = 'active-task-item active-task-item-clickable';
|
||||
if (task && task.conversationId) {
|
||||
item.title = (typeof window.t === 'function' ? window.t('tasks.viewConversation') : '查看会话');
|
||||
item.setAttribute('role', 'button');
|
||||
item.onclick = () => openActiveTaskConversation(task.conversationId);
|
||||
}
|
||||
|
||||
const startedTime = task.startedAt ? new Date(task.startedAt) : null;
|
||||
const taskTimeLocale = getCurrentTimeLocale();
|
||||
@@ -2388,7 +2407,10 @@ function renderActiveTasks(tasks) {
|
||||
if (!isFinalStatus) {
|
||||
const cancelBtn = item.querySelector('.active-task-cancel');
|
||||
if (cancelBtn) {
|
||||
cancelBtn.onclick = () => cancelActiveTask(task.conversationId, cancelBtn);
|
||||
cancelBtn.onclick = (evt) => {
|
||||
evt.stopPropagation();
|
||||
cancelActiveTask(task.conversationId, cancelBtn);
|
||||
};
|
||||
if (task.status === 'cancelling') {
|
||||
cancelBtn.disabled = true;
|
||||
cancelBtn.textContent = typeof window.t === 'function' ? window.t('tasks.cancelling') : '取消中...';
|
||||
|
||||
@@ -0,0 +1,329 @@
|
||||
(function () {
|
||||
const STORAGE_LAST_SEEN_KEY = 'cyberstrike-notification-last-seen-at';
|
||||
const POLL_INTERVAL_ACTIVE_MS = 15000;
|
||||
const POLL_INTERVAL_HIDDEN_MS = 60000;
|
||||
const MAX_RENDER_ITEMS = 20;
|
||||
|
||||
const state = {
|
||||
inFlight: false,
|
||||
timerId: null,
|
||||
dropdownOpen: false,
|
||||
lastSeenAt: readLastSeenAt(),
|
||||
items: [],
|
||||
unreadCount: 0,
|
||||
};
|
||||
|
||||
function readLastSeenAt() {
|
||||
try {
|
||||
const raw = localStorage.getItem(STORAGE_LAST_SEEN_KEY);
|
||||
const n = Number(raw);
|
||||
if (Number.isFinite(n) && n > 0) return n;
|
||||
} catch (e) {
|
||||
console.warn('读取通知已读时间失败:', e);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
function persistLastSeenAt(ts) {
|
||||
try {
|
||||
localStorage.setItem(STORAGE_LAST_SEEN_KEY, String(ts));
|
||||
} catch (e) {
|
||||
console.warn('保存通知已读时间失败:', e);
|
||||
}
|
||||
}
|
||||
|
||||
function getTimeMs(value) {
|
||||
if (!value) return 0;
|
||||
const d = new Date(value);
|
||||
const ms = d.getTime();
|
||||
return Number.isFinite(ms) ? ms : 0;
|
||||
}
|
||||
|
||||
function getLocale() {
|
||||
if (typeof window !== 'undefined') {
|
||||
if (typeof window.__locale === 'string' && window.__locale) {
|
||||
return window.__locale;
|
||||
}
|
||||
if (typeof window.currentLang === 'string' && window.currentLang) {
|
||||
return window.currentLang;
|
||||
}
|
||||
}
|
||||
return 'zh-CN';
|
||||
}
|
||||
|
||||
function formatTime(value) {
|
||||
const ms = getTimeMs(value);
|
||||
if (!ms) return '-';
|
||||
return new Date(ms).toLocaleString(getLocale());
|
||||
}
|
||||
|
||||
function htmlEscape(value) {
|
||||
if (typeof window.escapeHtml === 'function') {
|
||||
return window.escapeHtml(value == null ? '' : String(value));
|
||||
}
|
||||
const div = document.createElement('div');
|
||||
div.textContent = value == null ? '' : String(value);
|
||||
return div.innerHTML;
|
||||
}
|
||||
|
||||
function t(key, fallback, params) {
|
||||
if (typeof window !== 'undefined' && typeof window.t === 'function') {
|
||||
try {
|
||||
const translated = window.t(key, params || {});
|
||||
if (translated && translated !== key) return translated;
|
||||
} catch (_ignored) {}
|
||||
}
|
||||
return fallback;
|
||||
}
|
||||
|
||||
async function apiJson(url, options) {
|
||||
if (typeof window.apiFetch !== 'function') return null;
|
||||
const res = await window.apiFetch(url, options || {});
|
||||
if (!res.ok) return null;
|
||||
return res.json();
|
||||
}
|
||||
|
||||
async function fetchNotificationSummary() {
|
||||
const url = '/api/notifications/summary?since='
|
||||
+ encodeURIComponent(String(state.lastSeenAt || 0))
|
||||
+ '&limit=80&lang=' + encodeURIComponent(getLocale());
|
||||
try {
|
||||
const summary = await apiJson(url);
|
||||
if (summary && typeof summary === 'object') {
|
||||
return summary;
|
||||
}
|
||||
} catch (_ignored) {}
|
||||
return null;
|
||||
}
|
||||
|
||||
function renderBadge(count) {
|
||||
const badge = document.getElementById('notification-badge');
|
||||
const btn = document.getElementById('notification-bell-btn');
|
||||
if (!badge || !btn) return;
|
||||
if (count <= 0) {
|
||||
badge.style.display = 'none';
|
||||
btn.classList.remove('has-alert');
|
||||
return;
|
||||
}
|
||||
const text = count > 99 ? '99+' : String(count);
|
||||
badge.innerHTML = '<span class="notification-badge-text">' + htmlEscape(text) + '</span>';
|
||||
badge.style.display = 'inline-block';
|
||||
btn.classList.add('has-alert');
|
||||
}
|
||||
|
||||
function countP0(items) {
|
||||
return (Array.isArray(items) ? items : []).reduce((acc, item) => {
|
||||
if (!item || item.level !== 'p0') return acc;
|
||||
if (typeof item.count === 'number' && item.count > 0) return acc + item.count;
|
||||
return acc + 1;
|
||||
}, 0);
|
||||
}
|
||||
|
||||
function markableItems(items) {
|
||||
return (Array.isArray(items) ? items : []).filter(item => item && item.actionable !== true && item.id);
|
||||
}
|
||||
|
||||
function hasAction(item) {
|
||||
if (!item || !item.type) return false;
|
||||
if (item.type === 'vulnerability_created' && item.vulnerabilityId) return true;
|
||||
if ((item.type === 'task_completed' || item.type === 'long_running_tasks') && item.conversationId) return true;
|
||||
if (item.type === 'task_failed' && item.executionId) return true;
|
||||
if (item.type === 'hitl_pending') return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
function openNotificationTarget(item) {
|
||||
if (!item || !item.type) return;
|
||||
if (item.type === 'vulnerability_created' && item.vulnerabilityId) {
|
||||
window.location.hash = 'vulnerabilities?id=' + encodeURIComponent(item.vulnerabilityId);
|
||||
return;
|
||||
}
|
||||
if ((item.type === 'task_completed' || item.type === 'long_running_tasks') && item.conversationId) {
|
||||
window.location.hash = 'chat?conversation=' + encodeURIComponent(item.conversationId);
|
||||
return;
|
||||
}
|
||||
if (item.type === 'task_failed' && item.executionId) {
|
||||
window.location.hash = 'mcp-monitor';
|
||||
setTimeout(function () {
|
||||
if (typeof showMCPDetail === 'function') {
|
||||
showMCPDetail(item.executionId);
|
||||
}
|
||||
}, 450);
|
||||
return;
|
||||
}
|
||||
if (item.type === 'hitl_pending') {
|
||||
window.location.hash = 'hitl';
|
||||
}
|
||||
}
|
||||
|
||||
async function markItemsRead(eventIds) {
|
||||
if (!Array.isArray(eventIds) || !eventIds.length) return true;
|
||||
const payload = { eventIds: eventIds };
|
||||
try {
|
||||
const result = await apiJson('/api/notifications/read', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(payload),
|
||||
});
|
||||
return !!result;
|
||||
} catch (_ignored) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
function renderNotificationList(items) {
|
||||
const list = document.getElementById('notification-list');
|
||||
if (!list) return;
|
||||
const renderItems = Array.isArray(items) ? items.slice(0, MAX_RENDER_ITEMS) : [];
|
||||
if (!renderItems.length) {
|
||||
list.innerHTML = '<div class="notification-empty">' + htmlEscape(t('notifications.empty', '暂无新事件')) + '</div>';
|
||||
return;
|
||||
}
|
||||
const html = renderItems.map(item => {
|
||||
const canMarkRead = item.actionable !== true && !!item.id;
|
||||
const canView = hasAction(item);
|
||||
return `
|
||||
<div class="notification-item notification-level-${htmlEscape(item.level || 'p2')}">
|
||||
<div class="notification-item-header">
|
||||
<div class="notification-item-title">${htmlEscape(item.title || t('notifications.itemDefaultTitle', '通知'))}</div>
|
||||
<div class="notification-item-actions">
|
||||
${canView ? `<button class="notification-item-action-btn notification-item-view-btn" type="button" data-action-id="${htmlEscape(item.id || '')}">${htmlEscape(t('common.view', '查看'))}</button>` : ''}
|
||||
${canMarkRead ? `<button class="notification-item-action-btn notification-item-read-btn" type="button" data-notification-id="${htmlEscape(item.id)}">${htmlEscape(t('notifications.markSingleRead', '已读'))}</button>` : ''}
|
||||
</div>
|
||||
</div>
|
||||
<div class="notification-item-desc">${htmlEscape(item.desc || '')}</div>
|
||||
<div class="notification-item-time">${htmlEscape(formatTime(item.ts))}</div>
|
||||
</div>
|
||||
`;
|
||||
}).join('');
|
||||
list.innerHTML = html;
|
||||
const viewButtons = list.querySelectorAll('.notification-item-view-btn');
|
||||
viewButtons.forEach(btn => {
|
||||
btn.addEventListener('click', function (event) {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
const eventID = btn.getAttribute('data-action-id') || '';
|
||||
if (!eventID) return;
|
||||
const item = state.items.find(it => it && it.id === eventID);
|
||||
if (!item) return;
|
||||
openNotificationTarget(item);
|
||||
closeDropdown();
|
||||
});
|
||||
});
|
||||
const readButtons = list.querySelectorAll('.notification-item-read-btn');
|
||||
readButtons.forEach(btn => {
|
||||
btn.addEventListener('click', async function (event) {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
const eventID = btn.getAttribute('data-notification-id') || '';
|
||||
if (!eventID) return;
|
||||
const ok = await markItemsRead([eventID]);
|
||||
if (ok) {
|
||||
await refreshNotifications();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function closeDropdown() {
|
||||
const dropdown = document.getElementById('notification-dropdown');
|
||||
const bellBtn = document.getElementById('notification-bell-btn');
|
||||
if (dropdown) dropdown.style.display = 'none';
|
||||
if (bellBtn) bellBtn.classList.remove('active');
|
||||
state.dropdownOpen = false;
|
||||
}
|
||||
|
||||
function markSeenNow() {
|
||||
state.lastSeenAt = Date.now();
|
||||
persistLastSeenAt(state.lastSeenAt);
|
||||
}
|
||||
|
||||
async function refreshNotifications() {
|
||||
if (state.inFlight) return;
|
||||
state.inFlight = true;
|
||||
try {
|
||||
const summary = await fetchNotificationSummary();
|
||||
const items = summary && Array.isArray(summary.items) ? summary.items : [];
|
||||
state.items = items;
|
||||
const unreadCount = summary && Number.isFinite(Number(summary.unreadCount))
|
||||
? Number(summary.unreadCount)
|
||||
: countP0(items);
|
||||
state.unreadCount = Math.max(0, unreadCount);
|
||||
renderBadge(state.unreadCount);
|
||||
renderNotificationList(items);
|
||||
} catch (e) {
|
||||
console.warn('刷新通知失败:', e);
|
||||
} finally {
|
||||
state.inFlight = false;
|
||||
}
|
||||
}
|
||||
|
||||
function scheduleNextPoll() {
|
||||
if (state.timerId) {
|
||||
window.clearTimeout(state.timerId);
|
||||
state.timerId = null;
|
||||
}
|
||||
const interval = document.hidden ? POLL_INTERVAL_HIDDEN_MS : POLL_INTERVAL_ACTIVE_MS;
|
||||
state.timerId = window.setTimeout(async function () {
|
||||
await refreshNotifications();
|
||||
scheduleNextPoll();
|
||||
}, interval);
|
||||
}
|
||||
|
||||
function handleDocumentClick(event) {
|
||||
const container = document.querySelector('.notification-menu-container');
|
||||
if (!container) return;
|
||||
if (!container.contains(event.target)) {
|
||||
closeDropdown();
|
||||
}
|
||||
}
|
||||
|
||||
async function toggleDropdown() {
|
||||
const dropdown = document.getElementById('notification-dropdown');
|
||||
const bellBtn = document.getElementById('notification-bell-btn');
|
||||
if (!dropdown || !bellBtn) return;
|
||||
const isOpen = dropdown.style.display !== 'none';
|
||||
if (isOpen) {
|
||||
closeDropdown();
|
||||
return;
|
||||
}
|
||||
// 从仪表盘「查看全部」等容器外入口打开时,同一 click 会冒泡到 document,
|
||||
// handleDocumentClick 会误判为「点在外面」并立刻关掉。推迟到宏任务再展开即可。
|
||||
const runOpen = async function () {
|
||||
if (dropdown.style.display !== 'none') return;
|
||||
dropdown.style.display = 'block';
|
||||
bellBtn.classList.add('active');
|
||||
state.dropdownOpen = true;
|
||||
await refreshNotifications();
|
||||
};
|
||||
window.setTimeout(function () {
|
||||
void runOpen();
|
||||
}, 0);
|
||||
}
|
||||
|
||||
async function markAllSeen() {
|
||||
const ids = markableItems(state.items).map(item => item.id);
|
||||
const ok = await markItemsRead(ids);
|
||||
if (ok) {
|
||||
markSeenNow();
|
||||
await refreshNotifications();
|
||||
}
|
||||
}
|
||||
|
||||
function initNotifications() {
|
||||
const bellBtn = document.getElementById('notification-bell-btn');
|
||||
if (!bellBtn) return;
|
||||
document.addEventListener('click', handleDocumentClick);
|
||||
document.addEventListener('visibilitychange', scheduleNextPoll);
|
||||
document.addEventListener('languagechange', function () {
|
||||
refreshNotifications();
|
||||
});
|
||||
refreshNotifications();
|
||||
scheduleNextPoll();
|
||||
}
|
||||
|
||||
window.toggleNotificationDropdown = toggleDropdown;
|
||||
window.markAllNotificationsSeen = markAllSeen;
|
||||
|
||||
document.addEventListener('DOMContentLoaded', initNotifications);
|
||||
})();
|
||||
+35
-5
@@ -1,6 +1,28 @@
|
||||
// 角色管理相关功能
|
||||
function _t(key, opts) {
|
||||
return typeof window.t === 'function' ? window.t(key, opts) : key;
|
||||
if (typeof window.t === 'function') {
|
||||
try {
|
||||
var translated = window.t(key, opts);
|
||||
if (typeof translated === 'string' && translated && translated !== key) {
|
||||
return translated;
|
||||
}
|
||||
} catch (e) { /* ignore */ }
|
||||
}
|
||||
// i18n 未就绪或词条缺失时避免把 key 暴露给用户(与 zh-CN 默认一致)
|
||||
if (key === 'roles.noDescription') return '暂无描述';
|
||||
if (key === 'roles.noDescriptionShort') return '无描述';
|
||||
if (key === 'roles.defaultRoleDescription') {
|
||||
return '默认角色,不额外携带用户提示词,使用默认MCP';
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
/** 角色配置中的描述:trim,并把误存为 i18n key 的字面量视为空 */
|
||||
function rolePlainDescription(role) {
|
||||
const raw = typeof role.description === 'string' ? role.description.trim() : '';
|
||||
if (!raw) return '';
|
||||
if (raw === 'roles.noDescription' || raw === 'roles.noDescriptionShort') return '';
|
||||
return raw;
|
||||
}
|
||||
let currentRole = localStorage.getItem('currentRole') || '';
|
||||
let roles = [];
|
||||
@@ -56,6 +78,11 @@ function sortRoles(rolesArray) {
|
||||
|
||||
// 加载所有角色
|
||||
async function loadRoles() {
|
||||
if (window.i18nReady && typeof window.i18nReady.then === 'function') {
|
||||
try {
|
||||
await window.i18nReady;
|
||||
} catch (e) { /* ignore */ }
|
||||
}
|
||||
try {
|
||||
const response = await apiFetch('/api/roles');
|
||||
if (!response.ok) {
|
||||
@@ -189,8 +216,9 @@ function renderRoleSelectionSidebar() {
|
||||
const icon = getRoleIcon(role);
|
||||
|
||||
// 处理默认角色的描述
|
||||
let description = role.description || _t('roles.noDescription');
|
||||
if (isDefaultRole && !role.description) {
|
||||
const plainDesc = rolePlainDescription(role);
|
||||
let description = plainDesc || _t('roles.noDescription');
|
||||
if (isDefaultRole && !plainDesc) {
|
||||
description = _t('roles.defaultRoleDescription');
|
||||
}
|
||||
|
||||
@@ -316,6 +344,7 @@ function renderRolesList() {
|
||||
const sortedRoles = sortRoles(filteredRoles);
|
||||
|
||||
rolesList.innerHTML = sortedRoles.map(role => {
|
||||
const plainDesc = rolePlainDescription(role);
|
||||
// 获取角色图标,如果是Unicode转义格式则转换为emoji
|
||||
let roleIcon = role.icon || '👤';
|
||||
if (roleIcon && typeof roleIcon === 'string') {
|
||||
@@ -369,7 +398,7 @@ function renderRolesList() {
|
||||
${role.enabled !== false ? _t('roles.enabled') : _t('roles.disabled')}
|
||||
</span>
|
||||
</div>
|
||||
<div class="role-card-description">${escapeHtml(role.description || _t('roles.noDescriptionShort'))}</div>
|
||||
<div class="role-card-description">${escapeHtml(plainDesc || _t('roles.noDescriptionShort'))}</div>
|
||||
<div class="role-card-tools">
|
||||
<span class="role-card-tools-label">${_t('roleModal.toolsLabel')}</span>
|
||||
<span class="role-card-tools-value">${toolsDisplay}</span>
|
||||
@@ -1575,9 +1604,10 @@ document.addEventListener('DOMContentLoaded', () => {
|
||||
updateRoleSelectorDisplay();
|
||||
});
|
||||
|
||||
// 语言切换后刷新角色选择器显示(默认/自定义角色名)
|
||||
// 语言切换后刷新角色选择器与「选择角色」列表文案
|
||||
document.addEventListener('languagechange', () => {
|
||||
updateRoleSelectorDisplay();
|
||||
renderRoleSelectionSidebar();
|
||||
});
|
||||
|
||||
// 获取当前选中的角色(供chat.js使用)
|
||||
|
||||
+26
-14
@@ -301,26 +301,38 @@ async function initPage(pageId) {
|
||||
break;
|
||||
case 'mcp-management':
|
||||
// 初始化MCP管理
|
||||
const startLoadMcpTools = () => {
|
||||
// 加载工具列表(MCP工具配置已移到MCP管理页面)
|
||||
// 使用异步加载,避免阻塞页面渲染
|
||||
if (typeof loadToolsList === 'function') {
|
||||
// 确保工具分页设置已初始化
|
||||
if (typeof getToolsPageSize === 'function' && typeof toolsPagination !== 'undefined') {
|
||||
toolsPagination.pageSize = getToolsPageSize();
|
||||
}
|
||||
// 延迟加载,让页面先渲染
|
||||
setTimeout(() => {
|
||||
loadToolsList(1, '').catch(err => {
|
||||
console.error('加载工具列表失败:', err);
|
||||
});
|
||||
}, 100);
|
||||
}
|
||||
};
|
||||
// 先拉取全局配置,确保 tool_search 常驻状态按后端生效集合展示
|
||||
if (typeof loadConfig === 'function') {
|
||||
loadConfig(false)
|
||||
.catch(err => {
|
||||
console.warn('加载配置失败(将继续加载工具列表):', err);
|
||||
})
|
||||
.finally(startLoadMcpTools);
|
||||
} else {
|
||||
startLoadMcpTools();
|
||||
}
|
||||
// 先加载外部MCP列表(快速),然后加载工具列表
|
||||
if (typeof loadExternalMCPs === 'function') {
|
||||
loadExternalMCPs().catch(err => {
|
||||
console.warn('加载外部MCP列表失败:', err);
|
||||
});
|
||||
}
|
||||
// 加载工具列表(MCP工具配置已移到MCP管理页面)
|
||||
// 使用异步加载,避免阻塞页面渲染
|
||||
if (typeof loadToolsList === 'function') {
|
||||
// 确保工具分页设置已初始化
|
||||
if (typeof getToolsPageSize === 'function' && typeof toolsPagination !== 'undefined') {
|
||||
toolsPagination.pageSize = getToolsPageSize();
|
||||
}
|
||||
// 延迟加载,让页面先渲染
|
||||
setTimeout(() => {
|
||||
loadToolsList(1, '').catch(err => {
|
||||
console.error('加载工具列表失败:', err);
|
||||
});
|
||||
}, 100);
|
||||
}
|
||||
break;
|
||||
case 'vulnerabilities':
|
||||
// 初始化漏洞管理页面
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
// 设置相关功能
|
||||
let currentConfig = null;
|
||||
let allTools = [];
|
||||
let alwaysVisibleToolNames = new Set();
|
||||
let alwaysVisibleBuiltinToolNames = new Set();
|
||||
// 全局工具状态映射,用于保存用户在所有页面的修改
|
||||
// key: 唯一工具标识符(toolKey),value: { enabled: boolean, is_external: boolean, external_mcp: string }
|
||||
let toolStateMap = new Map();
|
||||
@@ -100,6 +102,14 @@ async function loadConfig(loadTools = true) {
|
||||
}
|
||||
|
||||
currentConfig = await response.json();
|
||||
const alwaysVisibleList = currentConfig?.multi_agent?.tool_search_always_visible_effective_tools;
|
||||
const alwaysVisibleConfigured = currentConfig?.multi_agent?.tool_search_always_visible_tools;
|
||||
alwaysVisibleToolNames = new Set(Array.isArray(alwaysVisibleList) ? alwaysVisibleList.filter(Boolean) : []);
|
||||
alwaysVisibleBuiltinToolNames = new Set(
|
||||
alwaysVisibleToolNames.size > 0 && Array.isArray(alwaysVisibleConfigured)
|
||||
? Array.from(alwaysVisibleToolNames).filter(name => !alwaysVisibleConfigured.includes(name))
|
||||
: []
|
||||
);
|
||||
|
||||
// 填充OpenAI配置
|
||||
const providerEl = document.getElementById('openai-provider');
|
||||
@@ -395,10 +405,13 @@ async function loadToolsList(page = 1, searchKeyword = '') {
|
||||
}
|
||||
}
|
||||
|
||||
// 每行有两类复选框:行首「启用工具」与名称旁「常驻」;统计/全选只应针对行首启用复选框
|
||||
const TOOL_ENABLE_CHECKBOX_SELECTOR = '#tools-list .tool-item > input[type="checkbox"]';
|
||||
|
||||
// 保存当前页的工具状态到全局映射
|
||||
function saveCurrentPageToolStates() {
|
||||
document.querySelectorAll('#tools-list .tool-item').forEach(item => {
|
||||
const checkbox = item.querySelector('input[type="checkbox"]');
|
||||
const checkbox = item.querySelector(':scope > input[type="checkbox"]');
|
||||
const toolKey = item.dataset.toolKey; // 使用唯一标识符
|
||||
const toolName = item.dataset.toolName;
|
||||
const isExternal = item.dataset.isExternal === 'true';
|
||||
@@ -498,6 +511,8 @@ function renderToolsList() {
|
||||
is_external: tool.is_external || false,
|
||||
external_mcp: tool.external_mcp || ''
|
||||
};
|
||||
const alwaysVisibleChecked = alwaysVisibleToolNames.has(tool.name);
|
||||
const alwaysVisibleLocked = alwaysVisibleBuiltinToolNames.has(tool.name);
|
||||
|
||||
// 外部工具标签,显示来源信息(可点击跳转到对应 MCP 卡片)
|
||||
let externalBadge = '';
|
||||
@@ -521,6 +536,11 @@ function renderToolsList() {
|
||||
<div class="tool-item-name">
|
||||
${escapeHtml(tool.name)}
|
||||
${externalBadge}
|
||||
<label class="tool-resident-toggle" title="${typeof window.t === 'function' ? window.t('mcp.alwaysVisibleHint') : '始终常驻在 Tool Search 可见列表'}" onclick="event.stopPropagation()">
|
||||
<input type="checkbox" ${alwaysVisibleChecked ? 'checked' : ''} ${alwaysVisibleLocked ? 'disabled' : ''} onchange="handleToolAlwaysVisibleChange('${escapeHtml(tool.name)}', this.checked)" />
|
||||
<span>${typeof window.t === 'function' ? window.t('mcp.alwaysVisibleLabel') : '常驻'}</span>
|
||||
</label>
|
||||
${alwaysVisibleLocked ? `<span class="external-tool-badge" title="${typeof window.t === 'function' ? window.t('mcp.alwaysVisibleBuiltinHint') : '后端内置工具默认常驻,不可关闭'}">${typeof window.t === 'function' ? window.t('mcp.alwaysVisibleBuiltinLabel') : '内置默认'}</span>` : ''}
|
||||
<span class="tool-expand-icon">▶</span>
|
||||
</div>
|
||||
<div class="tool-item-desc">${escapeHtml(tool.description || (typeof window.t === 'function' ? window.t('mcp.noDescription') : '无描述'))}</div>
|
||||
@@ -716,9 +736,19 @@ function handleToolCheckboxChange(toolKey, enabled) {
|
||||
updateToolsStats();
|
||||
}
|
||||
|
||||
function handleToolAlwaysVisibleChange(toolName, alwaysVisible) {
|
||||
const name = (toolName || '').trim();
|
||||
if (!name) return;
|
||||
if (alwaysVisible) {
|
||||
alwaysVisibleToolNames.add(name);
|
||||
} else {
|
||||
alwaysVisibleToolNames.delete(name);
|
||||
}
|
||||
}
|
||||
|
||||
// 全选工具
|
||||
function selectAllTools() {
|
||||
document.querySelectorAll('#tools-list input[type="checkbox"]').forEach(checkbox => {
|
||||
document.querySelectorAll(TOOL_ENABLE_CHECKBOX_SELECTOR).forEach(checkbox => {
|
||||
checkbox.checked = true;
|
||||
// 更新全局状态映射
|
||||
const toolItem = checkbox.closest('.tool-item');
|
||||
@@ -742,7 +772,7 @@ function selectAllTools() {
|
||||
|
||||
// 全不选工具
|
||||
function deselectAllTools() {
|
||||
document.querySelectorAll('#tools-list input[type="checkbox"]').forEach(checkbox => {
|
||||
document.querySelectorAll(TOOL_ENABLE_CHECKBOX_SELECTOR).forEach(checkbox => {
|
||||
checkbox.checked = false;
|
||||
// 更新全局状态映射
|
||||
const toolItem = checkbox.closest('.tool-item');
|
||||
@@ -799,9 +829,9 @@ async function updateToolsStats() {
|
||||
// 先保存当前页的状态到全局映射
|
||||
saveCurrentPageToolStates();
|
||||
|
||||
// 计算当前页的启用工具数
|
||||
const currentPageEnabled = Array.from(document.querySelectorAll('#tools-list input[type="checkbox"]:checked')).length;
|
||||
const currentPageTotal = document.querySelectorAll('#tools-list input[type="checkbox"]').length;
|
||||
// 计算当前页的启用工具数(仅行首「启用」复选框,不含「常驻」)
|
||||
const currentPageEnabled = Array.from(document.querySelectorAll(`${TOOL_ENABLE_CHECKBOX_SELECTOR}:checked`)).length;
|
||||
const currentPageTotal = document.querySelectorAll(TOOL_ENABLE_CHECKBOX_SELECTOR).length;
|
||||
|
||||
// 计算所有工具的启用数
|
||||
let totalEnabled = 0;
|
||||
@@ -886,9 +916,11 @@ async function updateToolsStats() {
|
||||
}
|
||||
|
||||
const tStats = typeof window.t === 'function' ? window.t : (k) => k;
|
||||
const pinnedCount = alwaysVisibleToolNames.size;
|
||||
statsEl.innerHTML = `
|
||||
<span title="${tStats('mcp.currentPageEnabled')}">✅ ${tStats('mcp.currentPageEnabled')}: <strong>${currentPageEnabled}</strong> / ${currentPageTotal}</span>
|
||||
<span title="${tStats('mcp.totalEnabled')}">📊 ${tStats('mcp.totalEnabled')}: <strong>${totalEnabled}</strong> / ${totalTools}</span>
|
||||
<span title="${tStats('mcp.alwaysVisibleHint')}">📌 ${tStats('mcp.alwaysVisibleLabel')}: <strong>${pinnedCount}</strong></span>
|
||||
`;
|
||||
}
|
||||
|
||||
@@ -1230,6 +1262,13 @@ async function saveToolsConfig() {
|
||||
const config = {
|
||||
openai: currentConfig.openai || {},
|
||||
agent: currentConfig.agent || {},
|
||||
multi_agent: {
|
||||
enabled: currentConfig?.multi_agent?.enabled === true,
|
||||
robot_use_multi_agent: currentConfig?.multi_agent?.robot_use_multi_agent === true,
|
||||
batch_use_multi_agent: currentConfig?.multi_agent?.batch_use_multi_agent === true,
|
||||
plan_execute_loop_max_iterations: Number(currentConfig?.multi_agent?.plan_execute_loop_max_iterations || 0),
|
||||
tool_search_always_visible_tools: Array.from(alwaysVisibleToolNames).filter(name => !alwaysVisibleBuiltinToolNames.has(name))
|
||||
},
|
||||
tools: []
|
||||
};
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ let vulnerabilityPagination = {
|
||||
totalPages: 1
|
||||
};
|
||||
|
||||
// 从地址栏 #vulnerabilities?conversation_id= / ?task_id= 同步筛选(对话菜单、任务管理联动)
|
||||
// 从地址栏 #vulnerabilities?conversation_id= / ?task_id= / ?id= 同步筛选(通知/对话菜单/任务管理联动)
|
||||
function syncVulnerabilityFiltersFromLocationHash() {
|
||||
const hash = window.location.hash.slice(1);
|
||||
const hashParts = hash.split('?');
|
||||
@@ -69,19 +69,27 @@ function syncVulnerabilityFiltersFromLocationHash() {
|
||||
return;
|
||||
}
|
||||
const params = new URLSearchParams(hashParts.slice(1).join('?'));
|
||||
const vid = (params.get('id') || '').trim();
|
||||
const cid = (params.get('conversation_id') || '').trim();
|
||||
const tid = (params.get('task_id') || '').trim();
|
||||
if (!cid && !tid) {
|
||||
if (!vid && !cid && !tid) {
|
||||
return;
|
||||
}
|
||||
|
||||
vulnerabilityFilters.id = '';
|
||||
vulnerabilityFilters.conversation_id = '';
|
||||
vulnerabilityFilters.task_id = '';
|
||||
const idEl = document.getElementById('vulnerability-id-filter');
|
||||
const convEl = document.getElementById('vulnerability-conversation-filter');
|
||||
const taskEl = document.getElementById('vulnerability-task-filter');
|
||||
if (idEl) idEl.value = '';
|
||||
if (convEl) convEl.value = '';
|
||||
if (taskEl) taskEl.value = '';
|
||||
|
||||
if (vid) {
|
||||
vulnerabilityFilters.id = vid;
|
||||
if (idEl) idEl.value = vid;
|
||||
}
|
||||
if (cid) {
|
||||
vulnerabilityFilters.conversation_id = cid;
|
||||
if (convEl) convEl.value = cid;
|
||||
@@ -334,6 +342,13 @@ function renderVulnerabilities(vulnerabilities) {
|
||||
if (typeof window.applyTranslations === 'function') {
|
||||
window.applyTranslations(listContainer);
|
||||
}
|
||||
|
||||
// 如果通过漏洞ID筛选且只返回一条记录,自动展开详情(提升“点击查看”的用户体验)
|
||||
if (vulnerabilities.length === 1 && vulnerabilityFilters.id && vulnerabilityFilters.id === vulnerabilities[0].id) {
|
||||
setTimeout(() => {
|
||||
toggleVulnerabilityDetails(vulnerabilities[0].id);
|
||||
}, 300);
|
||||
}
|
||||
}
|
||||
|
||||
// 渲染分页控件
|
||||
|
||||
+285
-80
@@ -63,6 +63,24 @@
|
||||
<div class="lang-option" data-lang="en-US" onclick="onLanguageSelect('en-US')">English</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="notification-menu-container">
|
||||
<button class="notification-btn" id="notification-bell-btn" onclick="toggleNotificationDropdown()" data-i18n="notifications.title" data-i18n-attr="title" data-i18n-skip-text="true" title="事件通知">
|
||||
<svg width="18" height="18" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M18 8a6 6 0 0 0-12 0c0 7-3 9-3 9h18s-3-2-3-9" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M13.73 21a2 2 0 0 1-3.46 0" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
<span class="notification-badge" id="notification-badge" style="display: none;">0</span>
|
||||
</button>
|
||||
<div id="notification-dropdown" class="notification-dropdown" style="display: none;">
|
||||
<div class="notification-dropdown-header">
|
||||
<span id="notification-dropdown-title" data-i18n="notifications.title">事件通知</span>
|
||||
<button class="notification-mark-read-btn" id="notification-mark-all-read-btn" type="button" onclick="markAllNotificationsSeen()" data-i18n="notifications.markAllRead">标记已读</button>
|
||||
</div>
|
||||
<div id="notification-list" class="notification-list">
|
||||
<div class="notification-empty" data-i18n="notifications.empty">暂无新事件</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="user-menu-container">
|
||||
<button class="user-avatar-btn" onclick="toggleUserMenu()" data-i18n="header.userMenu" data-i18n-attr="title" data-i18n-skip-text="true" title="用户菜单">
|
||||
<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
@@ -287,41 +305,207 @@
|
||||
<div class="page-header">
|
||||
<h2 data-i18n="dashboard.title">仪表盘</h2>
|
||||
<div class="page-header-actions">
|
||||
<span class="dashboard-last-updated" id="dashboard-last-updated" aria-live="polite">
|
||||
<svg class="dashboard-last-updated-icon" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><circle cx="12" cy="12" r="10"/><polyline points="12 6 12 12 16 14"/></svg>
|
||||
<span data-i18n="dashboard.lastUpdated">上次更新</span>
|
||||
<span class="dashboard-last-updated-time" id="dashboard-last-updated-time">-</span>
|
||||
<span class="dashboard-last-updated-stale" id="dashboard-last-updated-stale" hidden data-i18n="dashboard.dataStale" data-i18n-attr="title" title="数据可能已过期">
|
||||
<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><path d="M10.29 3.86L1.82 18a2 2 0 0 0 1.71 3h16.94a2 2 0 0 0 1.71-3L13.71 3.86a2 2 0 0 0-3.42 0z"/><line x1="12" y1="9" x2="12" y2="13"/><line x1="12" y1="17" x2="12.01" y2="17"/></svg>
|
||||
</span>
|
||||
</span>
|
||||
<button class="btn-secondary" onclick="refreshDashboard()" data-i18n="dashboard.refreshData" data-i18n-attr="title" title="刷新数据"><span data-i18n="common.refresh">刷新</span></button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="dashboard-content">
|
||||
<!-- 第一行:核心 KPI(仪表盘最佳实践:关键指标置顶) -->
|
||||
<!-- 关键提醒条(仅当存在严重风险时渲染,默认 hidden);右侧 × 可在 session 内忽略 -->
|
||||
<div class="dashboard-alert-banner" id="dashboard-alert-banner" hidden>
|
||||
<span class="dashboard-alert-icon" aria-hidden="true">
|
||||
<svg width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M10.29 3.86L1.82 18a2 2 0 0 0 1.71 3h16.94a2 2 0 0 0 1.71-3L13.71 3.86a2 2 0 0 0-3.42 0z"/><line x1="12" y1="9" x2="12" y2="13"/><line x1="12" y1="17" x2="12.01" y2="17"/></svg>
|
||||
</span>
|
||||
<div class="dashboard-alert-content">
|
||||
<div class="dashboard-alert-title" id="dashboard-alert-title" data-i18n="dashboard.alertTitle">需要关注</div>
|
||||
<div class="dashboard-alert-desc" id="dashboard-alert-desc"></div>
|
||||
</div>
|
||||
<div class="dashboard-alert-actions" id="dashboard-alert-actions"></div>
|
||||
<button type="button" class="dashboard-alert-close" id="dashboard-alert-close" data-i18n="dashboard.alertDismiss" data-i18n-attr="title" data-i18n-skip-text="true" title="忽略此提醒(仅本次会话)" aria-label="dismiss">
|
||||
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.4" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><line x1="18" y1="6" x2="6" y2="18"/><line x1="6" y1="6" x2="18" y2="18"/></svg>
|
||||
</button>
|
||||
</div>
|
||||
<!-- 第一行:核心 KPI(关键指标置顶 + 副标徽章承载次级信息) -->
|
||||
<div class="dashboard-kpi-row" id="dashboard-cards">
|
||||
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('tasks')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('tasks'); }" data-i18n="dashboard.clickToViewTasks" data-i18n-attr="title" title="点击查看任务管理"> <div class="dashboard-kpi-value" id="dashboard-running-tasks">-</div><div class="dashboard-kpi-label" data-i18n="dashboard.runningTasks">运行中任务</div></div>
|
||||
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('vulnerabilities')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('vulnerabilities'); }" data-i18n="dashboard.clickToViewVuln" data-i18n-attr="title" title="点击查看漏洞管理"><div class="dashboard-kpi-value" id="dashboard-vuln-total">-</div><div class="dashboard-kpi-label" data-i18n="dashboard.vulnTotal">漏洞总数</div></div>
|
||||
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('mcp-monitor')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('mcp-monitor'); }" data-i18n="dashboard.clickToViewMCP" data-i18n-attr="title" title="点击查看 MCP 监控"><div class="dashboard-kpi-value" id="dashboard-kpi-tools-calls">-</div><div class="dashboard-kpi-label" data-i18n="dashboard.toolCalls">工具调用次数</div></div>
|
||||
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('mcp-monitor')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('mcp-monitor'); }" data-i18n="dashboard.clickToViewMCP" data-i18n-attr="title" title="点击查看 MCP 监控"><div class="dashboard-kpi-value" id="dashboard-kpi-success-rate">-</div><div class="dashboard-kpi-label" data-i18n="dashboard.successRate">工具执行成功率</div></div>
|
||||
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('tasks')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('tasks'); }" data-i18n="dashboard.clickToViewTasks" data-i18n-attr="title" title="点击查看任务管理">
|
||||
<div class="dashboard-kpi-head">
|
||||
<div class="dashboard-kpi-label" data-i18n="dashboard.runningTasks">运行中任务</div>
|
||||
<span class="dashboard-kpi-icon dashboard-kpi-icon-tasks" aria-hidden="true"><svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 2v4"/><path d="M12 18v4"/><path d="M4.93 4.93l2.83 2.83"/><path d="M16.24 16.24l2.83 2.83"/><path d="M2 12h4"/><path d="M18 12h4"/><path d="M4.93 19.07l2.83-2.83"/><path d="M16.24 7.76l2.83-2.83"/></svg></span>
|
||||
</div>
|
||||
<div class="dashboard-kpi-value" id="dashboard-running-tasks">-</div>
|
||||
<div class="dashboard-kpi-sub" id="dashboard-kpi-tasks-sub">
|
||||
<span class="dashboard-kpi-sub-text" id="dashboard-kpi-tasks-sub-text">-</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('vulnerabilities')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('vulnerabilities'); }" data-i18n="dashboard.clickToViewVuln" data-i18n-attr="title" title="点击查看漏洞管理">
|
||||
<div class="dashboard-kpi-head">
|
||||
<div class="dashboard-kpi-label" data-i18n="dashboard.vulnTotal">漏洞总数</div>
|
||||
<span class="dashboard-kpi-icon dashboard-kpi-icon-vuln" aria-hidden="true"><svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 22s8-4 8-10V5l-8-3-8 3v7c0 6 8 10 8 10z"/></svg></span>
|
||||
</div>
|
||||
<div class="dashboard-kpi-value" id="dashboard-vuln-total">-</div>
|
||||
<div class="dashboard-kpi-sub" id="dashboard-kpi-vuln-sub">
|
||||
<span class="dashboard-kpi-sub-badge dashboard-kpi-sub-badge-critical" id="dashboard-kpi-vuln-critical-badge" hidden>
|
||||
<span class="dashboard-kpi-sub-badge-dot"></span>
|
||||
<span data-i18n="dashboard.severityCritical">严重</span>
|
||||
<span id="dashboard-kpi-vuln-critical-count">0</span>
|
||||
</span>
|
||||
<span class="dashboard-kpi-sub-text" id="dashboard-kpi-vuln-sub-text" data-i18n="dashboard.allClear">暂无新增风险</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('mcp-monitor')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('mcp-monitor'); }" data-i18n="dashboard.clickToViewMCP" data-i18n-attr="title" title="点击查看 MCP 监控">
|
||||
<div class="dashboard-kpi-head">
|
||||
<div class="dashboard-kpi-label" data-i18n="dashboard.toolCalls">工具调用次数</div>
|
||||
<span class="dashboard-kpi-icon dashboard-kpi-icon-calls" aria-hidden="true"><svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polyline points="22 12 18 12 15 21 9 3 6 12 2 12"/></svg></span>
|
||||
</div>
|
||||
<div class="dashboard-kpi-value" id="dashboard-kpi-tools-calls">-</div>
|
||||
<div class="dashboard-kpi-sub">
|
||||
<span class="dashboard-kpi-sub-text" id="dashboard-kpi-tools-sub-text">-</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('mcp-monitor')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('mcp-monitor'); }" data-i18n="dashboard.clickToViewMCP" data-i18n-attr="title" title="点击查看 MCP 监控">
|
||||
<div class="dashboard-kpi-head">
|
||||
<div class="dashboard-kpi-label" data-i18n="dashboard.successRate">工具执行成功率</div>
|
||||
<span class="dashboard-kpi-icon dashboard-kpi-icon-rate" aria-hidden="true"><svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polyline points="20 6 9 17 4 12"/></svg></span>
|
||||
</div>
|
||||
<div class="dashboard-kpi-value" id="dashboard-kpi-success-rate">-</div>
|
||||
<div class="dashboard-kpi-sub">
|
||||
<span class="dashboard-kpi-sub-text" id="dashboard-kpi-rate-sub-text" data-i18n="dashboard.healthyStatus">运行平稳</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<!-- 两列主内容区 -->
|
||||
<div class="dashboard-grid">
|
||||
<div class="dashboard-main">
|
||||
<section class="dashboard-section dashboard-section-chart">
|
||||
<h3 class="dashboard-section-title" data-i18n="dashboard.severityDistribution">漏洞严重程度分布</h3>
|
||||
<div class="dashboard-chart-wrap">
|
||||
<div class="dashboard-stacked-bar" id="dashboard-stacked-bar">
|
||||
<span class="dashboard-bar-seg seg-critical" id="dashboard-bar-critical" style="width: 0%"></span>
|
||||
<span class="dashboard-bar-seg seg-high" id="dashboard-bar-high" style="width: 0%"></span>
|
||||
<span class="dashboard-bar-seg seg-medium" id="dashboard-bar-medium" style="width: 0%"></span>
|
||||
<span class="dashboard-bar-seg seg-low" id="dashboard-bar-low" style="width: 0%"></span>
|
||||
<span class="dashboard-bar-seg seg-info" id="dashboard-bar-info" style="width: 0%"></span>
|
||||
<div class="dashboard-section-header">
|
||||
<h3 class="dashboard-section-title" data-i18n="dashboard.severityDistribution">漏洞严重程度分布</h3>
|
||||
<a class="dashboard-section-link" onclick="switchPage('vulnerabilities')" data-i18n="dashboard.viewAll">查看全部 →</a>
|
||||
</div>
|
||||
<div class="dashboard-severity-wrap">
|
||||
<div class="dashboard-severity-chart">
|
||||
<svg class="dashboard-severity-donut" id="dashboard-severity-donut" viewBox="0 0 480 260" preserveAspectRatio="xMidYMid meet" aria-hidden="true">
|
||||
<g id="dashboard-severity-donut-track"></g>
|
||||
<g id="dashboard-severity-donut-segments"></g>
|
||||
<g id="dashboard-severity-donut-labels"></g>
|
||||
</svg>
|
||||
<div class="dashboard-severity-center">
|
||||
<div class="dashboard-severity-center-value" id="dashboard-severity-total">0</div>
|
||||
<div class="dashboard-severity-center-label" data-i18n="dashboard.totalVulns">总漏洞数</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="dashboard-legend" id="dashboard-vuln-bars">
|
||||
<div class="dashboard-legend-item"><span class="dashboard-legend-dot critical"></span><span class="dashboard-legend-label" data-i18n="dashboard.severityCritical">严重</span><span class="dashboard-legend-value" id="dashboard-severity-critical">0</span></div>
|
||||
<div class="dashboard-legend-item"><span class="dashboard-legend-dot high"></span><span class="dashboard-legend-label" data-i18n="dashboard.severityHigh">高危</span><span class="dashboard-legend-value" id="dashboard-severity-high">0</span></div>
|
||||
<div class="dashboard-legend-item"><span class="dashboard-legend-dot medium"></span><span class="dashboard-legend-label" data-i18n="dashboard.severityMedium">中危</span><span class="dashboard-legend-value" id="dashboard-severity-medium">0</span></div>
|
||||
<div class="dashboard-legend-item"><span class="dashboard-legend-dot low"></span><span class="dashboard-legend-label" data-i18n="dashboard.severityLow">低危</span><span class="dashboard-legend-value" id="dashboard-severity-low">0</span></div>
|
||||
<div class="dashboard-legend-item"><span class="dashboard-legend-dot info"></span><span class="dashboard-legend-label" data-i18n="dashboard.severityInfo">信息</span><span class="dashboard-legend-value" id="dashboard-severity-info">0</span></div>
|
||||
<div class="dashboard-severity-legend" id="dashboard-vuln-bars">
|
||||
<div class="dashboard-severity-legend-item">
|
||||
<span class="dashboard-severity-legend-dot critical"></span>
|
||||
<span class="dashboard-severity-legend-label" data-i18n="dashboard.severityCritical">严重</span>
|
||||
<span class="dashboard-severity-legend-value" id="dashboard-severity-critical">0</span>
|
||||
<span class="dashboard-severity-legend-pct" id="dashboard-severity-critical-pct">0%</span>
|
||||
</div>
|
||||
<div class="dashboard-severity-legend-item">
|
||||
<span class="dashboard-severity-legend-dot high"></span>
|
||||
<span class="dashboard-severity-legend-label" data-i18n="dashboard.severityHigh">高危</span>
|
||||
<span class="dashboard-severity-legend-value" id="dashboard-severity-high">0</span>
|
||||
<span class="dashboard-severity-legend-pct" id="dashboard-severity-high-pct">0%</span>
|
||||
</div>
|
||||
<div class="dashboard-severity-legend-item">
|
||||
<span class="dashboard-severity-legend-dot medium"></span>
|
||||
<span class="dashboard-severity-legend-label" data-i18n="dashboard.severityMedium">中危</span>
|
||||
<span class="dashboard-severity-legend-value" id="dashboard-severity-medium">0</span>
|
||||
<span class="dashboard-severity-legend-pct" id="dashboard-severity-medium-pct">0%</span>
|
||||
</div>
|
||||
<div class="dashboard-severity-legend-item">
|
||||
<span class="dashboard-severity-legend-dot low"></span>
|
||||
<span class="dashboard-severity-legend-label" data-i18n="dashboard.severityLow">低危</span>
|
||||
<span class="dashboard-severity-legend-value" id="dashboard-severity-low">0</span>
|
||||
<span class="dashboard-severity-legend-pct" id="dashboard-severity-low-pct">0%</span>
|
||||
</div>
|
||||
<div class="dashboard-severity-legend-item">
|
||||
<span class="dashboard-severity-legend-dot info"></span>
|
||||
<span class="dashboard-severity-legend-label" data-i18n="dashboard.severityInfo">信息</span>
|
||||
<span class="dashboard-severity-legend-value" id="dashboard-severity-info">0</span>
|
||||
<span class="dashboard-severity-legend-pct" id="dashboard-severity-info-pct">0%</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<!-- 处置状态 + 修复进度(利用 by_status 数据,避免下半部分留白) -->
|
||||
<div class="dashboard-severity-status">
|
||||
<div class="dashboard-severity-status-grid">
|
||||
<div class="dashboard-severity-status-cell s-open" role="button" tabindex="0" onclick="switchPage('vulnerabilities')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('vulnerabilities'); }">
|
||||
<span class="dashboard-severity-status-icon" aria-hidden="true">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="9"/><line x1="12" y1="8" x2="12" y2="12"/><line x1="12" y1="16" x2="12.01" y2="16"/></svg>
|
||||
</span>
|
||||
<div class="dashboard-severity-status-text">
|
||||
<span class="dashboard-severity-status-value" id="dashboard-status-open">0</span>
|
||||
<span class="dashboard-severity-status-label" data-i18n="dashboard.statusOpen">待处理</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="dashboard-severity-status-cell s-confirmed" role="button" tabindex="0" onclick="switchPage('vulnerabilities')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('vulnerabilities'); }">
|
||||
<span class="dashboard-severity-status-icon" aria-hidden="true">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M22 11.08V12a10 10 0 1 1-5.93-9.14"/><polyline points="22 4 12 14.01 9 11.01"/></svg>
|
||||
</span>
|
||||
<div class="dashboard-severity-status-text">
|
||||
<span class="dashboard-severity-status-value" id="dashboard-status-confirmed">0</span>
|
||||
<span class="dashboard-severity-status-label" data-i18n="dashboard.statusConfirmed">已确认</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="dashboard-severity-status-cell s-fixed" role="button" tabindex="0" onclick="switchPage('vulnerabilities')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('vulnerabilities'); }">
|
||||
<span class="dashboard-severity-status-icon" aria-hidden="true">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 22s8-4 8-10V5l-8-3-8 3v7c0 6 8 10 8 10z"/><polyline points="9 12 11 14 15 10"/></svg>
|
||||
</span>
|
||||
<div class="dashboard-severity-status-text">
|
||||
<span class="dashboard-severity-status-value" id="dashboard-status-fixed">0</span>
|
||||
<span class="dashboard-severity-status-label" data-i18n="dashboard.statusFixed">已修复</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="dashboard-severity-status-cell s-fp" role="button" tabindex="0" onclick="switchPage('vulnerabilities')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('vulnerabilities'); }">
|
||||
<span class="dashboard-severity-status-icon" aria-hidden="true">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="4.93" y1="4.93" x2="19.07" y2="19.07"/></svg>
|
||||
</span>
|
||||
<div class="dashboard-severity-status-text">
|
||||
<span class="dashboard-severity-status-value" id="dashboard-status-fp">0</span>
|
||||
<span class="dashboard-severity-status-label" data-i18n="dashboard.statusFalsePositive">误报</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="dashboard-severity-progress">
|
||||
<div class="dashboard-severity-progress-meta">
|
||||
<span class="dashboard-severity-progress-title" data-i18n="dashboard.fixRate">修复率</span>
|
||||
<span class="dashboard-severity-progress-value">
|
||||
<span id="dashboard-fix-rate">0%</span>
|
||||
<span class="dashboard-severity-progress-detail" id="dashboard-fix-detail">(0 / 0)</span>
|
||||
</span>
|
||||
</div>
|
||||
<div class="dashboard-severity-progress-track" aria-hidden="true">
|
||||
<div class="dashboard-severity-progress-fixed" id="dashboard-fix-progress-fixed" style="width: 0%"></div>
|
||||
<div class="dashboard-severity-progress-confirmed" id="dashboard-fix-progress-confirmed" style="width: 0%"></div>
|
||||
</div>
|
||||
<div class="dashboard-severity-progress-legend">
|
||||
<span class="dashboard-severity-progress-legend-item"><span class="dashboard-severity-progress-legend-dot legend-fixed"></span><span data-i18n="dashboard.statusFixed">已修复</span></span>
|
||||
<span class="dashboard-severity-progress-legend-item"><span class="dashboard-severity-progress-legend-dot legend-confirmed"></span><span data-i18n="dashboard.statusConfirmed">已确认</span></span>
|
||||
<span class="dashboard-severity-progress-legend-item"><span class="dashboard-severity-progress-legend-dot legend-open"></span><span data-i18n="dashboard.statusOpen">待处理</span></span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<section class="dashboard-section dashboard-section-recent-vulns">
|
||||
<div class="dashboard-section-header">
|
||||
<h3 class="dashboard-section-title" data-i18n="dashboard.recentVulns">最近漏洞</h3>
|
||||
<a class="dashboard-section-link" onclick="switchPage('vulnerabilities')" data-i18n="dashboard.viewAll">查看全部 →</a>
|
||||
</div>
|
||||
<div class="dashboard-recent-vulns" id="dashboard-recent-vulns">
|
||||
<div class="dashboard-recent-vulns-empty" id="dashboard-recent-vulns-empty" data-i18n="dashboard.noVulnYet">暂无最近漏洞</div>
|
||||
</div>
|
||||
</section>
|
||||
<section class="dashboard-section dashboard-section-overview">
|
||||
<h3 class="dashboard-section-title" data-i18n="dashboard.runOverview">运行概览</h3>
|
||||
<div class="dashboard-section-header">
|
||||
<h3 class="dashboard-section-title" data-i18n="dashboard.batchQueues">批量任务队列</h3>
|
||||
<a class="dashboard-section-link" onclick="switchPage('tasks')" data-i18n="dashboard.viewAll">查看全部 →</a>
|
||||
</div>
|
||||
<div class="dashboard-overview-list">
|
||||
<div class="dashboard-overview-item dashboard-overview-item-batch" role="button" tabindex="0" onclick="switchPage('tasks')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('tasks'); }">
|
||||
<span class="dashboard-overview-icon dashboard-overview-icon-batch" aria-hidden="true"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><rect x="3" y="3" width="7" height="7"/><rect x="14" y="3" width="7" height="7"/><rect x="14" y="14" width="7" height="7"/><rect x="3" y="14" width="7" height="7"/></svg></span>
|
||||
@@ -356,80 +540,100 @@
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="dashboard-overview-item dashboard-overview-item-tools" role="button" tabindex="0" onclick="switchPage('mcp-monitor')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('mcp-monitor'); }">
|
||||
<span class="dashboard-overview-icon dashboard-overview-icon-tools" aria-hidden="true"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M14.7 6.3a1 1 0 0 0 0 1.4l1.6 1.6a1 1 0 0 0 1.4 0l3.77-3.77a6 6 0 0 1-7.94 7.94l-6.91 6.91a2.12 2.12 0 0 1-3-3l6.91-6.91a6 6 0 0 1 7.94-7.94l-3.76 3.76z"/></svg></span>
|
||||
<div class="dashboard-overview-content">
|
||||
<div class="dashboard-overview-header">
|
||||
<span class="dashboard-overview-label" data-i18n="dashboard.toolInvocations">工具调用</span>
|
||||
<span class="dashboard-overview-success-rate" id="dashboard-tools-success-rate">-</span>
|
||||
</div>
|
||||
<div class="dashboard-overview-value-group">
|
||||
<span class="dashboard-overview-value-large" id="dashboard-tools-calls">-</span>
|
||||
<span class="dashboard-overview-value-unit" data-i18n="dashboard.callsUnit">次调用</span>
|
||||
<span class="dashboard-overview-value-separator">·</span>
|
||||
<span class="dashboard-overview-value-normal" id="dashboard-tools-count">-</span>
|
||||
<span class="dashboard-overview-value-unit" data-i18n="dashboard.toolsUnit">个工具</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="dashboard-overview-item dashboard-overview-item-knowledge" role="button" tabindex="0" onclick="switchPage('knowledge-management')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('knowledge-management'); }">
|
||||
<span class="dashboard-overview-icon dashboard-overview-icon-knowledge" aria-hidden="true"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M4 19.5A2.5 2.5 0 0 1 6.5 17H20"></path><path d="M6.5 2H20v20H6.5A2.5 2.5 0 0 1 4 19.5v-15A2.5 2.5 0 0 1 6.5 2z"></path></svg></span>
|
||||
<div class="dashboard-overview-content">
|
||||
<div class="dashboard-overview-header">
|
||||
<span class="dashboard-overview-label" data-i18n="dashboard.knowledgeLabel">知识</span>
|
||||
<span class="dashboard-overview-status" id="dashboard-knowledge-status">-</span>
|
||||
</div>
|
||||
<div class="dashboard-overview-value-group">
|
||||
<span class="dashboard-overview-value-large" id="dashboard-knowledge-items">-</span>
|
||||
<span class="dashboard-overview-value-unit" data-i18n="dashboard.knowledgeItems">项知识</span>
|
||||
<span class="dashboard-overview-value-separator">·</span>
|
||||
<span class="dashboard-overview-value-normal" id="dashboard-knowledge-categories">-</span>
|
||||
<span class="dashboard-overview-value-unit" data-i18n="dashboard.categoriesUnit">个分类</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="dashboard-overview-item dashboard-overview-item-skills" role="button" tabindex="0" onclick="switchPage('skills-monitor')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('skills-monitor'); }">
|
||||
<span class="dashboard-overview-icon dashboard-overview-icon-skills" aria-hidden="true"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><polyline points="14 2 14 8 20 8"/><line x1="16" y1="13" x2="8" y2="13"/><line x1="16" y1="17" x2="8" y2="17"/></svg></span>
|
||||
<div class="dashboard-overview-content">
|
||||
<div class="dashboard-overview-header">
|
||||
<span class="dashboard-overview-label" data-i18n="dashboard.skillsLabel">Skills</span>
|
||||
<span class="dashboard-overview-status" id="dashboard-skills-status">-</span>
|
||||
</div>
|
||||
<div class="dashboard-overview-value-group">
|
||||
<span class="dashboard-overview-value-large" id="dashboard-skills-calls">-</span>
|
||||
<span class="dashboard-overview-value-unit" data-i18n="dashboard.callsUnit">次调用</span>
|
||||
<span class="dashboard-overview-value-separator">·</span>
|
||||
<span class="dashboard-overview-value-normal" id="dashboard-skills-count">-</span>
|
||||
<span class="dashboard-overview-value-unit" data-i18n="dashboard.skillUnit">个 Skill</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<section class="dashboard-section dashboard-section-quick dashboard-quick-inline">
|
||||
<h3 class="dashboard-section-title" data-i18n="dashboard.quickLinks">快捷入口</h3>
|
||||
<div class="dashboard-quick-links dashboard-quick-links-row">
|
||||
<a class="dashboard-quick-link" onclick="switchPage('chat')"><span class="dashboard-quick-icon"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z"></path></svg></span><span data-i18n="nav.chat">对话</span></a>
|
||||
<a class="dashboard-quick-link" onclick="switchPage('tasks')"><span class="dashboard-quick-icon"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M9 11l3 3L22 4"></path><path d="M21 12v7a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2V5a2 2 0 0 1 2-2h11"></path></svg></span><span data-i18n="nav.tasks">任务管理</span></a>
|
||||
<a class="dashboard-quick-link" onclick="switchPage('vulnerabilities')"><span class="dashboard-quick-icon"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M12 22s8-4 8-10V5l-8-3-8 3v7c0 6 8 10 8 10z"></path></svg></span><span data-i18n="nav.vulnerabilities">漏洞管理</span></a>
|
||||
<a class="dashboard-quick-link" onclick="switchPage('mcp-management')"><span class="dashboard-quick-icon"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M13 2L3 14h9l-1 8 10-12h-9l1-8z"></path></svg></span><span data-i18n="nav.mcpManagement">MCP 管理</span></a>
|
||||
<a class="dashboard-quick-link" onclick="switchPage('knowledge-management')"><span class="dashboard-quick-icon"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M2 3h6a4 4 0 0 1 4 4v14a3 3 0 0 0-3-3H2z"></path><path d="M22 3h-6a4 4 0 0 0-4 4v14a3 3 0 0 1 3-3h7z"></path></svg></span><span data-i18n="nav.knowledgeManagement">知识管理</span></a>
|
||||
<a class="dashboard-quick-link" onclick="switchPage('skills-management')"><span class="dashboard-quick-icon"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M14.5 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V7.5L14.5 2z"></path><polyline points="14 2 14 8 20 8"></polyline></svg></span><span data-i18n="nav.skillsManagement">Skills 管理</span></a>
|
||||
<a class="dashboard-quick-link" onclick="switchPage('roles-management')"><span class="dashboard-quick-icon"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M17 21v-2a4 4 0 0 0-4-4H5a4 4 0 0 0-4 4v2"></path><circle cx="9" cy="7" r="4"></circle><path d="M23 21v-2a4 4 0 0 0-3-3.87"></path><path d="M16 3.13a4 4 0 0 1 0 7.75"></path></svg></span><span data-i18n="nav.rolesManagement">角色管理</span></a>
|
||||
<!-- 推荐操作:基于当前数据状态智能生成(如「修复 4 个待处理严重漏洞」「审批 2 个 HITL」),
|
||||
比纯静态导航更有意义;当没有任何推荐时整个 section 隐藏 -->
|
||||
<section class="dashboard-section dashboard-section-recommend" id="dashboard-section-recommend" hidden>
|
||||
<div class="dashboard-section-header">
|
||||
<h3 class="dashboard-section-title" data-i18n="dashboard.recommendedActions">推荐操作</h3>
|
||||
<span class="dashboard-section-hint" data-i18n="dashboard.recommendedActionsHint">基于当前状态自动生成</span>
|
||||
</div>
|
||||
<div class="dashboard-recommend-list" id="dashboard-recommend-list"></div>
|
||||
</section>
|
||||
</div>
|
||||
<div class="dashboard-side">
|
||||
<section class="dashboard-section dashboard-section-tools">
|
||||
<h3 class="dashboard-section-title" data-i18n="dashboard.toolsExecCount">工具执行次数</h3>
|
||||
<div class="dashboard-section-header">
|
||||
<h3 class="dashboard-section-title" data-i18n="dashboard.toolsExecCount">工具执行次数</h3>
|
||||
<a class="dashboard-section-link" onclick="switchPage('mcp-monitor')" data-i18n="dashboard.viewAll">查看全部 →</a>
|
||||
</div>
|
||||
<div class="dashboard-tools-chart-wrap">
|
||||
<div class="dashboard-tools-chart-placeholder" id="dashboard-tools-pie-placeholder" data-i18n="common.noData">暂无数据</div>
|
||||
<div class="dashboard-tools-bar-chart" id="dashboard-tools-bar-chart"></div>
|
||||
</div>
|
||||
</section>
|
||||
<!-- 最近事件:拉 /api/notifications/summary 取最新 3 条;空时整个隐藏 -->
|
||||
<section class="dashboard-section dashboard-section-events" id="dashboard-section-events" hidden>
|
||||
<div class="dashboard-section-header">
|
||||
<h3 class="dashboard-section-title" data-i18n="dashboard.recentEvents">最近事件</h3>
|
||||
<a class="dashboard-section-link" onclick="if(typeof toggleNotificationDropdown==='function') toggleNotificationDropdown()" data-i18n="dashboard.viewAll">查看全部 →</a>
|
||||
</div>
|
||||
<div class="dashboard-events-list" id="dashboard-events-list"></div>
|
||||
</section>
|
||||
<section class="dashboard-section dashboard-section-resources">
|
||||
<h3 class="dashboard-section-title" data-i18n="dashboard.capabilities">能力总览</h3>
|
||||
<div class="dashboard-resource-list" id="dashboard-resource-list">
|
||||
<a class="dashboard-resource-item" onclick="switchPage('mcp-management')" role="button" tabindex="0">
|
||||
<span class="dashboard-resource-icon dashboard-resource-icon-mcp" aria-hidden="true">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M13 2L3 14h9l-1 8 10-12h-9l1-8z"/></svg>
|
||||
</span>
|
||||
<span class="dashboard-resource-label" data-i18n="dashboard.mcpTools">MCP 工具</span>
|
||||
<span class="dashboard-resource-value" id="dashboard-resource-tools">-</span>
|
||||
</a>
|
||||
<!-- External MCP 服务器健康度:N 运行 / N 异常;只有配置过 External MCP 才显示 -->
|
||||
<a class="dashboard-resource-item" id="dashboard-resource-external-mcp-row" onclick="switchPage('mcp-management')" role="button" tabindex="0" hidden>
|
||||
<span class="dashboard-resource-icon dashboard-resource-icon-external" aria-hidden="true">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>
|
||||
</span>
|
||||
<span class="dashboard-resource-label" data-i18n="dashboard.externalMcpServers">External MCP</span>
|
||||
<span class="dashboard-resource-value" id="dashboard-resource-external-mcp">
|
||||
<span id="dashboard-resource-external-mcp-text">-</span>
|
||||
<span class="dashboard-resource-health" id="dashboard-resource-external-mcp-health" hidden></span>
|
||||
</span>
|
||||
</a>
|
||||
<a class="dashboard-resource-item" onclick="switchPage('skills-management')" role="button" tabindex="0">
|
||||
<span class="dashboard-resource-icon dashboard-resource-icon-skills" aria-hidden="true">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M14.5 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V7.5L14.5 2z"/><polyline points="14 2 14 8 20 8"/></svg>
|
||||
</span>
|
||||
<span class="dashboard-resource-label" data-i18n="dashboard.skillsLabel">Skills</span>
|
||||
<span class="dashboard-resource-value" id="dashboard-resource-skills">-</span>
|
||||
</a>
|
||||
<a class="dashboard-resource-item" onclick="switchPage('knowledge-management')" role="button" tabindex="0">
|
||||
<span class="dashboard-resource-icon dashboard-resource-icon-knowledge" aria-hidden="true">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M4 19.5A2.5 2.5 0 0 1 6.5 17H20"/><path d="M6.5 2H20v20H6.5A2.5 2.5 0 0 1 4 19.5v-15A2.5 2.5 0 0 1 6.5 2z"/></svg>
|
||||
</span>
|
||||
<span class="dashboard-resource-label" data-i18n="dashboard.knowledgeLabel">知识</span>
|
||||
<span class="dashboard-resource-value" id="dashboard-resource-knowledge">-</span>
|
||||
</a>
|
||||
<a class="dashboard-resource-item" onclick="switchPage('roles-management')" role="button" tabindex="0">
|
||||
<span class="dashboard-resource-icon dashboard-resource-icon-roles" aria-hidden="true">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M17 21v-2a4 4 0 0 0-4-4H5a4 4 0 0 0-4 4v2"/><circle cx="9" cy="7" r="4"/><path d="M23 21v-2a4 4 0 0 0-3-3.87"/><path d="M16 3.13a4 4 0 0 1 0 7.75"/></svg>
|
||||
</span>
|
||||
<span class="dashboard-resource-label" data-i18n="dashboard.rolesLabel">角色</span>
|
||||
<span class="dashboard-resource-value" id="dashboard-resource-roles">-</span>
|
||||
</a>
|
||||
<a class="dashboard-resource-item" onclick="switchPage('agents-management')" role="button" tabindex="0">
|
||||
<span class="dashboard-resource-icon dashboard-resource-icon-agents" aria-hidden="true">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polygon points="12 2 2 7 12 12 22 7 12 2"/><polyline points="2 17 12 22 22 17"/><polyline points="2 12 12 17 22 12"/></svg>
|
||||
</span>
|
||||
<span class="dashboard-resource-label" data-i18n="dashboard.agentsLabel">Agents</span>
|
||||
<span class="dashboard-resource-value" id="dashboard-resource-agents">-</span>
|
||||
</a>
|
||||
<!-- WebShell 连接:渗透落地后建立的 foothold,对安全运维场景非常关键 -->
|
||||
<a class="dashboard-resource-item" onclick="switchPage('webshell')" role="button" tabindex="0">
|
||||
<span class="dashboard-resource-icon dashboard-resource-icon-webshell" aria-hidden="true">
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polyline points="4 17 10 11 4 5"/><line x1="12" y1="19" x2="20" y2="19"/></svg>
|
||||
</span>
|
||||
<span class="dashboard-resource-label" data-i18n="dashboard.webshellLabel">WebShell</span>
|
||||
<span class="dashboard-resource-value" id="dashboard-resource-webshell">-</span>
|
||||
</a>
|
||||
</div>
|
||||
</section>
|
||||
</div>
|
||||
</div>
|
||||
<div class="dashboard-cta-block">
|
||||
<!-- "开始你的安全之旅" CTA:默认显示;当用户已经有数据(任务/漏洞/调用)后,由 JS 隐藏避免冗余 -->
|
||||
<div class="dashboard-cta-block" id="dashboard-cta-block">
|
||||
<div class="dashboard-cta-content">
|
||||
<div class="dashboard-cta-icon" aria-hidden="true">
|
||||
<svg width="28" height="28" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.75" stroke-linecap="round" stroke-linejoin="round"><path d="M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z"></path></svg>
|
||||
@@ -2832,6 +3036,7 @@
|
||||
<script src="/static/js/i18n.js"></script>
|
||||
<script src="/static/js/builtin-tools.js"></script>
|
||||
<script src="/static/js/auth.js"></script>
|
||||
<script src="/static/js/notifications.js"></script>
|
||||
<script src="/static/js/info-collect.js"></script>
|
||||
<script src="/static/js/router.js"></script>
|
||||
<script src="/static/js/agents.js"></script>
|
||||
|
||||
Reference in New Issue
Block a user