Compare commits

...

51 Commits

Author SHA1 Message Date
公明 dec69a1993 Update config.yaml 2026-05-01 01:33:17 +08:00
公明 15aab2584a Add files via upload 2026-05-01 01:32:54 +08:00
公明 399b697d75 Add files via upload 2026-05-01 01:31:19 +08:00
公明 e0753fd03e Add files via upload 2026-05-01 01:28:19 +08:00
公明 9b1e493023 Add files via upload 2026-05-01 01:05:48 +08:00
公明 77d212098d Add files via upload 2026-05-01 01:03:28 +08:00
公明 39926007fe Add files via upload 2026-05-01 01:01:30 +08:00
公明 0e35506ae1 Add files via upload 2026-05-01 01:00:23 +08:00
公明 9ff8bfa44b Add files via upload 2026-04-30 20:31:17 +08:00
公明 1d9fcfd87e Update version number to v1.5.16 2026-04-30 20:28:21 +08:00
公明 91cb650234 Add files via upload 2026-04-30 15:20:13 +08:00
公明 44e7d3b340 Add files via upload 2026-04-30 15:01:35 +08:00
公明 531b05299a Add files via upload 2026-04-30 10:49:19 +08:00
公明 0de69a6345 Add files via upload 2026-04-30 10:43:23 +08:00
公明 6a2a445f32 Update config.yaml 2026-04-30 01:56:47 +08:00
公明 6aaa21d3e0 Add files via upload 2026-04-30 01:55:23 +08:00
公明 5c57d358ef Add files via upload 2026-04-30 01:53:46 +08:00
公明 65a3475c02 Add files via upload 2026-04-30 01:52:11 +08:00
公明 516ebf7a65 Add files via upload 2026-04-29 22:40:17 +08:00
公明 2558be3d7d Add files via upload 2026-04-29 22:38:14 +08:00
公明 f6bb455313 Update config.yaml 2026-04-29 17:14:19 +08:00
公明 fc64356282 Add files via upload 2026-04-29 17:10:53 +08:00
公明 3d4fce9b89 Add files via upload 2026-04-29 17:09:37 +08:00
公明 3e41a47abf Add files via upload 2026-04-29 17:05:02 +08:00
公明 5b942c7bc8 Add files via upload 2026-04-29 17:03:51 +08:00
公明 bcfb7b8da1 Update config.yaml 2026-04-29 04:11:31 +08:00
公明 f420ae0265 Add files via upload 2026-04-29 03:28:32 +08:00
公明 e3f59b29ab Add files via upload 2026-04-29 03:26:27 +08:00
公明 87cba37203 Add files via upload 2026-04-29 03:24:48 +08:00
公明 4773b9e963 Update config.yaml 2026-04-29 03:01:21 +08:00
公明 eda5f9bba1 Add files via upload 2026-04-29 02:59:34 +08:00
公明 1318607813 Add files via upload 2026-04-29 02:57:22 +08:00
公明 5100924abe Add files via upload 2026-04-29 02:54:43 +08:00
公明 44079674dd Add files via upload 2026-04-28 14:07:01 +08:00
公明 d959390e27 Update config.yaml 2026-04-28 11:45:27 +08:00
公明 62a0d8cb71 Add files via upload 2026-04-28 11:40:09 +08:00
公明 b53cae3a02 Add files via upload 2026-04-28 11:37:52 +08:00
公明 3b3d094dc4 Add files via upload 2026-04-28 10:26:09 +08:00
公明 47922c2083 Add files via upload 2026-04-28 10:23:24 +08:00
公明 dfaf0bc77f Update config.yaml 2026-04-28 01:23:57 +08:00
公明 3eb7edb1b8 Add files via upload 2026-04-28 01:23:33 +08:00
公明 f82f6b861e Add files via upload 2026-04-28 01:22:21 +08:00
公明 2acf43c454 Add files via upload 2026-04-28 01:19:01 +08:00
公明 fad6b3c808 Add files via upload 2026-04-28 01:05:58 +08:00
公明 0597838217 Add files via upload 2026-04-28 01:04:58 +08:00
公明 1532426b4f Add files via upload 2026-04-28 01:02:30 +08:00
公明 3aeb8c3474 Add files via upload 2026-04-28 00:37:46 +08:00
公明 b2b166972a Add files via upload 2026-04-28 00:33:29 +08:00
公明 36b669771c Delete internal/multiagent directory 2026-04-28 00:30:34 +08:00
公明 96564d4d89 Update default_single_system_prompt.go 2026-04-27 14:58:49 +08:00
公明 d85afa2d39 Add files via upload 2026-04-27 11:29:16 +08:00
87 changed files with 11624 additions and 2619 deletions
+15 -5
View File
@@ -10,7 +10,7 @@
# ============================================
# 前端显示的版本号(可选,不填则显示默认版本)
version: "v1.5.8"
version: "v1.5.17"
# 服务器配置
server:
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
@@ -70,7 +70,7 @@ multi_agent:
robot_use_multi_agent: true # true 时企业微信/钉钉/飞书机器人也走 Eino 多代理(成本更高)
batch_use_multi_agent: false # true 时「批量任务」队列中每个子任务也走 Eino 多代理(成本更高)
max_iteration: 0 # 主代理 / plan_execute 执行器最大轮次,0 表示沿用 agent.max_iterations
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。Executor 未暴露 Handlerspatch/reduction/plantask 不作用于 PE,但 tool_search 工具列表拆分仍通过共享 ToolsConfig 作用于执行器
# plan_execute 专用:execute↔replan 外层循环上限,0 表示 Eino 默认 10。当前实现下 Executor 会挂载 patch/reduction/tool_search 等前置中间件
plan_execute_loop_max_iterations: 0
sub_agent_max_iterations: 120
sub_agent_user_context_max_runes: 0 # 子代理 task 描述中自动注入用户原始请求的字符上限;0=默认2000,负数=禁用
@@ -87,15 +87,25 @@ multi_agent:
# Eino ADK 中间件与 Deep/Supervisor 调参(结构体见 internal/config/config.go → MultiAgentEinoMiddlewareConfig
eino_middleware:
patch_tool_calls: true # true:修补历史中无 tool_result 的悬空 tool_call(流式中断/重试后更稳);false:关闭;字段省略时默认等同 true
tool_search_enable: false # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
tool_search_enable: true # true:工具数 ≥ min 时启用 tool_search,仅前 N 个工具常驻,其余按正则按需解锁,省 token、减误选;false:全量工具进上下文
tool_search_min_tools: 20 # 达到该数量才启用 tool_search(避免工具很少时多此一举);与 always_visible 配合使用
tool_search_always_visible: 12 # 始终直接暴露给模型的工具个数(顺序与角色工具列表一致);其余工具进入动态池,需 tool_search 解锁
tool_search_always_visible_tools: [read_file, glob, grep, write_file, edit_file, execute, task, transfer_to_agent, exit, write_todos, skill, tool_search, TaskCreate, TaskGet, TaskUpdate, TaskList, record_vulnerability, list_knowledge_risk_types, search_knowledge_base, webshell_exec, webshell_file_list, webshell_file_read, webshell_file_write, manage_webshell_list, manage_webshell_add, manage_webshell_update, manage_webshell_delete, manage_webshell_test, batch_task_list, batch_task_get, batch_task_start, batch_task_rerun, batch_task_pause, batch_task_update_metadata, batch_task_update_schedule, batch_task_schedule_enabled, batch_task_update_task, batch_task_remove_task, batch_task_delete, batch_task_create, batch_task_add_task, http-framework-test] # 后端内置常驻工具白名单(优先于 always_visible 数量策略)
plantask_enable: false # true:主代理(Deep / Supervisor 主)挂载 TaskCreate/Get/Update/List;需 eino_skills 可用且 skills_dir 存在,否则仅打日志并跳过
plantask_rel_dir: .eino/plantask # 结构化任务文件相对 skills_dir 的子目录,其下再按会话 ID 分子目录存放
reduction_enable: false # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
reduction_enable: true # true:大工具输出截断/落盘以控上下文;依赖与 plantask 相同的 eino local 写盘后端,无后端时不挂载
reduction_max_length_for_trunc: 50000 # 单条工具结果超过该字符数(bytes)时截断并落盘(由 reduction 中间件处理)
reduction_max_tokens_for_clear: 160000 # 历史工具结果清理阈值(tokens),超阈值时在模型调用前清理旧结果
reduction_root_dir: "" # 非空:截断/清理内容落盘根路径;空:使用系统临时目录下按会话隔离的默认路径
reduction_clear_exclude: [] # 不参与「清理阶段」的工具名额外列表(会与 task/transfer/exit 等内置排除项合并);需要时用 YAML 列表填写
reduction_sub_agents: false # true:子代理也挂 reductionfalse:仅编排主代理使用 reduction
reduction_sub_agents: true # true:子代理也挂 reductionfalse:仅编排主代理使用 reduction
summarization_trigger_ratio: 0.8 # summarization 触发比例(max_total_tokens * ratio),建议 0.75~0.85
summarization_emit_internal_events: true # true:发出 summarization 内部事件(便于诊断)
history_input_budget_ratio: 0.35 # 历史入队预算比例(max_total_tokens * ratio
plan_execute_user_input_budget_ratio: 0.35 # plan_execute 中 userInput 预算比例(planner/replanner/executor 共用)
plan_execute_executed_steps_budget_ratio: 0.2 # plan_execute 中 executed_steps 预算比例
plan_execute_max_step_result_runes: 4000 # plan_execute 每步结果最大字符数(超出截断)
plan_execute_keep_last_steps: 8 # plan_execute 仅保留最近 N 步正文,早期步骤折叠为标题
checkpoint_dir: "" # 非空:为 adk.NewRunner 启用按会话子目录的文件型 CheckPointStore,便于中断恢复持久化;Resume 的 HTTP/前端流程需另行对接
deep_output_key: "" # 非空:将最终助手输出写入 adk session 的键名(Deep 与 Supervisor 主代理);空表示不写入
deep_model_retry_max_retries: 0 # >0ChatModel 调用失败时的框架级最大重试次数(Deep 与 Supervisor 主);0:不重试
+51 -32
View File
@@ -39,6 +39,7 @@ type Agent struct {
toolNameMapping map[string]string // 工具名称映射:OpenAI格式 -> 原始格式(用于外部MCP工具)
currentConversationID string // 当前对话ID(用于自动传递给工具)
promptBaseDir string // 解析 system_prompt_path 时相对路径的基准目录(通常为 config.yaml 所在目录)
toolDescriptionMode string // 工具描述模式: "short" | "full",默认 short
}
// ResultStorage 结果存储接口(直接使用 storage 包的类型)
@@ -162,6 +163,7 @@ func NewAgent(cfg *config.OpenAIConfig, agentCfg *config.AgentConfig, mcpServer
resultStorage: resultStorage,
largeResultThreshold: largeResultThreshold,
toolNameMapping: make(map[string]string), // 初始化工具名称映射
toolDescriptionMode: "short",
}
}
@@ -336,10 +338,10 @@ func (fc *FunctionCall) UnmarshalJSON(data []byte) error {
// AgentLoopResult Agent Loop执行结果
type AgentLoopResult struct {
Response string
MCPExecutionIDs []string
LastReActInput string // 最后一轮ReAct的输入(压缩后的messagesJSON格式
LastReActOutput string // 最终大模型的输出
Response string
MCPExecutionIDs []string
LastAgentTraceInput string // 最后一轮代理消息轨迹(压缩后的 messagesJSON;与 multiagent.RunResult 字段对齐
LastAgentTraceOutput string // 最终助手输出文本
}
// ProgressCallback 进度回调函数类型
@@ -471,7 +473,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
}
// 用于保存当前的messages,以便在异常情况下也能保存ReAct输入
var currentReActInput string
var currentAgentTraceInput string
maxIterations := a.maxIterations
thinkingStreamSeq := 0
@@ -490,9 +492,9 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
if err != nil {
a.logger.Warn("序列化ReAct输入失败", zap.Error(err))
} else {
currentReActInput = string(messagesJSON)
currentAgentTraceInput = string(messagesJSON)
// 更新result中的值,确保始终保存最新的ReAct输入(压缩后的)
result.LastReActInput = currentReActInput
result.LastAgentTraceInput = currentAgentTraceInput
}
// 检查上下文是否已取消
@@ -500,13 +502,13 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
case <-ctx.Done():
// 上下文被取消(可能是用户主动暂停或其他原因)
a.logger.Info("检测到上下文取消,保存当前ReAct数据", zap.Error(ctx.Err()))
result.LastReActInput = currentReActInput
result.LastAgentTraceInput = currentAgentTraceInput
if ctx.Err() == context.Canceled {
result.Response = "任务已被取消。"
} else {
result.Response = fmt.Sprintf("任务执行中断: %v", ctx.Err())
}
result.LastReActOutput = result.Response
result.LastAgentTraceOutput = result.Response
return result, ctx.Err()
default:
}
@@ -600,10 +602,10 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
})
if err != nil {
// API调用失败,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput
result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := fmt.Sprintf("调用OpenAI失败: %v", err)
result.Response = errorMsg
result.LastReActOutput = errorMsg
result.LastAgentTraceOutput = errorMsg
a.logger.Warn("OpenAI调用失败,已保存ReAct数据", zap.Error(err))
return result, fmt.Errorf("调用OpenAI失败: %w", err)
}
@@ -629,19 +631,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
continue
}
// OpenAI返回错误,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput
result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := fmt.Sprintf("OpenAI错误: %s", response.Error.Message)
result.Response = errorMsg
result.LastReActOutput = errorMsg
result.LastAgentTraceOutput = errorMsg
return result, fmt.Errorf("OpenAI错误: %s", response.Error.Message)
}
if len(response.Choices) == 0 {
// 没有收到响应,保存当前的ReAct输入和错误信息作为输出
result.LastReActInput = currentReActInput
result.LastAgentTraceInput = currentAgentTraceInput
errorMsg := "没有收到响应"
result.Response = errorMsg
result.LastReActOutput = errorMsg
result.LastAgentTraceOutput = errorMsg
return result, fmt.Errorf("没有收到响应")
}
@@ -816,7 +818,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
})
if strings.TrimSpace(streamText) != "" {
result.Response = streamText
result.LastReActOutput = result.Response
result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil)
return result, nil
}
@@ -863,14 +865,14 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
})
if strings.TrimSpace(streamText) != "" {
result.Response = streamText
result.LastReActOutput = result.Response
result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil)
return result, nil
}
// 如果获取总结失败,使用当前回复作为结果
if choice.Message.Content != "" {
result.Response = choice.Message.Content
result.LastReActOutput = result.Response
result.LastAgentTraceOutput = result.Response
return result, nil
}
// 如果都没有内容,跳出循环,让后续逻辑处理
@@ -881,7 +883,7 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
if choice.FinishReason == "stop" {
sendProgress("progress", "正在生成最终回复...", nil)
result.Response = choice.Message.Content
result.LastReActOutput = result.Response
result.LastAgentTraceOutput = result.Response
return result, nil
}
}
@@ -910,19 +912,19 @@ func (a *Agent) AgentLoopWithProgress(ctx context.Context, userInput string, his
})
if strings.TrimSpace(streamText) != "" {
result.Response = streamText
result.LastReActOutput = result.Response
result.LastAgentTraceOutput = result.Response
sendProgress("progress", "总结生成完成", nil)
return result, nil
}
// 如果无法生成总结,返回友好的提示
result.Response = fmt.Sprintf("已达到最大迭代次数(%d轮)。系统已执行了多轮测试,但由于达到迭代上限,无法继续自动执行。建议您查看已执行的工具结果,或提出新的测试请求以继续测试。", a.maxIterations)
result.LastReActOutput = result.Response
result.LastAgentTraceOutput = result.Response
return result, nil
}
// getAvailableTools 获取可用工具
// 从MCP服务器动态获取工具列表,使用简短描述以减少token消耗
// 从MCP服务器动态获取工具列表,描述模式由 tool_description_mode 控制
// roleTools: 角色配置的工具列表(toolKey格式),如果为空或nil,则使用所有工具(默认角色)
func (a *Agent) getAvailableTools(roleTools []string) []Tool {
// 构建角色工具集合(用于快速查找)
@@ -946,11 +948,7 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
continue // 不在角色工具列表中,跳过
}
}
// 使用简短描述(如果存在),否则使用详细描述
description := mcpTool.ShortDescription
if description == "" {
description = mcpTool.Description
}
description := a.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description)
// 转换schema中的类型为OpenAI标准类型
convertedSchema := a.convertSchemaTypes(mcpTool.InputSchema)
@@ -1024,11 +1022,7 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
continue
}
// 使用简短描述(如果存在),否则使用详细描述
description := externalTool.ShortDescription
if description == "" {
description = externalTool.Description
}
description := a.pickToolDescription(externalTool.ShortDescription, externalTool.Description)
// 转换schema中的类型为OpenAI标准类型
convertedSchema := a.convertSchemaTypes(externalTool.InputSchema)
@@ -1063,6 +1057,19 @@ func (a *Agent) getAvailableTools(roleTools []string) []Tool {
return tools
}
func (a *Agent) pickToolDescription(shortDesc, fullDesc string) string {
a.mu.RLock()
mode := strings.TrimSpace(strings.ToLower(a.toolDescriptionMode))
a.mu.RUnlock()
if mode == "full" {
return fullDesc
}
if shortDesc != "" {
return shortDesc
}
return fullDesc
}
// convertSchemaTypes 递归转换schema中的类型为OpenAI标准类型
func (a *Agent) convertSchemaTypes(schema map[string]interface{}) map[string]interface{} {
if schema == nil {
@@ -1665,6 +1672,18 @@ func (a *Agent) UpdateMaxIterations(maxIterations int) {
}
}
// UpdateToolDescriptionMode 更新工具描述模式(short/full)
func (a *Agent) UpdateToolDescriptionMode(mode string) {
a.mu.Lock()
defer a.mu.Unlock()
mode = strings.TrimSpace(strings.ToLower(mode))
if mode != "full" {
mode = "short"
}
a.toolDescriptionMode = mode
a.logger.Info("Agent工具描述模式已更新", zap.String("tool_description_mode", mode))
}
// formatToolError 格式化工具错误信息,提供更友好的错误描述
func (a *Agent) formatToolError(toolName string, args map[string]interface{}, err error) string {
errorMsg := fmt.Sprintf(`工具执行失败
+48 -49
View File
@@ -18,62 +18,62 @@ import (
func setupTestAgent(t *testing.T) (*Agent, *storage.FileResultStorage) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{
APIKey: "test-key",
BaseURL: "https://api.test.com/v1",
Model: "test-model",
}
agentCfg := &config.AgentConfig{
MaxIterations: 10,
LargeResultThreshold: 100, // 设置较小的阈值便于测试
ResultStorageDir: "",
}
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 10)
// 创建测试存储
tmpDir := filepath.Join(os.TempDir(), "test_agent_storage_"+time.Now().Format("20060102_150405"))
testStorage, err := storage.NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建测试存储失败: %v", err)
}
agent.SetResultStorage(testStorage)
return agent, testStorage
}
func TestAgent_FormatMinimalNotification(t *testing.T) {
agent, testStorage := setupTestAgent(t)
_ = testStorage // 避免未使用变量警告
executionID := "test_exec_001"
toolName := "nmap_scan"
size := 50000
lineCount := 1000
filePath := "tmp/test_exec_001.txt"
notification := agent.formatMinimalNotification(executionID, toolName, size, lineCount, filePath)
// 验证通知包含必要信息
if !strings.Contains(notification, executionID) {
t.Errorf("通知中应该包含执行ID: %s", executionID)
}
if !strings.Contains(notification, toolName) {
t.Errorf("通知中应该包含工具名称: %s", toolName)
}
if !strings.Contains(notification, "50000") {
t.Errorf("通知中应该包含大小信息")
}
if !strings.Contains(notification, "1000") {
t.Errorf("通知中应该包含行数信息")
}
if !strings.Contains(notification, "query_execution_result") {
t.Errorf("通知中应该包含查询工具的使用说明")
}
@@ -81,7 +81,7 @@ func TestAgent_FormatMinimalNotification(t *testing.T) {
func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建模拟的MCP工具结果(大结果)
largeResult := &mcp.ToolResult{
Content: []mcp.Content{
@@ -92,59 +92,59 @@ func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
},
IsError: false,
}
// 模拟MCP服务器返回大结果
// 由于我们需要模拟CallTool的行为,这里需要创建一个mock或者使用实际的MCP服务器
// 为了简化测试,我们直接测试结果处理逻辑
// 设置阈值
agent.mu.Lock()
agent.largeResultThreshold = 1000 // 设置较小的阈值
agent.mu.Unlock()
// 创建执行ID
executionID := "test_exec_large_001"
toolName := "test_tool"
// 格式化结果
var resultText strings.Builder
for _, content := range largeResult.Content {
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果并保存
agent.mu.RLock()
threshold := agent.largeResultThreshold
storage := agent.resultStorage
agent.mu.RUnlock()
if resultSize > threshold && storage != nil {
// 保存大结果
err := storage.SaveResult(executionID, toolName, resultStr)
if err != nil {
t.Fatalf("保存大结果失败: %v", err)
}
// 生成通知
lines := strings.Split(resultStr, "\n")
filePath := storage.GetResultPath(executionID)
notification := agent.formatMinimalNotification(executionID, toolName, resultSize, len(lines), filePath)
// 验证通知格式
if !strings.Contains(notification, executionID) {
t.Errorf("通知中应该包含执行ID")
}
// 验证结果已保存
savedResult, err := storage.GetResult(executionID)
if err != nil {
t.Fatalf("获取保存的结果失败: %v", err)
}
if savedResult != resultStr {
t.Errorf("保存的结果与原始结果不匹配")
}
@@ -155,7 +155,7 @@ func TestAgent_ExecuteToolViaMCP_LargeResult(t *testing.T) {
func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建小结果
smallResult := &mcp.ToolResult{
Content: []mcp.Content{
@@ -166,32 +166,32 @@ func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
},
IsError: false,
}
// 设置较大的阈值
agent.mu.Lock()
agent.largeResultThreshold = 100000 // 100KB
agent.mu.Unlock()
// 格式化结果
var resultText strings.Builder
for _, content := range smallResult.Content {
resultText.WriteString(content.Text)
resultText.WriteString("\n")
}
resultStr := resultText.String()
resultSize := len(resultStr)
// 检测大结果
agent.mu.RLock()
threshold := agent.largeResultThreshold
storage := agent.resultStorage
agent.mu.RUnlock()
if resultSize > threshold && storage != nil {
t.Fatal("小结果不应该被保存")
}
// 小结果应该直接返回
if resultSize <= threshold {
// 这是预期的行为
@@ -203,26 +203,26 @@ func TestAgent_ExecuteToolViaMCP_SmallResult(t *testing.T) {
func TestAgent_SetResultStorage(t *testing.T) {
agent, _ := setupTestAgent(t)
// 创建新的存储
tmpDir := filepath.Join(os.TempDir(), "test_new_storage_"+time.Now().Format("20060102_150405"))
newStorage, err := storage.NewFileResultStorage(tmpDir, zap.NewNop())
if err != nil {
t.Fatalf("创建新存储失败: %v", err)
}
// 设置新存储
agent.SetResultStorage(newStorage)
// 验证存储已更新
agent.mu.RLock()
currentStorage := agent.resultStorage
agent.mu.RUnlock()
if currentStorage != newStorage {
t.Fatal("存储未正确更新")
}
// 清理
os.RemoveAll(tmpDir)
}
@@ -230,24 +230,24 @@ func TestAgent_SetResultStorage(t *testing.T) {
func TestAgent_NewAgent_DefaultValues(t *testing.T) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{
APIKey: "test-key",
BaseURL: "https://api.test.com/v1",
Model: "test-model",
}
// 测试默认配置
agent := NewAgent(openAICfg, nil, mcpServer, nil, logger, 0)
if agent.maxIterations != 30 {
t.Errorf("默认迭代次数不匹配。期望: 30, 实际: %d", agent.maxIterations)
}
agent.mu.RLock()
threshold := agent.largeResultThreshold
agent.mu.RUnlock()
if threshold != 50*1024 {
t.Errorf("默认阈值不匹配。期望: %d, 实际: %d", 50*1024, threshold)
}
@@ -256,31 +256,30 @@ func TestAgent_NewAgent_DefaultValues(t *testing.T) {
func TestAgent_NewAgent_CustomConfig(t *testing.T) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
openAICfg := &config.OpenAIConfig{
APIKey: "test-key",
BaseURL: "https://api.test.com/v1",
Model: "test-model",
}
agentCfg := &config.AgentConfig{
MaxIterations: 20,
LargeResultThreshold: 100 * 1024, // 100KB
ResultStorageDir: "custom_tmp",
}
agent := NewAgent(openAICfg, agentCfg, mcpServer, nil, logger, 15)
if agent.maxIterations != 15 {
t.Errorf("迭代次数不匹配。期望: 15, 实际: %d", agent.maxIterations)
}
agent.mu.RLock()
threshold := agent.largeResultThreshold
agent.mu.RUnlock()
if threshold != 100*1024 {
t.Errorf("阈值不匹配。期望: %d, 实际: %d", 100*1024, threshold)
}
}
@@ -91,6 +91,20 @@ func DefaultSingleAgentSystemPrompt() string {
当工具返回错误时,错误信息会包含在工具响应中,请仔细阅读并做出合理的决策。
## 结束条件与停止约束
- 在「未完成用户目标」前,不得输出纯计划/纯建议式结论并结束本轮;必须继续给出可执行下一步,并优先通过工具验证。
- 若你准备结束回答,先执行一次自检:
1) 是否已有可验证证据支撑“任务完成/无法继续”的结论;
2) 是否至少尝试过当前路径的合理替代(参数、路径、方法、入口);
3) 是否仍存在可执行且低成本的下一步验证动作。
- 仅当满足以下任一条件时,才允许输出最终收尾:
1) 已达到用户目标并给出证据;
2) 达到明确边界(超时、权限、目标不可达、工具不可用且无替代),并清楚说明阻断点与已尝试项;
3) 用户明确要求停止。
- 若最近一步得到 404/空结果/无效响应,不得直接结束;至少再进行一次“同目标不同策略”的验证(如变更路径、参数、请求方法、上下文来源)。
- 避免无效空转:同一工具+同类参数连续失败 3 次后,必须切换策略(改工具、改入口、改假设)并说明切换原因。
## 漏洞记录
发现有效漏洞时,必须使用 ` + builtin.ToolRecordVulnerability + ` 记录:标题、描述、严重程度、类型、目标、证明(POC)、影响、修复建议。
+5 -5
View File
@@ -256,11 +256,11 @@ func orchestratorConfigFromOrchestrator(o *OrchestratorMarkdown) config.MultiAge
return config.MultiAgentSubConfig{}
}
return config.MultiAgentSubConfig{
ID: o.EinoName,
Name: o.DisplayName,
Description: o.Description,
Instruction: o.Instruction,
Kind: "orchestrator",
ID: o.EinoName,
Name: o.DisplayName,
Description: o.Description,
Instruction: o.Instruction,
Kind: "orchestrator",
}
}
+8
View File
@@ -133,6 +133,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
maxIterations = 30 // 默认值
}
agent := agent.NewAgent(&cfg.OpenAI, &cfg.Agent, mcpServer, externalMCPMgr, log.Logger, maxIterations)
agent.UpdateToolDescriptionMode(cfg.Security.ToolDescriptionMode)
// 设置结果存储到Agent
agent.SetResultStorage(resultStorage)
@@ -317,6 +318,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
}
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
groupHandler := handler.NewGroupHandler(db, log.Logger)
authHandler := handler.NewAuthHandler(authManager, cfg, configPath, log.Logger)
attackChainHandler := handler.NewAttackChainHandler(db, &cfg.OpenAI, log.Logger)
@@ -433,6 +435,7 @@ func New(cfg *config.Config, log *logger.Logger) (*App, error) {
authHandler,
agentHandler,
monitorHandler,
notificationHandler,
conversationHandler,
robotHandler,
groupHandler,
@@ -599,6 +602,7 @@ func setupRoutes(
authHandler *handler.AuthHandler,
agentHandler *handler.AgentHandler,
monitorHandler *handler.MonitorHandler,
notificationHandler *handler.NotificationHandler,
conversationHandler *handler.ConversationHandler,
robotHandler *handler.RobotHandler,
groupHandler *handler.GroupHandler,
@@ -727,6 +731,8 @@ func setupRoutes(
protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution)
protected.DELETE("/monitor/executions", monitorHandler.DeleteExecutions)
protected.GET("/monitor/stats", monitorHandler.GetStats)
protected.GET("/notifications/summary", notificationHandler.GetSummary)
protected.POST("/notifications/read", notificationHandler.MarkRead)
// 配置管理
protected.GET("/config", configHandler.GetConfig)
@@ -901,6 +907,8 @@ func setupRoutes(
// 漏洞管理
protected.GET("/vulnerabilities", vulnerabilityHandler.ListVulnerabilities)
protected.GET("/vulnerabilities/export", vulnerabilityHandler.ExportVulnerabilities)
protected.GET("/vulnerabilities/filter-options", vulnerabilityHandler.GetVulnerabilityFilterOptions)
protected.GET("/vulnerabilities/stats", vulnerabilityHandler.GetVulnerabilityStats)
protected.GET("/vulnerabilities/:id", vulnerabilityHandler.GetVulnerability)
protected.POST("/vulnerabilities", vulnerabilityHandler.CreateVulnerability)
+10 -10
View File
@@ -145,7 +145,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
}
// 1. 优先尝试从数据库获取保存的最后一轮ReAct输入和输出
reactInputJSON, modelOutput, err := b.db.GetReActData(conversationID)
reactInputJSON, modelOutput, err := b.db.GetAgentTrace(conversationID)
if err != nil {
b.logger.Warn("获取保存的ReAct数据失败,将使用消息历史构建", zap.Error(err))
// 继续使用原来的逻辑
@@ -170,7 +170,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
messageCount = len(tempMessages)
}
dataSource = "database_last_react_input"
dataSource = "database_last_agent_trace"
b.logger.Info("使用保存的ReAct数据构建攻击链",
zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource),
@@ -183,7 +183,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
// userInput = b.extractUserInputFromReActInput(reactInputJSON)
// 将JSON格式的messages转换为可读格式
reactInputFinal = b.formatReActInputFromJSON(reactInputJSON)
reactInputFinal = b.formatAgentTraceInputFromJSON(reactInputJSON)
} else {
// 2. 如果没有保存的ReAct数据,从对话消息构建
dataSource = "messages_table"
@@ -201,7 +201,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
}
// 提取最后一轮ReAct的输入(历史消息+当前用户输入)
reactInputFinal = b.buildReActInput(messages)
reactInputFinal = b.buildAgentTraceInput(messages)
// 提取大模型最后的输出(最后一条assistant消息)
for i := len(messages) - 1; i >= 0; i-- {
@@ -212,7 +212,7 @@ func (b *Builder) BuildChainFromConversation(ctx context.Context, conversationID
}
}
// 多代理:保存的 last_react_input 可能仅为首轮用户消息,不含工具轨迹;补充最后一轮助手的过程详情(与单代理「最后一轮 ReAct」对齐)
// 多代理:保存的轨迹列可能仅为首轮用户消息,不含工具轨迹;补充最后一轮助手的过程详情(与单代理完整轨迹对齐)
hasMCPOnAssistant := false
var lastAssistantID string
for i := len(messages) - 1; i >= 0; i-- {
@@ -320,7 +320,7 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
}
// 1) 编排器的工具调用/结果:保留(这是“主 agent 调了什么工具”)
if (d.EventType == "tool_call" || d.EventType == "tool_result" || d.EventType == "tool_calls_detected" || d.EventType == "iteration" || d.EventType == "eino_recovery") && einoRole == "orchestrator" {
if (d.EventType == "tool_call" || d.EventType == "tool_result" || d.EventType == "tool_calls_detected" || d.EventType == "iteration") && einoRole == "orchestrator" {
sb.WriteString("[")
sb.WriteString(d.EventType)
sb.WriteString("] ")
@@ -366,8 +366,8 @@ func (b *Builder) formatProcessDetailsForAttackChain(details []database.ProcessD
return strings.TrimSpace(sb.String())
}
// buildReActInput 构建最后一轮ReAct的输入(历史消息+当前用户输入)
func (b *Builder) buildReActInput(messages []database.Message) string {
// buildAgentTraceInput 构建最后一轮ReAct的输入(历史消息+当前用户输入)
func (b *Builder) buildAgentTraceInput(messages []database.Message) string {
var builder strings.Builder
for _, msg := range messages {
builder.WriteString(fmt.Sprintf("[%s]: %s\n\n", msg.Role, msg.Content))
@@ -396,8 +396,8 @@ func (b *Builder) buildReActInput(messages []database.Message) string {
// return ""
// }
// formatReActInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
func (b *Builder) formatReActInputFromJSON(reactInputJSON string) string {
// formatAgentTraceInputFromJSON 将JSON格式的messages数组转换为可读的字符串格式
func (b *Builder) formatAgentTraceInputFromJSON(reactInputJSON string) string {
var messages []map[string]interface{}
if err := json.Unmarshal([]byte(reactInputJSON), &messages); err != nil {
b.logger.Warn("解析ReAct输入JSON失败", zap.Error(err))
+112
View File
@@ -72,6 +72,8 @@ type MultiAgentEinoMiddlewareConfig struct {
ToolSearchEnable bool `yaml:"tool_search_enable,omitempty" json:"tool_search_enable,omitempty"`
ToolSearchMinTools int `yaml:"tool_search_min_tools,omitempty" json:"tool_search_min_tools,omitempty"` // default 20; applies when len(tools) >= this
ToolSearchAlwaysVisible int `yaml:"tool_search_always_visible,omitempty" json:"tool_search_always_visible,omitempty"` // default 12; first N tools stay always visible
// ToolSearchAlwaysVisibleTools keeps specified tool names always visible (never hidden by tool_search).
ToolSearchAlwaysVisibleTools []string `yaml:"tool_search_always_visible_tools,omitempty" json:"tool_search_always_visible_tools,omitempty"`
// Plantask adds TaskCreate/Get/Update/List (file-backed under skills dir); requires eino_skills + local backend.
PlantaskEnable bool `yaml:"plantask_enable,omitempty" json:"plantask_enable,omitempty"`
// PlantaskRelDir relative to skills_dir for per-conversation task boards (default .eino/plantask).
@@ -79,8 +81,24 @@ type MultiAgentEinoMiddlewareConfig struct {
// Reduction truncates/offloads large tool outputs (requires eino local backend for Write).
ReductionEnable bool `yaml:"reduction_enable,omitempty" json:"reduction_enable,omitempty"`
ReductionRootDir string `yaml:"reduction_root_dir,omitempty" json:"reduction_root_dir,omitempty"` // default: os temp + conversation id
ReductionMaxLengthForTrunc int `yaml:"reduction_max_length_for_trunc,omitempty" json:"reduction_max_length_for_trunc,omitempty"` // default 12000
ReductionMaxTokensForClear int `yaml:"reduction_max_tokens_for_clear,omitempty" json:"reduction_max_tokens_for_clear,omitempty"` // default 50000
ReductionClearExclude []string `yaml:"reduction_clear_exclude,omitempty" json:"reduction_clear_exclude,omitempty"`
ReductionSubAgents bool `yaml:"reduction_sub_agents,omitempty" json:"reduction_sub_agents,omitempty"` // also attach to sub-agents
// SummarizationTriggerRatio controls summarization trigger threshold as max_total_tokens * ratio (default 0.8).
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
// HistoryInputBudgetRatio caps pre-agent history tokens as max_total_tokens * ratio (default 0.35).
HistoryInputBudgetRatio float64 `yaml:"history_input_budget_ratio,omitempty" json:"history_input_budget_ratio,omitempty"`
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
// PlanExecuteExecutedStepsBudgetRatio caps executed_steps prompt budget ratio (default 0.2).
PlanExecuteExecutedStepsBudgetRatio float64 `yaml:"plan_execute_executed_steps_budget_ratio,omitempty" json:"plan_execute_executed_steps_budget_ratio,omitempty"`
// PlanExecuteMaxStepResultRunes caps each executed step result length for prompt view (default 4000).
PlanExecuteMaxStepResultRunes int `yaml:"plan_execute_max_step_result_runes,omitempty" json:"plan_execute_max_step_result_runes,omitempty"`
// PlanExecuteKeepLastSteps keeps only the tail steps in prompt view (default 8).
PlanExecuteKeepLastSteps int `yaml:"plan_execute_keep_last_steps,omitempty" json:"plan_execute_keep_last_steps,omitempty"`
// CheckpointDir when non-empty enables adk.Runner CheckPointStore (file-backed) for interrupt/resume persistence.
CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"`
// DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off.
@@ -91,6 +109,97 @@ type MultiAgentEinoMiddlewareConfig struct {
TaskToolDescriptionPrefix string `yaml:"task_tool_description_prefix,omitempty" json:"task_tool_description_prefix,omitempty"`
}
func (c MultiAgentEinoMiddlewareConfig) SummarizationTriggerRatioEffective() float64 {
v := c.SummarizationTriggerRatio
if v <= 0 {
return 0.8
}
if v < 0.5 {
return 0.5
}
if v > 0.95 {
return 0.95
}
return v
}
func (c MultiAgentEinoMiddlewareConfig) SummarizationEmitInternalEventsEffective() bool {
if c.SummarizationEmitInternalEvents != nil {
return *c.SummarizationEmitInternalEvents
}
return true
}
func (c MultiAgentEinoMiddlewareConfig) HistoryInputBudgetRatioEffective() float64 {
v := c.HistoryInputBudgetRatio
if v <= 0 {
return 0.35
}
if v < 0.15 {
return 0.15
}
if v > 0.6 {
return 0.6
}
return v
}
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteUserInputBudgetRatioEffective() float64 {
v := c.PlanExecuteUserInputBudgetRatio
if v <= 0 {
return 0.35
}
if v < 0.1 {
return 0.1
}
if v > 0.6 {
return 0.6
}
return v
}
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteExecutedStepsBudgetRatioEffective() float64 {
v := c.PlanExecuteExecutedStepsBudgetRatio
if v <= 0 {
return 0.2
}
if v < 0.08 {
return 0.08
}
if v > 0.5 {
return 0.5
}
return v
}
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteMaxStepResultRunesEffective() int {
if c.PlanExecuteMaxStepResultRunes > 0 {
return c.PlanExecuteMaxStepResultRunes
}
return 4000
}
func (c MultiAgentEinoMiddlewareConfig) PlanExecuteKeepLastStepsEffective() int {
if c.PlanExecuteKeepLastSteps > 0 {
return c.PlanExecuteKeepLastSteps
}
return 8
}
func (c MultiAgentEinoMiddlewareConfig) ReductionMaxLengthForTruncEffective() int {
if c.ReductionMaxLengthForTrunc > 0 {
return c.ReductionMaxLengthForTrunc
}
return 12000
}
func (c MultiAgentEinoMiddlewareConfig) ReductionMaxTokensForClearEffective() int {
if c.ReductionMaxTokensForClear > 0 {
return c.ReductionMaxTokensForClear
}
return 50000
}
// MultiAgentEinoSkillsConfig toggles Eino official skill progressive disclosure and host filesystem tools.
type MultiAgentEinoSkillsConfig struct {
// Disable skips skill middleware (and does not attach local FS tools for Deep).
@@ -137,6 +246,8 @@ type MultiAgentPublic struct {
SubAgentCount int `json:"sub_agent_count"`
Orchestration string `json:"orchestration,omitempty"`
PlanExecuteLoopMaxIterations int `json:"plan_execute_loop_max_iterations"`
ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"`
ToolSearchAlwaysVisibleEffectiveTools []string `json:"tool_search_always_visible_effective_tools,omitempty"`
}
// NormalizeMultiAgentOrchestration 返回 deep、plan_execute 或 supervisor。
@@ -158,6 +269,7 @@ type MultiAgentAPIUpdate struct {
RobotUseMultiAgent bool `json:"robot_use_multi_agent"`
BatchUseMultiAgent bool `json:"batch_use_multi_agent"`
PlanExecuteLoopMaxIterations *int `json:"plan_execute_loop_max_iterations,omitempty"`
ToolSearchAlwaysVisibleTools []string `json:"tool_search_always_visible_tools,omitempty"`
}
// RobotsConfig 机器人配置(企业微信、钉钉、飞书等)
-1
View File
@@ -165,4 +165,3 @@ func (db *DB) DeleteAttackChain(conversationID string) error {
return nil
}
+23 -12
View File
@@ -4,6 +4,8 @@ import (
"database/sql"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"time"
@@ -308,7 +310,7 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversation, error) {
var rows *sql.Rows
var err error
if search != "" {
// 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积
searchPattern := "%" + search + "%"
@@ -327,7 +329,7 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati
limit, offset,
)
}
if err != nil {
return nil, fmt.Errorf("查询对话列表失败: %w", err)
}
@@ -416,25 +418,34 @@ func (db *DB) DeleteConversation(id string) error {
if err != nil {
return fmt.Errorf("删除对话失败: %w", err)
}
// Best-effort cleanup for conversation-scoped filesystem artifacts
// (e.g., summarization transcript, reduction/checkpoint files under conversation_artifacts/<id>).
if base := strings.TrimSpace(db.conversationArtifactsDir); base != "" {
artDir := filepath.Join(base, id)
if rmErr := os.RemoveAll(artDir); rmErr != nil {
db.logger.Warn("删除会话 artifacts 目录失败", zap.String("conversationId", id), zap.String("dir", artDir), zap.Error(rmErr))
}
}
db.logger.Info("对话及其所有相关数据已删除", zap.String("conversationId", id))
return nil
}
// SaveReActData 保存最后一轮ReAct的输入和输出
func (db *DB) SaveReActData(conversationID, reactInput, reactOutput string) error {
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
// SQLite 列名仍为 last_react_input / last_react_output,与历史库表兼容;语义上为「全模式代理轨迹」,非仅 ReAct。
func (db *DB) SaveAgentTrace(conversationID, traceInputJSON, assistantOutput string) error {
_, err := db.Exec(
"UPDATE conversations SET last_react_input = ?, last_react_output = ?, updated_at = ? WHERE id = ?",
reactInput, reactOutput, time.Now(), conversationID,
traceInputJSON, assistantOutput, time.Now(), conversationID,
)
if err != nil {
return fmt.Errorf("保存ReAct数据失败: %w", err)
return fmt.Errorf("保存代理轨迹失败: %w", err)
}
return nil
}
// GetReActData 获取最后一轮ReAct的输入和输出
func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput string, err error) {
// GetAgentTrace 读取 conversations 中保存的代理轨迹(列名 last_react_*)。
func (db *DB) GetAgentTrace(conversationID string) (traceInputJSON, assistantOutput string, err error) {
var input, output sql.NullString
err = db.QueryRow(
"SELECT last_react_input, last_react_output FROM conversations WHERE id = ?",
@@ -444,17 +455,17 @@ func (db *DB) GetReActData(conversationID string) (reactInput, reactOutput strin
if err == sql.ErrNoRows {
return "", "", fmt.Errorf("对话不存在")
}
return "", "", fmt.Errorf("获取ReAct数据失败: %w", err)
return "", "", fmt.Errorf("获取代理轨迹失败: %w", err)
}
if input.Valid {
reactInput = input.String
traceInputJSON = input.String
}
if output.Valid {
reactOutput = output.String
assistantOutput = output.String
}
return reactInput, reactOutput, nil
return traceInputJSON, assistantOutput, nil
}
// ConversationHasToolProcessDetails 对话是否存在已落库的工具调用/结果(用于多代理等场景下 MCP execution id 未汇总时的攻击链判定)。
+89 -2
View File
@@ -3,6 +3,8 @@ package database
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"strings"
"time"
@@ -21,7 +23,8 @@ func configureDBPool(db *sql.DB) {
// DB 数据库连接
type DB struct {
*sql.DB
logger *zap.Logger
logger *zap.Logger
conversationArtifactsDir string
}
// NewDB 创建数据库连接
@@ -41,6 +44,13 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
DB: db,
logger: logger,
}
// Keep conversation-scoped artifacts near database files, so cleanup can follow conversation lifecycle.
baseDir := filepath.Join(filepath.Dir(dbPath), "conversation_artifacts")
if mkErr := os.MkdirAll(baseDir, 0o755); mkErr == nil {
database.conversationArtifactsDir = baseDir
} else if logger != nil {
logger.Warn("创建 conversation artifacts 目录失败", zap.String("dir", baseDir), zap.Error(mkErr))
}
// 初始化表
if err := database.initTables(); err != nil {
@@ -52,7 +62,7 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
// initTables 初始化数据库表
func (db *DB) initTables() error {
// 创建对话表
// 创建对话表last_react_input / last_react_output 存「代理消息轨迹」JSON 与助手摘要,列名保留以兼容已有库)
createConversationsTable := `
CREATE TABLE IF NOT EXISTS conversations (
id TEXT PRIMARY KEY,
@@ -197,6 +207,8 @@ func (db *DB) initTables() error {
CREATE TABLE IF NOT EXISTS vulnerabilities (
id TEXT PRIMARY KEY,
conversation_id TEXT NOT NULL,
conversation_tag TEXT,
task_tag TEXT,
title TEXT NOT NULL,
description TEXT,
severity TEXT NOT NULL,
@@ -257,6 +269,8 @@ func (db *DB) initTables() error {
method TEXT NOT NULL DEFAULT 'post',
cmd_param TEXT NOT NULL DEFAULT '',
remark TEXT NOT NULL DEFAULT '',
encoding TEXT NOT NULL DEFAULT '',
os TEXT NOT NULL DEFAULT '',
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);`
@@ -289,6 +303,8 @@ func (db *DB) initTables() error {
CREATE INDEX IF NOT EXISTS idx_conversation_group_mappings_group ON conversation_group_mappings(group_id);
CREATE INDEX IF NOT EXISTS idx_conversations_pinned ON conversations(pinned);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_id ON vulnerabilities(conversation_id);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_conversation_tag ON vulnerabilities(conversation_tag);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_task_tag ON vulnerabilities(task_tag);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_severity ON vulnerabilities(severity);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_status ON vulnerabilities(status);
CREATE INDEX IF NOT EXISTS idx_vulnerabilities_created_at ON vulnerabilities(created_at);
@@ -383,6 +399,15 @@ func (db *DB) initTables() error {
db.logger.Warn("迁移batch_task_queues表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if err := db.migrateVulnerabilitiesTable(); err != nil {
db.logger.Warn("迁移vulnerabilities表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if err := db.migrateWebshellConnectionsTable(); err != nil {
db.logger.Warn("迁移webshell_connections表失败", zap.Error(err))
// 不返回错误,允许继续运行
}
if _, err := db.Exec(createIndexes); err != nil {
return fmt.Errorf("创建索引失败: %w", err)
@@ -683,6 +708,68 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
return nil
}
// migrateVulnerabilitiesTable 迁移 vulnerabilities 表,补充标签字段
func (db *DB) migrateVulnerabilitiesTable() error {
columns := []struct {
name string
stmt string
}{
{name: "conversation_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN conversation_tag TEXT"},
{name: "task_tag", stmt: "ALTER TABLE vulnerabilities ADD COLUMN task_tag TEXT"},
}
for _, col := range columns {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('vulnerabilities') WHERE name=?", col.name).Scan(&count)
if err != nil {
if _, addErr := db.Exec(col.stmt); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加vulnerabilities字段失败", zap.String("field", col.name), zap.Error(addErr))
}
}
continue
}
if count == 0 {
if _, addErr := db.Exec(col.stmt); addErr != nil {
db.logger.Warn("添加vulnerabilities字段失败", zap.String("field", col.name), zap.Error(addErr))
}
}
}
return nil
}
// migrateWebshellConnectionsTable 迁移 webshell_connections 表,补充新字段
func (db *DB) migrateWebshellConnectionsTable() error {
columns := []struct {
name string
stmt string
}{
{name: "encoding", stmt: "ALTER TABLE webshell_connections ADD COLUMN encoding TEXT NOT NULL DEFAULT ''"},
{name: "os", stmt: "ALTER TABLE webshell_connections ADD COLUMN os TEXT NOT NULL DEFAULT ''"},
}
for _, col := range columns {
var count int
err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('webshell_connections') WHERE name=?", col.name).Scan(&count)
if err != nil {
if _, addErr := db.Exec(col.stmt); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加webshell_connections字段失败", zap.String("field", col.name), zap.Error(addErr))
}
}
continue
}
if count == 0 {
if _, addErr := db.Exec(col.stmt); addErr != nil {
db.logger.Warn("添加webshell_connections字段失败", zap.String("field", col.name), zap.Error(addErr))
}
}
}
return nil
}
// NewKnowledgeDB 创建知识库数据库连接(只包含知识库相关的表)
func NewKnowledgeDB(dbPath string, logger *zap.Logger) (*DB, error) {
sqlDB, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=1&_busy_timeout=5000&_synchronous=NORMAL")
+118 -30
View File
@@ -12,7 +12,11 @@ import (
// Vulnerability 漏洞
type Vulnerability struct {
ID string `json:"id"`
ConversationID string `json:"conversation_id"`
ConversationID string `json:"conversation_id"`
ConversationTag string `json:"conversation_tag,omitempty"`
TaskTag string `json:"task_tag,omitempty"`
TaskID string `json:"task_id,omitempty"`
TaskQueueID string `json:"task_queue_id,omitempty"`
Title string `json:"title"`
Description string `json:"description"`
Severity string `json:"severity"` // critical, high, medium, low, info
@@ -42,15 +46,15 @@ func (db *DB) CreateVulnerability(vuln *Vulnerability) (*Vulnerability, error) {
query := `
INSERT INTO vulnerabilities (
id, conversation_id, title, description, severity, status,
id, conversation_id, conversation_tag, task_tag, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(
query,
vuln.ID, vuln.ConversationID, vuln.Title, vuln.Description,
vuln.ID, vuln.ConversationID, vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description,
vuln.Severity, vuln.Status, vuln.Type, vuln.Target,
vuln.Proof, vuln.Impact, vuln.Recommendation,
vuln.CreatedAt, vuln.UpdatedAt,
@@ -67,7 +71,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
var vuln Vulnerability
query := `
SELECT id, conversation_id, title, description, severity, status,
vulnerability_type, target, proof, impact, recommendation,
conversation_tag, task_tag, vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
created_at, updated_at
FROM vulnerabilities
WHERE id = ?
@@ -75,8 +81,9 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
err := db.QueryRow(query, id).Scan(
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target,
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.TaskID, &vuln.TaskQueueID,
&vuln.CreatedAt, &vuln.UpdatedAt,
)
if err != nil {
@@ -90,10 +97,12 @@ func (db *DB) GetVulnerability(id string) (*Vulnerability, error) {
}
// ListVulnerabilities 列出漏洞
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status string) ([]*Vulnerability, error) {
func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severity, status, taskID, conversationTag, taskTag string) ([]*Vulnerability, error) {
query := `
SELECT id, conversation_id, title, description, severity, status,
SELECT id, conversation_id, title, description, severity, status, conversation_tag, task_tag,
vulnerability_type, target, proof, impact, recommendation,
COALESCE((SELECT bt.id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_id,
COALESCE((SELECT bt.queue_id FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id LIMIT 1), '') AS task_queue_id,
created_at, updated_at
FROM vulnerabilities
WHERE 1=1
@@ -108,6 +117,18 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
query += " AND conversation_id = ?"
args = append(args, conversationID)
}
if taskID != "" {
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, taskID, taskID)
}
if conversationTag != "" {
query += " AND conversation_tag = ?"
args = append(args, conversationTag)
}
if taskTag != "" {
query += " AND task_tag = ?"
args = append(args, taskTag)
}
if severity != "" {
query += " AND severity = ?"
args = append(args, severity)
@@ -131,8 +152,9 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
var vuln Vulnerability
err := rows.Scan(
&vuln.ID, &vuln.ConversationID, &vuln.Title, &vuln.Description,
&vuln.Severity, &vuln.Status, &vuln.Type, &vuln.Target,
&vuln.Severity, &vuln.Status, &vuln.ConversationTag, &vuln.TaskTag, &vuln.Type, &vuln.Target,
&vuln.Proof, &vuln.Impact, &vuln.Recommendation,
&vuln.TaskID, &vuln.TaskQueueID,
&vuln.CreatedAt, &vuln.UpdatedAt,
)
if err != nil {
@@ -146,7 +168,7 @@ func (db *DB) ListVulnerabilities(limit, offset int, id, conversationID, severit
}
// CountVulnerabilities 统计漏洞总数(支持筛选条件)
func (db *DB) CountVulnerabilities(id, conversationID, severity, status string) (int, error) {
func (db *DB) CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag string) (int, error) {
query := "SELECT COUNT(*) FROM vulnerabilities WHERE 1=1"
args := []interface{}{}
@@ -158,6 +180,18 @@ func (db *DB) CountVulnerabilities(id, conversationID, severity, status string)
query += " AND conversation_id = ?"
args = append(args, conversationID)
}
if taskID != "" {
query += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, taskID, taskID)
}
if conversationTag != "" {
query += " AND conversation_tag = ?"
args = append(args, conversationTag)
}
if taskTag != "" {
query += " AND task_tag = ?"
args = append(args, taskTag)
}
if severity != "" {
query += " AND severity = ?"
args = append(args, severity)
@@ -182,7 +216,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
query := `
UPDATE vulnerabilities
SET title = ?, description = ?, severity = ?, status = ?,
SET conversation_tag = ?, task_tag = ?, title = ?, description = ?, severity = ?, status = ?,
vulnerability_type = ?, target = ?, proof = ?, impact = ?,
recommendation = ?, updated_at = ?
WHERE id = ?
@@ -190,7 +224,7 @@ func (db *DB) UpdateVulnerability(id string, vuln *Vulnerability) error {
_, err := db.Exec(
query,
vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
vuln.ConversationTag, vuln.TaskTag, vuln.Title, vuln.Description, vuln.Severity, vuln.Status,
vuln.Type, vuln.Target, vuln.Proof, vuln.Impact,
vuln.Recommendation, vuln.UpdatedAt, id,
)
@@ -210,18 +244,24 @@ func (db *DB) DeleteVulnerability(id string) error {
return nil
}
// GetVulnerabilityStats 获取漏洞统计
func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface{}, error) {
// GetVulnerabilityStats 获取漏洞统计(筛选条件与 ListVulnerabilities / CountVulnerabilities 一致)
func (db *DB) GetVulnerabilityStats(conversationID, taskID string) (map[string]interface{}, error) {
stats := make(map[string]interface{})
where := "WHERE 1=1"
args := []interface{}{}
if conversationID != "" {
where += " AND conversation_id = ?"
args = append(args, conversationID)
}
if taskID != "" {
where += " AND EXISTS (SELECT 1 FROM batch_tasks bt WHERE bt.conversation_id = vulnerabilities.conversation_id AND (bt.id = ? OR bt.queue_id = ?))"
args = append(args, taskID, taskID)
}
// 总漏洞数
var totalCount int
query := "SELECT COUNT(*) FROM vulnerabilities"
args := []interface{}{}
if conversationID != "" {
query += " WHERE conversation_id = ?"
args = append(args, conversationID)
}
query := "SELECT COUNT(*) FROM vulnerabilities " + where
err := db.QueryRow(query, args...).Scan(&totalCount)
if err != nil {
return nil, fmt.Errorf("获取总漏洞数失败: %w", err)
@@ -229,11 +269,7 @@ func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface
stats["total"] = totalCount
// 按严重程度统计
severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities"
if conversationID != "" {
severityQuery += " WHERE conversation_id = ?"
}
severityQuery += " GROUP BY severity"
severityQuery := "SELECT severity, COUNT(*) FROM vulnerabilities " + where + " GROUP BY severity"
rows, err := db.Query(severityQuery, args...)
if err != nil {
@@ -253,11 +289,7 @@ func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface
stats["by_severity"] = severityStats
// 按状态统计
statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities"
if conversationID != "" {
statusQuery += " WHERE conversation_id = ?"
}
statusQuery += " GROUP BY status"
statusQuery := "SELECT status, COUNT(*) FROM vulnerabilities " + where + " GROUP BY status"
rows, err = db.Query(statusQuery, args...)
if err != nil {
@@ -279,3 +311,59 @@ func (db *DB) GetVulnerabilityStats(conversationID string) (map[string]interface
return stats, nil
}
// GetVulnerabilityFilterOptions 获取漏洞筛选建议项
func (db *DB) GetVulnerabilityFilterOptions() (map[string][]string, error) {
collect := func(query string, args ...interface{}) ([]string, error) {
rows, err := db.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
items := make([]string, 0)
for rows.Next() {
var val string
if err := rows.Scan(&val); err != nil {
continue
}
if val == "" {
continue
}
items = append(items, val)
}
return items, nil
}
vulnIDs, err := collect(`SELECT DISTINCT id FROM vulnerabilities ORDER BY created_at DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询漏洞ID建议失败: %w", err)
}
conversationIDs, err := collect(`SELECT DISTINCT conversation_id FROM vulnerabilities WHERE conversation_id <> '' ORDER BY created_at DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询会话ID建议失败: %w", err)
}
taskIDs, err := collect(`SELECT DISTINCT id FROM batch_tasks WHERE id <> '' ORDER BY rowid DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询任务ID建议失败: %w", err)
}
queueIDs, err := collect(`SELECT DISTINCT queue_id FROM batch_tasks WHERE queue_id <> '' ORDER BY rowid DESC LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询队列ID建议失败: %w", err)
}
conversationTags, err := collect(`SELECT DISTINCT conversation_tag FROM vulnerabilities WHERE conversation_tag IS NOT NULL AND conversation_tag <> '' ORDER BY conversation_tag LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询对话标签建议失败: %w", err)
}
taskTags, err := collect(`SELECT DISTINCT task_tag FROM vulnerabilities WHERE task_tag IS NOT NULL AND task_tag <> '' ORDER BY task_tag LIMIT 500`)
if err != nil {
return nil, fmt.Errorf("查询任务标签建议失败: %w", err)
}
return map[string][]string{
"vulnerability_ids": vulnIDs,
"conversation_ids": conversationIDs,
"task_ids": taskIDs,
"queue_ids": queueIDs,
"conversation_tags": conversationTags,
"task_tags": taskTags,
}, nil
}
+13 -9
View File
@@ -16,6 +16,8 @@ type WebShellConnection struct {
Method string `json:"method"`
CmdParam string `json:"cmdParam"`
Remark string `json:"remark"`
Encoding string `json:"encoding"` // 目标响应编码:auto / utf-8 / gbk / gb18030,空值视为 auto
OS string `json:"os"` // 目标操作系统:auto / linux / windows,空值/未知视为 auto
CreatedAt time.Time `json:"createdAt"`
}
@@ -58,7 +60,8 @@ func (db *DB) UpsertWebshellConnectionState(connectionID, stateJSON string) erro
// ListWebshellConnections 列出所有 WebShell 连接,按创建时间倒序
func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) {
query := `
SELECT id, url, password, type, method, cmd_param, remark, created_at
SELECT id, url, password, type, method, cmd_param, remark,
COALESCE(encoding, '') AS encoding, COALESCE(os, '') AS os, created_at
FROM webshell_connections
ORDER BY created_at DESC
`
@@ -72,7 +75,7 @@ func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) {
var list []WebShellConnection
for rows.Next() {
var c WebShellConnection
err := rows.Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.CreatedAt)
err := rows.Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.Encoding, &c.OS, &c.CreatedAt)
if err != nil {
db.logger.Warn("扫描 WebShell 连接行失败", zap.Error(err))
continue
@@ -85,11 +88,12 @@ func (db *DB) ListWebshellConnections() ([]WebShellConnection, error) {
// GetWebshellConnection 根据 ID 获取一条连接
func (db *DB) GetWebshellConnection(id string) (*WebShellConnection, error) {
query := `
SELECT id, url, password, type, method, cmd_param, remark, created_at
SELECT id, url, password, type, method, cmd_param, remark,
COALESCE(encoding, '') AS encoding, COALESCE(os, '') AS os, created_at
FROM webshell_connections WHERE id = ?
`
var c WebShellConnection
err := db.QueryRow(query, id).Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.CreatedAt)
err := db.QueryRow(query, id).Scan(&c.ID, &c.URL, &c.Password, &c.Type, &c.Method, &c.CmdParam, &c.Remark, &c.Encoding, &c.OS, &c.CreatedAt)
if err == sql.ErrNoRows {
return nil, nil
}
@@ -103,10 +107,10 @@ func (db *DB) GetWebshellConnection(id string) (*WebShellConnection, error) {
// CreateWebshellConnection 创建 WebShell 连接
func (db *DB) CreateWebshellConnection(c *WebShellConnection) error {
query := `
INSERT INTO webshell_connections (id, url, password, type, method, cmd_param, remark, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
INSERT INTO webshell_connections (id, url, password, type, method, cmd_param, remark, encoding, os, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
_, err := db.Exec(query, c.ID, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.CreatedAt)
_, err := db.Exec(query, c.ID, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.Encoding, c.OS, c.CreatedAt)
if err != nil {
db.logger.Error("创建 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID))
return err
@@ -118,10 +122,10 @@ func (db *DB) CreateWebshellConnection(c *WebShellConnection) error {
func (db *DB) UpdateWebshellConnection(c *WebShellConnection) error {
query := `
UPDATE webshell_connections
SET url = ?, password = ?, type = ?, method = ?, cmd_param = ?, remark = ?
SET url = ?, password = ?, type = ?, method = ?, cmd_param = ?, remark = ?, encoding = ?, os = ?
WHERE id = ?
`
result, err := db.Exec(query, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.ID)
result, err := db.Exec(query, c.URL, c.Password, c.Type, c.Method, c.CmdParam, c.Remark, c.Encoding, c.OS, c.ID)
if err != nil {
db.logger.Error("更新 WebShell 连接失败", zap.Error(err), zap.String("id", c.ID))
return err
+5 -5
View File
@@ -160,17 +160,17 @@ func runMCPToolInvocation(
}
// UnknownToolReminderHandler 供 compose.ToolsNodeConfig.UnknownToolsHandler 使用:
// 模型请求了未注册的工具名时,返回一个「可恢复」的错误,让上层 runner 触发重试与纠错提示
// 同时避免 UI 永远停留在“执行中”(runner 会在 recoverable 分支 flush 掉 pending 的 tool_call
// 模型请求了未注册的工具名时,返回一个「软错误」工具结果(nil error
// 让模型在同一轮继续自我修正,避免触发 run-loop 级别的 full rerun
// 不进行名称猜测或映射,避免误执行。
func UnknownToolReminderHandler() func(ctx context.Context, name, input string) (string, error) {
return func(ctx context.Context, name, input string) (string, error) {
_ = ctx
_ = input
requested := strings.TrimSpace(name)
// Return a recoverable error that still carries a friendly, bilingual hint.
// This will be caught by multiagent runner as "tool not found" and trigger a retry.
return "", fmt.Errorf("tool %q not found: %s", requested, unknownToolReminderText(requested))
// Return a soft tool-result error so the graph keeps running and the LLM
// can correct tool name/arguments within the same run.
return ToolErrorPrefix + unknownToolReminderText(requested), nil
}
}
+86 -104
View File
@@ -497,10 +497,10 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
defer h.hitlManager.DeactivateConversation(conversationID)
}
// 优先尝试从保存的ReAct数据恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
// 优先尝试从保存的代理轨迹恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil {
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err))
// 回退到使用数据库消息表
historyMessages, err := h.db.GetMessages(conversationID)
if err != nil {
@@ -518,7 +518,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
}
} else {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 校验附件数量(非流式)
@@ -539,12 +539,7 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "未找到该 WebShell 连接"})
return
}
remark := conn.Remark
if remark == "" {
remark = conn.URL
}
webshellContext := fmt.Sprintf("[WebShell 助手上下文] 当前连接 ID:%s,备注:%s。可用工具(仅在该连接上操作时使用,connection_id 填 \"%s\"):webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base。Skills 包请使用「多代理 / Eino DeepAgent」会话中的内置 `skill` 工具渐进加载。\n\n用户请求:%s",
conn.ID, remark, conn.ID, req.Message)
webshellContext := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, req.Message)
// WebShell 模式下如果同时指定了角色,追加角色 user_prompt(工具集仍仅限 webshell 专用工具)
if req.Role != "" && req.Role != "默认" && h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled && role.UserPrompt != "" {
@@ -613,12 +608,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
if err != nil {
h.logger.Error("Agent Loop执行失败", zap.Error(err))
// 即使执行失败,也尝试保存ReAct数据(如果result中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if saveErr := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); saveErr != nil {
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(saveErr))
// 即使执行失败,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if saveErr := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); saveErr != nil {
h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(saveErr))
} else {
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID))
h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID))
}
}
@@ -634,12 +629,12 @@ func (h *AgentHandler) AgentLoop(c *gin.Context) {
// 因为AI已经生成了回复,用户应该能看到
}
// 保存最后一轮ReAct的输入和输出
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.Error(err))
// 保存最后一轮代理轨迹与助手输出
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} else {
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID))
h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID))
}
}
@@ -666,7 +661,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
}
}
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil {
historyMessages, getErr := h.db.GetMessages(conversationID)
if getErr != nil {
@@ -722,6 +717,7 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
"deep",
)
if errMA != nil {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
errMsg := "执行失败: " + errMA.Error()
if assistantMessageID != "" {
_, _ = h.db.Exec("UPDATE messages SET content = ? WHERE id = ?", errMsg, assistantMessageID)
@@ -747,8 +743,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
}
}
if resultMA.LastReActInput != "" || resultMA.LastReActOutput != "" {
_ = h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput)
if resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "" {
_ = h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput)
}
return resultMA.Response, conversationID, nil
}
@@ -782,8 +778,8 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, conversationI
h.logger.Warn("机器人:保存助手消息失败", zap.Error(err))
}
}
if result.LastReActInput != "" || result.LastReActOutput != "" {
_ = h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput)
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
_ = h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
}
return result.Response, conversationID, nil
}
@@ -1359,10 +1355,10 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
}
ssePublishConversationID = conversationID
// 优先尝试从保存的ReAct数据恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
// 优先尝试从保存的代理轨迹恢复历史上下文
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil {
h.logger.Warn("从ReAct数据加载历史消息失败,使用消息表", zap.Error(err))
h.logger.Warn("从代理轨迹加载历史消息失败,使用消息表", zap.Error(err))
// 回退到使用数据库消息表
historyMessages, err := h.db.GetMessages(conversationID)
if err != nil {
@@ -1380,7 +1376,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.logger.Info("从消息表加载历史消息", zap.Int("count", len(agentHistoryMessages)))
}
} else {
h.logger.Info("从ReAct数据恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
h.logger.Info("从代理轨迹恢复历史上下文", zap.Int("count", len(agentHistoryMessages)))
}
// 校验附件数量
@@ -1399,12 +1395,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
sendEvent("error", "未找到该 WebShell 连接", nil)
return
}
remark := conn.Remark
if remark == "" {
remark = conn.URL
}
webshellContext := fmt.Sprintf("[WebShell 助手上下文] 当前连接 ID:%s,备注:%s。可用工具(仅在该连接上操作时使用,connection_id 填 \"%s\"):webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base。Skills 包请使用「多代理 / Eino DeepAgent」会话中的内置 `skill` 工具渐进加载。\n\n用户请求:%s",
conn.ID, remark, conn.ID, req.Message)
webshellContext := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, req.Message)
// WebShell 模式下如果同时指定了角色,追加角色 user_prompt(工具集仍仅限 webshell 专用工具)
if req.Role != "" && req.Role != "默认" && h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled && role.UserPrompt != "" {
@@ -1579,12 +1570,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil)
}
// 即使任务被取消,也尝试保存ReAct数据(如果result中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.Error(err))
// 即使任务被取消,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的代理轨迹失败", zap.Error(err))
} else {
h.logger.Info("已保存取消任务的ReAct数据", zap.String("conversationId", conversationID))
h.logger.Info("已保存取消任务的代理轨迹", zap.String("conversationId", conversationID))
}
}
@@ -1614,12 +1605,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "timeout", timeoutMsg, nil)
}
// 即使任务超时,也尝试保存ReAct数据(如果result中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存超时任务的ReAct数据失败", zap.Error(err))
// 即使任务超时,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存超时任务的代理轨迹失败", zap.Error(err))
} else {
h.logger.Info("已保存超时任务的ReAct数据", zap.String("conversationId", conversationID))
h.logger.Info("已保存超时任务的代理轨迹", zap.String("conversationId", conversationID))
}
}
@@ -1649,12 +1640,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil)
}
// 即使任务失败,也尝试保存ReAct数据(如果result中有)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存失败任务的ReAct数据失败", zap.Error(err))
// 即使任务失败,也尝试保存代理轨迹(如果 result 中有)
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存失败任务的代理轨迹失败", zap.Error(err))
} else {
h.logger.Info("已保存失败任务的ReAct数据", zap.String("conversationId", conversationID))
h.logger.Info("已保存失败任务的代理轨迹", zap.String("conversationId", conversationID))
}
}
@@ -1694,12 +1685,12 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
}
}
// 保存最后一轮ReAct的输入和输出
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.Error(err))
// 保存最后一轮代理轨迹与助手输出
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
} else {
h.logger.Info("已保存ReAct数据", zap.String("conversationId", conversationID))
h.logger.Info("已保存代理轨迹", zap.String("conversationId", conversationID))
}
}
@@ -2499,6 +2490,9 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
cancel()
if runErr != nil {
if useRunResult {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
}
// 检查是否是取消错误
// 1. 直接检查是否是 context.Canceled(包括包装后的错误)
// 2. 检查错误消息中是否包含"context canceled"或"cancelled"关键字
@@ -2542,14 +2536,14 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
}
}
// 保存ReAct数据(如果存在)
if result != nil && (result.LastReActInput != "" || result.LastReActOutput != "") {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
// 保存代理轨迹(如果存在)
if result != nil && (result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "") {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
} else if useRunResult && resultMA != nil && (resultMA.LastReActInput != "" || resultMA.LastReActOutput != "") {
if err := h.db.SaveReActData(conversationID, resultMA.LastReActInput, resultMA.LastReActOutput); err != nil {
h.logger.Warn("保存取消任务的ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} else if useRunResult && resultMA != nil && (resultMA.LastAgentTraceInput != "" || resultMA.LastAgentTraceOutput != "") {
if err := h.db.SaveAgentTrace(conversationID, resultMA.LastAgentTraceInput, resultMA.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存取消任务的代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
}
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
@@ -2581,13 +2575,13 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
if useRunResult {
resText = resultMA.Response
mcpIDs = resultMA.MCPExecutionIDs
lastIn = resultMA.LastReActInput
lastOut = resultMA.LastReActOutput
lastIn = resultMA.LastAgentTraceInput
lastOut = resultMA.LastAgentTraceOutput
} else {
resText = result.Response
mcpIDs = result.MCPExecutionIDs
lastIn = result.LastReActInput
lastOut = result.LastReActOutput
lastIn = result.LastAgentTraceInput
lastOut = result.LastAgentTraceOutput
}
// 更新助手消息内容
@@ -2618,12 +2612,12 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
}
}
// 保存ReAct数据
// 保存代理轨迹
if lastIn != "" || lastOut != "" {
if err := h.db.SaveReActData(conversationID, lastIn, lastOut); err != nil {
h.logger.Warn("保存ReAct数据失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} else {
h.logger.Info("已保存ReAct数据", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
}
}
@@ -2642,36 +2636,33 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
}
}
// loadHistoryFromReActData 从保存的ReAct数据恢复历史消息上下文
// 采用与攻击链生成类似的拼接逻辑:优先使用保存的last_react_input和last_react_output,若不存在则回退消息表
func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.ChatMessage, error) {
// 获取保存的ReAct输入和输出
reactInputJSON, reactOutput, err := h.db.GetReActData(conversationID)
// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
// 逻辑与攻击链一致:优先用保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表
func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
traceInputJSON, assistantOut, err := h.db.GetAgentTrace(conversationID)
if err != nil {
return nil, fmt.Errorf("获取ReAct数据失败: %w", err)
return nil, fmt.Errorf("获取代理轨迹失败: %w", err)
}
// 如果last_react_input为空,回退到使用消息表(与攻击链生成逻辑一致)
if reactInputJSON == "" {
return nil, fmt.Errorf("ReAct数据为空,将使用消息表")
if traceInputJSON == "" {
return nil, fmt.Errorf("代理轨迹为空,将使用消息表")
}
dataSource := "database_last_react_input"
dataSource := "database_last_agent_trace"
// 解析JSON格式的messages数组
var messagesArray []map[string]interface{}
if err := json.Unmarshal([]byte(reactInputJSON), &messagesArray); err != nil {
return nil, fmt.Errorf("解析ReAct输入JSON失败: %w", err)
if err := json.Unmarshal([]byte(traceInputJSON), &messagesArray); err != nil {
return nil, fmt.Errorf("解析代理轨迹 JSON 失败: %w", err)
}
messageCount := len(messagesArray)
h.logger.Info("使用保存的ReAct数据恢复历史上下文",
h.logger.Info("使用保存的代理轨迹恢复历史上下文",
zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource),
zap.Int("reactInputSize", len(reactInputJSON)),
zap.Int("traceInputSize", len(traceInputJSON)),
zap.Int("messageCount", messageCount),
zap.Int("reactOutputSize", len(reactOutput)),
zap.Int("assistantOutSize", len(assistantOut)),
)
// fmt.Println("messagesArray:", messagesArray)//debug
@@ -2755,53 +2746,44 @@ func (h *AgentHandler) loadHistoryFromReActData(conversationID string) ([]agent.
agentMessages = append(agentMessages, msg)
}
// 如果存在last_react_output,需要将其作为最后一条assistant消息
// 因为last_react_input是在迭代开始前保存的,不包含最后一轮的最终输出
if reactOutput != "" {
// 检查最后一条消息是否是assistant消息且没有tool_calls
// 如果有tool_calls,说明后面应该还有tool消息和最终的assistant回复
// 存在 last_react_output(助手摘要),合并为最后一条 assistant(与保存格式一致)
if assistantOut != "" {
if len(agentMessages) > 0 {
lastMsg := &agentMessages[len(agentMessages)-1]
if strings.EqualFold(lastMsg.Role, "assistant") && len(lastMsg.ToolCalls) == 0 {
// 最后一条是assistant消息且没有tool_calls,用最终输出更新其content
lastMsg.Content = reactOutput
lastMsg.Content = assistantOut
} else {
// 最后一条不是assistant消息,或者有tool_calls,添加最终输出作为新的assistant消息
agentMessages = append(agentMessages, agent.ChatMessage{
Role: "assistant",
Content: reactOutput,
Content: assistantOut,
})
}
} else {
// 如果没有消息,直接添加最终输出
agentMessages = append(agentMessages, agent.ChatMessage{
Role: "assistant",
Content: reactOutput,
Content: assistantOut,
})
}
}
if len(agentMessages) == 0 {
return nil, fmt.Errorf("从ReAct数据解析的消息为空")
return nil, fmt.Errorf("从代理轨迹解析的消息为空")
}
// 修复可能存在的失配tool消息,避免OpenAI报错
// 这可以防止出现"messages with role 'tool' must be a response to a preceeding message with 'tool_calls'"错误
if h.agent != nil {
if fixed := h.agent.RepairOrphanToolMessages(&agentMessages); fixed {
h.logger.Info("修复了从ReAct数据恢复的历史消息中的失配tool消息",
h.logger.Info("修复了从代理轨迹恢复的历史消息中的失配 tool 消息",
zap.String("conversationId", conversationID),
)
}
}
h.logger.Info("从ReAct数据恢复历史消息完成",
h.logger.Info("从代理轨迹恢复历史消息完成",
zap.String("conversationId", conversationID),
zap.String("dataSource", dataSource),
zap.Int("originalMessageCount", messageCount),
zap.Int("finalMessageCount", len(agentMessages)),
zap.Bool("hasReactOutput", reactOutput != ""),
zap.Bool("hasAssistantOut", assistantOut != ""),
)
fmt.Println("agentMessages:", agentMessages) //debug
return agentMessages, nil
}
+2 -3
View File
@@ -83,7 +83,7 @@ func (h *AttackChainHandler) GetAttackChain(c *gin.Context) {
// 使用锁机制防止同一对话的并发生成
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
lock := lockInterface.(*sync.Mutex)
// 尝试获取锁,如果正在生成则返回错误
acquired := lock.TryLock()
if !acquired {
@@ -144,7 +144,7 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
// 使用锁机制防止并发生成
lockInterface, _ := h.generatingLocks.LoadOrStore(conversationID, &sync.Mutex{})
lock := lockInterface.(*sync.Mutex)
acquired := lock.TryLock()
if !acquired {
h.logger.Info("攻击链正在生成中,请稍后再试", zap.String("conversationId", conversationID))
@@ -170,4 +170,3 @@ func (h *AttackChainHandler) RegenerateAttackChain(c *gin.Context) {
c.JSON(http.StatusOK, chain)
}
+39 -14
View File
@@ -17,6 +17,7 @@ import (
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/knowledge"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/openai"
"cyberstrike-ai/internal/security"
@@ -90,6 +91,7 @@ type AttackChainUpdater interface {
type AgentUpdater interface {
UpdateConfig(cfg *config.OpenAIConfig)
UpdateMaxIterations(maxIterations int)
UpdateToolDescriptionMode(mode string)
}
// NewConfigHandler 创建新的配置处理器
@@ -232,13 +234,7 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
if configToolMap[mcpTool.Name] {
continue
}
description := mcpTool.ShortDescription
if description == "" {
description = mcpTool.Description
}
if len(description) > 10000 {
description = description[:10000] + "..."
}
description := h.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description)
tools = append(tools, ToolConfigInfo{
Name: mcpTool.Name,
Description: description,
@@ -275,6 +271,11 @@ func (h *ConfigHandler) GetConfig(c *gin.Context) {
SubAgentCount: subAgentCount,
Orchestration: config.NormalizeMultiAgentOrchestration(h.config.MultiAgent.Orchestration),
PlanExecuteLoopMaxIterations: h.config.MultiAgent.PlanExecuteLoopMaxIterations,
ToolSearchAlwaysVisibleTools: append([]string(nil), h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools...),
ToolSearchAlwaysVisibleEffectiveTools: mergeToolNameLists(
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools,
builtin.GetAllBuiltinTools(),
),
}
c.JSON(http.StatusOK, GetConfigResponse{
@@ -430,13 +431,7 @@ func (h *ConfigHandler) GetTools(c *gin.Context) {
continue
}
description := mcpTool.ShortDescription
if description == "" {
description = mcpTool.Description
}
if len(description) > 10000 {
description = description[:10000] + "..."
}
description := h.pickToolDescription(mcpTool.ShortDescription, mcpTool.Description)
toolInfo := ToolConfigInfo{
Name: mcpTool.Name,
@@ -689,11 +684,13 @@ func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
if req.MultiAgent.PlanExecuteLoopMaxIterations != nil {
h.config.MultiAgent.PlanExecuteLoopMaxIterations = *req.MultiAgent.PlanExecuteLoopMaxIterations
}
h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools = dedupeToolNameList(req.MultiAgent.ToolSearchAlwaysVisibleTools)
h.logger.Info("更新多代理配置",
zap.Bool("enabled", h.config.MultiAgent.Enabled),
zap.Bool("robot_use_multi_agent", h.config.MultiAgent.RobotUseMultiAgent),
zap.Bool("batch_use_multi_agent", h.config.MultiAgent.BatchUseMultiAgent),
zap.Int("plan_execute_loop_max_iterations", h.config.MultiAgent.PlanExecuteLoopMaxIterations),
zap.Int("tool_search_always_visible_tools", len(h.config.MultiAgent.EinoMiddleware.ToolSearchAlwaysVisibleTools)),
)
}
@@ -1061,6 +1058,7 @@ func (h *ConfigHandler) ApplyConfig(c *gin.Context) {
if h.agent != nil {
h.agent.UpdateConfig(&h.config.OpenAI)
h.agent.UpdateMaxIterations(h.config.Agent.MaxIterations)
h.agent.UpdateToolDescriptionMode(h.config.Security.ToolDescriptionMode)
h.logger.Info("Agent配置已更新")
}
@@ -1383,6 +1381,33 @@ func updateMultiAgentConfig(doc *yaml.Node, cfg config.MultiAgentConfig) {
setBoolInMap(maNode, "robot_use_multi_agent", cfg.RobotUseMultiAgent)
setBoolInMap(maNode, "batch_use_multi_agent", cfg.BatchUseMultiAgent)
setIntInMap(maNode, "plan_execute_loop_max_iterations", cfg.PlanExecuteLoopMaxIterations)
mwNode := ensureMap(maNode, "eino_middleware")
setFlowStringSliceInMap(mwNode, "tool_search_always_visible_tools", dedupeToolNameList(cfg.EinoMiddleware.ToolSearchAlwaysVisibleTools))
}
func dedupeToolNameList(in []string) []string {
if len(in) == 0 {
return []string{}
}
seen := make(map[string]struct{}, len(in))
out := make([]string, 0, len(in))
for _, name := range in {
n := strings.TrimSpace(name)
if n == "" {
continue
}
key := strings.ToLower(n)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, n)
}
return out
}
func mergeToolNameLists(a, b []string) []string {
return dedupeToolNameList(append(append([]string{}, a...), b...))
}
func ensureMap(parent *yaml.Node, path ...string) *yaml.Node {
-1
View File
@@ -230,4 +230,3 @@ func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
"message": "ok",
})
}
+7 -5
View File
@@ -175,6 +175,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
)
if runErr != nil {
h.persistEinoAgentTraceForResume(conversationID, result)
cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled"
@@ -239,9 +240,9 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
)
}
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err))
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
}
}
@@ -306,6 +307,7 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
progressCallback,
)
if runErr != nil {
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
c.JSON(http.StatusInternalServerError, gin.H{"error": runErr.Error()})
return
}
@@ -323,8 +325,8 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
prep.AssistantMessageID,
)
}
if result.LastReActInput != "" || result.LastReActOutput != "" {
_ = h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput)
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
_ = h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput)
}
c.JSON(http.StatusOK, gin.H{
+10 -10
View File
@@ -247,7 +247,7 @@ func TestExternalMCPHandler_DeleteExternalMCP(t *testing.T) {
// 先添加一个配置
configObj := config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
ExternalMCPEnable: true,
}
handler.manager.AddOrUpdateConfig("test-delete", configObj)
@@ -276,11 +276,11 @@ func TestExternalMCPHandler_GetExternalMCPs(t *testing.T) {
// 添加多个配置
handler.manager.AddOrUpdateConfig("test1", config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
ExternalMCPEnable: true,
})
handler.manager.AddOrUpdateConfig("test2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: false,
})
@@ -319,15 +319,15 @@ func TestExternalMCPHandler_GetExternalMCPStats(t *testing.T) {
// 添加配置
handler.manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
ExternalMCPEnable: true,
})
handler.manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: true,
})
handler.manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
})
req := httptest.NewRequest("GET", "/api/external-mcp/stats", nil)
@@ -360,7 +360,7 @@ func TestExternalMCPHandler_StartStopExternalMCP(t *testing.T) {
// 添加一个禁用的配置
handler.manager.AddOrUpdateConfig("test-start-stop", config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
})
// 测试启动(可能会失败,因为没有真实的服务器)
@@ -416,7 +416,7 @@ func TestExternalMCPHandler_AddOrUpdateExternalMCP_EmptyName(t *testing.T) {
router, _, _ := setupTestRouter()
configObj := config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
ExternalMCPEnable: true,
}
@@ -459,14 +459,14 @@ func TestExternalMCPHandler_UpdateExistingConfig(t *testing.T) {
// 先添加配置
config1 := config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
ExternalMCPEnable: true,
}
handler.manager.AddOrUpdateConfig("test-update", config1)
// 更新配置
config2 := config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: true,
}
+21 -11
View File
@@ -85,7 +85,7 @@ CREATE TABLE IF NOT EXISTS hitl_conversation_configs (
enabled INTEGER NOT NULL DEFAULT 0,
mode TEXT NOT NULL DEFAULT 'off',
sensitive_tools TEXT NOT NULL DEFAULT '[]',
timeout_seconds INTEGER NOT NULL DEFAULT 300,
timeout_seconds INTEGER NOT NULL DEFAULT 0,
updated_at DATETIME NOT NULL
);`)
if err != nil {
@@ -133,7 +133,8 @@ func (m *HITLManager) ActivateConversation(conversationID string, req *HITLReque
tools[n] = struct{}{}
}
}
timeout := 5 * time.Minute
// timeout <= 0 means wait forever (no timeout).
timeout := time.Duration(0)
if req.TimeoutSeconds > 0 {
timeout = time.Duration(req.TimeoutSeconds) * time.Second
}
@@ -275,8 +276,8 @@ func (m *HITLManager) ensureConversationHITLModePersisted(conversationID, interr
}
cfg.Enabled = true
cfg.Mode = nm
if cfg.TimeoutSeconds <= 0 {
cfg.TimeoutSeconds = 300
if cfg.TimeoutSeconds < 0 {
cfg.TimeoutSeconds = 0
}
return m.SaveConversationConfig(conversationID, cfg)
}
@@ -341,7 +342,7 @@ func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLReq
return errors.New("conversationId is required")
}
if req == nil {
req = &HITLRequest{Enabled: false, Mode: "off", TimeoutSeconds: 300}
req = &HITLRequest{Enabled: false, Mode: "off", TimeoutSeconds: 0}
}
mode := normalizeHitlMode(req.Mode)
if !req.Enabled {
@@ -349,8 +350,8 @@ func (m *HITLManager) SaveConversationConfig(conversationID string, req *HITLReq
}
tools, _ := json.Marshal(req.SensitiveTools)
timeout := req.TimeoutSeconds
if timeout <= 0 {
timeout = 300
if timeout < 0 {
timeout = 0
}
_, err := m.db.Exec(`INSERT INTO hitl_conversation_configs
(conversation_id, enabled, mode, sensitive_tools, timeout_seconds, updated_at)
@@ -368,11 +369,14 @@ func (m *HITLManager) LoadConversationConfig(conversationID string) (*HITLReques
err := m.db.QueryRow(`SELECT enabled, mode, sensitive_tools, timeout_seconds FROM hitl_conversation_configs WHERE conversation_id = ?`, conversationID).
Scan(&enabledInt, &mode, &toolsJSON, &timeout)
if errors.Is(err, sql.ErrNoRows) {
return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 300}, nil
return &HITLRequest{Enabled: false, Mode: "off", SensitiveTools: []string{}, TimeoutSeconds: 0}, nil
}
if err != nil {
return nil, err
}
if timeout < 0 {
timeout = 0
}
tools := make([]string, 0)
_ = json.Unmarshal([]byte(toolsJSON), &tools)
return &HITLRequest{
@@ -389,6 +393,12 @@ func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, tim
delete(m.pending, p.InterruptID)
m.mu.Unlock()
}()
var timeoutCh <-chan time.Time
if timeout > 0 {
timer := time.NewTimer(timeout)
defer timer.Stop()
timeoutCh = timer.C
}
select {
case d := <-p.decideCh:
// 只有 review_edit 模式允许改参;其他模式一律忽略 edited arguments
@@ -398,7 +408,7 @@ func (m *HITLManager) waitDecision(ctx context.Context, p *pendingInterrupt, tim
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='decided', decision=?, decision_comment=?, decided_at=? WHERE id=?`,
d.Decision, d.Comment, time.Now(), p.InterruptID)
return d, nil
case <-time.After(timeout):
case <-timeoutCh:
_, _ = m.db.Exec(`UPDATE hitl_interrupts SET status='timeout', decision='approve', decision_comment='timeout auto approve', decided_at=? WHERE id=?`,
time.Now(), p.InterruptID)
return hitlDecision{Decision: "approve", Comment: "timeout auto approve"}, nil
@@ -718,8 +728,8 @@ func (h *AgentHandler) GetHITLConversationConfig(c *gin.Context) {
cfg2 := *cfg
cfg2.Enabled = true
cfg2.Mode = normalizeHitlMode(pendMode)
if cfg2.TimeoutSeconds <= 0 {
cfg2.TimeoutSeconds = 300
if cfg2.TimeoutSeconds < 0 {
cfg2.TimeoutSeconds = 0
}
cfg = &cfg2
}
+10 -10
View File
@@ -131,16 +131,16 @@ func (h *MarkdownAgentsHandler) GetMarkdownAgent(c *gin.Context) {
}
type markdownAgentBody struct {
Filename string `json:"filename"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Tools []string `json:"tools"`
Instruction string `json:"instruction"`
BindRole string `json:"bind_role"`
MaxIterations int `json:"max_iterations"`
Kind string `json:"kind"`
Raw string `json:"raw"`
Filename string `json:"filename"`
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
Tools []string `json:"tools"`
Instruction string `json:"instruction"`
BindRole string `json:"bind_role"`
MaxIterations int `json:"max_iterations"`
Kind string `json:"kind"`
Raw string `json:"raw"`
}
// CreateMarkdownAgent POST /api/multi-agent/markdown-agents
+5 -8
View File
@@ -42,11 +42,11 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
type MonitorResponse struct {
Executions []*mcp.ToolExecution `json:"executions"`
Stats map[string]*mcp.ToolStats `json:"stats"`
Timestamp time.Time `json:"timestamp"`
Total int `json:"total,omitempty"`
Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"`
TotalPages int `json:"total_pages,omitempty"`
Timestamp time.Time `json:"timestamp"`
Total int `json:"total,omitempty"`
Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"`
TotalPages int `json:"total_pages,omitempty"`
}
// Monitor 获取监控信息
@@ -213,7 +213,6 @@ func (h *MonitorHandler) loadStats() map[string]*mcp.ToolStats {
return stats
}
// GetExecution 获取特定执行记录
func (h *MonitorHandler) GetExecution(c *gin.Context) {
id := c.Param("id")
@@ -416,5 +415,3 @@ func (h *MonitorHandler) DeleteExecutions(c *gin.Context) {
h.logger.Info("尝试批量删除内存中的执行记录", zap.Int("count", len(request.IDs)))
c.JSON(http.StatusOK, gin.H{"message": "执行记录已删除(如果存在)"})
}
+21 -6
View File
@@ -185,6 +185,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
)
if runErr != nil {
h.persistEinoAgentTraceForResume(conversationID, result)
cause := context.Cause(baseCtx)
if errors.Is(cause, ErrTaskCancelled) {
taskStatus = "cancelled"
@@ -249,9 +250,9 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
)
}
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(conversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err))
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
}
}
@@ -318,6 +319,7 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
strings.TrimSpace(req.Orchestration),
)
if runErr != nil {
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
h.logger.Error("Eino DeepAgent 执行失败", zap.Error(runErr))
errMsg := "执行失败: " + runErr.Error()
if prep.AssistantMessageID != "" {
@@ -341,9 +343,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
)
}
if result.LastReActInput != "" || result.LastReActOutput != "" {
if err := h.db.SaveReActData(prep.ConversationID, result.LastReActInput, result.LastReActOutput); err != nil {
h.logger.Warn("保存 ReAct 数据失败", zap.Error(err))
if result.LastAgentTraceInput != "" || result.LastAgentTraceOutput != "" {
if err := h.db.SaveAgentTrace(prep.ConversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.Error(err))
}
}
@@ -355,6 +357,19 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
})
}
// persistEinoAgentTraceForResume 在 Eino 运行异常结束时写入代理轨迹(库列 last_react_*),供下一请求 loadHistoryFromAgentTrace 软续跑。
func (h *AgentHandler) persistEinoAgentTraceForResume(conversationID string, result *multiagent.RunResult) {
if h == nil || result == nil {
return
}
if result.LastAgentTraceInput == "" && result.LastAgentTraceOutput == "" {
return
}
if err := h.db.SaveAgentTrace(conversationID, result.LastAgentTraceInput, result.LastAgentTraceOutput); err != nil {
h.logger.Warn("保存 Eino 续跑上下文失败", zap.String("conversationId", conversationID), zap.Error(err))
}
}
func multiAgentHTTPErrorStatus(err error) (int, string) {
msg := err.Error()
switch {
+2 -7
View File
@@ -49,7 +49,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
}
}
agentHistoryMessages, err := h.loadHistoryFromReActData(conversationID)
agentHistoryMessages, err := h.loadHistoryFromAgentTrace(conversationID)
if err != nil {
historyMessages, getErr := h.db.GetMessages(conversationID)
if getErr != nil {
@@ -73,12 +73,7 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
h.logger.Warn("WebShell AI 助手:未找到连接", zap.String("id", req.WebShellConnectionID), zap.Error(errConn))
return nil, fmt.Errorf("未找到该 WebShell 连接")
}
remark := conn.Remark
if remark == "" {
remark = conn.URL
}
webshellContext := fmt.Sprintf("[WebShell 助手上下文] 当前连接 ID:%s,备注:%s。可用工具(仅在该连接上操作时使用,connection_id 填 \"%s\"):webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base。Skills 包请使用 Eino 多代理内置 `skill` 工具。\n\n用户请求:%s",
conn.ID, remark, conn.ID, req.Message)
webshellContext := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, req.Message)
// WebShell 模式下如果同时指定了角色,追加角色 user_prompt(工具集仍仅限 webshell 专用工具)
if req.Role != "" && req.Role != "默认" && h.config != nil && h.config.Roles != nil {
if role, exists := h.config.Roles[req.Role]; exists && role.Enabled && role.UserPrompt != "" {
+642
View File
@@ -0,0 +1,642 @@
package handler
import (
"fmt"
"net/http"
"sort"
"strconv"
"strings"
"time"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// NotificationHandler 聚合通知(Phase 2:服务端统一计算)
type NotificationHandler struct {
db *database.DB
agentHandler *AgentHandler
logger *zap.Logger
}
const notificationReadMaxRows = 150
// NotificationSummaryItem 通知项
type NotificationSummaryItem struct {
ID string `json:"id"`
Level string `json:"level"` // p0/p1/p2
Type string `json:"type"`
Title string `json:"title"`
Desc string `json:"desc"`
Ts string `json:"ts"` // RFC3339
Count int `json:"count,omitempty"`
Actionable bool `json:"actionable"`
Read bool `json:"read"`
// 以下字段用于前端深链跳转(通知即入口)
ConversationID string `json:"conversationId,omitempty"`
VulnerabilityID string `json:"vulnerabilityId,omitempty"`
ExecutionID string `json:"executionId,omitempty"`
InterruptID string `json:"interruptId,omitempty"`
}
// NotificationSummaryResponse 聚合响应
type NotificationSummaryResponse struct {
SinceMs int64 `json:"sinceMs"`
GeneratedAt string `json:"generatedAt"`
P0Count int `json:"p0Count"`
UnreadCount int `json:"unreadCount"`
Counts map[string]int `json:"counts"`
Items []NotificationSummaryItem `json:"items"`
}
func NewNotificationHandler(db *database.DB, agentHandler *AgentHandler, logger *zap.Logger) *NotificationHandler {
return &NotificationHandler{
db: db,
agentHandler: agentHandler,
logger: logger,
}
}
func parseSinceMs(raw string) int64 {
v := strings.TrimSpace(raw)
if v == "" {
return 0
}
if ms, err := strconv.ParseInt(v, 10, 64); err == nil && ms > 0 {
return ms
}
if t, err := time.Parse(time.RFC3339, v); err == nil {
return t.UnixMilli()
}
return 0
}
func unixSecToRFC3339(sec int64) string {
if sec <= 0 {
return time.Now().UTC().Format(time.RFC3339)
}
return time.Unix(sec, 0).UTC().Format(time.RFC3339)
}
func normalizedSinceSec(sinceMs int64) int64 {
sec := sinceMs / 1000
// SQLite 默认时间精度到秒;给 1s 回看窗口,避免“同秒内新增”被漏算。
if sec > 0 {
return sec - 1
}
return 0
}
func normalizeSinceMs(raw int64) int64 {
if raw > 0 {
return raw
}
// 默认仅看最近 24 小时,避免首次打开拉全量历史噪音。
return time.Now().Add(-24 * time.Hour).UnixMilli()
}
func levelBySeverity(sev string) string {
switch strings.ToLower(strings.TrimSpace(sev)) {
case "critical", "high":
return "p0"
case "medium":
return "p1"
default:
return "p2"
}
}
func requestWantsEnglish(c *gin.Context) bool {
if c == nil {
return false
}
lang := strings.ToLower(strings.TrimSpace(c.Query("lang")))
if lang == "" {
lang = strings.ToLower(strings.TrimSpace(c.GetHeader("Accept-Language")))
}
return strings.HasPrefix(lang, "en")
}
func i18nText(english bool, zh string, en string) string {
if english {
return en
}
return zh
}
func (h *NotificationHandler) loadPendingHITLItems(limit int, english bool) ([]NotificationSummaryItem, error) {
rows, err := h.db.Query(`
SELECT
id,
conversation_id,
tool_name,
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
FROM hitl_interrupts
WHERE status = 'pending'
ORDER BY created_at DESC
LIMIT ?
`, limit)
if err != nil {
return nil, err
}
defer rows.Close()
items := make([]NotificationSummaryItem, 0, limit)
for rows.Next() {
var id, conversationID, toolName string
var createdSec int64
if err := rows.Scan(&id, &conversationID, &toolName, &createdSec); err != nil {
continue
}
desc := i18nText(english, "会话 "+conversationID+" 的审批中断待处理", "Conversation "+conversationID+" has pending HITL approval")
if strings.TrimSpace(toolName) != "" {
desc = i18nText(english, "工具 "+toolName+" 等待审批", "Tool "+toolName+" is waiting for approval")
}
items = append(items, NotificationSummaryItem{
ID: "hitl:" + id,
Level: "p0",
Type: "hitl_pending",
Title: i18nText(english, "HITL 待审批", "HITL Pending Approval"),
Desc: desc,
Ts: unixSecToRFC3339(createdSec),
Count: 1,
Actionable: true,
Read: false,
ConversationID: conversationID,
InterruptID: id,
})
}
return items, nil
}
func (h *NotificationHandler) loadVulnerabilityItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, map[string]int, error) {
sinceSec := normalizedSinceSec(sinceMs)
rows, err := h.db.Query(`
SELECT
id,
title,
severity,
conversation_id,
COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0)
FROM vulnerabilities
WHERE CAST(strftime('%s', created_at) AS INTEGER) > ?
ORDER BY created_at DESC
LIMIT ?
`, sinceSec, limit)
if err != nil {
return nil, nil, err
}
defer rows.Close()
items := make([]NotificationSummaryItem, 0, limit)
counts := map[string]int{
"newCriticalVulns": 0,
"newHighVulns": 0,
"newMediumVulns": 0,
"newLowVulns": 0,
"newInfoVulns": 0,
}
for rows.Next() {
var id, title, severity, conversationID string
var createdSec int64
if err := rows.Scan(&id, &title, &severity, &conversationID, &createdSec); err != nil {
continue
}
switch strings.ToLower(strings.TrimSpace(severity)) {
case "critical":
counts["newCriticalVulns"]++
case "high":
counts["newHighVulns"]++
case "medium":
counts["newMediumVulns"]++
case "low":
counts["newLowVulns"]++
default:
counts["newInfoVulns"]++
}
sevUpper := strings.ToUpper(strings.TrimSpace(severity))
if sevUpper == "" {
sevUpper = "INFO"
}
finalTitle := i18nText(english, "新漏洞("+sevUpper+"", "New Vulnerability ("+sevUpper+")")
finalDesc := strings.TrimSpace(title)
if finalDesc == "" {
finalDesc = i18nText(english, "(无标题)", "(Untitled)")
}
items = append(items, NotificationSummaryItem{
ID: "vuln:" + id,
Level: levelBySeverity(severity),
Type: "vulnerability_created",
Title: finalTitle,
Desc: finalDesc,
Ts: unixSecToRFC3339(createdSec),
Count: 1,
Actionable: false,
Read: false,
ConversationID: conversationID,
VulnerabilityID: id,
})
}
return items, counts, nil
}
func (h *NotificationHandler) loadFailedExecutionItems(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int, error) {
sinceSec := normalizedSinceSec(sinceMs)
rows, err := h.db.Query(`
SELECT
id,
tool_name,
COALESCE(CAST(strftime('%s', start_time) AS INTEGER), 0)
FROM tool_executions
WHERE status = 'failed'
AND CAST(strftime('%s', start_time) AS INTEGER) > ?
ORDER BY start_time DESC
LIMIT ?
`, sinceSec, limit)
if err != nil {
return nil, 0, err
}
defer rows.Close()
items := make([]NotificationSummaryItem, 0, limit)
count := 0
for rows.Next() {
var id, toolName string
var startSec int64
if err := rows.Scan(&id, &toolName, &startSec); err != nil {
continue
}
count++
if strings.TrimSpace(toolName) == "" {
toolName = i18nText(english, "未知工具", "unknown")
}
items = append(items, NotificationSummaryItem{
ID: "exec_failed:" + id,
Level: "p0",
Type: "task_failed",
Title: i18nText(english, "任务执行失败", "Task Execution Failed"),
Desc: i18nText(english, "工具 "+toolName+" 执行失败", "Tool "+toolName+" execution failed"),
Ts: unixSecToRFC3339(startSec),
Count: 1,
Actionable: false,
Read: false,
ExecutionID: id,
})
}
return items, count, nil
}
func (h *NotificationHandler) summarizeLongRunningTasks(threshold time.Duration, english bool) ([]NotificationSummaryItem, int) {
if h.agentHandler == nil || h.agentHandler.tasks == nil {
return nil, 0
}
tasks := h.agentHandler.tasks.GetActiveTasks()
now := time.Now()
items := make([]NotificationSummaryItem, 0, len(tasks))
for _, t := range tasks {
if t == nil {
continue
}
if now.Sub(t.StartedAt) >= threshold {
items = append(items, NotificationSummaryItem{
ID: "task_long:" + t.ConversationID,
Level: "p1",
Type: "long_running_tasks",
Title: i18nText(english, "长时间运行任务", "Long Running Task"),
Desc: i18nText(english, "会话 "+t.ConversationID+" 运行超过 15 分钟", "Conversation "+t.ConversationID+" has been running over 15 minutes"),
Ts: t.StartedAt.UTC().Format(time.RFC3339),
Count: 1,
Actionable: true,
Read: false,
ConversationID: t.ConversationID,
})
}
}
return items, len(items)
}
func (h *NotificationHandler) summarizeCompletedTasksSince(sinceMs int64, limit int, english bool) ([]NotificationSummaryItem, int) {
if h.agentHandler == nil || h.agentHandler.tasks == nil {
return nil, 0
}
since := time.UnixMilli(sinceMs)
completed := h.agentHandler.tasks.GetCompletedTasks()
items := make([]NotificationSummaryItem, 0, limit)
for _, t := range completed {
if t == nil {
continue
}
if t.CompletedAt.After(since) {
items = append(items, NotificationSummaryItem{
ID: "task_completed:" + t.ConversationID + ":" + strconv.FormatInt(t.CompletedAt.Unix(), 10),
Level: "p2",
Type: "task_completed",
Title: i18nText(english, "任务完成", "Task Completed"),
Desc: i18nText(english, "会话 "+t.ConversationID+" 已完成", "Conversation "+t.ConversationID+" completed"),
Ts: t.CompletedAt.UTC().Format(time.RFC3339),
Count: 1,
Actionable: false,
Read: false,
ConversationID: t.ConversationID,
})
if len(items) >= limit {
break
}
}
}
return items, len(items)
}
func buildPlaceholders(n int) string {
if n <= 0 {
return ""
}
out := make([]string, 0, n)
for i := 0; i < n; i++ {
out = append(out, "?")
}
return strings.Join(out, ",")
}
func (h *NotificationHandler) readStatesByIDs(ids []string) (map[string]bool, error) {
result := make(map[string]bool, len(ids))
if len(ids) == 0 {
return result, nil
}
holders := buildPlaceholders(len(ids))
query := "SELECT event_id FROM notification_reads WHERE event_id IN (" + holders + ")"
args := make([]interface{}, 0, len(ids))
for _, id := range ids {
args = append(args, id)
}
rows, err := h.db.Query(query, args...)
if err != nil {
return result, err
}
defer rows.Close()
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
continue
}
result[id] = true
}
return result, nil
}
func (h *NotificationHandler) applyReadStates(items []NotificationSummaryItem) ([]NotificationSummaryItem, error) {
markableIDs := make([]string, 0, len(items))
for _, item := range items {
if item.Actionable {
continue
}
markableIDs = append(markableIDs, item.ID)
}
readMap, err := h.readStatesByIDs(markableIDs)
if err != nil {
return items, err
}
for i := range items {
if items[i].Actionable {
items[i].Read = false
continue
}
items[i].Read = readMap[items[i].ID]
}
return items, nil
}
func filterVisibleItems(items []NotificationSummaryItem) []NotificationSummaryItem {
out := make([]NotificationSummaryItem, 0, len(items))
for _, item := range items {
if item.Actionable || !item.Read {
out = append(out, item)
}
}
return out
}
func countP0(items []NotificationSummaryItem) int {
total := 0
for _, item := range items {
if item.Level == "p0" {
if item.Count > 0 {
total += item.Count
} else {
total++
}
}
}
return total
}
func countUnread(items []NotificationSummaryItem) int {
total := 0
for _, item := range items {
if item.Actionable || !item.Read {
if item.Count > 0 {
total += item.Count
} else {
total++
}
}
}
return total
}
func createNotificationReadTableIfNeeded(db *database.DB) error {
if db == nil {
return fmt.Errorf("db is nil")
}
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS notification_reads (
event_id TEXT PRIMARY KEY,
read_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP
);
`)
if err != nil {
return err
}
_, idxErr := db.Exec(`CREATE INDEX IF NOT EXISTS idx_notification_reads_read_at ON notification_reads(read_at DESC);`)
return idxErr
}
func pruneNotificationReads(db *database.DB, maxRows int) error {
if db == nil {
return fmt.Errorf("db is nil")
}
if maxRows <= 0 {
return nil
}
_, err := db.Exec(`
DELETE FROM notification_reads
WHERE event_id NOT IN (
SELECT event_id
FROM notification_reads
ORDER BY read_at DESC, rowid DESC
LIMIT ?
)
`, maxRows)
return err
}
type markReadRequest struct {
EventIDs []string `json:"eventIds"`
}
func normalizeMarkableEventID(id string) (string, bool) {
v := strings.TrimSpace(id)
if v == "" {
return "", false
}
// 仅允许“可读后隐藏”的信息类事件;Actionable 事件不参与 read 标记。
allowedPrefixes := []string{
"vuln:",
"exec_failed:",
"task_completed:",
}
for _, prefix := range allowedPrefixes {
if strings.HasPrefix(v, prefix) {
return v, true
}
}
return "", false
}
// MarkRead 按事件 ID 标记已读
func (h *NotificationHandler) MarkRead(c *gin.Context) {
if err := createNotificationReadTableIfNeeded(h.db); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to prepare notification read table"})
return
}
var req markReadRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
return
}
if len(req.EventIDs) == 0 {
c.JSON(http.StatusOK, gin.H{"ok": true, "marked": 0})
return
}
tx, err := h.db.Begin()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to begin transaction"})
return
}
defer func() {
_ = tx.Rollback()
}()
stmt, err := tx.Prepare(`
INSERT INTO notification_reads(event_id, read_at)
VALUES(?, CURRENT_TIMESTAMP)
ON CONFLICT(event_id) DO UPDATE SET read_at = CURRENT_TIMESTAMP
`)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to prepare statement"})
return
}
defer stmt.Close()
marked := 0
for _, raw := range req.EventIDs {
id, ok := normalizeMarkableEventID(raw)
if !ok {
continue
}
if _, err := stmt.Exec(id); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to mark read"})
return
}
marked++
}
if err := tx.Commit(); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to commit read marks"})
return
}
if err := pruneNotificationReads(h.db, notificationReadMaxRows); err != nil {
h.logger.Warn("裁剪通知已读记录失败", zap.Error(err))
}
c.JSON(http.StatusOK, gin.H{"ok": true, "marked": marked})
}
// GetSummary 返回通知聚合视图(用于头部铃铛)
func (h *NotificationHandler) GetSummary(c *gin.Context) {
if h.db == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "database unavailable"})
return
}
if err := createNotificationReadTableIfNeeded(h.db); err != nil {
h.logger.Warn("初始化通知已读表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to initialize notification read table"})
return
}
english := requestWantsEnglish(c)
sinceMs := normalizeSinceMs(parseSinceMs(c.Query("since")))
limit, _ := strconv.Atoi(strings.TrimSpace(c.DefaultQuery("limit", "50")))
if limit <= 0 {
limit = 50
}
if limit > 200 {
limit = 200
}
hitlItems, err := h.loadPendingHITLItems(limit, english)
if err != nil {
h.logger.Warn("加载 HITL 通知失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize hitl notifications"})
return
}
vulnItems, vulnCounts, err := h.loadVulnerabilityItems(sinceMs, limit, english)
if err != nil {
h.logger.Warn("加载漏洞通知失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to summarize vulnerabilities"})
return
}
longRunningItems, longRunningCount := h.summarizeLongRunningTasks(15*time.Minute, english)
completedItems, completedCount := h.summarizeCompletedTasksSince(sinceMs, limit, english)
items := make([]NotificationSummaryItem, 0, len(hitlItems)+len(vulnItems)+len(longRunningItems)+len(completedItems))
items = append(items, hitlItems...)
items = append(items, vulnItems...)
items = append(items, longRunningItems...)
items = append(items, completedItems...)
items, err = h.applyReadStates(items)
if err != nil {
h.logger.Warn("加载通知已读状态失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load notification read states"})
return
}
items = filterVisibleItems(items)
sort.Slice(items, func(i, j int) bool {
ti, errI := time.Parse(time.RFC3339, items[i].Ts)
tj, errJ := time.Parse(time.RFC3339, items[j].Ts)
if errI != nil || errJ != nil {
return i < j
}
return ti.After(tj)
})
p0Count := countP0(items)
unreadCount := countUnread(items)
c.JSON(http.StatusOK, NotificationSummaryResponse{
SinceMs: sinceMs,
GeneratedAt: time.Now().UTC().Format(time.RFC3339),
P0Count: p0Count,
UnreadCount: unreadCount,
Counts: map[string]int{
"hitlPending": len(hitlItems),
"newCriticalVulns": vulnCounts["newCriticalVulns"],
"newHighVulns": vulnCounts["newHighVulns"],
"newMediumVulns": vulnCounts["newMediumVulns"],
"newLowVulns": vulnCounts["newLowVulns"],
"newInfoVulns": vulnCounts["newInfoVulns"],
"failedExecutions": 0,
"longRunningTasks": longRunningCount,
"completedTasks": completedCount,
},
Items: items,
})
}
+27 -27
View File
@@ -4445,7 +4445,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"messageId"},
"properties": map[string]interface{}{
"messageId": map[string]interface{}{
@@ -4689,7 +4689,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"scheduleEnabled"},
"properties": map[string]interface{}{
"scheduleEnabled": map[string]interface{}{"type": "boolean", "description": "是否启用自动调度"},
@@ -4761,7 +4761,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"query"},
"properties": map[string]interface{}{
"query": map[string]interface{}{"type": "string", "description": "FOFA查询语法", "example": "domain=\"example.com\""},
@@ -4810,7 +4810,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"text"},
"properties": map[string]interface{}{
"text": map[string]interface{}{"type": "string", "description": "自然语言描述", "example": "查找使用WordPress的网站"},
@@ -4853,7 +4853,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"api_key", "model"},
"properties": map[string]interface{}{
"provider": map[string]interface{}{"type": "string", "description": "LLM提供商(openai/claude", "example": "openai"},
@@ -4900,7 +4900,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"command"},
"properties": map[string]interface{}{
"command": map[string]interface{}{"type": "string", "description": "要执行的命令"},
@@ -4943,7 +4943,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"command"},
"properties": map[string]interface{}{
"command": map[string]interface{}{"type": "string", "description": "要执行的命令"},
@@ -5027,7 +5027,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"url"},
"properties": map[string]interface{}{
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
@@ -5231,7 +5231,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"url", "command"},
"properties": map[string]interface{}{
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
@@ -5277,7 +5277,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"url", "action", "path"},
"properties": map[string]interface{}{
"url": map[string]interface{}{"type": "string", "description": "WebShell URL"},
@@ -5339,14 +5339,14 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"items": map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"relativePath": map[string]interface{}{"type": "string"},
"absolutePath": map[string]interface{}{"type": "string"},
"name": map[string]interface{}{"type": "string"},
"size": map[string]interface{}{"type": "integer"},
"modifiedUnix": map[string]interface{}{"type": "integer"},
"date": map[string]interface{}{"type": "string"},
"relativePath": map[string]interface{}{"type": "string"},
"absolutePath": map[string]interface{}{"type": "string"},
"name": map[string]interface{}{"type": "string"},
"size": map[string]interface{}{"type": "integer"},
"modifiedUnix": map[string]interface{}{"type": "integer"},
"date": map[string]interface{}{"type": "string"},
"conversationId": map[string]interface{}{"type": "string"},
"subPath": map[string]interface{}{"type": "string"},
"subPath": map[string]interface{}{"type": "string"},
},
},
},
@@ -5369,7 +5369,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"multipart/form-data": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"file"},
"properties": map[string]interface{}{
"file": map[string]interface{}{"type": "string", "format": "binary", "description": "上传的文件"},
@@ -5410,7 +5410,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"path"},
"properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
@@ -5485,7 +5485,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"path", "content"},
"properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
@@ -5512,7 +5512,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"name"},
"properties": map[string]interface{}{
"parent": map[string]interface{}{"type": "string", "description": "父目录相对路径"},
@@ -5552,7 +5552,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"path", "newName"},
"properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "当前文件相对路径"},
@@ -5646,7 +5646,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"platform", "text"},
"properties": map[string]interface{}{
"platform": map[string]interface{}{"type": "string", "description": "平台类型", "enum": []string{"dingtalk", "lark", "wecom"}},
@@ -5712,7 +5712,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"name"},
"properties": map[string]interface{}{
"filename": map[string]interface{}{"type": "string", "description": "文件名(可选,自动生成)"},
@@ -5932,7 +5932,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"path"},
"properties": map[string]interface{}{
"path": map[string]interface{}{"type": "string", "description": "文件相对路径"},
@@ -5974,7 +5974,7 @@ func (h *OpenAPIHandler) GetOpenAPISpec(c *gin.Context) {
"content": map[string]interface{}{
"application/json": map[string]interface{}{
"schema": map[string]interface{}{
"type": "object",
"type": "object",
"required": []string{"ids"},
"properties": map[string]interface{}{
"ids": map[string]interface{}{
@@ -6197,7 +6197,7 @@ func (h *OpenAPIHandler) GetConversationResults(c *gin.Context) {
}
// 获取漏洞列表
vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "")
vulnList, err := h.db.ListVulnerabilities(1000, 0, "", conversationID, "", "", "", "", "")
if err != nil {
h.logger.Warn("获取漏洞列表失败", zap.Error(err))
vulnList = []*database.Vulnerability{}
+6 -6
View File
@@ -26,7 +26,7 @@ var apiDocI18nSummaryToKey = map[string]string{
"创建分组": "createGroup", "列出分组": "listGroups", "获取分组": "getGroup", "更新分组": "updateGroup",
"删除分组": "deleteGroup", "获取分组中的对话": "getGroupConversations", "添加对话到分组": "addConversationToGroup",
"从分组移除对话": "removeConversationFromGroup",
"列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats",
"列出漏洞": "listVulnerabilities", "创建漏洞": "createVulnerability", "获取漏洞统计": "getVulnerabilityStats",
"获取漏洞": "getVulnerability", "更新漏洞": "updateVulnerability", "删除漏洞": "deleteVulnerability",
"列出角色": "listRoles", "创建角色": "createRole", "获取角色": "getRole", "更新角色": "updateRole", "删除角色": "deleteRole",
"获取可用Skills列表": "getAvailableSkills", "列出Skills": "listSkills", "创建Skill": "createSkill",
@@ -52,9 +52,9 @@ var apiDocI18nSummaryToKey = map[string]string{
"重跑批量任务队列": "rerunBatchQueue", "修改队列元数据": "updateBatchQueueMetadata",
"修改队列调度配置": "updateBatchQueueSchedule", "开关Cron自动调度": "setBatchQueueScheduleEnabled",
"获取所有分组映射": "getAllGroupMappings",
"FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse",
"FOFA搜索": "fofaSearch", "自然语言解析为FOFA语法": "fofaParse",
"测试OpenAI API连接": "testOpenAI",
"执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS",
"执行终端命令": "terminalRun", "流式执行终端命令": "terminalRunStream", "WebSocket终端": "terminalWS",
"列出WebShell连接": "listWebshellConnections", "创建WebShell连接": "createWebshellConnection",
"更新WebShell连接": "updateWebshellConnection", "删除WebShell连接": "deleteWebshellConnection",
"获取连接状态": "getWebshellConnectionState", "保存连接状态": "saveWebshellConnectionState",
@@ -69,7 +69,7 @@ var apiDocI18nSummaryToKey = map[string]string{
"获取Markdown代理详情": "getMarkdownAgent", "更新Markdown代理": "updateMarkdownAgent", "删除Markdown代理": "deleteMarkdownAgent",
"列出技能包文件": "listSkillPackageFiles", "获取技能包文件内容": "getSkillPackageFile", "写入技能包文件": "putSkillPackageFile",
"批量获取工具名称": "batchGetToolNames",
"获取知识库统计": "getKnowledgeStats",
"获取知识库统计": "getKnowledgeStats",
}
var apiDocI18nResponseDescToKey = map[string]string{
@@ -78,7 +78,7 @@ var apiDocI18nResponseDescToKey = map[string]string{
"对话不存在或结果不存在": "conversationOrResultNotFound", "请求参数错误(如task为空)": "badRequestTaskEmpty",
"请求参数错误或分组名称已存在": "badRequestGroupNameExists", "分组不存在": "groupNotFound",
"请求参数错误(如配置格式不正确、缺少必需字段等)": "badRequestConfig",
"请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed",
"请求参数错误(如query为空)": "badRequestQueryEmpty", "方法不允许(仅支持POST请求)": "methodNotAllowed",
"登录成功": "loginSuccess", "密码错误": "invalidPassword", "登出成功": "logoutSuccess",
"密码修改成功": "passwordChanged", "Token有效": "tokenValid", "Token无效或已过期": "tokenInvalid",
"对话创建成功": "conversationCreated", "服务器内部错误": "internalError", "更新成功": "updateSuccess",
@@ -89,7 +89,7 @@ var apiDocI18nResponseDescToKey = map[string]string{
"消息发送成功,返回AI回复": "messageSent", "流式响应(Server-Sent Events": "streamResponse",
// 新增缺失端点响应
"参数错误或删除失败": "badRequestOrDeleteFailed",
"参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun",
"参数错误": "paramError", "仅已完成或已取消的队列可以重跑": "onlyCompletedOrCancelledCanRerun",
"参数错误或队列正在运行中": "badRequestOrQueueRunning", "设置成功": "setSuccess",
"搜索成功": "searchSuccess", "解析成功": "parseSuccess", "测试结果": "testResult",
"执行完成": "executionDone", "SSE事件流": "sseEventStream", "WebSocket连接已建立": "wsEstablished",
+14 -14
View File
@@ -28,20 +28,20 @@ import (
)
const (
robotCmdHelp = "帮助"
robotCmdList = "列表"
robotCmdListAlt = "对话列表"
robotCmdSwitch = "切换"
robotCmdContinue = "继续"
robotCmdNew = "新对话"
robotCmdClear = "清空"
robotCmdCurrent = "当前"
robotCmdStop = "停止"
robotCmdRoles = "角色"
robotCmdRolesList = "角色列表"
robotCmdSwitchRole = "切换角色"
robotCmdDelete = "删除"
robotCmdVersion = "版本"
robotCmdHelp = "帮助"
robotCmdList = "列表"
robotCmdListAlt = "对话列表"
robotCmdSwitch = "切换"
robotCmdContinue = "继续"
robotCmdNew = "新对话"
robotCmdClear = "清空"
robotCmdCurrent = "当前"
robotCmdStop = "停止"
robotCmdRoles = "角色"
robotCmdRolesList = "角色列表"
robotCmdSwitchRole = "切换角色"
robotCmdDelete = "删除"
robotCmdVersion = "版本"
)
// RobotHandler 企业微信/钉钉/飞书等机器人回调处理
+13 -13
View File
@@ -65,19 +65,19 @@ func (h *SkillsHandler) GetSkills(c *gin.Context) {
allSkillsInfo := make([]map[string]interface{}, 0, len(allSummaries))
for _, s := range allSummaries {
skillInfo := map[string]interface{}{
"id": s.ID,
"name": s.Name,
"dir_name": s.DirName,
"description": s.Description,
"version": s.Version,
"path": s.Path,
"tags": s.Tags,
"triggers": s.Triggers,
"script_count": s.ScriptCount,
"file_count": s.FileCount,
"progressive": s.Progressive,
"file_size": s.FileSize,
"mod_time": s.ModTime,
"id": s.ID,
"name": s.Name,
"dir_name": s.DirName,
"description": s.Description,
"version": s.Version,
"path": s.Path,
"tags": s.Tags,
"triggers": s.Triggers,
"script_count": s.ScriptCount,
"file_count": s.FileCount,
"progressive": s.Progressive,
"file_size": s.FileSize,
"mod_time": s.ModTime,
}
allSkillsInfo = append(allSkillsInfo, skillInfo)
}
-1
View File
@@ -109,4 +109,3 @@ func (h *TerminalHandler) RunCommandWS(c *gin.Context) {
<-doneChan
}
+222 -23
View File
@@ -1,8 +1,11 @@
package handler
import (
"fmt"
"net/http"
"strconv"
"strings"
"time"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
@@ -25,7 +28,9 @@ func NewVulnerabilityHandler(db *database.DB, logger *zap.Logger) *Vulnerability
// CreateVulnerabilityRequest 创建漏洞请求
type CreateVulnerabilityRequest struct {
ConversationID string `json:"conversation_id" binding:"required"`
ConversationID string `json:"conversation_id" binding:"required"`
ConversationTag string `json:"conversation_tag"`
TaskTag string `json:"task_tag"`
Title string `json:"title" binding:"required"`
Description string `json:"description"`
Severity string `json:"severity" binding:"required"`
@@ -46,16 +51,18 @@ func (h *VulnerabilityHandler) CreateVulnerability(c *gin.Context) {
}
vuln := &database.Vulnerability{
ConversationID: req.ConversationID,
Title: req.Title,
Description: req.Description,
Severity: req.Severity,
Status: req.Status,
Type: req.Type,
Target: req.Target,
Proof: req.Proof,
Impact: req.Impact,
Recommendation: req.Recommendation,
ConversationID: req.ConversationID,
ConversationTag: req.ConversationTag,
TaskTag: req.TaskTag,
Title: req.Title,
Description: req.Description,
Severity: req.Severity,
Status: req.Status,
Type: req.Type,
Target: req.Target,
Proof: req.Proof,
Impact: req.Impact,
Recommendation: req.Recommendation,
}
created, err := h.db.CreateVulnerability(vuln)
@@ -100,6 +107,9 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
conversationID := c.Query("conversation_id")
severity := c.Query("severity")
status := c.Query("status")
taskID := c.Query("task_id")
conversationTag := c.Query("conversation_tag")
taskTag := c.Query("task_tag")
limit, _ := strconv.Atoi(limitStr)
offset, _ := strconv.Atoi(offsetStr)
@@ -121,7 +131,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
}
// 获取总数
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status)
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
if err != nil {
h.logger.Error("获取漏洞总数失败", zap.Error(err))
// 继续执行,使用0作为总数
@@ -129,7 +139,7 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
}
// 获取漏洞列表
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status)
vulnerabilities, err := h.db.ListVulnerabilities(limit, offset, id, conversationID, severity, status, taskID, conversationTag, taskTag)
if err != nil {
h.logger.Error("获取漏洞列表失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -160,15 +170,17 @@ func (h *VulnerabilityHandler) ListVulnerabilities(c *gin.Context) {
// UpdateVulnerabilityRequest 更新漏洞请求
type UpdateVulnerabilityRequest struct {
Title string `json:"title"`
Description string `json:"description"`
Severity string `json:"severity"`
Status string `json:"status"`
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
ConversationTag string `json:"conversation_tag"`
TaskTag string `json:"task_tag"`
Title string `json:"title"`
Description string `json:"description"`
Severity string `json:"severity"`
Status string `json:"status"`
Type string `json:"type"`
Target string `json:"target"`
Proof string `json:"proof"`
Impact string `json:"impact"`
Recommendation string `json:"recommendation"`
}
// UpdateVulnerability 更新漏洞
@@ -189,6 +201,12 @@ func (h *VulnerabilityHandler) UpdateVulnerability(c *gin.Context) {
}
// 更新字段
if req.ConversationTag != "" {
existing.ConversationTag = req.ConversationTag
}
if req.TaskTag != "" {
existing.TaskTag = req.TaskTag
}
if req.Title != "" {
existing.Title = req.Title
}
@@ -250,8 +268,9 @@ func (h *VulnerabilityHandler) DeleteVulnerability(c *gin.Context) {
// GetVulnerabilityStats 获取漏洞统计
func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
conversationID := c.Query("conversation_id")
taskID := c.Query("task_id")
stats, err := h.db.GetVulnerabilityStats(conversationID)
stats, err := h.db.GetVulnerabilityStats(conversationID, taskID)
if err != nil {
h.logger.Error("获取漏洞统计失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -261,3 +280,183 @@ func (h *VulnerabilityHandler) GetVulnerabilityStats(c *gin.Context) {
c.JSON(http.StatusOK, stats)
}
// GetVulnerabilityFilterOptions 获取漏洞筛选建议项
func (h *VulnerabilityHandler) GetVulnerabilityFilterOptions(c *gin.Context) {
options, err := h.db.GetVulnerabilityFilterOptions()
if err != nil {
h.logger.Error("获取漏洞筛选建议失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, options)
}
// ExportVulnerabilities 导出漏洞(支持按对话/任务分组,汇总或拆分)
func (h *VulnerabilityHandler) ExportVulnerabilities(c *gin.Context) {
groupBy := c.DefaultQuery("group_by", "conversation")
mode := c.DefaultQuery("mode", "summary")
if groupBy != "conversation" && groupBy != "task" {
c.JSON(http.StatusBadRequest, gin.H{"error": "group_by 仅支持 conversation 或 task"})
return
}
if mode != "summary" && mode != "split" {
c.JSON(http.StatusBadRequest, gin.H{"error": "mode 仅支持 summary 或 split"})
return
}
id := c.Query("id")
conversationID := c.Query("conversation_id")
severity := c.Query("severity")
status := c.Query("status")
taskID := c.Query("task_id")
conversationTag := c.Query("conversation_tag")
taskTag := c.Query("task_tag")
total, err := h.db.CountVulnerabilities(id, conversationID, severity, status, taskID, conversationTag, taskTag)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if total == 0 {
c.JSON(http.StatusOK, gin.H{"mode": mode, "group_by": groupBy, "total": 0, "files": []any{}})
return
}
items, err := h.db.ListVulnerabilities(total, 0, id, conversationID, severity, status, taskID, conversationTag, taskTag)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
type exportFile struct {
FileName string `json:"filename"`
Content string `json:"content"`
}
grouped := map[string][]*database.Vulnerability{}
for _, v := range items {
key := v.ConversationID
if groupBy == "conversation" {
if strings.TrimSpace(v.ConversationTag) != "" {
key = strings.TrimSpace(v.ConversationTag)
}
} else {
key = firstNonEmpty(v.TaskTag, v.TaskID, v.TaskQueueID, "unassigned-task")
}
grouped[key] = append(grouped[key], v)
}
files := make([]exportFile, 0)
nowStr := time.Now().Format("20060102-150405")
if mode == "summary" {
var b strings.Builder
b.WriteString("# 漏洞批量导出报告\n\n")
b.WriteString(fmt.Sprintf("- 导出时间: %s\n", time.Now().Format("2006-01-02 15:04:05")))
b.WriteString(fmt.Sprintf("- 分组维度: %s\n", groupBy))
b.WriteString(fmt.Sprintf("- 漏洞总数: %d\n", len(items)))
b.WriteString(fmt.Sprintf("- 分组数: %d\n\n", len(grouped)))
for group, list := range grouped {
b.WriteString(fmt.Sprintf("## %s (%d)\n\n", group, len(list)))
for _, v := range list {
appendVulnerabilityMarkdown(&b, v, "###")
}
}
files = append(files, exportFile{
FileName: fmt.Sprintf("vulnerability-report-%s-%s.md", groupBy, nowStr),
Content: b.String(),
})
} else {
for group, list := range grouped {
var b strings.Builder
b.WriteString(fmt.Sprintf("# 漏洞报告 - %s\n\n", group))
b.WriteString(fmt.Sprintf("- 导出时间: %s\n", time.Now().Format("2006-01-02 15:04:05")))
b.WriteString(fmt.Sprintf("- 漏洞数量: %d\n\n", len(list)))
for _, v := range list {
appendVulnerabilityMarkdown(&b, v, "##")
}
files = append(files, exportFile{
FileName: fmt.Sprintf("vulnerability-%s-%s.md", sanitizeExportName(group), nowStr),
Content: b.String(),
})
}
}
c.JSON(http.StatusOK, gin.H{
"mode": mode,
"group_by": groupBy,
"total": len(items),
"files": files,
})
}
// appendVulnerabilityMarkdown 单条漏洞的 Markdown 片段(与单文件下载字段对齐,缺省字段不写)
func appendVulnerabilityMarkdown(b *strings.Builder, v *database.Vulnerability, titleHeading string) {
b.WriteString(fmt.Sprintf("%s %s\n\n", titleHeading, v.Title))
b.WriteString(fmt.Sprintf("- 漏洞ID: `%s`\n", v.ID))
b.WriteString(fmt.Sprintf("- 严重程度: %s\n", v.Severity))
b.WriteString(fmt.Sprintf("- 状态: %s\n", v.Status))
if v.Type != "" {
b.WriteString(fmt.Sprintf("- 类型: %s\n", v.Type))
}
if v.Target != "" {
b.WriteString(fmt.Sprintf("- 目标: %s\n", v.Target))
}
b.WriteString(fmt.Sprintf("- 对话ID: `%s`\n", v.ConversationID))
if v.ConversationTag != "" {
b.WriteString(fmt.Sprintf("- 对话标签: %s\n", v.ConversationTag))
}
if v.TaskTag != "" {
b.WriteString(fmt.Sprintf("- 任务标签: %s\n", v.TaskTag))
}
if v.TaskID != "" {
b.WriteString(fmt.Sprintf("- 任务ID: `%s`\n", v.TaskID))
}
if v.TaskQueueID != "" {
b.WriteString(fmt.Sprintf("- 任务队列ID: `%s`\n", v.TaskQueueID))
}
if !v.CreatedAt.IsZero() {
b.WriteString(fmt.Sprintf("- 创建时间: %s\n", v.CreatedAt.Format("2006-01-02 15:04:05")))
}
if !v.UpdatedAt.IsZero() {
b.WriteString(fmt.Sprintf("- 更新时间: %s\n", v.UpdatedAt.Format("2006-01-02 15:04:05")))
}
if v.Description != "" {
b.WriteString("\n#### 描述\n\n")
b.WriteString(v.Description)
b.WriteString("\n")
}
if v.Proof != "" {
b.WriteString("\n#### 证明(POC\n\n```\n")
b.WriteString(v.Proof)
b.WriteString("\n```\n")
}
if v.Impact != "" {
b.WriteString("\n#### 影响\n\n")
b.WriteString(v.Impact)
b.WriteString("\n")
}
if v.Recommendation != "" {
b.WriteString("\n#### 修复建议\n\n")
b.WriteString(v.Recommendation)
b.WriteString("\n")
}
b.WriteString("\n")
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
trimmed := strings.TrimSpace(v)
if trimmed != "" {
return trimmed
}
}
return ""
}
func sanitizeExportName(raw string) string {
name := strings.TrimSpace(raw)
if name == "" {
return "unknown"
}
replacer := strings.NewReplacer("/", "-", "\\", "-", ":", "-", "*", "-", "?", "-", "\"", "-", "<", "-", ">", "-", "|", "-")
return replacer.Replace(name)
}
+369 -138
View File
@@ -3,20 +3,302 @@ package handler
import (
"bytes"
"database/sql"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
"time"
"unicode/utf8"
"cyberstrike-ai/internal/database"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
"golang.org/x/text/encoding/simplifiedchinese"
"golang.org/x/text/transform"
)
// webshellSupportedEncodings 允许的 WebShell 响应编码取值(小写,含空串代表 auto)
// 仅暴露目前最常见的几种,其他需求可后续扩展(如 Big5、Shift_JIS 等)。
var webshellSupportedEncodings = map[string]struct{}{
"": {}, // 未配置,按 auto 处理
"auto": {},
"utf-8": {},
"utf8": {},
"gbk": {},
"gb18030": {},
}
// normalizeWebshellEncoding 归一化编码标识:统一为小写,未知值回退为 auto,供持久化使用
func normalizeWebshellEncoding(enc string) string {
enc = strings.ToLower(strings.TrimSpace(enc))
if _, ok := webshellSupportedEncodings[enc]; !ok {
return "auto"
}
if enc == "" {
return "auto"
}
if enc == "utf8" {
return "utf-8"
}
return enc
}
// decodeWebshellOutput 把 WebShell 返回的字节按指定编码转换为合法 UTF-8 字符串。
// 约定:
// - "" / "auto":若已是合法 UTF-8 原样返回,否则依次尝试 GB18030(GBK 超集)解码。
// - "utf-8" / "utf8":原样返回,非法字节交由 JSON 层按 U+FFFD 处理(保持原有行为)。
// - "gbk" / "gb18030":强制按对应编码解码;失败则回退原始字节。
//
// 该函数对空输入直接返回空串,避免不必要的转换。
func decodeWebshellOutput(raw []byte, encoding string) string {
if len(raw) == 0 {
return ""
}
enc := normalizeWebshellEncoding(encoding)
switch enc {
case "utf-8":
return string(raw)
case "gbk":
if out, _, err := transform.Bytes(simplifiedchinese.GBK.NewDecoder(), raw); err == nil {
return string(out)
}
return string(raw)
case "gb18030":
if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil {
return string(out)
}
return string(raw)
default: // auto
if utf8.Valid(raw) {
return string(raw)
}
// GB18030 是 GBK 的超集,覆盖范围最广,auto 模式统一用它兜底
if out, _, err := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), raw); err == nil {
return string(out)
}
return string(raw)
}
}
// webshellSupportedOS 允许的 WebShell 目标操作系统(小写,空串代表 auto)
var webshellSupportedOS = map[string]struct{}{
"": {},
"auto": {},
"linux": {},
"windows": {},
}
// normalizeWebshellOS 归一化 OS 标识,未知值回退为 auto,供持久化使用
func normalizeWebshellOS(osTag string) string {
osTag = strings.ToLower(strings.TrimSpace(osTag))
if _, ok := webshellSupportedOS[osTag]; !ok {
return "auto"
}
if osTag == "" {
return "auto"
}
return osTag
}
// resolveWebshellOS 根据连接的 os 与 shellType 推断最终目标 OS(仅返回 "linux" 或 "windows")。
// 规则:
// - 显式 linux / windows:按用户选择。
// - auto 或未知:asp/aspx → windows,其他 → linux。保持历史行为,平滑向后兼容。
func resolveWebshellOS(osTag, shellType string) string {
osTag = strings.ToLower(strings.TrimSpace(osTag))
switch osTag {
case "linux":
return "linux"
case "windows":
return "windows"
}
t := strings.ToLower(strings.TrimSpace(shellType))
if t == "asp" || t == "aspx" {
return "windows"
}
return "linux"
}
// quoteCmdPath 把路径按 Windows cmd.exe 规则转义。
// 使用双引号包裹,内部双引号转义为 ""(cmd 接受的写法)。
func quoteCmdPath(p string) string {
if p == "" {
return "\".\""
}
return "\"" + strings.ReplaceAll(p, "\"", "\"\"") + "\""
}
// quotePsSingle 把字符串按 PowerShell 单引号字符串规则转义(内部 ' → '')。
// 供 PowerShell 脚本参数使用,全脚本只用单引号,外层 cmd 再用双引号包裹即可安全传递。
func quotePsSingle(s string) string {
return "'" + strings.ReplaceAll(s, "'", "''") + "'"
}
// quoteShellSinglePosix 把路径按 POSIX sh 单引号规则转义(内部 ' → '\''
func quoteShellSinglePosix(p string) string {
if p == "" {
return "."
}
return "'" + strings.ReplaceAll(p, "'", "'\\''") + "'"
}
// quoteWebshellPath 按目标 OS 选择转义方案:Linux 用 POSIX 单引号,Windows 用 cmd 双引号
func quoteWebshellPath(path, osTag string) string {
if resolveWebshellOS(osTag, "") == "windows" {
return quoteCmdPath(path)
}
return quoteShellSinglePosix(path)
}
// buildWindowsPowerShellWrite 构造 Windows 端把 base64 内容一次性写入目标路径的 cmd 命令。
// 外层走 cmd.exe 的 powershell 调用,PowerShell 脚本里只用单引号字符串,避免嵌套引号陷阱。
func buildWindowsPowerShellWrite(path, b64 string) string {
script := "$b=[Convert]::FromBase64String(" + quotePsSingle(b64) + ");" +
"[IO.File]::WriteAllBytes(" + quotePsSingle(path) + ",$b)"
return "powershell -NoProfile -NonInteractive -Command \"" + script + "\""
}
// buildWindowsPowerShellAppend 构造 Windows 端把 base64 内容追加写入目标路径的 cmd 命令(用于分块上传)
func buildWindowsPowerShellAppend(path, b64 string) string {
script := "$b=[Convert]::FromBase64String(" + quotePsSingle(b64) + ");" +
"$f=[IO.File]::Open(" + quotePsSingle(path) + ",[IO.FileMode]::Append,[IO.FileAccess]::Write,[IO.FileShare]::None);" +
"try{$f.Write($b,0,$b.Length)}finally{$f.Close()}"
return "powershell -NoProfile -NonInteractive -Command \"" + script + "\""
}
// fileCommandInput 封装 buildFileCommand 的输入,避免长参数列表
type fileCommandInput struct {
Action string
Path string
TargetPath string
Content string
ChunkIndex int
OS string
ShellType string
}
// buildFileCommand 根据目标 OS 与文件操作类型生成具体的远端命令字符串。
// 同一份实现供 HTTP 入口(FileOp)与 MCP 入口(FileOpWithConnection)共用,避免双份维护。
// 返回值第二位是用户可见的业务错误(如 "path is required")。
func (h *WebShellHandler) buildFileCommand(in fileCommandInput) (string, error) {
targetOS := resolveWebshellOS(in.OS, in.ShellType)
action := strings.ToLower(strings.TrimSpace(in.Action))
path := strings.TrimSpace(in.Path)
switch action {
case "list":
p := path
if p == "" {
p = "."
}
if targetOS == "windows" {
return "dir /a " + quoteCmdPath(p), nil
}
return "ls -la " + quoteShellSinglePosix(p), nil
case "read":
if path == "" {
return "", errFileOpPathRequired
}
if targetOS == "windows" {
return "type " + quoteCmdPath(path), nil
}
return "cat " + quoteShellSinglePosix(path), nil
case "delete":
if path == "" {
return "", errFileOpPathRequired
}
if targetOS == "windows" {
return "del /q /f " + quoteCmdPath(path), nil
}
return "rm -f " + quoteShellSinglePosix(path), nil
case "mkdir":
if path == "" {
return "", errFileOpPathRequired
}
if targetOS == "windows" {
// cmd 的 md 默认会自动创建中间目录(等价于 Linux 的 mkdir -p
return "md " + quoteCmdPath(path), nil
}
return "mkdir -p " + quoteShellSinglePosix(path), nil
case "rename":
oldPath := path
newPath := strings.TrimSpace(in.TargetPath)
if oldPath == "" || newPath == "" {
return "", errFileOpRenameNeedsBothPaths
}
if targetOS == "windows" {
return "move /y " + quoteCmdPath(oldPath) + " " + quoteCmdPath(newPath), nil
}
return "mv -f " + quoteShellSinglePosix(oldPath) + " " + quoteShellSinglePosix(newPath), nil
case "write":
if path == "" {
return "", errFileOpPathRequired
}
// 统一策略:先把内容 base64 编码,再用目标平台对应方式解码写回,
// 这样既能写入任意二进制/含引号的文本,又避免各家 shell 的转义地狱。
b64 := base64.StdEncoding.EncodeToString([]byte(in.Content))
if targetOS == "windows" {
return buildWindowsPowerShellWrite(path, b64), nil
}
return "echo '" + b64 + "' | base64 -d > " + quoteShellSinglePosix(path), nil
case "upload":
if path == "" {
return "", errFileOpPathRequired
}
if len(in.Content) > 512*1024 {
return "", errFileOpUploadTooLarge
}
if targetOS == "windows" {
return buildWindowsPowerShellWrite(path, in.Content), nil
}
return "echo '" + in.Content + "' | base64 -d > " + quoteShellSinglePosix(path), nil
case "upload_chunk":
if path == "" {
return "", errFileOpPathRequired
}
if targetOS == "windows" {
if in.ChunkIndex == 0 {
return buildWindowsPowerShellWrite(path, in.Content), nil
}
return buildWindowsPowerShellAppend(path, in.Content), nil
}
redir := ">>"
if in.ChunkIndex == 0 {
redir = ">"
}
return "echo '" + in.Content + "' | base64 -d " + redir + " " + quoteShellSinglePosix(path), nil
}
return "", errFileOpUnsupportedAction(action)
}
// 业务错误常量,便于上层统一返回用户可见提示
var (
errFileOpPathRequired = simpleError("path is required")
errFileOpRenameNeedsBothPaths = simpleError("path and target_path are required for rename")
errFileOpUploadTooLarge = simpleError("upload content too large (max 512KB base64)")
)
func errFileOpUnsupportedAction(action string) error {
return simpleError("unsupported action: " + action)
}
// simpleError 是不带堆栈的轻量错误类型,供 buildFileCommand 报可预期的参数校验错误
type simpleError string
func (e simpleError) Error() string { return string(e) }
// WebShellHandler 代理执行 WebShell 命令(类似冰蝎/蚁剑),避免前端跨域并统一构建请求
type WebShellHandler struct {
logger *zap.Logger
@@ -44,6 +326,8 @@ type CreateConnectionRequest struct {
Method string `json:"method"`
CmdParam string `json:"cmd_param"`
Remark string `json:"remark"`
Encoding string `json:"encoding"`
OS string `json:"os"`
}
// UpdateConnectionRequest 更新连接请求
@@ -54,6 +338,8 @@ type UpdateConnectionRequest struct {
Method string `json:"method"`
CmdParam string `json:"cmd_param"`
Remark string `json:"remark"`
Encoding string `json:"encoding"`
OS string `json:"os"`
}
// ListConnections 列出所有 WebShell 连接(GET /api/webshell/connections
@@ -109,6 +395,8 @@ func (h *WebShellHandler) CreateConnection(c *gin.Context) {
Method: method,
CmdParam: strings.TrimSpace(req.CmdParam),
Remark: strings.TrimSpace(req.Remark),
Encoding: normalizeWebshellEncoding(req.Encoding),
OS: normalizeWebshellOS(req.OS),
CreatedAt: time.Now(),
}
if err := h.db.CreateWebshellConnection(conn); err != nil {
@@ -159,6 +447,8 @@ func (h *WebShellHandler) UpdateConnection(c *gin.Context) {
Method: method,
CmdParam: strings.TrimSpace(req.CmdParam),
Remark: strings.TrimSpace(req.Remark),
Encoding: normalizeWebshellEncoding(req.Encoding),
OS: normalizeWebshellOS(req.OS),
}
if err := h.db.UpdateWebshellConnection(conn); err != nil {
if err == sql.ErrNoRows {
@@ -331,6 +621,8 @@ type ExecRequest struct {
Type string `json:"type"` // php, asp, aspx, jsp, custom
Method string `json:"method"` // GET 或 POST,空则默认 POST
CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd
Encoding string `json:"encoding"` // 响应编码:auto / utf-8 / gbk / gb18030,空则 auto
OS string `json:"os"` // 目标操作系统:auto / linux / windows,当前 exec 不用它,保留字段便于未来扩展
Command string `json:"command" binding:"required"`
}
@@ -344,23 +636,27 @@ type ExecResponse struct {
// FileOpRequest 文件操作请求
type FileOpRequest struct {
URL string `json:"url" binding:"required"`
Password string `json:"password"`
Type string `json:"type"`
Method string `json:"method"` // GET 或 POST,空则默认 POST
CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd
Action string `json:"action" binding:"required"` // list, read, delete, write, mkdir, rename, upload, upload_chunk
Path string `json:"path"`
TargetPath string `json:"target_path"` // rename 时目标路径
Content string `json:"content"` // write/upload 时使用
ChunkIndex int `json:"chunk_index"` // upload_chunk 时,0 表示首块
URL string `json:"url" binding:"required"`
Password string `json:"password"`
Type string `json:"type"`
Method string `json:"method"` // GET 或 POST,空则默认 POST
CmdParam string `json:"cmd_param"` // 命令参数名,如 cmd/xxx,空则默认 cmd
Encoding string `json:"encoding"` // 响应编码:auto / utf-8 / gbk / gb18030,空则 auto
OS string `json:"os"` // 目标操作系统:auto / linux / windows,空则按 shellType 推断
ConnectionID string `json:"connection_id,omitempty"` // 可选:连接 ID;服务端探活出 OS 后会回写到此连接
Action string `json:"action" binding:"required"` // list, read, delete, write, mkdir, rename, upload, upload_chunk
Path string `json:"path"`
TargetPath string `json:"target_path"` // rename 时目标路径
Content string `json:"content"` // write/upload 时使用
ChunkIndex int `json:"chunk_index"` // upload_chunk 时,0 表示首块
}
// FileOpResponse 文件操作响应
type FileOpResponse struct {
OK bool `json:"ok"`
Output string `json:"output"`
Error string `json:"error,omitempty"`
OK bool `json:"ok"`
Output string `json:"output"`
Error string `json:"error,omitempty"`
DetectedOS string `json:"detected_os,omitempty"` // 仅在 auto 模式且探活成功时返回,前端应更新本地缓存
}
func (h *WebShellHandler) Exec(c *gin.Context) {
@@ -415,7 +711,7 @@ func (h *WebShellHandler) Exec(c *gin.Context) {
if readErr != nil {
h.logger.Warn("webshell exec read body", zap.Error(readErr))
}
output := string(out)
output := decodeWebshellOutput(out, req.Encoding)
httpCode := resp.StatusCode
c.JSON(http.StatusOK, ExecResponse{
@@ -474,83 +770,32 @@ func (h *WebShellHandler) FileOp(c *gin.Context) {
return
}
// 通过执行系统命令实现文件操作(与通用一句话兼容)
var command string
shellType := strings.ToLower(strings.TrimSpace(req.Type))
switch req.Action {
case "list":
path := strings.TrimSpace(req.Path)
if path == "" {
path = "."
// 若 OS 未显式配置,先发一次探活命令,识别出真实 OS 再构造文件操作命令。
// 这解决了 "Windows + PHP + OS=auto" 场景下旧 fallback 错发 `ls -la` 导致目录列不出来的问题。
osTag := req.OS
detectedOS := ""
if normalizeWebshellOS(osTag) == "auto" {
if probed := probeWebshellOSViaExec(h.newHTTPExecFn(req.URL, req.Password, req.Type, req.Method, req.CmdParam, req.Encoding)); probed != "" {
osTag = probed
detectedOS = probed
// 若前端带了 connection_id,顺带把探活结果持久化到该连接,后续刷新零成本
if cid := strings.TrimSpace(req.ConnectionID); cid != "" {
h.persistDetectedOS(cid, probed)
}
}
if shellType == "asp" || shellType == "aspx" {
command = "dir " + h.escapePath(path)
} else {
command = "ls -la " + h.escapePath(path)
}
case "read":
if shellType == "asp" || shellType == "aspx" {
command = "type " + h.escapePath(strings.TrimSpace(req.Path))
} else {
command = "cat " + h.escapePath(strings.TrimSpace(req.Path))
}
case "delete":
if shellType == "asp" || shellType == "aspx" {
command = "del " + h.escapePath(strings.TrimSpace(req.Path))
} else {
command = "rm -f " + h.escapePath(strings.TrimSpace(req.Path))
}
case "write":
path := h.escapePath(strings.TrimSpace(req.Path))
command = "echo " + h.escapeForEcho(req.Content) + " > " + path
case "mkdir":
path := strings.TrimSpace(req.Path)
if path == "" {
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for mkdir"})
return
}
if shellType == "asp" || shellType == "aspx" {
command = "md " + h.escapePath(path)
} else {
command = "mkdir -p " + h.escapePath(path)
}
case "rename":
oldPath := strings.TrimSpace(req.Path)
newPath := strings.TrimSpace(req.TargetPath)
if oldPath == "" || newPath == "" {
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path and target_path are required for rename"})
return
}
if shellType == "asp" || shellType == "aspx" {
command = "move /y " + h.escapePath(oldPath) + " " + h.escapePath(newPath)
} else {
command = "mv " + h.escapePath(oldPath) + " " + h.escapePath(newPath)
}
case "upload":
path := strings.TrimSpace(req.Path)
if path == "" {
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for upload"})
return
}
if len(req.Content) > 512*1024 {
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "upload content too large (max 512KB base64)"})
return
}
// base64 仅含 A-Za-z0-9+/=,用单引号包裹安全
command = "echo " + "'" + req.Content + "'" + " | base64 -d > " + h.escapePath(path)
case "upload_chunk":
path := strings.TrimSpace(req.Path)
if path == "" {
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "path is required for upload_chunk"})
return
}
redir := ">>"
if req.ChunkIndex == 0 {
redir = ">"
}
command = "echo " + "'" + req.Content + "'" + " | base64 -d " + redir + " " + h.escapePath(path)
default:
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: "unsupported action: " + req.Action})
}
command, cmdErr := h.buildFileCommand(fileCommandInput{
Action: req.Action,
Path: req.Path,
TargetPath: req.TargetPath,
Content: req.Content,
ChunkIndex: req.ChunkIndex,
OS: osTag,
ShellType: req.Type,
})
if cmdErr != nil {
c.JSON(http.StatusBadRequest, FileOpResponse{OK: false, Error: cmdErr.Error()})
return
}
@@ -585,27 +830,15 @@ func (h *WebShellHandler) FileOp(c *gin.Context) {
if readErr != nil {
h.logger.Warn("webshell fileop read body", zap.Error(readErr))
}
output := string(out)
output := decodeWebshellOutput(out, req.Encoding)
c.JSON(http.StatusOK, FileOpResponse{
OK: resp.StatusCode == http.StatusOK,
Output: output,
OK: resp.StatusCode == http.StatusOK,
Output: output,
DetectedOS: detectedOS,
})
}
func (h *WebShellHandler) escapePath(p string) string {
if p == "" {
return "."
}
// 简单转义空格与敏感字符,避免命令注入
return "'" + strings.ReplaceAll(p, "'", "'\\''") + "'"
}
func (h *WebShellHandler) escapeForEcho(s string) string {
// 仅用于 write:base64 写入更安全,这里简单用单引号包裹
return "'" + strings.ReplaceAll(s, "'", "'\"'\"'") + "'"
}
// ExecWithConnection 在指定 WebShell 连接上执行命令(供 MCP/Agent 等非 HTTP 调用)
func (h *WebShellHandler) ExecWithConnection(conn *database.WebShellConnection, command string) (output string, ok bool, errMsg string) {
if conn == nil {
@@ -643,7 +876,7 @@ func (h *WebShellHandler) ExecWithConnection(conn *database.WebShellConnection,
if readErr != nil {
h.logger.Warn("webshell ExecWithConnection read body", zap.Error(readErr))
}
return string(out), resp.StatusCode == http.StatusOK, ""
return decodeWebshellOutput(out, conn.Encoding), resp.StatusCode == http.StatusOK, ""
}
// FileOpWithConnection 在指定 WebShell 连接上执行文件操作(供 MCP/Agent 调用),支持 list / read / write
@@ -652,40 +885,38 @@ func (h *WebShellHandler) FileOpWithConnection(conn *database.WebShellConnection
return "", false, "connection is nil"
}
action = strings.ToLower(strings.TrimSpace(action))
shellType := strings.ToLower(strings.TrimSpace(conn.Type))
if shellType == "" {
shellType = "php"
}
var command string
// MCP 入口仅开放 list / read / write 三种动作,与工具文档的承诺保持一致
switch action {
case "list":
if path == "" {
path = "."
}
if shellType == "asp" || shellType == "aspx" {
command = "dir " + h.escapePath(strings.TrimSpace(path))
} else {
command = "ls -la " + h.escapePath(strings.TrimSpace(path))
}
case "read":
path = strings.TrimSpace(path)
if path == "" {
return "", false, "path is required for read"
}
if shellType == "asp" || shellType == "aspx" {
command = "type " + h.escapePath(path)
} else {
command = "cat " + h.escapePath(path)
}
case "write":
path = strings.TrimSpace(path)
if path == "" {
return "", false, "path is required for write"
}
command = "echo " + h.escapeForEcho(content) + " > " + h.escapePath(path)
case "list", "read", "write":
// 支持的动作
default:
return "", false, "unsupported action: " + action + " (supported: list, read, write)"
}
// 若连接的 OS 为 auto,先探活并持久化,避免 AI/MCP 每次都对 Windows 发 `ls -la`
osTag := conn.OS
if normalizeWebshellOS(osTag) == "auto" {
if probed := probeWebshellOSViaExec(func(cmd string) (string, bool) {
out, exOk, _ := h.ExecWithConnection(conn, cmd)
return out, exOk
}); probed != "" {
osTag = probed
conn.OS = probed // 本次请求内使用探活结果
h.persistDetectedOS(conn.ID, probed)
}
}
command, cmdErr := h.buildFileCommand(fileCommandInput{
Action: action,
Path: path,
TargetPath: targetPath,
Content: content,
OS: osTag,
ShellType: conn.Type,
})
if cmdErr != nil {
return "", false, cmdErr.Error()
}
useGET := strings.ToUpper(strings.TrimSpace(conn.Method)) == "GET"
cmdParam := strings.TrimSpace(conn.CmdParam)
if cmdParam == "" {
@@ -714,5 +945,5 @@ func (h *WebShellHandler) FileOpWithConnection(conn *database.WebShellConnection
if readErr != nil {
h.logger.Warn("webshell FileOpWithConnection read body", zap.Error(readErr))
}
return string(out), resp.StatusCode == http.StatusOK, ""
return decodeWebshellOutput(out, conn.Encoding), resp.StatusCode == http.StatusOK, ""
}
+106
View File
@@ -0,0 +1,106 @@
package handler
import (
"strings"
"cyberstrike-ai/internal/database"
)
// WebshellSkillHintDefault 对话页 / Eino 单代理共用的 Skills 说明,放在 webshell 上下文末尾,
// 供 AI 选择 skill 加载入口时参考。
const WebshellSkillHintDefault = "Skills 包请使用「多代理 / Eino DeepAgent」会话中的内置 `skill` 工具渐进加载。"
// WebshellSkillHintMultiAgent 多代理 / Eino 多代理准备阶段使用的 Skills 说明
const WebshellSkillHintMultiAgent = "Skills 包请使用 Eino 多代理内置 `skill` 工具。"
// webshellAssistantToolList AI 助手在 WebShell 上下文下允许使用的工具清单(展示给模型用)。
// 注意:此处只是展示字符串,真正的权限限制是在调用方设置的 roleTools 切片里。
const webshellAssistantToolList = "webshell_exec、webshell_file_list、webshell_file_read、webshell_file_write、record_vulnerability、list_knowledge_risk_types、search_knowledge_base"
// BuildWebshellAssistantContext 根据连接信息与用户原始消息组装 AI 助手的上下文提示词。
// 上下文包含:连接 ID、备注、目标系统(及对应命令集建议)、响应编码、可用工具清单、Skills 加载入口、
// 以及最终的用户请求。调用方只需要决定 skillHint 的文案(默认使用 WebshellSkillHintDefault)。
//
// 之所以把这段逻辑抽到共享函数里,是为了避免 agent.go / multi_agent_prepare.go 等多处复制粘贴,
// 并确保当我们升级 OS / Encoding 文案时只需要改一处、测一处、同步生效。
func BuildWebshellAssistantContext(conn *database.WebShellConnection, skillHint, userMsg string) string {
if conn == nil {
// 兜底:调用方已保证 conn 非 nil,这里只是防御性返回原消息
return userMsg
}
remark := conn.Remark
if remark == "" {
remark = conn.URL
}
targetOS := resolveWebshellOS(conn.OS, conn.Type) // 归一为 "linux" / "windows"
encoding := normalizeWebshellEncoding(conn.Encoding)
if skillHint == "" {
skillHint = WebshellSkillHintDefault
}
var b strings.Builder
b.Grow(512 + len(userMsg))
b.WriteString("[WebShell 助手上下文] 连接 ID")
b.WriteString(conn.ID)
b.WriteString(",备注:")
b.WriteString(remark)
b.WriteByte('\n')
// 目标系统:明确告诉 AI 能用/不能用的命令集,避免它对着 Windows 发 ls/cat/rm
b.WriteString("- 目标系统:")
b.WriteString(describeTargetOSForPrompt(targetOS))
b.WriteByte('\n')
// 响应编码:仅在非 auto 时显式告知,auto 模式由后端自适应,不打扰模型
if encHint := describeEncodingForPrompt(encoding); encHint != "" {
b.WriteString("- 响应编码:")
b.WriteString(encHint)
b.WriteByte('\n')
}
// 工具清单 & connection_id 约束:保持旧有表达,AI 已熟悉
b.WriteString("可用工具(仅在该连接上操作时使用,connection_id 填 \"")
b.WriteString(conn.ID)
b.WriteString("\"):")
b.WriteString(webshellAssistantToolList)
b.WriteString("。")
b.WriteString(skillHint)
b.WriteString("\n\n用户请求:")
b.WriteString(userMsg)
return b.String()
}
// describeTargetOSForPrompt 返回某个 OS 对应的中文描述 + 推荐命令集 + 反例,
// 命令列表覆盖文件管理最常用的 6 类动作(查看/读/删/改名/建目录/查找),让 AI 能直接照抄。
func describeTargetOSForPrompt(targetOS string) string {
switch targetOS {
case "windows":
return "Windows(推荐 cmd/PowerShelldir /a、type、del /q /f、move /y、md、ren" +
"查找文件用 `dir /s /b 过滤词` 或 PowerShell `Get-ChildItem -Recurse`" +
"避免 ls / cat / rm / mv / find 等 Unix 命令,否则将返回 `不是内部或外部命令`)"
case "linux":
return "Linux/Unix(推荐 sh/bashls -la、cat、rm -f、mv、mkdir -p" +
"查找文件用 `find /path -name '*pattern*'`" +
"避免 dir、type、del、move 等 Windows 命令)"
default:
// 理论上不会走到这里,resolveWebshellOS 已经兜底
return "未知(请先执行 `uname || ver` 探测再决定命令集)"
}
}
// describeEncodingForPrompt 返回响应编码的人类可读描述;auto 返回空串以减少 token。
func describeEncodingForPrompt(encoding string) string {
switch encoding {
case "utf-8":
return "UTF-8(目标原生 UTF-8,无需额外解码)"
case "gbk":
return "GBK(中文 Windows;后端已自动转码为 UTF-8 返回,若仍出现大量 \\uFFFD 替换字符说明命令失败或编码识别错误)"
case "gb18030":
return "GB18030(后端已自动转码为 UTF-8 返回)"
default:
return ""
}
}
+170
View File
@@ -0,0 +1,170 @@
package handler
import (
"strings"
"testing"
"cyberstrike-ai/internal/database"
)
func TestBuildWebshellAssistantContext_WindowsExplicit(t *testing.T) {
conn := &database.WebShellConnection{
ID: "ws_win01",
Remark: "IIS Windows 靶机",
URL: "http://example.com/shell.php",
Type: "php",
OS: "windows",
Encoding: "gbk",
}
got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "列出当前目录并告诉我 flag 在哪")
mustContain(t, got,
"[WebShell 助手上下文]",
"ws_win01",
"IIS Windows 靶机",
"目标系统:Windows",
"dir /a",
"move /y",
"避免 ls / cat / rm",
"响应编码:GBK",
"后端已自动转码为 UTF-8",
"connection_id 填 \"ws_win01\"",
"webshell_exec、webshell_file_list",
WebshellSkillHintDefault,
"用户请求:列出当前目录并告诉我 flag 在哪",
)
// Windows 场景下不应出现 Linux 命令推荐
mustNotContain(t, got, "推荐 sh/bash")
}
func TestBuildWebshellAssistantContext_LinuxAutoFromPHP(t *testing.T) {
conn := &database.WebShellConnection{
ID: "ws_lnx01",
Remark: "", // 测试备注为空时 fallback URL
URL: "http://example.com/a.php",
Type: "php",
OS: "auto", // auto + php → linux
Encoding: "", // auto 编码不显式提示
}
got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "看看 /etc/passwd")
mustContain(t, got,
"连接 IDws_lnx01",
"备注:http://example.com/a.php", // 备注空时 fallback URL
"目标系统:Linux/Unix",
"ls -la",
"mkdir -p",
"避免 dir、type、del、move",
"用户请求:看看 /etc/passwd",
)
// encoding=auto 不应出现"响应编码:"这一行
mustNotContain(t, got, "响应编码:")
// Linux 场景不应出现 Windows 命令
mustNotContain(t, got, "推荐 cmd/PowerShell")
}
func TestBuildWebshellAssistantContext_AutoFromASPDefaultsToWindows(t *testing.T) {
// 保留向后兼容:旧连接没配 os,shellType=asp 时应视为 Windows
conn := &database.WebShellConnection{
ID: "ws_asp01",
Remark: "老 ASP 靶机",
Type: "asp",
OS: "", // 空串等同 auto
Encoding: "gb18030",
}
got := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, "查当前用户")
mustContain(t, got,
"目标系统:Windows",
"响应编码:GB18030",
"后端已自动转码为 UTF-8 返回",
WebshellSkillHintMultiAgent,
)
// 多代理 skill 文案里没有 DeepAgent,不应混入 default 文案
mustNotContain(t, got, "DeepAgent")
}
func TestBuildWebshellAssistantContext_MultiAgentSkillHint(t *testing.T) {
conn := &database.WebShellConnection{ID: "ws_m1", Remark: "x", Type: "php", OS: "linux"}
got := BuildWebshellAssistantContext(conn, WebshellSkillHintMultiAgent, "hi")
mustContain(t, got, WebshellSkillHintMultiAgent)
mustNotContain(t, got, "DeepAgent")
}
func TestBuildWebshellAssistantContext_DefaultSkillHintFallback(t *testing.T) {
conn := &database.WebShellConnection{ID: "ws_d1", Remark: "x", Type: "php", OS: "linux"}
// skillHint 传空字符串时应回退到 default
got := BuildWebshellAssistantContext(conn, "", "hi")
mustContain(t, got, WebshellSkillHintDefault)
}
func TestBuildWebshellAssistantContext_UTF8EncodingIsAnnotated(t *testing.T) {
conn := &database.WebShellConnection{
ID: "ws_u1", Remark: "u", Type: "jsp", OS: "linux", Encoding: "utf-8",
}
got := BuildWebshellAssistantContext(conn, WebshellSkillHintDefault, "hi")
mustContain(t, got, "响应编码:UTF-8", "目标原生 UTF-8")
}
func TestBuildWebshellAssistantContext_NilConnReturnsUserMsg(t *testing.T) {
// 防御性:conn == nil 时不 panic,直接返回原消息
got := BuildWebshellAssistantContext(nil, WebshellSkillHintDefault, "just the message")
if got != "just the message" {
t.Errorf("nil conn should return userMsg as-is, got %q", got)
}
}
func TestDescribeTargetOSForPrompt(t *testing.T) {
cases := map[string][]string{
"windows": {"Windows", "dir /a", "move /y", "PowerShell"},
"linux": {"Linux/Unix", "ls -la", "mkdir -p"},
"": {"未知", "uname"}, // 防御性分支
}
for in, wants := range cases {
got := describeTargetOSForPrompt(in)
for _, w := range wants {
if !strings.Contains(got, w) {
t.Errorf("describeTargetOSForPrompt(%q) should contain %q, got: %s", in, w, got)
}
}
}
}
func TestDescribeEncodingForPrompt(t *testing.T) {
cases := map[string]string{
"utf-8": "UTF-8",
"gbk": "GBK",
"gb18030": "GB18030",
"auto": "",
"": "",
}
for in, want := range cases {
got := describeEncodingForPrompt(in)
if want == "" && got != "" {
t.Errorf("describeEncodingForPrompt(%q) should return empty string, got: %s", in, got)
}
if want != "" && !strings.Contains(got, want) {
t.Errorf("describeEncodingForPrompt(%q) should contain %q, got: %s", in, want, got)
}
}
}
// ---- 小工具 ----
func mustContain(t *testing.T, text string, substrings ...string) {
t.Helper()
for _, s := range substrings {
if !strings.Contains(text, s) {
t.Errorf("expected text to contain %q\n--- text ---\n%s", s, text)
}
}
}
func mustNotContain(t *testing.T, text string, substrings ...string) {
t.Helper()
for _, s := range substrings {
if strings.Contains(text, s) {
t.Errorf("text should not contain %q\n--- text ---\n%s", s, text)
}
}
}
+103
View File
@@ -0,0 +1,103 @@
package handler
import (
"testing"
"golang.org/x/text/encoding/simplifiedchinese"
"golang.org/x/text/transform"
)
// mustEncode 使用指定编码对 UTF-8 字符串做编码,得到原始字节,用于构造测试输入
func mustEncode(t *testing.T, s string, enc string) []byte {
t.Helper()
var tr transform.Transformer
switch enc {
case "gbk":
tr = simplifiedchinese.GBK.NewEncoder()
case "gb18030":
tr = simplifiedchinese.GB18030.NewEncoder()
default:
t.Fatalf("unsupported test encoding: %s", enc)
}
out, _, err := transform.Bytes(tr, []byte(s))
if err != nil {
t.Fatalf("mustEncode(%s) failed: %v", enc, err)
}
return out
}
func TestNormalizeWebshellEncoding(t *testing.T) {
cases := map[string]string{
"": "auto",
" ": "auto",
"auto": "auto",
"AUTO": "auto",
"utf-8": "utf-8",
"UTF-8": "utf-8",
"utf8": "utf-8",
"gbk": "gbk",
"GBK": "gbk",
"gb18030": "gb18030",
"big5": "auto", // 未支持的回退到 auto
"anything": "auto",
}
for in, want := range cases {
if got := normalizeWebshellEncoding(in); got != want {
t.Errorf("normalizeWebshellEncoding(%q) = %q, want %q", in, got, want)
}
}
}
func TestDecodeWebshellOutput_AutoDetectsGBK(t *testing.T) {
// 模拟 Windows 中文 cmd 输出的 GBK 字节流
want := "用户名 SID 类型"
raw := mustEncode(t, want, "gbk")
// auto 模式:UTF-8 校验失败后应当回退 GB18030 解码,得到原始中文
got := decodeWebshellOutput(raw, "auto")
if got != want {
t.Errorf("decodeWebshellOutput(auto) = %q, want %q", got, want)
}
// 显式 GBK 模式:同样应当正确解码
got = decodeWebshellOutput(raw, "gbk")
if got != want {
t.Errorf("decodeWebshellOutput(gbk) = %q, want %q", got, want)
}
// 显式 GB18030 模式:GBK 是 GB18030 子集,也应正确解码
got = decodeWebshellOutput(raw, "gb18030")
if got != want {
t.Errorf("decodeWebshellOutput(gb18030) = %q, want %q", got, want)
}
}
func TestDecodeWebshellOutput_PassthroughUTF8(t *testing.T) {
// 已经是 UTF-8 的中文字符串,各模式都应返回原串(不破坏)
want := "hello 世界"
for _, enc := range []string{"", "auto", "utf-8"} {
if got := decodeWebshellOutput([]byte(want), enc); got != want {
t.Errorf("decodeWebshellOutput(%q) passthrough = %q, want %q", enc, got, want)
}
}
}
func TestDecodeWebshellOutput_ASCIIStable(t *testing.T) {
// 纯 ASCII 在任何模式下都必须保持原样
want := "whoami\nAdministrator\n"
for _, enc := range []string{"", "auto", "utf-8", "gbk", "gb18030"} {
if got := decodeWebshellOutput([]byte(want), enc); got != want {
t.Errorf("decodeWebshellOutput(%q) ASCII = %q, want %q", enc, got, want)
}
}
}
func TestDecodeWebshellOutput_EmptyInput(t *testing.T) {
// 空输入直接返回空串,不做额外分配
if got := decodeWebshellOutput(nil, "gbk"); got != "" {
t.Errorf("decodeWebshellOutput(nil) = %q, want empty", got)
}
if got := decodeWebshellOutput([]byte{}, "auto"); got != "" {
t.Errorf("decodeWebshellOutput([]) = %q, want empty", got)
}
}
+348
View File
@@ -0,0 +1,348 @@
package handler
import (
"encoding/base64"
"strings"
"testing"
"go.uber.org/zap"
)
func newTestWebShellHandler() *WebShellHandler {
return NewWebShellHandler(zap.NewNop(), nil)
}
func TestNormalizeWebshellOS(t *testing.T) {
cases := map[string]string{
"": "auto",
" ": "auto",
"auto": "auto",
"AUTO": "auto",
"linux": "linux",
"Linux": "linux",
"windows": "windows",
"WINDOWS": "windows",
"macos": "auto", // 未支持的回退 auto
"solaris": "auto",
}
for in, want := range cases {
if got := normalizeWebshellOS(in); got != want {
t.Errorf("normalizeWebshellOS(%q) = %q, want %q", in, got, want)
}
}
}
func TestResolveWebshellOS(t *testing.T) {
type testCase struct {
osTag string
shellType string
want string
}
cases := []testCase{
// 显式 OS:按用户选择,忽略 shellType
{"linux", "asp", "linux"},
{"windows", "php", "windows"},
{"LINUX", "jsp", "linux"},
// auto + 各种 shellTypeasp/aspx → windows,其他 → linux
{"auto", "asp", "windows"},
{"auto", "aspx", "windows"},
{"auto", "ASP", "windows"},
{"auto", "php", "linux"},
{"auto", "jsp", "linux"},
{"auto", "custom", "linux"},
{"auto", "", "linux"},
// 空/未知 OS 等价 auto
{"", "asp", "windows"},
{"", "php", "linux"},
{"unknown", "aspx", "windows"},
}
for _, c := range cases {
got := resolveWebshellOS(c.osTag, c.shellType)
if got != c.want {
t.Errorf("resolveWebshellOS(%q,%q) = %q, want %q", c.osTag, c.shellType, got, c.want)
}
}
}
func TestQuoteCmdPath(t *testing.T) {
cases := map[string]string{
"": `"."`,
`C:\Windows\Temp`: `"C:\Windows\Temp"`,
`C:\Program Files\a`: `"C:\Program Files\a"`,
`C:\weird"name\f.txt`: `"C:\weird""name\f.txt"`,
`.`: `"."`,
}
for in, want := range cases {
if got := quoteCmdPath(in); got != want {
t.Errorf("quoteCmdPath(%q) = %q, want %q", in, got, want)
}
}
}
func TestQuoteShellSinglePosix(t *testing.T) {
cases := map[string]string{
"": ".",
"/tmp/a b": "'/tmp/a b'",
"/tmp/it's.txt": `'/tmp/it'\''s.txt'`,
}
for in, want := range cases {
if got := quoteShellSinglePosix(in); got != want {
t.Errorf("quoteShellSinglePosix(%q) = %q, want %q", in, got, want)
}
}
}
// TestBuildFileCommand_LinuxBranch 覆盖 Linux 目标下每个 action 产出的命令
func TestBuildFileCommand_LinuxBranch(t *testing.T) {
h := newTestWebShellHandler()
base := fileCommandInput{OS: "linux", ShellType: "php"}
mustContain := func(t *testing.T, cmd string, substrings ...string) {
t.Helper()
for _, s := range substrings {
if !strings.Contains(cmd, s) {
t.Errorf("expected command to contain %q, got: %s", s, cmd)
}
}
}
mustNotContain := func(t *testing.T, cmd string, substrings ...string) {
t.Helper()
for _, s := range substrings {
if strings.Contains(cmd, s) {
t.Errorf("command should not contain %q, got: %s", s, cmd)
}
}
}
// list with empty path defaults to '.'
in := base
in.Action = "list"
cmd, err := h.buildFileCommand(in)
if err != nil {
t.Fatalf("list linux: unexpected err: %v", err)
}
mustContain(t, cmd, "ls -la", "'.'")
// list with path containing spaces
in.Path = "/tmp/my files"
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "ls -la ", "'/tmp/my files'")
// read with path
in = base
in.Action = "read"
in.Path = "/etc/passwd"
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "cat ", "'/etc/passwd'")
// read without path → error
in.Path = ""
if _, err := h.buildFileCommand(in); err != errFileOpPathRequired {
t.Errorf("read empty path: want errFileOpPathRequired, got %v", err)
}
// delete
in = base
in.Action = "delete"
in.Path = "/tmp/a.txt"
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "rm -f ", "'/tmp/a.txt'")
mustNotContain(t, cmd, "del")
// mkdir
in.Action = "mkdir"
in.Path = "/tmp/new/sub"
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "mkdir -p ", "'/tmp/new/sub'")
// rename
in = base
in.Action = "rename"
in.Path = "/tmp/a"
in.TargetPath = "/tmp/b"
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "mv -f ", "'/tmp/a'", "'/tmp/b'")
// rename missing target → error
in.TargetPath = ""
if _, err := h.buildFileCommand(in); err != errFileOpRenameNeedsBothPaths {
t.Errorf("rename empty target: want errFileOpRenameNeedsBothPaths, got %v", err)
}
// write
in = base
in.Action = "write"
in.Path = "/tmp/w.txt"
in.Content = "hello 世界"
cmd, _ = h.buildFileCommand(in)
b64 := base64.StdEncoding.EncodeToString([]byte("hello 世界"))
mustContain(t, cmd, "echo '"+b64+"'", "| base64 -d", "> '/tmp/w.txt'")
// upload
in = base
in.Action = "upload"
in.Path = "/tmp/bin"
in.Content = "YWJjZA==" // base64 of "abcd"
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "echo 'YWJjZA=='", "| base64 -d", "> '/tmp/bin'")
// upload oversized content → error
in.Content = strings.Repeat("A", 513*1024)
if _, err := h.buildFileCommand(in); err != errFileOpUploadTooLarge {
t.Errorf("upload too large: want errFileOpUploadTooLarge, got %v", err)
}
// upload_chunk with chunk_index=0 uses single redirect
in = base
in.Action = "upload_chunk"
in.Path = "/tmp/bin"
in.Content = "YWJj"
in.ChunkIndex = 0
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "base64 -d > '/tmp/bin'")
mustNotContain(t, cmd, ">>")
// upload_chunk with chunk_index>0 uses append redirect
in.ChunkIndex = 1
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "base64 -d >> '/tmp/bin'")
// unsupported action
in = base
in.Action = "nope"
if _, err := h.buildFileCommand(in); err == nil || !strings.Contains(err.Error(), "unsupported action") {
t.Errorf("unknown action: want unsupported action error, got %v", err)
}
}
// TestBuildFileCommand_WindowsBranch 覆盖 Windows 目标下每个 action 产出的命令
func TestBuildFileCommand_WindowsBranch(t *testing.T) {
h := newTestWebShellHandler()
base := fileCommandInput{OS: "windows", ShellType: "php"}
mustContain := func(t *testing.T, cmd string, substrings ...string) {
t.Helper()
for _, s := range substrings {
if !strings.Contains(cmd, s) {
t.Errorf("expected command to contain %q, got: %s", s, cmd)
}
}
}
mustNotContain := func(t *testing.T, cmd string, substrings ...string) {
t.Helper()
for _, s := range substrings {
if strings.Contains(cmd, s) {
t.Errorf("command should not contain %q, got: %s", s, cmd)
}
}
}
// list
in := base
in.Action = "list"
cmd, _ := h.buildFileCommand(in)
mustContain(t, cmd, "dir /a ", `"."`)
mustNotContain(t, cmd, "ls -la")
in.Path = `C:\Users\Public Docs`
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "dir /a ", `"C:\Users\Public Docs"`)
// read
in = base
in.Action = "read"
in.Path = `C:\flag.txt`
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "type ", `"C:\flag.txt"`)
// delete
in.Action = "delete"
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "del /q /f ", `"C:\flag.txt"`)
mustNotContain(t, cmd, "rm -f")
// mkdir
in.Action = "mkdir"
in.Path = `C:\a\b\c`
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "md ", `"C:\a\b\c"`)
// rename
in = base
in.Action = "rename"
in.Path = `C:\a.txt`
in.TargetPath = `C:\b.txt`
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "move /y ", `"C:\a.txt"`, `"C:\b.txt"`)
// write → PowerShell base64 one-liner
in = base
in.Action = "write"
in.Path = `C:\out.txt`
in.Content = "hello 世界"
cmd, _ = h.buildFileCommand(in)
wantB64 := base64.StdEncoding.EncodeToString([]byte("hello 世界"))
mustContain(t, cmd,
"powershell -NoProfile -NonInteractive -Command",
"[Convert]::FromBase64String('"+wantB64+"')",
"[IO.File]::WriteAllBytes('C:\\out.txt'",
)
mustNotContain(t, cmd, "echo ", "base64 -d")
// upload (chunk_index=0 equivalent) uses WriteAllBytes
in = base
in.Action = "upload"
in.Path = `C:\bin\f`
in.Content = "YWJjZA=="
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "WriteAllBytes('C:\\bin\\f'", "FromBase64String('YWJjZA==')")
// upload_chunk index=0 → WriteAllBytes
in.Action = "upload_chunk"
in.ChunkIndex = 0
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "WriteAllBytes(")
mustNotContain(t, cmd, "FileMode]::Append")
// upload_chunk index>0 → append (Open with Append mode)
in.ChunkIndex = 1
cmd, _ = h.buildFileCommand(in)
mustContain(t, cmd, "[IO.FileMode]::Append", "FromBase64String('YWJjZA==')")
}
// TestBuildFileCommand_AutoFallbackMatchesLegacyBehavior 确保 os=auto 时与旧版 shellType 判定行为完全一致
// asp/aspx 视为 Windows(旧行为),其他视为 Linux。
func TestBuildFileCommand_AutoFallbackMatchesLegacyBehavior(t *testing.T) {
h := newTestWebShellHandler()
// asp + auto → windows 命令
cmd, _ := h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: "asp"})
if !strings.Contains(cmd, "dir /a") {
t.Errorf("auto + asp should use Windows cmd, got: %s", cmd)
}
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: "aspx"})
if !strings.Contains(cmd, "dir /a") {
t.Errorf("auto + aspx should use Windows cmd, got: %s", cmd)
}
// php/jsp/custom + auto → linux 命令(与历史行为一致)
for _, st := range []string{"php", "jsp", "custom", ""} {
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "auto", ShellType: st})
if !strings.Contains(cmd, "ls -la") {
t.Errorf("auto + %q should use Linux cmd, got: %s", st, cmd)
}
}
// 显式 OS 覆盖 shellType
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "windows", ShellType: "php"})
if !strings.Contains(cmd, "dir /a") {
t.Errorf("explicit windows should override php shellType, got: %s", cmd)
}
cmd, _ = h.buildFileCommand(fileCommandInput{Action: "list", OS: "linux", ShellType: "asp"})
if !strings.Contains(cmd, "ls -la") {
t.Errorf("explicit linux should override asp shellType, got: %s", cmd)
}
}
+127
View File
@@ -0,0 +1,127 @@
package handler
import (
"bytes"
"io"
"net/http"
"strings"
"go.uber.org/zap"
)
// webshellOSProbeCommand 探活命令:利用 Windows cmd 与 POSIX shell 对 `%OS%` 展开差异进行判定。
// - Windows cmd`%OS%` 被展开为 `Windows_NT`,回显 `:OSPROBE_Windows_NT:END`
// - POSIX sh/bash`%OS%` 不是变量语法,作为字面量原样保留,回显 `:OSPROBE_%OS%:END`
//
// 一条命令即可得到明确的、互斥的信号,避免探活成本(相比发两次命令)。
// 冒号包裹是为了避免部分 shell 输出多余空白/BOM 时字符串匹配失效。
const webshellOSProbeCommand = "echo :OSPROBE_%OS%:END"
// probeWebshellOSViaExec 通过一次命令执行的回显推断目标操作系统。
//
// 返回值:
// - "windows" / "linux":识别成功
// - "":无法判定(调用方应保留既有 fallback 逻辑)
//
// 入参 execFn 是一个"发命令并拿到回显"的闭包;让 HTTP 入口和 MCP 入口可以共用同一套探活逻辑
// 而不必关心底层是如何发包的。
func probeWebshellOSViaExec(execFn func(cmd string) (output string, ok bool)) string {
if execFn == nil {
return ""
}
out, ok := execFn(webshellOSProbeCommand)
if !ok {
return ""
}
return classifyWebshellOSProbeOutput(out)
}
// classifyWebshellOSProbeOutput 纯函数:根据探活命令的回显判定 OS。
// 抽出来是为了单测可直接覆盖所有分支,无需真实 HTTP 调用。
func classifyWebshellOSProbeOutput(out string) string {
if out == "" {
return ""
}
lower := strings.ToLower(out)
// Windows 强信号:cmd.exe 成功展开了 %OS% 变量
if strings.Contains(out, "Windows_NT") {
return "windows"
}
// 容错:部分老版本 Windows 可能 `%OS%` 展开为其他字样(极少见),再看 PATH/OS 等次级线索
if strings.Contains(lower, "microsoft windows") {
return "windows"
}
// Linux/Unix 强信号:`%OS%` 字面量被原样回显,说明 shell 不是 cmd.exe
if strings.Contains(out, "%OS%") {
return "linux"
}
// 次级线索:部分 webshell 在 Linux 上可能走了其他外壳(如 zsh/ash),
// 但它们对 `%OS%` 同样不展开;若命中 OSPROBE 头部却没拿到 %OS% 字面量,
// 说明回显被中途截断或过滤,保守返回空让上层 fallback。
return ""
}
// newHTTPExecFn 为 HTTP FileOp 路径构造"发命令取回显"的闭包,供探活复用。
// 参数来自 HTTP 请求,复用 buildExecURL / buildExecBody 两个已有的命令编排器,
// 确保探活包与实际文件操作包走完全一致的 webshell 协议(GET/POST、参数名、编码)。
func (h *WebShellHandler) newHTTPExecFn(targetURL, password, shellType, method, cmdParam, encoding string) func(string) (string, bool) {
useGET := strings.ToUpper(strings.TrimSpace(method)) == "GET"
if strings.TrimSpace(cmdParam) == "" {
cmdParam = "cmd"
}
return func(cmd string) (string, bool) {
var (
httpReq *http.Request
err error
)
if useGET {
u := h.buildExecURL(targetURL, shellType, password, cmdParam, cmd)
httpReq, err = http.NewRequest(http.MethodGet, u, nil)
} else {
body := h.buildExecBody(shellType, password, cmdParam, cmd)
httpReq, err = http.NewRequest(http.MethodPost, targetURL, bytes.NewReader(body))
if err == nil {
httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
}
}
if err != nil {
return "", false
}
httpReq.Header.Set("User-Agent", "Mozilla/5.0 (compatible; CyberStrikeAI-WebShell/1.0)")
resp, err := h.client.Do(httpReq)
if err != nil {
return "", false
}
defer resp.Body.Close()
raw, _ := io.ReadAll(resp.Body)
return decodeWebshellOutput(raw, encoding), resp.StatusCode == http.StatusOK
}
}
// persistDetectedOS 把探活结果回写到连接表;失败只记日志不阻断主流程。
// 设计上故意只触发 UPDATE,不会新建记录,因此即便 connectionID 不存在也只是悄悄放弃。
func (h *WebShellHandler) persistDetectedOS(connectionID, detected string) {
connectionID = strings.TrimSpace(connectionID)
detected = normalizeWebshellOS(detected)
if connectionID == "" || detected == "" || detected == "auto" {
return
}
conn, err := h.db.GetWebshellConnection(connectionID)
if err != nil || conn == nil {
// 不是所有调用方都能提供有效 ID(比如临时测试),这里静默返回
return
}
if normalizeWebshellOS(conn.OS) != "auto" {
// 用户已经显式选过 OS,尊重用户选择,不自动覆盖
return
}
conn.OS = detected
if err := h.db.UpdateWebshellConnection(conn); err != nil {
h.logger.Warn("webshell 探活结果持久化失败", zap.String("id", connectionID), zap.String("os", detected), zap.Error(err))
return
}
h.logger.Info("webshell auto OS 探活成功并持久化", zap.String("id", connectionID), zap.String("os", detected))
}
+68
View File
@@ -0,0 +1,68 @@
package handler
import "testing"
func TestClassifyWebshellOSProbeOutput(t *testing.T) {
cases := []struct {
name string
in string
want string
}{
{"Windows cmd 回显完整", ":OSPROBE_Windows_NT:END\r\n", "windows"},
{"Windows cmd 回显带额外空行", "\r\n:OSPROBE_Windows_NT:END\r\n", "windows"},
{"Windows 次级线索 - ver banner", "Microsoft Windows [版本 10.0.19045]\r\n", "windows"},
{"Linux sh 字面量回显", ":OSPROBE_%OS%:END\n", "linux"},
{"Linux 紧凑输出(无换行)", ":OSPROBE_%OS%:END", "linux"},
{"空输出 - 无法判定", "", ""},
{"被过滤的输出 - 无法判定", "something weird", ""},
{"仅有 OSPROBE 前缀但被截断 - 保守返回空", ":OSPROBE_:END", ""},
}
for _, c := range cases {
if got := classifyWebshellOSProbeOutput(c.in); got != c.want {
t.Errorf("case %q: got %q, want %q", c.name, got, c.want)
}
}
}
func TestProbeWebshellOSViaExec_SendsOneCommandOnly(t *testing.T) {
var calls []string
fn := func(cmd string) (string, bool) {
calls = append(calls, cmd)
return ":OSPROBE_Windows_NT:END", true
}
got := probeWebshellOSViaExec(fn)
if got != "windows" {
t.Fatalf("want windows, got %q", got)
}
if len(calls) != 1 {
t.Fatalf("probe should issue exactly one exec call, got %d: %v", len(calls), calls)
}
if calls[0] != webshellOSProbeCommand {
t.Errorf("probe command mismatch: got %q", calls[0])
}
}
func TestProbeWebshellOSViaExec_NotOkReturnsEmpty(t *testing.T) {
// HTTP 非 200 的场景:execFn 返回 ok=false,探活应放弃
fn := func(cmd string) (string, bool) { return "whatever", false }
if got := probeWebshellOSViaExec(fn); got != "" {
t.Errorf("want empty when exec not ok, got %q", got)
}
}
func TestProbeWebshellOSViaExec_NilSafeguard(t *testing.T) {
if got := probeWebshellOSViaExec(nil); got != "" {
t.Errorf("nil execFn should return empty, got %q", got)
}
}
func TestProbeWebshellOSViaExec_LinuxUname(t *testing.T) {
// 某些 webshell 对 `%OS%` 字面量也会过滤(例如安全规则),
// 但主要路径是"%OS% 字面量被原样回显"。这里覆盖标准 Linux 场景。
fn := func(cmd string) (string, bool) {
return ":OSPROBE_%OS%:END\n", true
}
if got := probeWebshellOSViaExec(fn); got != "linux" {
t.Errorf("Linux case: want linux, got %q", got)
}
}
+3 -3
View File
@@ -16,9 +16,9 @@ const (
// DSL keys for [VectorEinoRetriever.Retrieve] via [retriever.WithDSLInfo].
const (
DSLRiskType = "risk_type"
DSLSimilarityThreshold = "similarity_threshold"
DSLSubIndexFilter = "sub_index_filter"
DSLRiskType = "risk_type"
DSLSimilarityThreshold = "similarity_threshold"
DSLSubIndexFilter = "sub_index_filter"
)
// FormatEmbeddingInput matches the historical indexing format so existing embeddings
+1 -1
View File
@@ -8,8 +8,8 @@ import (
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
)
+9 -9
View File
@@ -11,9 +11,9 @@ import (
"cyberstrike-ai/internal/config"
fileloader "github.com/cloudwego/eino-ext/components/document/loader/file"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/components/document"
"github.com/cloudwego/eino/components/indexer"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
@@ -35,14 +35,14 @@ type Indexer struct {
lastErrorTime time.Time
errorCount int
rebuildMu sync.RWMutex
isRebuilding bool
rebuildTotalItems int
rebuildCurrent int
rebuildFailed int
rebuildStartTime time.Time
rebuildLastItemID string
rebuildLastChunks int
rebuildMu sync.RWMutex
isRebuilding bool
rebuildTotalItems int
rebuildCurrent int
rebuildFailed int
rebuildStartTime time.Time
rebuildLastItemID string
rebuildLastChunks int
}
// NewIndexer 创建索引器并编译 Eino 索引链;kcfg 为完整知识库配置(含 indexing 与路径相关行为)。
+3 -3
View File
@@ -108,9 +108,9 @@ func (r *RetrievalLog) MarshalJSON() ([]byte, error) {
// CategoryWithItems 分类及其下的知识项(用于按分类分页)
type CategoryWithItems struct {
Category string `json:"category"` // 分类名称
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
Category string `json:"category"` // 分类名称
ItemCount int `json:"itemCount"` // 该分类下的知识项总数
Items []*KnowledgeItemSummary `json:"items"` // 该分类下的知识项列表
}
// SearchRequest 搜索请求
+10 -10
View File
@@ -55,14 +55,14 @@ func New(level, output string) *Logger {
}
func (l *Logger) Fatal(msg string, fields ...interface{}) {
zapFields := make([]zap.Field, 0, len(fields))
for _, f := range fields {
switch v := f.(type) {
case error:
zapFields = append(zapFields, zap.Error(v))
default:
zapFields = append(zapFields, zap.Any("field", v))
}
}
l.Logger.Fatal(msg, zapFields...)
zapFields := make([]zap.Field, 0, len(fields))
for _, f := range fields {
switch v := f.(type) {
case error:
zapFields = append(zapFields, zap.Error(v))
default:
zapFields = append(zapFields, zap.Any("field", v))
}
}
l.Logger.Fatal(msg, zapFields...)
}
+10 -10
View File
@@ -62,7 +62,7 @@ func TestExternalMCPManager_RemoveConfig(t *testing.T) {
manager := NewExternalMCPManager(logger)
cfg := config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
ExternalMCPEnable: false,
}
@@ -86,17 +86,17 @@ func TestExternalMCPManager_GetStats(t *testing.T) {
// 添加多个配置
manager.AddOrUpdateConfig("enabled1", config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
ExternalMCPEnable: true,
})
manager.AddOrUpdateConfig("enabled2", config.ExternalMCPServerConfig{
URL: "http://127.0.0.1:8081/mcp",
URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: true,
})
manager.AddOrUpdateConfig("disabled1", config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
ExternalMCPEnable: false,
})
@@ -122,11 +122,11 @@ func TestExternalMCPManager_LoadConfigs(t *testing.T) {
externalMCPConfig := config.ExternalMCPConfig{
Servers: map[string]config.ExternalMCPServerConfig{
"loaded1": {
Command: "python3",
Command: "python3",
ExternalMCPEnable: true,
},
"loaded2": {
URL: "http://127.0.0.1:8081/mcp",
URL: "http://127.0.0.1:8081/mcp",
ExternalMCPEnable: false,
},
},
@@ -153,9 +153,9 @@ func TestLazySDKClient_InitializeFails(t *testing.T) {
logger := zap.NewNop()
// 使用不存在的 HTTP 地址,Initialize 应失败
cfg := config.ExternalMCPServerConfig{
Type: "http",
URL: "http://127.0.0.1:19999/nonexistent",
Timeout: 2,
Type: "http",
URL: "http://127.0.0.1:19999/nonexistent",
Timeout: 2,
}
c := newLazySDKClient(cfg, logger)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
@@ -176,7 +176,7 @@ func TestExternalMCPManager_StartStopClient(t *testing.T) {
// 添加一个禁用的配置
cfg := config.ExternalMCPServerConfig{
Command: "python3",
Command: "python3",
ExternalMCPEnable: false,
}
File diff suppressed because it is too large Load Diff
+133
View File
@@ -0,0 +1,133 @@
package multiagent
import (
"context"
"strings"
"cyberstrike-ai/internal/agent"
"github.com/bytedance/sonic"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
type einoModelInputTelemetryMiddleware struct {
adk.BaseChatModelAgentMiddleware
logger *zap.Logger
modelName string
conversationID string
phase string
}
func newEinoModelInputTelemetryMiddleware(
logger *zap.Logger,
modelName string,
conversationID string,
phase string,
) adk.ChatModelAgentMiddleware {
if logger == nil {
return nil
}
return &einoModelInputTelemetryMiddleware{
logger: logger,
modelName: strings.TrimSpace(modelName),
conversationID: strings.TrimSpace(conversationID),
phase: strings.TrimSpace(phase),
}
}
func (m *einoModelInputTelemetryMiddleware) BeforeModelRewriteState(
ctx context.Context,
state *adk.ChatModelAgentState,
mc *adk.ModelContext,
) (context.Context, *adk.ChatModelAgentState, error) {
if m == nil || m.logger == nil || state == nil {
return ctx, state, nil
}
tokens := estimateTokensForMessagesAndTools(ctx, m.modelName, state.Messages, mcTools(mc))
m.logger.Info("eino model input estimated",
zap.String("phase", m.phase),
zap.String("conversation_id", m.conversationID),
zap.Int("messages", len(state.Messages)),
zap.Int("tools", len(mcTools(mc))),
zap.Int("input_tokens_estimated", tokens),
)
return ctx, state, nil
}
func mcTools(mc *adk.ModelContext) []*schema.ToolInfo {
if mc == nil || len(mc.Tools) == 0 {
return nil
}
return mc.Tools
}
func estimateTokensForMessagesAndTools(
_ context.Context,
modelName string,
messages []adk.Message,
tools []*schema.ToolInfo,
) int {
var sb strings.Builder
for _, msg := range messages {
if msg == nil {
continue
}
sb.WriteString(string(msg.Role))
sb.WriteByte('\n')
sb.WriteString(msg.Content)
sb.WriteByte('\n')
if msg.ReasoningContent != "" {
sb.WriteString(msg.ReasoningContent)
sb.WriteByte('\n')
}
if len(msg.ToolCalls) > 0 {
if b, err := sonic.Marshal(msg.ToolCalls); err == nil {
sb.Write(b)
sb.WriteByte('\n')
}
}
}
for _, tl := range tools {
if tl == nil {
continue
}
cp := *tl
cp.Extra = nil
if text, err := sonic.MarshalString(cp); err == nil {
sb.WriteString(text)
sb.WriteByte('\n')
}
}
text := sb.String()
if text == "" {
return 0
}
tc := agent.NewTikTokenCounter()
if n, err := tc.Count(modelName, text); err == nil {
return n
}
return (len(text) + 3) / 4
}
func logPlanExecuteModelInputEstimate(
logger *zap.Logger,
modelName string,
conversationID string,
phase string,
msgs []adk.Message,
) {
if logger == nil {
return
}
tokens := estimateTokensForMessagesAndTools(context.Background(), modelName, msgs, nil)
logger.Info("eino model input estimated",
zap.String("phase", phase),
zap.String("conversation_id", strings.TrimSpace(conversationID)),
zap.Int("messages", len(msgs)),
zap.Int("tools", 0),
zap.Int("input_tokens_estimated", tokens),
)
}
+64 -1
View File
@@ -8,6 +8,7 @@ import (
"strings"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp/builtin"
localbk "github.com/cloudwego/eino-ext/adk/backend/local"
"github.com/cloudwego/eino/adk"
@@ -65,6 +66,66 @@ func splitToolsForToolSearch(all []tool.BaseTool, alwaysVisible int) (static []t
return append([]tool.BaseTool(nil), all[:alwaysVisible]...), append([]tool.BaseTool(nil), all[alwaysVisible:]...), true
}
func splitToolsForToolSearchByNames(all []tool.BaseTool, names []string, fallbackAlwaysVisible int) (static []tool.BaseTool, dynamic []tool.BaseTool, ok bool) {
nameSet := make(map[string]struct{}, len(names))
for _, n := range names {
n = strings.TrimSpace(strings.ToLower(n))
if n == "" {
continue
}
nameSet[n] = struct{}{}
}
if len(nameSet) == 0 {
return splitToolsForToolSearch(all, fallbackAlwaysVisible)
}
static = make([]tool.BaseTool, 0, len(all))
dynamic = make([]tool.BaseTool, 0, len(all))
for _, t := range all {
if t == nil {
continue
}
info, err := t.Info(context.Background())
name := ""
if err == nil && info != nil {
name = strings.TrimSpace(strings.ToLower(info.Name))
}
if _, keep := nameSet[name]; keep {
static = append(static, t)
continue
}
dynamic = append(dynamic, t)
}
if len(static) == 0 || len(dynamic) == 0 {
// fallback: preserve previous behavior when whitelist misses all or includes all.
return splitToolsForToolSearch(all, fallbackAlwaysVisible)
}
return static, dynamic, true
}
func mergeAlwaysVisibleToolNames(configured []string) []string {
merged := make([]string, 0, len(configured)+32)
seen := make(map[string]struct{}, len(configured)+32)
add := func(name string) {
n := strings.TrimSpace(strings.ToLower(name))
if n == "" {
return
}
if _, ok := seen[n]; ok {
return
}
seen[n] = struct{}{}
merged = append(merged, n)
}
for _, n := range configured {
add(n)
}
// Always include hardcoded backend builtin MCP tools from constants.
for _, n := range builtin.GetAllBuiltinTools() {
add(n)
}
return merged
}
func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddlewareConfig, convID string, loc *localbk.Local, logger *zap.Logger) (adk.ChatModelAgentMiddleware, error) {
if loc == nil {
return nil, fmt.Errorf("reduction: local backend nil")
@@ -87,6 +148,8 @@ func buildReductionMiddleware(ctx context.Context, mw config.MultiAgentEinoMiddl
RootDir: root,
ReadFileToolName: "read_file",
ClearExcludeTools: excl,
MaxLengthForTrunc: mw.ReductionMaxLengthForTruncEffective(),
MaxTokensForClear: int64(mw.ReductionMaxTokensForClearEffective()),
})
if err != nil {
return nil, err
@@ -142,7 +205,7 @@ func prependEinoMiddlewares(
alwaysVis = 12
}
if mw.ToolSearchEnable && len(tools) >= minTools {
static, dynamic, split := splitToolsForToolSearch(tools, alwaysVis)
static, dynamic, split := splitToolsForToolSearchByNames(tools, mergeAlwaysVisibleToolNames(mw.ToolSearchAlwaysVisibleTools), alwaysVis)
if split && len(dynamic) > 0 {
ts, terr := toolsearch.New(ctx, &toolsearch.Config{DynamicTools: dynamic})
if terr != nil {
@@ -0,0 +1,38 @@
package multiagent
import (
"context"
"fmt"
"github.com/cloudwego/eino/adk"
)
func applyBeforeModelRewriteHandlers(
ctx context.Context,
msgs []adk.Message,
handlers []adk.ChatModelAgentMiddleware,
) ([]adk.Message, error) {
if len(msgs) == 0 || len(handlers) == 0 {
return msgs, nil
}
state := &adk.ChatModelAgentState{Messages: msgs}
modelCtx := &adk.ModelContext{}
curCtx := ctx
for _, h := range handlers {
if h == nil {
continue
}
nextCtx, nextState, err := h.BeforeModelRewriteState(curCtx, state, modelCtx)
if err != nil {
return nil, fmt.Errorf("before model rewrite: %w", err)
}
if nextCtx != nil {
curCtx = nextCtx
}
if nextState != nil {
state = nextState
}
}
return state.Messages, nil
}
+171 -20
View File
@@ -5,6 +5,7 @@ import (
"fmt"
"strings"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino-ext/components/model/openai"
@@ -25,7 +26,12 @@ type PlanExecuteRootArgs struct {
LoopMaxIter int
// AppCfg / Logger 非空时为 Executor 挂载与 Deep/Supervisor 一致的 Eino summarization 中间件。
AppCfg *config.Config
MwCfg *config.MultiAgentEinoMiddlewareConfig
// ConversationID is used for transcript/isolation paths in middleware.
ConversationID string
Logger *zap.Logger
// ModelName is used for model input token estimation logs.
ModelName string
// ExecPreMiddlewares 是由 prependEinoMiddlewares 构建的前置中间件(patchtoolcalls, reduction, toolsearch, plantask),
// 与 Deep/Supervisor 主代理的 mainOrchestratorPre 一致。
ExecPreMiddlewares []adk.ChatModelAgentMiddleware
@@ -33,6 +39,8 @@ type PlanExecuteRootArgs struct {
SkillMiddleware adk.ChatModelAgentMiddleware
// FilesystemMiddleware 是 Eino filesystem 中间件,当 eino_skills.filesystem_tools 启用时提供本机文件读写与 Shell 能力(可选)。
FilesystemMiddleware adk.ChatModelAgentMiddleware
// PlannerReplannerRewriteHandlers applies BeforeModelRewriteState pipeline for planner/replanner input.
PlannerReplannerRewriteHandlers []adk.ChatModelAgentMiddleware
}
// NewPlanExecuteRoot 返回 plan → execute → replan 预置编排根节点(与 Deep / Supervisor 并列)。
@@ -50,7 +58,7 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
plannerCfg := &planexecute.PlannerConfig{
ToolCallingChatModel: tcm,
}
if fn := planExecutePlannerGenInput(a.OrchInstruction); fn != nil {
if fn := planExecutePlannerGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID, a.PlannerReplannerRewriteHandlers); fn != nil {
plannerCfg.GenInputFn = fn
}
planner, err := planexecute.NewPlanner(ctx, plannerCfg)
@@ -59,7 +67,7 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
}
replanner, err := planexecute.NewReplanner(ctx, &planexecute.ReplannerConfig{
ChatModel: tcm,
GenInputFn: planExecuteReplannerGenInput(a.OrchInstruction),
GenInputFn: planExecuteReplannerGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID, a.PlannerReplannerRewriteHandlers),
})
if err != nil {
return nil, fmt.Errorf("plan_execute replanner: %w", err)
@@ -81,17 +89,23 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
}
// 4. summarization(最后,与 Deep/Supervisor 一致)
if a.AppCfg != nil {
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.Logger)
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.Logger)
if sumErr != nil {
return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr)
}
execHandlers = append(execHandlers, sumMw)
}
// 5. 孤儿 tool 消息兜底:必须挂在所有改写历史中间件(summarization/reduction/skill)之后、
// telemetry 之前,保证送入 ChatModel 的消息序列 tool_call ↔ tool_result 配对完整。
execHandlers = append(execHandlers, newOrphanToolPrunerMiddleware(a.Logger, "plan_execute_executor"))
if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil {
execHandlers = append(execHandlers, teleMw)
}
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
Model: a.ExecModel,
ToolsConfig: a.ToolsCfg,
MaxIterations: a.ExecMaxIter,
GenInputFn: planExecuteExecutorGenInput(a.OrchInstruction),
GenInputFn: planExecuteExecutorGenInput(a.OrchInstruction, a.AppCfg, a.MwCfg, a.Logger, a.ModelName, a.ConversationID),
}, execHandlers)
if err != nil {
return nil, fmt.Errorf("plan_execute executor: %w", err)
@@ -110,20 +124,42 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
// planExecutePlannerGenInput 将 orchestrator instruction 作为 SystemMessage 注入 planner 输入。
// 返回 nil 时 Eino 使用内置默认 planner prompt。
func planExecutePlannerGenInput(orchInstruction string) planexecute.GenPlannerModelInputFn {
func planExecutePlannerGenInput(
orchInstruction string,
appCfg *config.Config,
mwCfg *config.MultiAgentEinoMiddlewareConfig,
logger *zap.Logger,
modelName string,
conversationID string,
rewriteHandlers []adk.ChatModelAgentMiddleware,
) planexecute.GenPlannerModelInputFn {
oi := strings.TrimSpace(orchInstruction)
if oi == "" {
if oi == "" && appCfg == nil {
return nil
}
return func(ctx context.Context, userInput []adk.Message) ([]adk.Message, error) {
userInput = capPlanExecuteUserInputMessages(userInput, appCfg, mwCfg)
msgs := make([]adk.Message, 0, 1+len(userInput))
msgs = append(msgs, schema.SystemMessage(oi))
if oi != "" {
msgs = append(msgs, schema.SystemMessage(oi))
}
msgs = append(msgs, userInput...)
if rewritten, rerr := applyBeforeModelRewriteHandlers(ctx, msgs, rewriteHandlers); rerr == nil && len(rewritten) > 0 {
msgs = rewritten
}
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_planner", msgs)
return msgs, nil
}
}
func planExecuteExecutorGenInput(orchInstruction string) planexecute.GenModelInputFn {
func planExecuteExecutorGenInput(
orchInstruction string,
appCfg *config.Config,
mwCfg *config.MultiAgentEinoMiddlewareConfig,
logger *zap.Logger,
modelName string,
conversationID string,
) planexecute.GenModelInputFn {
oi := strings.TrimSpace(orchInstruction)
return func(ctx context.Context, in *planexecute.ExecutionContext) ([]adk.Message, error) {
planContent, err := in.Plan.MarshalJSON()
@@ -131,9 +167,9 @@ func planExecuteExecutorGenInput(orchInstruction string) planexecute.GenModelInp
return nil, err
}
userMsgs, err := planexecute.ExecutorPrompt.Format(ctx, map[string]any{
"input": planExecuteFormatInput(in.UserInput),
"input": planExecuteFormatInput(capPlanExecuteUserInputMessages(in.UserInput, appCfg, mwCfg)),
"plan": string(planContent),
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps),
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps, appCfg, mwCfg),
"step": in.Plan.FirstStep(),
})
if err != nil {
@@ -142,6 +178,7 @@ func planExecuteExecutorGenInput(orchInstruction string) planexecute.GenModelInp
if oi != "" {
userMsgs = append([]adk.Message{schema.SystemMessage(oi)}, userMsgs...)
}
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_executor_gen_input", userMsgs)
return userMsgs, nil
}
}
@@ -155,18 +192,22 @@ func planExecuteFormatInput(input []adk.Message) string {
return sb.String()
}
func planExecuteFormatExecutedSteps(results []planexecute.ExecutedStep) string {
capped := capPlanExecuteExecutedSteps(results)
var sb strings.Builder
for _, result := range capped {
sb.WriteString(fmt.Sprintf("Step: %s\nResult: %s\n\n", result.Step, result.Result))
}
return sb.String()
func planExecuteFormatExecutedSteps(results []planexecute.ExecutedStep, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) string {
capped := capPlanExecuteExecutedStepsWithConfig(results, mwCfg)
return renderPlanExecuteStepsByBudget(capped, appCfg, mwCfg)
}
// planExecuteReplannerGenInput 与 Eino 默认 Replanner 输入一致,但 executed_steps 经 cap 后再写入 prompt
// 且在 orchInstruction 非空时 prepend SystemMessage 使 replanner 也能接收全局指令。
func planExecuteReplannerGenInput(orchInstruction string) planexecute.GenModelInputFn {
func planExecuteReplannerGenInput(
orchInstruction string,
appCfg *config.Config,
mwCfg *config.MultiAgentEinoMiddlewareConfig,
logger *zap.Logger,
modelName string,
conversationID string,
rewriteHandlers []adk.ChatModelAgentMiddleware,
) planexecute.GenModelInputFn {
oi := strings.TrimSpace(orchInstruction)
return func(ctx context.Context, in *planexecute.ExecutionContext) ([]adk.Message, error) {
planContent, err := in.Plan.MarshalJSON()
@@ -175,8 +216,8 @@ func planExecuteReplannerGenInput(orchInstruction string) planexecute.GenModelIn
}
msgs, err := planexecute.ReplannerPrompt.Format(ctx, map[string]any{
"plan": string(planContent),
"input": planExecuteFormatInput(in.UserInput),
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps),
"input": planExecuteFormatInput(capPlanExecuteUserInputMessages(in.UserInput, appCfg, mwCfg)),
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps, appCfg, mwCfg),
"plan_tool": planexecute.PlanToolInfo.Name,
"respond_tool": planexecute.RespondToolInfo.Name,
})
@@ -186,10 +227,120 @@ func planExecuteReplannerGenInput(orchInstruction string) planexecute.GenModelIn
if oi != "" {
msgs = append([]adk.Message{schema.SystemMessage(oi)}, msgs...)
}
if rewritten, rerr := applyBeforeModelRewriteHandlers(ctx, msgs, rewriteHandlers); rerr == nil && len(rewritten) > 0 {
msgs = rewritten
}
logPlanExecuteModelInputEstimate(logger, modelName, conversationID, "plan_execute_replanner", msgs)
return msgs, nil
}
}
func capPlanExecuteUserInputMessages(input []adk.Message, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) []adk.Message {
if len(input) == 0 {
return input
}
maxTotal := 120000
modelName := "gpt-4o"
if appCfg != nil {
if appCfg.OpenAI.MaxTotalTokens > 0 {
maxTotal = appCfg.OpenAI.MaxTotalTokens
}
if m := strings.TrimSpace(appCfg.OpenAI.Model); m != "" {
modelName = m
}
}
// Reserve most tokens for planner/replanner prompt and tool schema.
ratio := 0.35
if mwCfg != nil {
ratio = mwCfg.PlanExecuteUserInputBudgetRatioEffective()
}
budget := int(float64(maxTotal) * ratio)
if budget < 4096 {
budget = 4096
}
tc := agent.NewTikTokenCounter()
out := make([]adk.Message, 0, len(input))
used := 0
for i := len(input) - 1; i >= 0; i-- {
msg := input[i]
if msg == nil {
continue
}
n, err := tc.Count(modelName, string(msg.Role)+"\n"+msg.Content)
if err != nil {
n = (len(msg.Content) + 3) / 4
}
if n <= 0 {
n = 1
}
if used+n > budget {
break
}
used += n
out = append(out, msg)
}
for i, j := 0, len(out)-1; i < j; i, j = i+1, j-1 {
out[i], out[j] = out[j], out[i]
}
if len(out) == 0 {
// Keep the latest user message at least.
return []adk.Message{input[len(input)-1]}
}
return out
}
func renderPlanExecuteStepsByBudget(steps []planexecute.ExecutedStep, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) string {
if len(steps) == 0 {
return ""
}
maxTotal := 120000
modelName := "gpt-4o"
if appCfg != nil {
if appCfg.OpenAI.MaxTotalTokens > 0 {
maxTotal = appCfg.OpenAI.MaxTotalTokens
}
if m := strings.TrimSpace(appCfg.OpenAI.Model); m != "" {
modelName = m
}
}
ratio := 0.2
if mwCfg != nil {
ratio = mwCfg.PlanExecuteExecutedStepsBudgetRatioEffective()
}
budget := int(float64(maxTotal) * ratio)
if budget < 3072 {
budget = 3072
}
tc := agent.NewTikTokenCounter()
var kept []string
used := 0
skipped := 0
for i := len(steps) - 1; i >= 0; i-- {
block := fmt.Sprintf("Step: %s\nResult: %s\n\n", steps[i].Step, steps[i].Result)
n, err := tc.Count(modelName, block)
if err != nil {
n = (len(block) + 3) / 4
}
if n <= 0 {
n = 1
}
if used+n > budget {
skipped = i + 1
break
}
used += n
kept = append(kept, block)
}
var sb strings.Builder
if skipped > 0 {
sb.WriteString(fmt.Sprintf("Earlier executed steps omitted due to context budget: %d steps.\n\n", skipped))
}
for i := len(kept) - 1; i >= 0; i-- {
sb.WriteString(kept[i])
}
return sb.String()
}
// planExecuteStreamsMainAssistant 将规划/执行/重规划各阶段助手流式输出映射到主对话区。
func planExecuteStreamsMainAssistant(agent string) bool {
if agent == "" {
+24 -3
View File
@@ -125,7 +125,7 @@ func RunEinoSingleChatModelAgent(
return nil, fmt.Errorf("eino single 模型: %w", err)
}
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, logger)
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, logger)
if err != nil {
return nil, fmt.Errorf("eino single summarization: %w", err)
}
@@ -145,6 +145,9 @@ func RunEinoSingleChatModelAgent(
handlers = append(handlers, einoSkillMW)
}
handlers = append(handlers, mainSumMw)
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "eino_single"); teleMw != nil {
handlers = append(handlers, teleMw)
}
maxIter := ma.MaxIteration
if maxIter <= 0 {
@@ -165,11 +168,29 @@ func RunEinoSingleChatModelAgent(
},
EmitInternalEvents: true,
}
ins := injectToolNamesOnlyInstruction(ctx, ag.EinoSingleAgentSystemInstruction(), mainTools)
if logger != nil {
names := collectToolNames(ctx, mainTools)
mountedNames := collectToolNames(ctx, mainToolsForCfg)
hasToolSearch := false
for _, n := range names {
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
hasToolSearch = true
break
}
}
logger.Info("eino tool-name injection",
zap.String("scope", "eino_single"),
zap.Int("tool_names", len(names)),
zap.Int("mounted_tool_names", len(mountedNames)),
zap.Bool("has_tool_search", hasToolSearch),
)
}
chatCfg := &adk.ChatModelAgentConfig{
Name: einoSingleAgentName,
Description: "Eino ADK ChatModelAgent with MCP tools for authorized security testing.",
Instruction: ag.EinoSingleAgentSystemInstruction(),
Instruction: ins,
Model: mainModel,
ToolsConfig: mainToolsCfg,
MaxIterations: maxIter,
@@ -188,7 +209,7 @@ func RunEinoSingleChatModelAgent(
return nil, fmt.Errorf("eino single NewChatModelAgent: %w", err)
}
baseMsgs := historyToMessages(history)
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
streamsMainAssistant := func(agent string) bool {
+209 -3
View File
@@ -3,6 +3,8 @@ package multiagent
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"cyberstrike-ai/internal/agent"
@@ -32,6 +34,8 @@ func newEinoSummarizationMiddleware(
ctx context.Context,
summaryModel model.BaseChatModel,
appCfg *config.Config,
mwCfg *config.MultiAgentEinoMiddlewareConfig,
conversationID string,
logger *zap.Logger,
) (adk.ChatModelAgentMiddleware, error) {
if summaryModel == nil || appCfg == nil {
@@ -41,7 +45,14 @@ func newEinoSummarizationMiddleware(
if maxTotal <= 0 {
maxTotal = 120000
}
trigger := int(float64(maxTotal) * 0.9)
triggerRatio := 0.8
emitInternalEvents := true
if mwCfg != nil {
triggerRatio = mwCfg.SummarizationTriggerRatioEffective()
emitInternalEvents = mwCfg.SummarizationEmitInternalEventsEffective()
}
// Keep enough safety margin for tokenizer/model-side accounting mismatch.
trigger := int(float64(maxTotal) * triggerRatio)
if trigger < 4096 {
trigger = maxTotal
if trigger < 4096 {
@@ -57,28 +68,57 @@ func newEinoSummarizationMiddleware(
if modelName == "" {
modelName = "gpt-4o"
}
tokenCounter := einoSummarizationTokenCounter(modelName)
recentTrailMax := trigger / 4
if recentTrailMax < 2048 {
recentTrailMax = 2048
}
if recentTrailMax > trigger/2 {
recentTrailMax = trigger / 2
}
transcriptPath := ""
if conv := strings.TrimSpace(conversationID); conv != "" {
baseRoot := filepath.Join(os.TempDir(), "cyberstrike-summarization")
if dbPath := strings.TrimSpace(appCfg.Database.Path); dbPath != "" {
// Persist with the same lifecycle as local conversation storage.
baseRoot = filepath.Join(filepath.Dir(dbPath), "conversation_artifacts", sanitizeEinoPathSegment(conv), "summarization")
}
base := baseRoot
if mkErr := os.MkdirAll(base, 0o755); mkErr == nil {
transcriptPath = filepath.Join(base, "transcript.txt")
}
}
mw, err := summarization.New(ctx, &summarization.Config{
Model: summaryModel,
Trigger: &summarization.TriggerCondition{
ContextTokens: trigger,
},
TokenCounter: einoSummarizationTokenCounter(modelName),
TokenCounter: tokenCounter,
UserInstruction: einoSummarizeUserInstruction,
EmitInternalEvents: false,
EmitInternalEvents: emitInternalEvents,
TranscriptFilePath: transcriptPath,
PreserveUserMessages: &summarization.PreserveUserMessages{
Enabled: true,
MaxTokens: preserveMax,
},
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
return summarizeFinalizeWithRecentAssistantToolTrail(ctx, originalMessages, summary, tokenCounter, recentTrailMax)
},
Callback: func(ctx context.Context, before, after adk.ChatModelAgentState) error {
if logger == nil {
return nil
}
beforeTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: before.Messages})
afterTokens, _ := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: after.Messages})
logger.Info("eino summarization 已压缩上下文",
zap.Int("messages_before", len(before.Messages)),
zap.Int("messages_after", len(after.Messages)),
zap.Int("tokens_before_estimated", beforeTokens),
zap.Int("tokens_after_estimated", afterTokens),
zap.Int("max_total_tokens", maxTotal),
zap.Int("trigger_context_tokens", trigger),
zap.String("transcript_file", transcriptPath),
)
return nil
},
@@ -89,6 +129,172 @@ func newEinoSummarizationMiddleware(
return mw, nil
}
// summarizeFinalizeWithRecentAssistantToolTrail 在摘要消息后保留最近 assistant/tool 轨迹,避免压缩后执行链断裂。
//
// 关键不变量:tool_call ↔ tool_result 的 pair 必须整体保留或整体丢弃。
// 把消息切成 round(回合)为原子单位:
// - user(...) 单条为一个 round
// - assistant(tool_calls=[...]) 及其后连续的 role=tool 消息合成一个 round
// - 其它 assistant(reply, 无 tool_calls) 单条为一个 round。
//
// 倒序挑 round(预算不够即放弃该 round),保证 tool 消息不会跨 round 被孤立。
func summarizeFinalizeWithRecentAssistantToolTrail(
ctx context.Context,
originalMessages []adk.Message,
summary adk.Message,
tokenCounter summarization.TokenCounterFunc,
recentTrailTokenBudget int,
) ([]adk.Message, error) {
systemMsgs := make([]adk.Message, 0, len(originalMessages))
nonSystem := make([]adk.Message, 0, len(originalMessages))
for _, msg := range originalMessages {
if msg == nil {
continue
}
if msg.Role == schema.System {
systemMsgs = append(systemMsgs, msg)
continue
}
nonSystem = append(nonSystem, msg)
}
if recentTrailTokenBudget <= 0 || len(nonSystem) == 0 {
out := make([]adk.Message, 0, len(systemMsgs)+1)
out = append(out, systemMsgs...)
out = append(out, summary)
return out, nil
}
rounds := splitMessagesIntoRounds(nonSystem)
if len(rounds) == 0 {
out := make([]adk.Message, 0, len(systemMsgs)+1)
out = append(out, systemMsgs...)
out = append(out, summary)
return out, nil
}
// 目标:至少保留 minRounds 个 round 的执行轨迹;在预算允许时尽量多保留。
// 优先确保最后一个 round(通常是最新的 tool 往返或 assistant 回复)存在。
const minRounds = 2
selectedRoundsReverse := make([]messageRound, 0, 8)
selectedCount := 0
totalTokens := 0
tokensOfRound := func(r messageRound) (int, error) {
if len(r.messages) == 0 {
return 0, nil
}
n, err := tokenCounter(ctx, &summarization.TokenCounterInput{Messages: r.messages})
if err != nil {
return 0, err
}
if n <= 0 {
n = len(r.messages)
}
return n, nil
}
for i := len(rounds) - 1; i >= 0; i-- {
r := rounds[i]
n, err := tokensOfRound(r)
if err != nil {
return nil, err
}
// 预算不够:已经保留了足够 round 则停,否则跳过该 round 继续往前找
// (避免一个超大 round 挤占全部预算,至少保证有轨迹)。
if totalTokens+n > recentTrailTokenBudget {
if selectedCount >= minRounds {
break
}
continue
}
totalTokens += n
selectedRoundsReverse = append(selectedRoundsReverse, r)
selectedCount++
}
// 还原时间顺序
selectedMsgs := make([]adk.Message, 0, 8)
for i := len(selectedRoundsReverse) - 1; i >= 0; i-- {
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
}
out := make([]adk.Message, 0, len(systemMsgs)+1+len(selectedMsgs))
out = append(out, systemMsgs...)
out = append(out, summary)
out = append(out, selectedMsgs...)
return out, nil
}
// messageRound 表示一个"不可分割"的消息回合。
// - 对 assistant(tool_calls) + 随后若干 tool 消息的组合,round 内全部 call_id 成对完整;
// - 对独立的 user / assistant(reply) 消息,round 仅包含该条消息。
type messageRound struct {
messages []adk.Message
}
// splitMessagesIntoRounds 将非 system 消息切分为若干 round,保证:
// - 每个 assistant(tool_calls) 与其对应的 role=tool 响应消息在同一个 round
// - 孤立(无对应 assistant(tool_calls))的 role=tool 消息不会单独成为 round
// 而是被丢弃(这些消息在 pair 完整性层面已属孤儿,保留反而会触发 LLM 400)。
func splitMessagesIntoRounds(msgs []adk.Message) []messageRound {
if len(msgs) == 0 {
return nil
}
rounds := make([]messageRound, 0, len(msgs))
i := 0
for i < len(msgs) {
msg := msgs[i]
if msg == nil {
i++
continue
}
switch {
case msg.Role == schema.Assistant && len(msg.ToolCalls) > 0:
// 收集该 assistant 提供的 call_id 集合。
provided := make(map[string]struct{}, len(msg.ToolCalls))
for _, tc := range msg.ToolCalls {
if tc.ID != "" {
provided[tc.ID] = struct{}{}
}
}
round := messageRound{messages: []adk.Message{msg}}
j := i + 1
for j < len(msgs) {
next := msgs[j]
if next == nil {
j++
continue
}
if next.Role != schema.Tool {
break
}
if next.ToolCallID != "" {
if _, ok := provided[next.ToolCallID]; !ok {
// 下一条 tool 不属于当前 assistant,认为当前 round 结束。
break
}
}
round.messages = append(round.messages, next)
j++
}
rounds = append(rounds, round)
i = j
case msg.Role == schema.Tool:
// 孤儿 tool 消息:既不跟随在一个 assistant(tool_calls) 后,
// 说明它对应的 assistant 已被上游裁剪;直接丢弃,下一步到 orphan pruner
// 兜底也不会出错,但在 round 切分这里就剔除更干净。
i++
default:
// user / assistant(reply) / 其它:单条成 round。
rounds = append(rounds, messageRound{messages: []adk.Message{msg}})
i++
}
}
return rounds
}
func einoSummarizationTokenCounter(openAIModel string) summarization.TokenCounterFunc {
tc := agent.NewTikTokenCounter()
return func(ctx context.Context, input *summarization.TokenCounterInput) (int, error) {
+345
View File
@@ -0,0 +1,345 @@
package multiagent
import (
"context"
"testing"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/adk/middlewares/summarization"
"github.com/cloudwego/eino/schema"
)
// fixedTokenCounter 让 tool 消息按 tokensPerToolMessage 计,其它消息按 1 计。
// 用于验证 tool-round 超预算时整体被跳过的分支。
func fixedTokenCounter(tokensPerToolMessage int) summarization.TokenCounterFunc {
return func(_ context.Context, in *summarization.TokenCounterInput) (int, error) {
total := 0
for _, msg := range in.Messages {
if msg == nil {
continue
}
switch msg.Role {
case schema.Tool:
total += tokensPerToolMessage
default:
total++
}
}
return total, nil
}
}
// variableTokenCounter 让 tool 消息按 len(Content) 计(可区分不同大小的 tool 结果),
// 其它消息按 1 计;assistant 附加 len(ToolCalls) token 近似 tool_calls schema 开销。
func variableTokenCounter() summarization.TokenCounterFunc {
return func(_ context.Context, in *summarization.TokenCounterInput) (int, error) {
total := 0
for _, msg := range in.Messages {
if msg == nil {
continue
}
if msg.Role == schema.Tool {
total += len(msg.Content)
continue
}
total++
total += len(msg.ToolCalls)
}
return total, nil
}
}
func TestSplitMessagesIntoRounds_Complex(t *testing.T) {
msgs := []adk.Message{
schema.UserMessage("q1"),
assistantToolCallsMsg("", "c1", "c2"),
schema.ToolMessage("r1", "c1"),
schema.ToolMessage("r2", "c2"),
schema.AssistantMessage("reply1", nil),
schema.UserMessage("q2"),
assistantToolCallsMsg("", "c3"),
schema.ToolMessage("r3", "c3"),
}
rounds := splitMessagesIntoRounds(msgs)
// 5 rounds: user(q1) | assistant(tc:c1,c2)+tool*2 | assistant(reply1) | user(q2) | assistant(tc:c3)+tool(c3)
if len(rounds) != 5 {
t.Fatalf("want 5 rounds, got %d", len(rounds))
}
// round 1 应为 tool-round,必须成对
r1 := rounds[1]
if len(r1.messages) != 3 {
t.Fatalf("rounds[1] size: want 3, got %d", len(r1.messages))
}
if r1.messages[0].Role != schema.Assistant || len(r1.messages[0].ToolCalls) != 2 {
t.Fatalf("rounds[1][0] must be assistant(tc=2)")
}
for i := 1; i < 3; i++ {
if r1.messages[i].Role != schema.Tool {
t.Fatalf("rounds[1][%d] must be tool, got %s", i, r1.messages[i].Role)
}
}
// 最后一个 round 成对
rLast := rounds[len(rounds)-1]
if len(rLast.messages) != 2 {
t.Fatalf("rounds[last] size: want 2, got %d", len(rLast.messages))
}
if rLast.messages[0].Role != schema.Assistant || rLast.messages[1].Role != schema.Tool {
t.Fatalf("last round must be assistant(tc)+tool(c3)")
}
}
func TestSplitMessagesIntoRounds_DropsOrphanTool(t *testing.T) {
// 起点直接是 tool 消息(孤儿)—— 应被丢弃,不独立成 round。
msgs := []adk.Message{
schema.ToolMessage("orphan", "c_old"),
schema.UserMessage("continue"),
assistantToolCallsMsg("", "c_new"),
schema.ToolMessage("r_new", "c_new"),
}
rounds := splitMessagesIntoRounds(msgs)
// user(continue) | assistant(tc:c_new)+tool(c_new) → 2 rounds
if len(rounds) != 2 {
t.Fatalf("want 2 rounds after dropping orphan, got %d", len(rounds))
}
for _, r := range rounds {
for _, m := range r.messages {
if m.Role == schema.Tool && m.ToolCallID == "c_old" {
t.Fatalf("orphan tool c_old must not appear in any round")
}
}
}
}
func TestSplitMessagesIntoRounds_ToolBelongsToCurrentAssistantOnly(t *testing.T) {
// 两个相邻 assistant(tc),第二个的 tool 不应被归到第一个 assistant。
msgs := []adk.Message{
assistantToolCallsMsg("", "c1"),
schema.ToolMessage("r1", "c1"),
assistantToolCallsMsg("", "c2"),
schema.ToolMessage("r2", "c2"),
}
rounds := splitMessagesIntoRounds(msgs)
if len(rounds) != 2 {
t.Fatalf("want 2 rounds, got %d", len(rounds))
}
if len(rounds[0].messages) != 2 || rounds[0].messages[0].ToolCalls[0].ID != "c1" {
t.Fatalf("round[0] wrong: %+v", rounds[0].messages)
}
if len(rounds[1].messages) != 2 || rounds[1].messages[0].ToolCalls[0].ID != "c2" {
t.Fatalf("round[1] wrong: %+v", rounds[1].messages)
}
}
func TestSplitMessagesIntoRounds_ToolBelongsToWrongAssistant(t *testing.T) {
// assistant(tc:c1) 后面跟一个 tool_call_id=c999 的 tool 消息(本不属它)。
// 切分规则:该 tool 不应拼入第一个 round(配对不完整),round 在此结束。
// 而 c999 又没有对应 assistant,应被当孤儿丢弃。
msgs := []adk.Message{
assistantToolCallsMsg("", "c1"),
schema.ToolMessage("wrong", "c999"),
schema.UserMessage("hi"),
}
rounds := splitMessagesIntoRounds(msgs)
// assistant(tc:c1) 没有对应 tool(c1),但不是孤儿(patchtoolcalls 会兜底补);
// 它独立成 round 允许上游后处理。user(hi) 独立成 round。共 2 rounds。
if len(rounds) != 2 {
t.Fatalf("want 2 rounds, got %d: %+v", len(rounds), rounds)
}
for _, r := range rounds {
for _, m := range r.messages {
if m.Role == schema.Tool && m.ToolCallID == "c999" {
t.Fatalf("wrong-owner tool must be dropped as orphan")
}
}
}
}
func TestSummarizeFinalize_KeepsToolRoundIntact(t *testing.T) {
// 关键回归测试:一个 tool-round 整体被保留,而不是只保留 tool 消息。
sys := schema.SystemMessage("sys")
summary := schema.AssistantMessage("summary_content", nil)
msgs := []adk.Message{
sys,
schema.UserMessage("q1"),
schema.AssistantMessage("reply_before_tc", nil), // 填料,占预算
assistantToolCallsMsg("", "c1"),
schema.ToolMessage("r1", "c1"),
}
// token 预算:2 条消息(1 assistant + 1 tool)恰好够用。
// 若按条数保留,可能先吃 tool(c1) 再吃 assistant(reply) 落入 budgetassistant(tc:c1) 被挤掉,导致孤儿。
// 按 round 保留时,整个 tool-round 为原子,要么保留 2 条都在,要么都不在。
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
context.Background(),
msgs,
summary,
fixedTokenCounter(1),
2, // 预算:2 tokens
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// 必须包含 system + summary
if len(out) < 2 {
t.Fatalf("output too short: %d", len(out))
}
if out[0] != sys {
t.Fatalf("first message must be system")
}
if out[1] != summary {
t.Fatalf("second message must be summary")
}
// 关键不变量:每个被保留的 tool 消息,必须能在输出中找到提供其 ToolCallID 的 assistant(tc)。
assertNoOrphanTool(t, out)
}
func TestSummarizeFinalize_SkipsOversizedToolRoundButKeepsSmallerRound(t *testing.T) {
// 构造两个大小差异显著的 tool-round:
// c_big round 的 tool 结果 content="aaaaaaaaaa"10 bytes),round token ≈ 2 (assistant+tc) + 10 = 12
// c_ok round 的 tool 结果 content="ok"2 bytes),round token ≈ 2 + 2 = 4
// 配上 budget=8,使得:
// - 最新的 c_ok round4)能放下;
// - 进一步的中间 roundassistant reply + user)也能放下;
// - 更早的 c_big round12)放不下会被跳过(continue),而非 break。
sys := schema.SystemMessage("sys")
summary := schema.AssistantMessage("summary_content", nil)
msgs := []adk.Message{
sys,
schema.UserMessage("q1"),
assistantToolCallsMsg("", "c_big"),
schema.ToolMessage("aaaaaaaaaa", "c_big"),
schema.AssistantMessage("s", nil),
schema.UserMessage("q2"),
assistantToolCallsMsg("", "c_ok"),
schema.ToolMessage("ok", "c_ok"),
}
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
context.Background(),
msgs,
summary,
variableTokenCounter(),
8,
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
assertNoOrphanTool(t, out)
// c_big 整个 round 必须被丢弃(tool 和 assistant 都不能出现)
for _, m := range out {
if m == nil {
continue
}
if m.Role == schema.Tool && m.ToolCallID == "c_big" {
t.Fatal("oversized tool round must be skipped: tool(c_big) leaked")
}
if m.Role == schema.Assistant {
for _, tc := range m.ToolCalls {
if tc.ID == "c_big" {
t.Fatal("oversized tool round must be skipped: assistant(tc:c_big) leaked")
}
}
}
}
// 最近 round (c_ok) 作为一个原子单位必须整体保留。
foundOKTool, foundOKAsst := false, false
for _, m := range out {
if m == nil {
continue
}
if m.Role == schema.Tool && m.ToolCallID == "c_ok" {
foundOKTool = true
}
if m.Role == schema.Assistant {
for _, tc := range m.ToolCalls {
if tc.ID == "c_ok" {
foundOKAsst = true
}
}
}
}
if !foundOKTool || !foundOKAsst {
t.Fatalf("recent tool-round (c_ok) must be retained as an atomic pair: assistantKept=%v toolKept=%v", foundOKAsst, foundOKTool)
}
}
func TestSummarizeFinalize_BudgetZeroFallsBackToSummaryOnly(t *testing.T) {
sys := schema.SystemMessage("sys")
summary := schema.AssistantMessage("summary", nil)
msgs := []adk.Message{
sys,
assistantToolCallsMsg("", "c1"),
schema.ToolMessage("r1", "c1"),
}
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
context.Background(),
msgs,
summary,
fixedTokenCounter(1),
0,
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(out) != 2 || out[0] != sys || out[1] != summary {
t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out)
}
}
func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) {
sys1 := schema.SystemMessage("sys1")
sys2 := schema.SystemMessage("sys2")
summary := schema.AssistantMessage("s", nil)
msgs := []adk.Message{
sys1,
schema.UserMessage("q"),
sys2, // 非典型位置,但应当被 system group 捕获
}
out, err := summarizeFinalizeWithRecentAssistantToolTrail(
context.Background(),
msgs,
summary,
fixedTokenCounter(1),
100,
)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
systemCount := 0
for _, m := range out {
if m != nil && m.Role == schema.System {
systemCount++
}
}
if systemCount != 2 {
t.Fatalf("want 2 system messages retained, got %d", systemCount)
}
}
// assertNoOrphanTool 断言消息列表里的每个 role=tool 消息都能在更前面找到一个
// assistant(tool_calls) 提供相同 ID,否则说明产生了孤儿(触发 LLM 400 的根因)。
func assertNoOrphanTool(t *testing.T, msgs []adk.Message) {
t.Helper()
provided := make(map[string]struct{})
for _, m := range msgs {
if m == nil {
continue
}
if m.Role == schema.Assistant {
for _, tc := range m.ToolCalls {
if tc.ID != "" {
provided[tc.ID] = struct{}{}
}
}
}
if m.Role == schema.Tool && m.ToolCallID != "" {
if _, ok := provided[m.ToolCallID]; !ok {
t.Fatalf("orphan tool message found: ToolCallID=%q has no preceding assistant(tool_calls)", m.ToolCallID)
}
}
}
}
@@ -0,0 +1,73 @@
package multiagent
import (
"context"
"strings"
"github.com/cloudwego/eino/components/tool"
)
// injectToolNamesOnlyInstruction prepends a compact tool-name-only section into
// the system instruction so the model can reference current callable names.
func injectToolNamesOnlyInstruction(ctx context.Context, instruction string, tools []tool.BaseTool) string {
names := collectToolNames(ctx, tools)
if len(names) == 0 {
return strings.TrimSpace(instruction)
}
hasToolSearch := false
for _, n := range names {
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
hasToolSearch = true
break
}
}
var sb strings.Builder
sb.WriteString("以下是当前会话中可调用的工具名称列表(仅名称,无参数定义):\n")
for _, name := range names {
sb.WriteString("- ")
sb.WriteString(name)
sb.WriteByte('\n')
}
sb.WriteString("\n使用规则:\n")
sb.WriteString("1) 上述仅为名称列表,不包含参数定义。\n")
if hasToolSearch {
sb.WriteString("2) 在调用具体工具前,应先使用 tool_search 查看工具详情与参数要求,再发起调用。\n")
} else {
sb.WriteString("2) 调用具体工具前,请先确认该工具的参数要求;不确定时先澄清再调用。\n")
}
sb.WriteString("3) 不要臆造不存在的工具名。\n\n")
if s := strings.TrimSpace(instruction); s != "" {
sb.WriteString(s)
}
return sb.String()
}
func collectToolNames(ctx context.Context, tools []tool.BaseTool) []string {
if len(tools) == 0 {
return nil
}
seen := make(map[string]struct{}, len(tools))
out := make([]string, 0, len(tools))
for _, t := range tools {
if t == nil {
continue
}
info, err := t.Info(ctx)
if err != nil || info == nil {
continue
}
name := strings.TrimSpace(info.Name)
if name == "" {
continue
}
key := strings.ToLower(name)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, name)
}
return out
}
-1
View File
@@ -59,4 +59,3 @@ func (m *noNestedTaskMiddleware) WrapInvokableToolCall(
return endpoint(ctx2, argumentsInJSON, opts...)
}, nil
}
@@ -0,0 +1,124 @@
package multiagent
import (
"context"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// orphanToolPrunerMiddleware 在每次 ChatModel 调用前剪掉没有对应 assistant(tool_calls) 的孤儿 tool 消息。
//
// 背景:
// - eino 的 summarization 中间件在触发摘要后,默认把所有非 system 消息替换为 1 条 summary 消息;
// 本项目通过自定义 FinalizesummarizeFinalizeWithRecentAssistantToolTrail)在 summary 后回填
// 最近的 assistant/tool 轨迹。若 Finalize 的保留策略按"条数"截断而未按 round 对齐,可能保留
// 了 tool 结果却把对应的 assistant(tool_calls) 落在了 summary 前面,形成孤儿 tool 消息。
// - 同样,reduction / tool_search / 自定义断点恢复等任一改写历史的逻辑,都可能破坏
// tool_call ↔ tool_result 配对。
//
// 一旦孤儿 tool 消息进入 ChatModelOpenAI 兼容 API(含 DashScope / 各类中转)会返回
// 400 "No tool call found for function call output with call_id ...",并被 Eino 包装成
// [NodeRunError] 抛出,终止整轮编排。
//
// 设计取舍:
// - 官方 patchtoolcalls 中间件只补反向(assistant(tc) 缺 tool_result),不处理孤儿 tool。
// 本中间件与之互补,专职兜底正向孤儿。
// - 仅剔除消息,不向历史里注入虚构 assistant(tc):虚构 tool_calls 反而会误导模型后续推理。
// 摘要已覆盖被裁剪段的语义,丢一条原始 tool 结果对对话连贯性影响最小。
// - 位置建议:挂在所有可能改写历史的中间件(summarization / reduction / skill / plantask /
// tool_search)之后,靠近 ChatModel 调用的那一端。
type orphanToolPrunerMiddleware struct {
adk.BaseChatModelAgentMiddleware
logger *zap.Logger
phase string
}
// newOrphanToolPrunerMiddleware 构造中间件。phase 仅用于日志区分 deep / supervisor /
// plan_execute_executor / sub_agent,不影响运行时行为。
func newOrphanToolPrunerMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware {
return &orphanToolPrunerMiddleware{
logger: logger,
phase: phase,
}
}
// BeforeModelRewriteState 扫描消息列表,收集 assistant.tool_calls 提供的 call_id 集合,
// 再剔除掉 ToolCallID 不在该集合中的 role=tool 消息。
//
// 复杂度:O(N)。当未发现孤儿时不产生任何分配,state 原样返回以便上游快路径。
func (m *orphanToolPrunerMiddleware) BeforeModelRewriteState(
ctx context.Context,
state *adk.ChatModelAgentState,
mc *adk.ModelContext,
) (context.Context, *adk.ChatModelAgentState, error) {
_ = mc
if m == nil || state == nil || len(state.Messages) == 0 {
return ctx, state, nil
}
// 第一遍:收集所有已提供的 tool_call_id;同时快路径判定是否真的存在孤儿。
provided := make(map[string]struct{}, 8)
for _, msg := range state.Messages {
if msg == nil {
continue
}
if msg.Role == schema.Assistant {
for _, tc := range msg.ToolCalls {
if tc.ID != "" {
provided[tc.ID] = struct{}{}
}
}
}
}
hasOrphan := false
for _, msg := range state.Messages {
if msg == nil {
continue
}
if msg.Role == schema.Tool && msg.ToolCallID != "" {
if _, ok := provided[msg.ToolCallID]; !ok {
hasOrphan = true
break
}
}
}
if !hasOrphan {
return ctx, state, nil
}
// 第二遍:生成剪除孤儿后的新消息列表。
pruned := make([]adk.Message, 0, len(state.Messages))
droppedIDs := make([]string, 0, 2)
droppedNames := make([]string, 0, 2)
for _, msg := range state.Messages {
if msg == nil {
continue
}
if msg.Role == schema.Tool && msg.ToolCallID != "" {
if _, ok := provided[msg.ToolCallID]; !ok {
droppedIDs = append(droppedIDs, msg.ToolCallID)
droppedNames = append(droppedNames, msg.ToolName)
continue
}
}
pruned = append(pruned, msg)
}
if m.logger != nil {
m.logger.Warn("eino orphan tool messages pruned before model call",
zap.String("phase", m.phase),
zap.Int("dropped_count", len(droppedIDs)),
zap.Strings("dropped_tool_call_ids", droppedIDs),
zap.Strings("dropped_tool_names", droppedNames),
zap.Int("messages_before", len(state.Messages)),
zap.Int("messages_after", len(pruned)),
)
}
ns := *state
ns.Messages = pruned
return ctx, &ns, nil
}
@@ -0,0 +1,131 @@
package multiagent
import (
"context"
"testing"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
)
func assistantToolCallsMsg(content string, callIDs ...string) *schema.Message {
tcs := make([]schema.ToolCall, 0, len(callIDs))
for _, id := range callIDs {
tcs = append(tcs, schema.ToolCall{
ID: id,
Type: "function",
Function: schema.FunctionCall{
Name: "stub_tool",
Arguments: `{}`,
},
})
}
return schema.AssistantMessage(content, tcs)
}
func TestOrphanToolPruner_NoOpWhenPaired(t *testing.T) {
mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware)
msgs := []adk.Message{
schema.SystemMessage("sys"),
schema.UserMessage("hi"),
assistantToolCallsMsg("", "c1", "c2"),
schema.ToolMessage("r1", "c1"),
schema.ToolMessage("r2", "c2"),
schema.AssistantMessage("done", nil),
}
in := &adk.ChatModelAgentState{Messages: msgs}
_, out, err := mw.BeforeModelRewriteState(context.Background(), in, &adk.ModelContext{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if out == nil {
t.Fatal("expected non-nil state")
}
if len(out.Messages) != len(msgs) {
t.Fatalf("expected %d messages kept, got %d", len(msgs), len(out.Messages))
}
// 快路径:未发现孤儿时必须原地返回 state,不分配新切片。
if &out.Messages[0] != &msgs[0] {
t.Fatalf("expected state to be returned as-is (same backing slice) when no orphan present")
}
}
func TestOrphanToolPruner_DropsOrphanToolMessages(t *testing.T) {
mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware)
msgs := []adk.Message{
schema.SystemMessage("sys"),
// 摘要前的 assistant(tc: c_old) 已被裁剪,但对应的 tool 结果漏保留了。
schema.ToolMessage("orphan result", "c_old"),
schema.UserMessage("continue"),
assistantToolCallsMsg("", "c_new"),
schema.ToolMessage("r_new", "c_new"),
}
in := &adk.ChatModelAgentState{Messages: msgs}
_, out, err := mw.BeforeModelRewriteState(context.Background(), in, &adk.ModelContext{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if out == nil {
t.Fatal("expected non-nil state")
}
if len(out.Messages) != len(msgs)-1 {
t.Fatalf("expected %d messages after pruning, got %d", len(msgs)-1, len(out.Messages))
}
for _, m := range out.Messages {
if m != nil && m.Role == schema.Tool && m.ToolCallID == "c_old" {
t.Fatalf("orphan tool message with ToolCallID=c_old should have been dropped")
}
}
// 合法的 tool(c_new) 必须保留。
foundNew := false
for _, m := range out.Messages {
if m != nil && m.Role == schema.Tool && m.ToolCallID == "c_new" {
foundNew = true
break
}
}
if !foundNew {
t.Fatal("paired tool message (c_new) must be retained")
}
}
func TestOrphanToolPruner_EmptyToolCallIDIsIgnored(t *testing.T) {
// 空 ToolCallID 的 tool 消息在真实场景中极罕见,但不应当被误判为孤儿。
// 语义上把它当作"无法校验,保留",避免误删。
mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware)
odd := schema.ToolMessage("no_id", "")
msgs := []adk.Message{
schema.UserMessage("hi"),
odd,
schema.AssistantMessage("ok", nil),
}
in := &adk.ChatModelAgentState{Messages: msgs}
_, out, err := mw.BeforeModelRewriteState(context.Background(), in, &adk.ModelContext{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(out.Messages) != len(msgs) {
t.Fatalf("empty ToolCallID tool message should be kept, got %d messages", len(out.Messages))
}
}
func TestOrphanToolPruner_NilAndEmpty(t *testing.T) {
mw := newOrphanToolPrunerMiddleware(nil, "test").(*orphanToolPrunerMiddleware)
ctx := context.Background()
// nil state
if _, out, err := mw.BeforeModelRewriteState(ctx, nil, &adk.ModelContext{}); err != nil || out != nil {
t.Fatalf("nil state: expected (nil,nil), got (%v,%v)", out, err)
}
// empty messages
empty := &adk.ChatModelAgentState{}
if _, out, err := mw.BeforeModelRewriteState(ctx, empty, &adk.ModelContext{}); err != nil || out != empty {
t.Fatalf("empty messages: expected same state, got (%v,%v)", out, err)
}
}
+1 -1
View File
@@ -71,7 +71,7 @@ func planExecuteDefaultGenExecutorInput(ctx context.Context, in *planexecute.Exe
return planexecute.ExecutorPrompt.Format(ctx, map[string]any{
"input": planExecuteFormatInput(in.UserInput),
"plan": string(planContent),
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps),
"executed_steps": planExecuteFormatExecutedSteps(in.ExecutedSteps, nil, nil),
"step": in.Plan.FirstStep(),
})
}
+22 -7
View File
@@ -5,6 +5,8 @@ import (
"strings"
"unicode/utf8"
"cyberstrike-ai/internal/config"
"github.com/cloudwego/eino/adk/prebuilt/planexecute"
)
@@ -12,8 +14,11 @@ import (
// 此处仅约束「写入模型 prompt 的视图」,不修改 Eino session 中的原始 ExecutedSteps。
const (
planExecuteMaxStepResultRunes = 12000
planExecuteKeepLastSteps = 16
defaultPlanExecuteMaxStepResultRunes = 4000
defaultPlanExecuteKeepLastSteps = 8
// Backward-compatible aliases for tests and existing references.
planExecuteMaxStepResultRunes = defaultPlanExecuteMaxStepResultRunes
planExecuteKeepLastSteps = defaultPlanExecuteKeepLastSteps
)
func truncateRunesWithSuffix(s string, maxRunes int, suffix string) string {
@@ -29,16 +34,26 @@ func truncateRunesWithSuffix(s string, maxRunes int, suffix string) string {
// capPlanExecuteExecutedSteps 折叠较早步骤、截断单步过长结果,供 prompt 使用。
func capPlanExecuteExecutedSteps(steps []planexecute.ExecutedStep) []planexecute.ExecutedStep {
return capPlanExecuteExecutedStepsWithConfig(steps, nil)
}
func capPlanExecuteExecutedStepsWithConfig(steps []planexecute.ExecutedStep, mwCfg *config.MultiAgentEinoMiddlewareConfig) []planexecute.ExecutedStep {
if len(steps) == 0 {
return steps
}
maxStepResultRunes := defaultPlanExecuteMaxStepResultRunes
keepLastSteps := defaultPlanExecuteKeepLastSteps
if mwCfg != nil {
maxStepResultRunes = mwCfg.PlanExecuteMaxStepResultRunesEffective()
keepLastSteps = mwCfg.PlanExecuteKeepLastStepsEffective()
}
out := make([]planexecute.ExecutedStep, 0, len(steps)+1)
start := 0
if len(steps) > planExecuteKeepLastSteps {
start = len(steps) - planExecuteKeepLastSteps
if len(steps) > keepLastSteps {
start = len(steps) - keepLastSteps
var b strings.Builder
b.WriteString(fmt.Sprintf("(上文已完成 %d 步;此处仅保留步骤标题以节省上下文,完整输出已省略。后续 %d 步仍保留正文。)\n",
start, planExecuteKeepLastSteps))
start, keepLastSteps))
for i := 0; i < start; i++ {
b.WriteString(fmt.Sprintf("- %s\n", steps[i].Step))
}
@@ -50,8 +65,8 @@ func capPlanExecuteExecutedSteps(steps []planexecute.ExecutedStep) []planexecute
suffix := "\n…[step result truncated]"
for i := start; i < len(steps); i++ {
e := steps[i]
if utf8.RuneCountInString(e.Result) > planExecuteMaxStepResultRunes {
e.Result = truncateRunesWithSuffix(e.Result, planExecuteMaxStepResultRunes, suffix)
if utf8.RuneCountInString(e.Result) > maxStepResultRunes {
e.Result = truncateRunesWithSuffix(e.Result, maxStepResultRunes, suffix)
}
out = append(out, e)
}
+118 -14
View File
@@ -30,10 +30,10 @@ import (
// RunResult 与单 Agent 循环结果字段对齐,便于复用存储与 SSE 收尾逻辑。
type RunResult struct {
Response string
MCPExecutionIDs []string
LastReActInput string
LastReActOutput string
Response string
MCPExecutionIDs []string
LastAgentTraceInput string // 已序列化的消息带(JSON):原生循环或 Eino 均写入,供续跑/攻击链等恢复上下文
LastAgentTraceOutput string // 本轮助手侧对外展示文本(摘要或最终回复)
}
// toolCallPendingInfo tracks a tool_call emitted to the UI so we can later
@@ -237,7 +237,7 @@ func RunDeepAgent(
subMax = subDefaultIter
}
subSumMw, err := newEinoSummarizationMiddleware(ctx, subModel, appCfg, logger)
subSumMw, err := newEinoSummarizationMiddleware(ctx, subModel, appCfg, &ma.EinoMiddleware, conversationID, logger)
if err != nil {
return nil, fmt.Errorf("子代理 %q summarization 中间件: %w", id, err)
}
@@ -257,11 +257,36 @@ func RunDeepAgent(
subHandlers = append(subHandlers, einoSkillMW)
}
subHandlers = append(subHandlers, subSumMw)
// 孤儿 tool 消息兜底:放在 summarization 之后,telemetry 之前,
// 以便 telemetry 记录的 token 数与 LLM 实际入参一致。
subHandlers = append(subHandlers, newOrphanToolPrunerMiddleware(logger, "sub_agent:"+id))
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "sub_agent"); teleMw != nil {
subHandlers = append(subHandlers, teleMw)
}
subInstrFinal := injectToolNamesOnlyInstruction(ctx, instr, subTools)
if logger != nil {
subNames := collectToolNames(ctx, subTools)
mountedNames := collectToolNames(ctx, subToolsForCfg)
hasToolSearch := false
for _, n := range subNames {
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
hasToolSearch = true
break
}
}
logger.Info("eino tool-name injection",
zap.String("scope", "sub_agent"),
zap.String("agent", id),
zap.Int("tool_names", len(subNames)),
zap.Int("mounted_tool_names", len(mountedNames)),
zap.Bool("has_tool_search", hasToolSearch),
)
}
sa, err := adk.NewChatModelAgent(ctx, &adk.ChatModelAgentConfig{
Name: id,
Description: desc,
Instruction: instr,
Instruction: subInstrFinal,
Model: subModel,
ToolsConfig: adk.ToolsConfig{
ToolsNodeConfig: compose.ToolsNodeConfig{
@@ -289,7 +314,7 @@ func RunDeepAgent(
return nil, fmt.Errorf("多代理主模型: %w", err)
}
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, logger)
mainSumMw, err := newEinoSummarizationMiddleware(ctx, mainModel, appCfg, &ma.EinoMiddleware, conversationID, logger)
if err != nil {
return nil, fmt.Errorf("多代理主 summarization 中间件: %w", err)
}
@@ -313,6 +338,25 @@ func RunDeepAgent(
orchDescription = d
}
}
orchInstruction = injectToolNamesOnlyInstruction(ctx, orchInstruction, mainTools)
if logger != nil {
mainNames := collectToolNames(ctx, mainTools)
mountedNames := collectToolNames(ctx, mainToolsForCfg)
hasToolSearch := false
for _, n := range mainNames {
if strings.EqualFold(strings.TrimSpace(n), "tool_search") {
hasToolSearch = true
break
}
}
logger.Info("eino tool-name injection",
zap.String("scope", "orchestrator"),
zap.String("orchestration", orchMode),
zap.Int("tool_names", len(mainNames)),
zap.Int("mounted_tool_names", len(mountedNames)),
zap.Bool("has_tool_search", hasToolSearch),
)
}
supInstr := strings.TrimSpace(orchInstruction)
if orchMode == "supervisor" {
@@ -352,6 +396,10 @@ func RunDeepAgent(
deepHandlers = append(deepHandlers, einoSkillMW)
}
deepHandlers = append(deepHandlers, mainSumMw)
deepHandlers = append(deepHandlers, newOrphanToolPrunerMiddleware(logger, "deep_orchestrator"))
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil {
deepHandlers = append(deepHandlers, teleMw)
}
supHandlers := []adk.ChatModelAgentMiddleware{}
if len(mainOrchestratorPre) > 0 {
@@ -361,6 +409,10 @@ func RunDeepAgent(
supHandlers = append(supHandlers, einoSkillMW)
}
supHandlers = append(supHandlers, mainSumMw)
supHandlers = append(supHandlers, newOrphanToolPrunerMiddleware(logger, "supervisor_orchestrator"))
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil {
supHandlers = append(supHandlers, teleMw)
}
mainToolsCfg := adk.ToolsConfig{
ToolsNodeConfig: compose.ToolsNodeConfig{
@@ -399,10 +451,19 @@ func RunDeepAgent(
ExecMaxIter: deepMaxIter,
LoopMaxIter: ma.PlanExecuteLoopMaxIterations,
AppCfg: appCfg,
MwCfg: &ma.EinoMiddleware,
ConversationID: conversationID,
Logger: logger,
ModelName: appCfg.OpenAI.Model,
ExecPreMiddlewares: mainOrchestratorPre,
SkillMiddleware: einoSkillMW,
FilesystemMiddleware: peFsMw,
PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{
mainSumMw,
// 孤儿 tool 消息兜底:必须挂在 summarization 之后、telemetry 之前。
newOrphanToolPrunerMiddleware(logger, "plan_execute_planner_replanner"),
newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "plan_execute_planner_replanner_rewrite"),
},
})
if perr != nil {
return nil, perr
@@ -468,7 +529,7 @@ func RunDeepAgent(
da = dDeep
}
baseMsgs := historyToMessages(history)
baseMsgs := historyToMessages(history, appCfg, &ma.EinoMiddleware)
baseMsgs = append(baseMsgs, schema.UserMessage(userMessage))
streamsMainAssistant := func(agent string) bool {
@@ -505,34 +566,77 @@ func RunDeepAgent(
}, baseMsgs)
}
func historyToMessages(history []agent.ChatMessage) []adk.Message {
func historyToMessages(history []agent.ChatMessage, appCfg *config.Config, mwCfg *config.MultiAgentEinoMiddlewareConfig) []adk.Message {
if len(history) == 0 {
return nil
}
// 放宽条数上限:跨轮历史交给 Eino Summarization(阈值对齐 openai.max_total_tokens)在调用模型前压缩,避免在入队前硬截断为 40 条。
const maxHistoryMessages = 300
// Keep a bounded tail first; then enforce a token budget.
const maxHistoryMessages = 200
start := 0
if len(history) > maxHistoryMessages {
start = len(history) - maxHistoryMessages
}
out := make([]adk.Message, 0, len(history[start:]))
raw := make([]adk.Message, 0, len(history[start:]))
for _, h := range history[start:] {
switch h.Role {
case "user":
if strings.TrimSpace(h.Content) != "" {
out = append(out, schema.UserMessage(h.Content))
raw = append(raw, schema.UserMessage(h.Content))
}
case "assistant":
if strings.TrimSpace(h.Content) == "" && len(h.ToolCalls) > 0 {
continue
}
if strings.TrimSpace(h.Content) != "" {
out = append(out, schema.AssistantMessage(h.Content, nil))
raw = append(raw, schema.AssistantMessage(h.Content, nil))
}
default:
continue
}
}
if len(raw) == 0 {
return raw
}
maxTotal := 120000
modelName := "gpt-4o"
if appCfg != nil {
if appCfg.OpenAI.MaxTotalTokens > 0 {
maxTotal = appCfg.OpenAI.MaxTotalTokens
}
if m := strings.TrimSpace(appCfg.OpenAI.Model); m != "" {
modelName = m
}
}
ratio := 0.35
if mwCfg != nil {
ratio = mwCfg.HistoryInputBudgetRatioEffective()
}
budget := int(float64(maxTotal) * ratio)
if budget < 4096 {
budget = 4096
}
tc := agent.NewTikTokenCounter()
outRev := make([]adk.Message, 0, len(raw))
used := 0
for i := len(raw) - 1; i >= 0; i-- {
msg := raw[i]
n, err := tc.Count(modelName, string(msg.Role)+"\n"+msg.Content)
if err != nil {
n = (len(msg.Content) + 3) / 4
}
if n <= 0 {
n = 1
}
if used+n > budget {
break
}
used += n
outRev = append(outRev, msg)
}
out := make([]adk.Message, 0, len(outRev))
for i := len(outRev) - 1; i >= 0; i-- {
out = append(out, outRev[i])
}
return out
}
@@ -1,51 +0,0 @@
package multiagent
import (
"fmt"
"strings"
"github.com/cloudwego/eino/schema"
)
// maxToolCallRecoveryAttempts 含首次运行:首次 + 自动重试次数。
// 例如为 3 表示最多共 3 次完整 DeepAgent 运行(2 次失败后各追加一条纠错提示)。
// 该常量同时用于 JSON 参数错误和工具执行错误(如子代理名称不存在)的恢复重试。
const maxToolCallRecoveryAttempts = 5
// toolCallArgumentsJSONRetryHint 追加在用户消息后,提示模型输出合法 JSON 工具参数(部分云厂商会在流式阶段校验 arguments)。
func toolCallArgumentsJSONRetryHint() *schema.Message {
return schema.UserMessage(`[系统提示] 上一次输出中工具调用的 function.arguments 不是合法 JSON接口已拒绝请重新生成每个 tool call arguments 必须是完整可解析的 JSON 对象字符串键名用双引号无多余逗号括号配对不要输出截断或不完整的 JSON
[System] Your previous tool call used invalid JSON in function.arguments and was rejected by the API. Regenerate with strictly valid JSON objects only (double-quoted keys, matched braces, no trailing commas).`)
}
// toolCallArgumentsJSONRecoveryTimelineMessage 供 eino_recovery 事件落库与前端时间线展示。
func toolCallArgumentsJSONRecoveryTimelineMessage(attempt int) string {
return fmt.Sprintf(
"接口拒绝了无效的工具参数 JSON。已向对话追加系统提示并要求模型重新生成合法的 function.arguments。"+
"当前为第 %d/%d 轮完整运行。\n\n"+
"The API rejected invalid JSON in tool arguments. A system hint was appended. This is full run %d of %d.",
attempt+1, maxToolCallRecoveryAttempts, attempt+1, maxToolCallRecoveryAttempts,
)
}
// isRecoverableToolCallArgumentsJSONError 判断是否为「工具参数非合法 JSON」类流式错误,可通过追加提示后重跑一轮。
func isRecoverableToolCallArgumentsJSONError(err error) bool {
if err == nil {
return false
}
s := strings.ToLower(err.Error())
if !strings.Contains(s, "json") {
return false
}
if strings.Contains(s, "function.arguments") || strings.Contains(s, "function arguments") {
return true
}
if strings.Contains(s, "invalidparameter") && strings.Contains(s, "json") {
return true
}
if strings.Contains(s, "must be in json format") {
return true
}
return false
}
@@ -1,17 +0,0 @@
package multiagent
import (
"errors"
"testing"
)
func TestIsRecoverableToolCallArgumentsJSONError(t *testing.T) {
yes := errors.New(`failed to receive stream chunk: error, <400> InternalError.Algo.InvalidParameter: The "function.arguments" parameter of the code model must be in JSON format.`)
if !isRecoverableToolCallArgumentsJSONError(yes) {
t.Fatal("expected recoverable for function.arguments + JSON")
}
no := errors.New("unrelated network failure")
if isRecoverableToolCallArgumentsJSONError(no) {
t.Fatal("expected not recoverable")
}
}
@@ -1,44 +0,0 @@
package multiagent
import (
"fmt"
"github.com/cloudwego/eino/schema"
)
// toolExecutionRetryHint returns a user message appended to the conversation to prompt
// the LLM to adjust after a tool execution error (tool not found, binary missing,
// runtime failure, network error, etc.).
func toolExecutionRetryHint() *schema.Message {
return schema.UserMessage(`[System] Your previous tool call failed. Possible causes:
- The tool or sub-agent name does not exist (typo or unregistered name).
- The tool call arguments were not valid JSON.
- The tool's underlying binary is not installed or not in PATH.
- The tool encountered a runtime error (timeout, network failure, permission denied, etc.).
Please review the error message above, check available tools, and either:
1. Retry with corrected arguments or a different tool, OR
2. Inform the user about the limitation and proceed with an alternative approach.
[系统提示] 上一次工具调用失败可能原因
- 工具名或子代理名称不存在拼写错误或未注册
- 工具调用参数不是合法 JSON
- 工具依赖的底层二进制程序未安装或不在 PATH
- 工具运行时遇到错误超时网络故障权限不足等
请根据上述错误信息检查可用工具然后
1. 修正参数或改用其他工具重试或者
2. 告知用户当前限制并采用替代方案继续`)
}
// toolExecutionRecoveryTimelineMessage returns a message for the eino_recovery event
// displayed in the UI timeline when a tool execution error triggers a retry.
func toolExecutionRecoveryTimelineMessage(attempt int) string {
return fmt.Sprintf(
"工具调用执行失败。已向对话追加纠错提示并要求模型调整策略。"+
"当前为第 %d/%d 轮完整运行。\n\n"+
"Tool call execution failed. "+
"A corrective hint was appended. This is full run %d of %d.",
attempt+1, maxToolCallRecoveryAttempts, attempt+1, maxToolCallRecoveryAttempts,
)
}
+23 -15
View File
@@ -192,13 +192,13 @@ func convertOpenAIToClaude(payload interface{}) (*claudeRequest, error) {
fnName, _ := fn["name"].(string)
fnArgs, _ := fn["arguments"]
// 防御:缺少 name 或 id 的 tool_call 会被 Claude 拒绝
if strings.TrimSpace(fnName) == "" {
fnName = "unknown_function"
}
if strings.TrimSpace(tcID) == "" {
tcID = fmt.Sprintf("call_%d", time.Now().UnixNano())
}
// 防御:缺少 name 或 id 的 tool_call 会被 Claude 拒绝
if strings.TrimSpace(fnName) == "" {
fnName = "unknown_function"
}
if strings.TrimSpace(tcID) == "" {
tcID = fmt.Sprintf("call_%d", time.Now().UnixNano())
}
var inputRaw json.RawMessage
switch v := fnArgs.(type) {
@@ -752,25 +752,33 @@ func isClaudeProvider(cfg *config.OpenAIConfig) bool {
// Eino HTTP Client Bridge
// ============================================================
// NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个支持 Claude 自动桥接的 http.Client。
// 当 cfg.Provider 为 claude 时,会拦截 /chat/completions 请求,透明转换为 Anthropic Messages API。
// NewEinoHTTPClient 为 einoopenai.ChatModelConfig 返回一个 http.Client,包含两层 transport 包装:
// 1. 当 cfg.Provider 为 claude 时,最内层套 claudeRoundTripper,把 OpenAI /chat/completions 透明
// 桥接为 Anthropic /v1/messages(并把 Claude SSE 翻译回 OpenAI SSE 格式)。
// 2. 最外层无条件套 einoSSESanitizingRoundTripper,吞掉中转站发的 SSE 心跳/注释/控制行
// (": keepalive" / "event: ping" / "retry: 3000" 等),避免 Eino 用的 meguminnnnnnnnn/go-openai
// SDK 在累计超过 300 个非 "data:" 行后抛 "stream has sent too many empty messages"。
//
// 两层都对调用方完全透明:普通 JSON 响应原样透传,仅当响应 Content-Type 为 text/event-stream 时
// sanitizer 才会接管 bodydata: payload (含 [DONE]、{"error":...}) 一字节不改。
func NewEinoHTTPClient(cfg *config.OpenAIConfig, base *http.Client) *http.Client {
if base == nil {
base = http.DefaultClient
}
if !isClaudeProvider(cfg) {
return base
}
cloned := *base
transport := base.Transport
if transport == nil {
transport = http.DefaultTransport
}
cloned.Transport = &claudeRoundTripper{
base: transport,
config: cfg,
if isClaudeProvider(cfg) {
transport = &claudeRoundTripper{
base: transport,
config: cfg,
}
}
transport = &einoSSESanitizingRoundTripper{base: transport}
cloned.Transport = transport
return &cloned
}
+149
View File
@@ -0,0 +1,149 @@
package openai
// eino_sse_sanitizer.go 解决 Eino 走 meguminnnnnnnnn/go-openai SDK 时,
// 中转站心跳/SSE 控制行累计 > 300 行触发 ErrTooManyEmptyStreamMessages
// (报错文案: "stream has sent too many empty messages")的问题。
//
// 触发链路:
// einoopenai.NewChatModel
// → eino-ext/libs/acl/openai → meguminnnnnnnnn/go-openai
// → streamReader.processLines() 对所有非 "data:" 行计数, > 300 即抛错。
//
// 中转站常见的非 data: 行(合法 SSE 但 SDK 不接受):
// ":" / ": keepalive" / ": ping" / "event: ping" / "retry: 3000"
// 以及思考型模型 prefill 期间穿插的大量心跳。
//
// 兜底策略: 在 HTTP transport 层把响应 Body 包一层 reader, 只放行 "data:"
// 开头的行, 把心跳/注释/事件类型行就地吞掉。下游 SDK 永远见不到非 data: 行,
// 计数器始终为 0, 该错误不可能再发生。
//
// 该层对调用方完全透明:
// - 仅当响应 Content-Type 是 text/event-stream 时介入;普通 JSON 响应原样透传
// - data: payload (含 [DONE] 与 {"error":...}) 一字节不改
// - 上游真断流 (EOF / connection reset / context cancel) 原样透传
import (
"bufio"
"bytes"
"io"
"net/http"
"strings"
)
const (
// einoSSEReaderBufSize 给 bufio 一个较大的初始缓冲, 避免单行大 JSON chunk
// (含工具调用 arguments / reasoning_content) 频繁触发缓冲区扩容。
einoSSEReaderBufSize = 64 * 1024
)
// einoSSESanitizingRoundTripper 包装下游 RoundTripper, 对 SSE 响应做行级清洗。
type einoSSESanitizingRoundTripper struct {
base http.RoundTripper
}
func (rt *einoSSESanitizingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := rt.base.RoundTrip(req)
if err != nil || resp == nil {
return resp, err
}
if !isSSEResponse(resp) {
return resp, nil
}
resp.Body = newEinoSSESanitizingBody(resp.Body)
return resp, nil
}
// isSSEResponse 仅对 200 + text/event-stream 的响应做清洗;
// 错误响应 (4xx/5xx 通常是 application/json) 不动, 由 SDK 走原错误路径。
func isSSEResponse(resp *http.Response) bool {
if resp.StatusCode != http.StatusOK {
return false
}
ct := resp.Header.Get("Content-Type")
if ct == "" {
return false
}
ct = strings.ToLower(strings.TrimSpace(ct))
// 兼容 "text/event-stream", "text/event-stream; charset=utf-8" 等。
return strings.HasPrefix(ct, "text/event-stream")
}
// einoSSESanitizingBody 是包装后的响应体: 只放行 data: 行, 其它行吞掉。
type einoSSESanitizingBody struct {
upstream io.ReadCloser
reader *bufio.Reader
pending []byte // 已清洗、待返回给下游的字节 (永远以 \n 结尾的完整 data: 行)
err error // upstream 终态错误 (io.EOF 或网络错误)
}
func newEinoSSESanitizingBody(body io.ReadCloser) *einoSSESanitizingBody {
return &einoSSESanitizingBody{
upstream: body,
reader: bufio.NewReaderSize(body, einoSSEReaderBufSize),
}
}
func (b *einoSSESanitizingBody) Read(p []byte) (int, error) {
if len(p) == 0 {
return 0, nil
}
if len(b.pending) > 0 {
n := copy(p, b.pending)
b.pending = b.pending[n:]
return n, nil
}
// 从上游读, 直到攒出一行 data: 或拿到终态。
// 单次循环可能丢弃任意多行心跳, 但只放行至多一行 data: 后退出,
// 避免一次 Read 阻塞过久 / pending 缓冲过大。
for b.err == nil {
line, err := b.reader.ReadBytes('\n')
if len(line) > 0 {
if isPassThroughSSELine(line) {
if line[len(line)-1] != '\n' {
line = append(line, '\n')
}
b.pending = line
if err != nil {
b.err = err
}
break
}
// 非 data: 行 (空行 / ":" 注释 / event: / retry: / id: / 任何裸文本)
// 全部吞掉, 不向下游透出, 继续循环读下一行。
}
if err != nil {
b.err = err
break
}
}
if len(b.pending) > 0 {
n := copy(p, b.pending)
b.pending = b.pending[n:]
return n, nil
}
return 0, b.err
}
func (b *einoSSESanitizingBody) Close() error {
return b.upstream.Close()
}
// isPassThroughSSELine 判定该行是否需要原样放行给下游 SDK。
// 仅 "data:" (大小写不敏感, 可有任意前导空白) 开头的行需要保留。
// 注意: 不能用 TrimSpace 去尾部换行后再判, 否则 " data: x" 会被误判;
// 我们只 trim 前导空白, 与 SDK 内部 TrimSpace 后再正则 ^data:\s* 的语义一致。
func isPassThroughSSELine(line []byte) bool {
trimmed := bytes.TrimLeft(line, " \t")
if len(trimmed) < 5 {
return false
}
// 大小写不敏感比较前 5 字节是否为 "data:"。SSE 规范要求字段名小写,
// 但宽松匹配可以兼容个别中转站的非规范实现。
return (trimmed[0] == 'd' || trimmed[0] == 'D') &&
(trimmed[1] == 'a' || trimmed[1] == 'A') &&
(trimmed[2] == 't' || trimmed[2] == 'T') &&
(trimmed[3] == 'a' || trimmed[3] == 'A') &&
trimmed[4] == ':'
}
+303
View File
@@ -0,0 +1,303 @@
package openai
import (
"bufio"
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
"regexp"
"strings"
"testing"
)
// 复现 meguminnnnnnnnn/go-openai 的 SSE 行计数算法 (默认 limit=300):
// - 逐行读
// - 非 "data:" 行 (空行 / ":" 注释 / event: / retry:) 累计 emptyMessagesCount
// - > 300 抛 ErrTooManyEmptyStreamMessages
// - 遇到 data: 行 reset, 返回 payload
//
// 这一算法与上游 SDK 的 stream_reader.go processLines() 严格一致 (验证依据见
// /Users/temp/go/pkg/mod/github.com/meguminnnnnnnnn/go-openai@v0.1.2/stream_reader.go)。
// 测试中只复刻 "限制触发" 这一行为, 用来回归验证 sanitizer 的根因修复。
var errTooManyEmptyStreamMessages = errors.New("stream has sent too many empty messages")
func sdkLikeRecvAll(body io.Reader, limit uint) ([]string, error) {
headerData := regexp.MustCompile(`^data:\s*`)
r := bufio.NewReader(body)
var payloads []string
for {
var emptyMessagesCount uint
var payload []byte
for {
line, err := r.ReadBytes('\n')
if err != nil {
if err == io.EOF {
return payloads, nil
}
return payloads, err
}
noSpace := bytes.TrimSpace(line)
if !headerData.Match(noSpace) {
emptyMessagesCount++
if emptyMessagesCount > limit {
return payloads, errTooManyEmptyStreamMessages
}
continue
}
payload = headerData.ReplaceAll(noSpace, nil)
break
}
if string(payload) == "[DONE]" {
return payloads, nil
}
payloads = append(payloads, string(payload))
}
}
func newSSEServer(t *testing.T, body string, contentType string, status int) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
if contentType != "" {
w.Header().Set("Content-Type", contentType)
}
w.WriteHeader(status)
_, _ = io.WriteString(w, body)
}))
}
func sanitizingClient(base *http.Client) *http.Client {
if base == nil {
base = &http.Client{}
}
cloned := *base
transport := base.Transport
if transport == nil {
transport = http.DefaultTransport
}
cloned.Transport = &einoSSESanitizingRoundTripper{base: transport}
return &cloned
}
func readAll(t *testing.T, body io.ReadCloser) string {
t.Helper()
defer body.Close()
out, err := io.ReadAll(body)
if err != nil {
t.Fatalf("read body: %v", err)
}
return string(out)
}
// 1) 仅 data: 行 → 一字节不改地透传。
func TestSSESanitizer_PassesDataLinesUnchanged(t *testing.T) {
body := "data: {\"a\":1}\ndata: {\"b\":2}\ndata: [DONE]\n"
srv := newSSEServer(t, body, "text/event-stream", 200)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
got := readAll(t, resp.Body)
if got != body {
t.Fatalf("body mismatch:\nwant %q\ngot %q", body, got)
}
}
// 2) 心跳/注释/事件类型行被吞掉, 仅保留 data: 行。
func TestSSESanitizer_DropsHeartbeatsAndControlLines(t *testing.T) {
body := strings.Join([]string{
": keepalive",
"",
"event: ping",
"retry: 3000",
"id: 42",
"data: {\"x\":1}",
": ping",
"",
"data: {\"x\":2}",
"data: [DONE]",
"",
}, "\n")
srv := newSSEServer(t, body, "text/event-stream", 200)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
got := readAll(t, resp.Body)
want := "data: {\"x\":1}\ndata: {\"x\":2}\ndata: [DONE]\n"
if got != want {
t.Fatalf("sanitized body mismatch:\nwant %q\ngot %q", want, got)
}
}
// 3) 根因回归: 上游堆 500 行心跳后才发 data:, 原始 SDK 算法会抛
// ErrTooManyEmptyStreamMessages, sanitize 之后必须能正常拿到所有 data:。
func TestSSESanitizer_ProtectsAgainstTooManyEmptyMessages(t *testing.T) {
const heartbeats = 500
var buf bytes.Buffer
for i := 0; i < heartbeats; i++ {
buf.WriteString(": keepalive\n")
}
buf.WriteString("data: {\"chunk\":1}\n")
buf.WriteString("data: {\"chunk\":2}\n")
buf.WriteString("data: [DONE]\n")
t.Run("baseline_without_sanitizer_must_fail", func(t *testing.T) {
_, err := sdkLikeRecvAll(bytes.NewReader(buf.Bytes()), 300)
if !errors.Is(err, errTooManyEmptyStreamMessages) {
t.Fatalf("expected ErrTooManyEmptyStreamMessages, got %v", err)
}
})
t.Run("with_sanitizer_must_succeed", func(t *testing.T) {
srv := newSSEServer(t, buf.String(), "text/event-stream", 200)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
defer resp.Body.Close()
payloads, err := sdkLikeRecvAll(resp.Body, 300)
if err != nil {
t.Fatalf("sdk-like recv after sanitize: %v", err)
}
want := []string{`{"chunk":1}`, `{"chunk":2}`}
if len(payloads) != len(want) {
t.Fatalf("payload count mismatch: want %d got %d (%v)", len(want), len(payloads), payloads)
}
for i, w := range want {
if payloads[i] != w {
t.Fatalf("payload[%d] mismatch: want %q got %q", i, w, payloads[i])
}
}
})
}
// 4) 心跳穿插在 data: 之间也能正确清洗 (思考型模型 prefill 期间常见)。
func TestSSESanitizer_HeartbeatsInterleavedWithData(t *testing.T) {
var buf bytes.Buffer
buf.WriteString("data: {\"chunk\":1}\n")
for i := 0; i < 400; i++ {
buf.WriteString(": keepalive\n")
}
buf.WriteString("data: {\"chunk\":2}\n")
buf.WriteString("data: [DONE]\n")
srv := newSSEServer(t, buf.String(), "text/event-stream", 200)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
defer resp.Body.Close()
payloads, err := sdkLikeRecvAll(resp.Body, 300)
if err != nil {
t.Fatalf("sdk-like recv: %v", err)
}
if got, want := len(payloads), 2; got != want {
t.Fatalf("payload count: want %d got %d", want, got)
}
}
// 5) 非 SSE 响应 (例如非流式 JSON) 不应被 sanitizer 介入。
func TestSSESanitizer_PassesNonSSEResponseUntouched(t *testing.T) {
body := `{"id":"x","object":"chat.completion","choices":[]}`
srv := newSSEServer(t, body, "application/json", 200)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
got := readAll(t, resp.Body)
if got != body {
t.Fatalf("non-SSE body must be untouched:\nwant %q\ngot %q", body, got)
}
}
// 6) 错误响应 (4xx/5xx) 不应被 sanitize, 即使 Content-Type 是 SSE 也不动,
// 避免吞掉类似 "data: " 之外的错误正文。
func TestSSESanitizer_PassesNon200Untouched(t *testing.T) {
body := `{"error":{"message":"rate limit"}}`
srv := newSSEServer(t, body, "text/event-stream", 429)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
got := readAll(t, resp.Body)
if got != body {
t.Fatalf("error body must be untouched:\nwant %q\ngot %q", body, got)
}
}
// 7) data: 行末尾若缺 \n (异常上游) sanitizer 也补齐, 保证下游按行解析。
func TestSSESanitizer_AppendsTrailingNewlineIfMissing(t *testing.T) {
body := "data: {\"a\":1}"
srv := newSSEServer(t, body, "text/event-stream", 200)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
got := readAll(t, resp.Body)
want := "data: {\"a\":1}\n"
if got != want {
t.Fatalf("trailing newline:\nwant %q\ngot %q", want, got)
}
}
// 8) 大 chunk (一行数十 KB) 也能完整透传, 不被切断。
func TestSSESanitizer_LargeDataLinePassesIntact(t *testing.T) {
huge := strings.Repeat("x", 80*1024)
body := "data: {\"big\":\"" + huge + "\"}\ndata: [DONE]\n"
srv := newSSEServer(t, body, "text/event-stream", 200)
defer srv.Close()
resp, err := sanitizingClient(nil).Get(srv.URL)
if err != nil {
t.Fatalf("get: %v", err)
}
got := readAll(t, resp.Body)
if got != body {
t.Fatalf("large body length mismatch: want %d got %d", len(body), len(got))
}
}
// 9) isPassThroughSSELine 单元覆盖。
func TestIsPassThroughSSELine(t *testing.T) {
cases := []struct {
line string
want bool
}{
{"data: {\"a\":1}\n", true},
{"DATA: x\n", true},
{" data: x\n", true},
{"data:\n", true},
{"\n", false},
{"\r\n", false},
{": keepalive\n", false},
{":\n", false},
{"event: ping\n", false},
{"retry: 3000\n", false},
{"id: 42\n", false},
{"datax: y\n", false},
{"da", false},
}
for _, c := range cases {
if got := isPassThroughSSELine([]byte(c.line)); got != c.want {
t.Errorf("isPassThroughSSELine(%q) = %v, want %v", c.line, got, c.want)
}
}
}
+14 -14
View File
@@ -281,9 +281,9 @@ func (c *Client) ChatCompletionStream(ctx context.Context, payload interface{},
// StreamToolCall 流式工具调用的累积结果(arguments 以字符串形式拼接,留给上层再解析为 JSON)。
type StreamToolCall struct {
Index int
ID string
Type string
Index int
ID string
Type string
FunctionName string
FunctionArgsStr string
}
@@ -348,10 +348,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
Arguments string `json:"arguments,omitempty"`
}
type toolCallDelta struct {
Index int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Function toolCallFunctionDelta `json:"function,omitempty"`
Index int `json:"index,omitempty"`
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Function toolCallFunctionDelta `json:"function,omitempty"`
}
type streamDelta2 struct {
Content string `json:"content,omitempty"`
@@ -371,10 +371,10 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
}
type toolCallAccum struct {
id string
typ string
name string
args strings.Builder
id string
typ string
name string
args strings.Builder
}
toolCallAccums := make(map[int]*toolCallAccum)
@@ -475,9 +475,9 @@ func (c *Client) ChatCompletionStreamWithToolCalls(
for _, idx := range indices {
acc := toolCallAccums[idx]
tc := StreamToolCall{
Index: idx,
ID: acc.id,
Type: acc.typ,
Index: idx,
ID: acc.id,
Type: acc.typ,
FunctionName: acc.name,
FunctionArgsStr: acc.args.String(),
}
+40 -41
View File
@@ -19,11 +19,11 @@ import (
func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
logger := zap.NewNop()
mcpServer := mcp.NewServer(logger)
cfg := &config.SecurityConfig{
Tools: []config.ToolConfig{},
}
executor := NewExecutor(cfg, mcpServer, logger)
return executor, mcpServer
}
@@ -32,12 +32,12 @@ func setupTestExecutor(t *testing.T) (*Executor, *mcp.Server) {
func setupTestStorage(t *testing.T) *storage.FileResultStorage {
tmpDir := filepath.Join(os.TempDir(), "test_executor_storage_"+time.Now().Format("20060102_150405"))
logger := zap.NewNop()
storage, err := storage.NewFileResultStorage(tmpDir, logger)
if err != nil {
t.Fatalf("创建测试存储失败: %v", err)
}
return storage
}
@@ -45,46 +45,46 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
executor, _ := setupTestExecutor(t)
testStorage := setupTestStorage(t)
executor.SetResultStorage(testStorage)
// 准备测试数据
executionID := "test_exec_001"
toolName := "nmap_scan"
result := "Line 1: Port 22 open\nLine 2: Port 80 open\nLine 3: Port 443 open\nLine 4: error occurred"
// 保存测试结果
err := testStorage.SaveResult(executionID, toolName, result)
if err != nil {
t.Fatalf("保存测试结果失败: %v", err)
}
ctx := context.Background()
// 测试1: 基本查询(第一页)
args := map[string]interface{}{
"execution_id": executionID,
"page": float64(1),
"limit": float64(2),
}
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if toolResult.IsError {
t.Fatalf("查询应该成功,但返回了错误: %s", toolResult.Content[0].Text)
}
// 验证结果包含预期内容
resultText := toolResult.Content[0].Text
if !strings.Contains(resultText, executionID) {
t.Errorf("结果中应该包含执行ID: %s", executionID)
}
if !strings.Contains(resultText, "第 1/") {
t.Errorf("结果中应该包含分页信息")
}
// 测试2: 搜索功能
args2 := map[string]interface{}{
"execution_id": executionID,
@@ -92,21 +92,21 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
"page": float64(1),
"limit": float64(10),
}
toolResult2, err := executor.executeQueryExecutionResult(ctx, args2)
if err != nil {
t.Fatalf("执行搜索失败: %v", err)
}
if toolResult2.IsError {
t.Fatalf("搜索应该成功,但返回了错误: %s", toolResult2.Content[0].Text)
}
resultText2 := toolResult2.Content[0].Text
if !strings.Contains(resultText2, "error") {
t.Errorf("搜索结果中应该包含关键词: error")
}
// 测试3: 过滤功能
args3 := map[string]interface{}{
"execution_id": executionID,
@@ -114,46 +114,46 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
"page": float64(1),
"limit": float64(10),
}
toolResult3, err := executor.executeQueryExecutionResult(ctx, args3)
if err != nil {
t.Fatalf("执行过滤失败: %v", err)
}
if toolResult3.IsError {
t.Fatalf("过滤应该成功,但返回了错误: %s", toolResult3.Content[0].Text)
}
resultText3 := toolResult3.Content[0].Text
if !strings.Contains(resultText3, "Port") {
t.Errorf("过滤结果中应该包含关键词: Port")
}
// 测试4: 缺少必需参数
args4 := map[string]interface{}{
"page": float64(1),
}
toolResult4, err := executor.executeQueryExecutionResult(ctx, args4)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult4.IsError {
t.Fatal("缺少execution_id应该返回错误")
}
// 测试5: 不存在的执行ID
args5 := map[string]interface{}{
"execution_id": "nonexistent_id",
"page": float64(1),
}
toolResult5, err := executor.executeQueryExecutionResult(ctx, args5)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult5.IsError {
t.Fatal("不存在的执行ID应该返回错误")
}
@@ -161,22 +161,22 @@ func TestExecutor_ExecuteInternalTool_QueryExecutionResult(t *testing.T) {
func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
executor, _ := setupTestExecutor(t)
ctx := context.Background()
args := map[string]interface{}{
"test": "value",
}
// 测试未知的内部工具类型
toolResult, err := executor.executeInternalTool(ctx, "unknown_tool", "internal:unknown_tool", args)
if err != nil {
t.Fatalf("执行内部工具失败: %v", err)
}
if !toolResult.IsError {
t.Fatal("未知的工具类型应该返回错误")
}
if !strings.Contains(toolResult.Content[0].Text, "未知的内部工具类型") {
t.Errorf("错误消息应该包含'未知的内部工具类型'")
}
@@ -185,21 +185,21 @@ func TestExecutor_ExecuteInternalTool_UnknownTool(t *testing.T) {
func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) {
executor, _ := setupTestExecutor(t)
// 不设置存储,测试未初始化的情况
ctx := context.Background()
args := map[string]interface{}{
"execution_id": "test_id",
}
toolResult, err := executor.executeQueryExecutionResult(ctx, args)
if err != nil {
t.Fatalf("执行查询失败: %v", err)
}
if !toolResult.IsError {
t.Fatal("未初始化的存储应该返回错误")
}
if !strings.Contains(toolResult.Content[0].Text, "结果存储未初始化") {
t.Errorf("错误消息应该包含'结果存储未初始化'")
}
@@ -207,7 +207,7 @@ func TestExecutor_ExecuteInternalTool_NoStorage(t *testing.T) {
func TestPaginateLines(t *testing.T) {
lines := []string{"Line 1", "Line 2", "Line 3", "Line 4", "Line 5"}
// 测试第一页
page := paginateLines(lines, 1, 2)
if page.Page != 1 {
@@ -225,7 +225,7 @@ func TestPaginateLines(t *testing.T) {
if len(page.Lines) != 2 {
t.Errorf("第一页行数不匹配。期望: 2, 实际: %d", len(page.Lines))
}
// 测试第二页
page2 := paginateLines(lines, 2, 2)
if len(page2.Lines) != 2 {
@@ -234,13 +234,13 @@ func TestPaginateLines(t *testing.T) {
if page2.Lines[0] != "Line 3" {
t.Errorf("第二页第一行不匹配。期望: Line 3, 实际: %s", page2.Lines[0])
}
// 测试最后一页
page3 := paginateLines(lines, 3, 2)
if len(page3.Lines) != 1 {
t.Errorf("第三页行数不匹配。期望: 1, 实际: %d", len(page3.Lines))
}
// 测试超出范围的页码(应该返回最后一页)
page4 := paginateLines(lines, 4, 2)
if page4.Page != 3 {
@@ -249,13 +249,13 @@ func TestPaginateLines(t *testing.T) {
if len(page4.Lines) != 1 {
t.Errorf("最后一页应该只有1行。实际: %d行", len(page4.Lines))
}
// 测试无效页码(小于1
page0 := paginateLines(lines, 0, 2)
if page0.Page != 1 {
t.Errorf("无效页码应该被修正为1。实际: %d", page0.Page)
}
// 测试空列表
emptyPage := paginateLines([]string{}, 1, 10)
if emptyPage.TotalLines != 0 {
@@ -265,4 +265,3 @@ func TestPaginateLines(t *testing.T) {
t.Errorf("空列表应该返回空结果。实际: %d行", len(emptyPage.Lines))
}
}
+4 -4
View File
@@ -16,10 +16,10 @@ type rateLimitEntry struct {
// RateLimiter 基于 IP 的滑动窗口速率限制器
type RateLimiter struct {
mu sync.Mutex
entries map[string]*rateLimitEntry
limit int // 窗口内允许的最大请求数
window time.Duration // 窗口时长
mu sync.Mutex
entries map[string]*rateLimitEntry
limit int // 窗口内允许的最大请求数
window time.Duration // 窗口时长
}
// NewRateLimiter 创建速率限制器
-1
View File
@@ -162,4 +162,3 @@ func truncateRunes(s string, max int) string {
}
return string(r[:max]) + "…"
}
+6 -6
View File
@@ -49,12 +49,12 @@ func ParseSkillMD(raw []byte) (*SkillManifest, string, error) {
}
type skillFrontMatterExport struct {
Name string `yaml:"name"`
Description string `yaml:"description"`
License string `yaml:"license,omitempty"`
Compatibility string `yaml:"compatibility,omitempty"`
Metadata map[string]any `yaml:"metadata,omitempty"`
AllowedTools string `yaml:"allowed-tools,omitempty"`
Name string `yaml:"name"`
Description string `yaml:"description"`
License string `yaml:"license,omitempty"`
Compatibility string `yaml:"compatibility,omitempty"`
Metadata map[string]any `yaml:"metadata,omitempty"`
AllowedTools string `yaml:"allowed-tools,omitempty"`
}
// BuildSkillMD serializes SKILL.md per agentskills.io.
+4 -4
View File
@@ -9,10 +9,10 @@ import (
)
const (
maxPackageFiles = 4000
maxPackageDepth = 24
maxScriptsDepth = 24
defaultMaxRead = 10 << 20
maxPackageFiles = 4000
maxPackageDepth = 24
maxScriptsDepth = 24
defaultMaxRead = 10 << 20
)
// SafeRelPath resolves rel inside root (no ..).
+2025 -255
View File
File diff suppressed because it is too large Load Diff
+184 -5
View File
@@ -20,7 +20,13 @@
"copied": "Copied",
"copyFailed": "Copy failed",
"view": "View",
"actions": "Actions"
"actions": "Actions",
"loadFailed": "Load failed",
"untitled": "Untitled",
"justNow": "Just now",
"minutesAgo": "{{n}} min ago",
"hoursAgo": "{{n}} h ago",
"daysAgo": "{{n}} d ago"
},
"header": {
"title": "CyberStrikeAI",
@@ -33,6 +39,13 @@
"version": "Current version",
"toggleSidebar": "Collapse/expand sidebar"
},
"notifications": {
"title": "Notifications",
"empty": "No new events",
"markAllRead": "Mark all read",
"markSingleRead": "Read",
"itemDefaultTitle": "Notification"
},
"login": {
"title": "Sign in to CyberStrikeAI",
"subtitle": "Enter the access password from config",
@@ -81,6 +94,16 @@
"severityMedium": "Medium",
"severityLow": "Low",
"severityInfo": "Info",
"totalVulns": "Total vulnerabilities",
"riskLevel": "Risk level",
"riskScore": "Weighted risk score",
"riskSafe": "Safe",
"riskLow": "Low",
"riskMedium": "Medium",
"riskHigh": "High",
"riskSevere": "Severe",
"latestFound": "Latest found",
"noneYet": "None yet",
"runOverview": "Run overview",
"batchQueues": "Batch task queues",
"pending": "Pending",
@@ -107,7 +130,80 @@
"toUse": "To use",
"active": "Active",
"highFreq": "High frequency",
"noCallData": "No call data"
"noCallData": "No call data",
"lastUpdated": "Last updated",
"viewAll": "View all →",
"recentVulns": "Recent vulnerabilities",
"noVulnYet": "No recent vulnerabilities",
"capabilities": "Capabilities",
"mcpTools": "MCP tools",
"rolesLabel": "Roles",
"agentsLabel": "Agents",
"webshellLabel": "WebShell",
"pendingCountLabel": "{{count}} pending",
"highCountLabel": "High {{count}}",
"toolsCountLabel_one": "{{count}} tool",
"toolsCountLabel_other": "{{count}} tools",
"failedNCalls_one": "{{count}} failed",
"failedNCalls_other": "{{count}} failed",
"noCallYet": "No calls yet",
"allClear": "No new risks",
"allIdle": "System idle",
"executingNow": "Running",
"healthyStatus": "Healthy",
"normalStatus": "Mostly OK",
"degradedStatus": "Needs attention",
"alertTitle": "Heads up",
"alertWarningTitle": "Needs attention",
"alertDangerTitle": "Action required",
"alertCriticalReason_one": "{{count}} open critical vulnerability — please review immediately",
"alertCriticalReason_other": "{{count}} open critical vulnerabilities — please review immediately",
"alertFailedReason_one": "Tool success rate is low ({{count}} failed call) — check MCP monitor",
"alertFailedReason_other": "Tool success rate is low ({{count}} failed calls) — check MCP monitor",
"alertHitlReason_one": "{{count}} HITL request pending — Agent is waiting for your decision",
"alertHitlReason_other": "{{count}} HITL requests pending — Agent is waiting for your decision",
"alertMcpDownReason_one": "{{count}} External MCP server is down — related tools are unavailable",
"alertMcpDownReason_other": "{{count}} External MCP servers are down — related tools are unavailable",
"alertDismiss": "Dismiss (this session)",
"openHighCountLabel": "Open high {{count}}",
"allHandled": "All high severity handled",
"viewVulns": "View vulnerabilities",
"viewMonitor": "View monitor",
"viewHitl": "Approve",
"viewMcpManagement": "Manage MCP",
"statusOpen": "Open",
"statusConfirmed": "Confirmed",
"statusFixed": "Fixed",
"statusFalsePositive": "False positive",
"fixRate": "Fix rate",
"dataStale": "Data may be stale — please refresh",
"recommendedActions": "Recommended Actions",
"recommendedActionsHint": "Generated based on current state",
"recoFixCritical_one": "Fix {{count}} open critical vulnerability",
"recoFixCritical_other": "Fix {{count}} open critical vulnerabilities",
"recoFixCriticalDesc": "Critical-level vulnerabilities should be addressed first",
"recoApproveHitl_one": "Approve {{count}} HITL request",
"recoApproveHitl_other": "Approve {{count}} HITL requests",
"recoApproveHitlDesc": "Agent needs your decision to proceed",
"recoRestartMcp_one": "Check {{count}} stopped External MCP",
"recoRestartMcp_other": "Check {{count}} stopped External MCPs",
"recoRestartMcpDesc": "Related tools are unavailable until MCP recovers",
"recoCheckMonitor_one": "Investigate {{count}} failed tool call",
"recoCheckMonitor_other": "Investigate {{count}} failed tool calls",
"recoCheckMonitorDesc": "View failed request details in MCP monitor",
"recoSetupMcp": "Configure your first MCP tool",
"recoSetupMcpDesc": "Install MCP server before Agent can invoke specific capabilities",
"recoStartScan": "Start a scan from chat",
"recoStartScanDesc": "Describe your target in chat, AI will help execute",
"recentEvents": "Recent Events",
"eventUntitled": "Event",
"externalMcpServers": "External MCP",
"mcpAllRunning": "All running",
"mcpPartialDown_one": "{{count}} stopped",
"mcpPartialDown_other": "{{count}} stopped",
"mcpAllDown": "All stopped",
"noVulnDesc": "This list shows recent records; new results appear here when detection completes in chat",
"startScanBtn": "Go to chat to scan"
},
"chat": {
"newChat": "New chat",
@@ -178,7 +274,6 @@
"taskCancelled": "Task cancelled",
"unknownTool": "Unknown tool",
"einoAgentReplyTitle": "Sub-agent reply",
"einoRecoveryTitle": "🔄 Invalid tool JSON · run {{n}}/{{max}} (hint appended)",
"einoStreamErrorTitle": "⚠️ Eino stream interrupted ({{agent}})",
"einoStreamErrorMessage": "Streaming read failed; the system will retry or terminate according to policy.",
"iterationLimitReachedTitle": "⛔ Iteration limit reached",
@@ -240,7 +335,20 @@
},
"hitl": {
"pageTitle": "HITL approvals",
"pendingTitle": "Pending approvals"
"pendingTitle": "Pending approvals",
"loading": "Loading...",
"emptyState": "No pending approvals",
"dismiss": "Dismiss",
"conversationLabel": "Conversation:",
"reviewEditHelp": "Review & edit mode: provide a JSON object to override tool arguments. Example: {\"command\":\"ls -la\"}",
"approvalHelp": "Approval mode: only approve/reject, argument editing is disabled.",
"commentHelp": "Comment (optional): briefly note the approval reason.",
"commentPlaceholder": "e.g. allow read-only command",
"reject": "Reject",
"approve": "Approve",
"loadFailed": "Failed to load",
"invalidJson": "Invalid JSON arguments",
"submitFailedPrefix": "Submit failed:"
},
"progress": {
"callingAI": "Calling AI model...",
@@ -304,6 +412,8 @@
"clearHistory": "Clear history",
"cancelTask": "Cancel task",
"viewConversation": "View conversation",
"viewVulnerabilities": "View vulnerabilities",
"viewVulnerabilitiesQueueTitle": "View vulnerabilities: open management filtered to this queue",
"retryTask": "Retry",
"conversationIdLabel": "Conversation ID",
"statusPending": "Pending",
@@ -445,6 +555,17 @@
"typeCustom": "Custom",
"cmdParam": "Command parameter name",
"cmdParamPlaceholder": "Leave empty for cmd; e.g. xxx for xxx=command",
"encoding": "Response encoding",
"encodingAuto": "Auto detect",
"encodingUtf8": "UTF-8",
"encodingGbk": "GBK (Simplified Chinese Windows)",
"encodingGb18030": "GB18030",
"encodingHint": "Switch to GBK or GB18030 if the Simplified Chinese Windows target shows garbled output.",
"os": "Target OS",
"osAuto": "Auto (infer from Shell type)",
"osLinux": "Linux / Unix",
"osWindows": "Windows",
"osHint": "Determines whether file manager / uploads use Linux or Windows commands. Choose Windows for PHP/JSP hosted on Windows.",
"remark": "Remark",
"remarkPlaceholder": "Friendly name for this connection",
"deleteConfirm": "Delete this connection?",
@@ -575,6 +696,10 @@
"addExternal": "Add external MCP",
"toolConfig": "MCP tool config",
"saveToolConfig": "Save tool config",
"alwaysVisibleLabel": "Pinned",
"alwaysVisibleHint": "Always keep visible in Tool Search results",
"alwaysVisibleBuiltinLabel": "Builtin default",
"alwaysVisibleBuiltinHint": "Backend builtin tool is pinned by default and cannot be disabled",
"externalConfig": "External MCP config",
"loadingTools": "Loading tools...",
"loadToolsTimeout": "Tools load timeout. External MCP may be slow. Click Refresh to retry or check connection.",
@@ -1313,6 +1438,12 @@
"clear": "Clear",
"vulnId": "Vuln ID",
"conversationId": "Conversation ID",
"taskOrQueueId": "Task / queue ID",
"filterTaskOrQueue": "Filter by task or queue ID",
"conversationTag": "Conversation tag",
"filterConversationTag": "Filter by conversation tag",
"taskTag": "Task tag",
"filterTaskTag": "Filter by task tag",
"severity": "Severity",
"status": "Status",
"statusOpen": "Open",
@@ -1322,7 +1453,31 @@
"searchVulnId": "Search vuln ID",
"filterConversation": "Filter by conversation",
"loading": "Loading...",
"noRecords": "No vulnerability records"
"loadListFailed": "Failed to load",
"noRecords": "No vulnerability records",
"batchExport": "Batch export",
"downloadMarkdownTitle": "Download Markdown",
"exportNoResults": "No vulnerabilities match the current filters",
"exportStarted": "Started downloading {{count}} file(s)",
"exportFailed": "Export failed",
"saveRequiredFields": "Please fill in conversation ID, title, and severity",
"saveFailed": "Save failed",
"fetchFailed": "Failed to fetch vulnerability",
"deleteFailed": "Delete failed",
"detailVulnId": "Vuln ID",
"detailType": "Type",
"detailTarget": "Target",
"detailConversationId": "Conversation ID",
"detailTaskId": "Task ID",
"detailTaskQueueId": "Task queue ID",
"detailConversationTag": "Conversation tag",
"detailTaskTag": "Task tag",
"detailProof": "Proof",
"detailImpact": "Impact",
"detailRecommendation": "Remediation",
"downloadOkTitle": "Downloaded",
"exportFailedMessage": "Export failed",
"downloadFailed": "Download failed"
},
"tasksPage": {
"statusFilter": "Status filter",
@@ -1673,6 +1828,7 @@
},
"contextMenu": {
"viewAttackChain": "View attack chain",
"viewVulnerabilities": "View vulnerabilities",
"downloadMarkdown": "Download Markdown",
"downloadMarkdownSummary": "Summary",
"downloadMarkdownFull": "Full",
@@ -1768,6 +1924,10 @@
"vulnerabilityModal": {
"conversationId": "Conversation ID",
"conversationIdPlaceholder": "Enter conversation ID",
"conversationTag": "Conversation tag",
"conversationTagPlaceholder": "e.g. engagement A, weekly report",
"taskTag": "Task tag",
"taskTagPlaceholder": "e.g. batch scan Q2, retest",
"title": "Title",
"titlePlaceholder": "Vulnerability title",
"description": "Description",
@@ -1795,6 +1955,25 @@
"recommendation": "Recommendation",
"recommendationPlaceholder": "Remediation"
},
"vulnerabilityMd": {
"headingBasic": "Basic information",
"labelId": "Vulnerability ID",
"labelSeverity": "Severity",
"labelStatus": "Status",
"labelType": "Type",
"labelTarget": "Target",
"labelConversationId": "Conversation ID",
"labelTaskId": "Task ID",
"labelTaskQueueId": "Task queue ID",
"labelConversationTag": "Conversation tag",
"labelTaskTag": "Task tag",
"labelCreated": "Created at",
"labelUpdated": "Updated at",
"headingDescription": "Description",
"headingProof": "Proof (POC)",
"headingImpact": "Impact",
"headingRecommendation": "Remediation"
},
"roleModal": {
"addRole": "Add role",
"editRole": "Edit role",
+173 -5
View File
@@ -20,7 +20,13 @@
"copied": "已复制",
"copyFailed": "复制失败",
"view": "查看",
"actions": "操作"
"actions": "操作",
"loadFailed": "加载失败",
"untitled": "未命名",
"justNow": "刚刚",
"minutesAgo": "{{n}} 分钟前",
"hoursAgo": "{{n}} 小时前",
"daysAgo": "{{n}} 天前"
},
"header": {
"title": "CyberStrikeAI",
@@ -33,6 +39,13 @@
"version": "当前版本",
"toggleSidebar": "折叠/展开侧边栏"
},
"notifications": {
"title": "事件通知",
"empty": "暂无新事件",
"markAllRead": "标记已读",
"markSingleRead": "已读",
"itemDefaultTitle": "通知"
},
"login": {
"title": "登录 CyberStrikeAI",
"subtitle": "请输入配置中的访问密码",
@@ -81,6 +94,16 @@
"severityMedium": "中危",
"severityLow": "低危",
"severityInfo": "信息",
"totalVulns": "总漏洞数",
"riskLevel": "风险等级",
"riskScore": "加权风险分",
"riskSafe": "安全",
"riskLow": "低",
"riskMedium": "中",
"riskHigh": "高",
"riskSevere": "极高",
"latestFound": "最近发现",
"noneYet": "暂无",
"runOverview": "运行概览",
"batchQueues": "批量任务队列",
"pending": "待执行",
@@ -107,7 +130,69 @@
"toUse": "待使用",
"active": "活跃",
"highFreq": "高频",
"noCallData": "暂无调用数据"
"noCallData": "暂无调用数据",
"lastUpdated": "上次更新",
"viewAll": "查看全部 →",
"recentVulns": "最近漏洞",
"noVulnYet": "暂无最近漏洞",
"capabilities": "能力总览",
"mcpTools": "MCP 工具",
"rolesLabel": "角色",
"agentsLabel": "Agents",
"webshellLabel": "WebShell",
"pendingCountLabel": "{{count}} 待执行",
"highCountLabel": "高危 {{count}}",
"toolsCountLabel": "{{count}} 个工具",
"failedNCalls": "{{count}} 次失败",
"noCallYet": "暂无调用",
"allClear": "暂无新增风险",
"allIdle": "系统空闲",
"executingNow": "正在执行",
"healthyStatus": "运行平稳",
"normalStatus": "基本正常",
"degradedStatus": "需要关注",
"alertTitle": "需要关注",
"alertWarningTitle": "需要关注",
"alertDangerTitle": "需要立即处理",
"alertCriticalReason": "存在 {{count}} 个待处理的严重漏洞,建议立即处置",
"alertFailedReason": "工具调用成功率偏低({{count}} 次失败),请检查 MCP 监控",
"alertHitlReason": "有 {{count}} 个待审批的人机协同请求,Agent 正在等待你的决策",
"alertMcpDownReason": "External MCP 服务器有 {{count}} 个未运行,相关工具不可用",
"alertDismiss": "忽略此提醒(仅本次会话)",
"openHighCountLabel": "待处理高危 {{count}}",
"allHandled": "高严重度已全部处置",
"viewVulns": "查看漏洞",
"viewMonitor": "查看监控",
"viewHitl": "前往审批",
"viewMcpManagement": "管理 MCP",
"statusOpen": "待处理",
"statusConfirmed": "已确认",
"statusFixed": "已修复",
"statusFalsePositive": "误报",
"fixRate": "修复率",
"dataStale": "数据可能已过期,请手动刷新",
"recommendedActions": "推荐操作",
"recommendedActionsHint": "基于当前状态自动生成",
"recoFixCritical": "修复 {{count}} 个待处理严重漏洞",
"recoFixCriticalDesc": "严重等级的漏洞应优先处置",
"recoApproveHitl": "审批 {{count}} 个 HITL 请求",
"recoApproveHitlDesc": "Agent 正在等待你的决策才能继续",
"recoRestartMcp": "检查 {{count}} 个未运行的 External MCP",
"recoRestartMcpDesc": "相关工具在 MCP 服务恢复前不可用",
"recoCheckMonitor": "排查 {{count}} 次工具调用失败",
"recoCheckMonitorDesc": "在 MCP 监控中查看失败的请求详情",
"recoSetupMcp": "配置首个 MCP 工具",
"recoSetupMcpDesc": "安装 MCP 服务后 Agent 才能调用具体能力",
"recoStartScan": "在对话中发起扫描",
"recoStartScanDesc": "在对话中描述目标,让 AI 协助执行",
"recentEvents": "最近事件",
"eventUntitled": "事件",
"externalMcpServers": "External MCP",
"mcpAllRunning": "全部运行",
"mcpPartialDown": "{{count}} 个未运行",
"mcpAllDown": "全部未运行",
"noVulnDesc": "此处展示近期漏洞记录;在对话中完成检测后,新结果会出现在这里",
"startScanBtn": "前往对话发起扫描"
},
"chat": {
"newChat": "新对话",
@@ -178,7 +263,6 @@
"taskCancelled": "任务已取消",
"unknownTool": "未知工具",
"einoAgentReplyTitle": "子代理回复",
"einoRecoveryTitle": "🔄 工具参数无效 · 第 {{n}}/{{max}} 轮(已追加提示)",
"einoStreamErrorTitle": "⚠️ Eino 流式中断({{agent}}",
"einoStreamErrorMessage": "流式读取异常,系统将按策略重试或结束。",
"iterationLimitReachedTitle": "⛔ 达到迭代上限",
@@ -240,7 +324,20 @@
},
"hitl": {
"pageTitle": "人机协同审批",
"pendingTitle": "待处理审批"
"pendingTitle": "待处理审批",
"loading": "加载中...",
"emptyState": "暂无待审批项",
"dismiss": "忽略",
"conversationLabel": "会话:",
"reviewEditHelp": "审查编辑模式:可填写 JSON 对象覆盖参数。示例:{\"command\":\"ls -la\"}",
"approvalHelp": "审批模式:仅通过/拒绝,不支持改参。",
"commentHelp": "备注(可选):建议写审批依据。",
"commentPlaceholder": "例如:允许只读命令",
"reject": "拒绝",
"approve": "通过",
"loadFailed": "加载失败",
"invalidJson": "JSON 参数格式错误",
"submitFailedPrefix": "提交失败:"
},
"progress": {
"callingAI": "正在调用AI模型...",
@@ -304,6 +401,8 @@
"clearHistory": "清空历史",
"cancelTask": "取消任务",
"viewConversation": "查看对话",
"viewVulnerabilities": "查看漏洞",
"viewVulnerabilitiesQueueTitle": "查看漏洞:打开漏洞管理并筛选本队列",
"retryTask": "重试",
"conversationIdLabel": "对话ID",
"statusPending": "待执行",
@@ -445,6 +544,17 @@
"typeCustom": "自定义",
"cmdParam": "命令参数名",
"cmdParamPlaceholder": "不填默认为 cmd,如填 xxx 则请求为 xxx=命令",
"encoding": "响应编码",
"encodingAuto": "自动检测",
"encodingUtf8": "UTF-8",
"encodingGbk": "GBK(中文 Windows",
"encodingGb18030": "GB18030",
"encodingHint": "中文 Windows 目标若出现乱码,请切换为 GBK 或 GB18030",
"os": "目标系统",
"osAuto": "自动(按 Shell 类型推断)",
"osLinux": "Linux / Unix",
"osWindows": "Windows",
"osHint": "决定文件管理/上传使用 Linux 还是 Windows 命令;PHP/JSP 跑在 Windows 上请选 Windows",
"remark": "备注",
"remarkPlaceholder": "便于识别的备注名",
"deleteConfirm": "确定要删除该连接吗?",
@@ -575,6 +685,10 @@
"addExternal": "添加外部MCP",
"toolConfig": "MCP 工具配置",
"saveToolConfig": "保存工具配置",
"alwaysVisibleLabel": "常驻",
"alwaysVisibleHint": "始终常驻在 Tool Search 可见列表(不被 tool_search 隐藏)",
"alwaysVisibleBuiltinLabel": "内置默认",
"alwaysVisibleBuiltinHint": "后端内置工具默认常驻,不可关闭",
"externalConfig": "外部 MCP 配置",
"loadingTools": "正在加载工具列表...",
"loadToolsTimeout": "加载工具列表超时,可能是外部MCP连接较慢。请点击\"刷新\"按钮重试,或检查外部MCP连接状态。",
@@ -1313,6 +1427,12 @@
"clear": "清除",
"vulnId": "漏洞ID",
"conversationId": "会话ID",
"taskOrQueueId": "任务ID/队列ID",
"filterTaskOrQueue": "筛选任务ID或队列ID",
"conversationTag": "对话标签",
"filterConversationTag": "筛选对话标签",
"taskTag": "任务标签",
"filterTaskTag": "筛选任务标签",
"severity": "严重程度",
"status": "状态",
"statusOpen": "待处理",
@@ -1322,7 +1442,31 @@
"searchVulnId": "搜索漏洞ID",
"filterConversation": "筛选特定会话",
"loading": "加载中...",
"noRecords": "暂无漏洞记录"
"loadListFailed": "加载失败",
"noRecords": "暂无漏洞记录",
"batchExport": "批量导出",
"downloadMarkdownTitle": "下载 Markdown",
"exportNoResults": "当前筛选条件下无可导出漏洞",
"exportStarted": "已开始下载 {{count}} 份报告",
"exportFailed": "导出失败",
"saveRequiredFields": "请填写必填字段:会话ID、标题和严重程度",
"saveFailed": "保存失败",
"fetchFailed": "获取漏洞失败",
"deleteFailed": "删除失败",
"detailVulnId": "漏洞ID",
"detailType": "类型",
"detailTarget": "目标",
"detailConversationId": "会话ID",
"detailTaskId": "任务ID",
"detailTaskQueueId": "任务队列ID",
"detailConversationTag": "对话标签",
"detailTaskTag": "任务标签",
"detailProof": "证明",
"detailImpact": "影响",
"detailRecommendation": "修复建议",
"downloadOkTitle": "下载成功",
"exportFailedMessage": "导出失败",
"downloadFailed": "下载失败"
},
"tasksPage": {
"statusFilter": "状态筛选",
@@ -1673,6 +1817,7 @@
},
"contextMenu": {
"viewAttackChain": "查看攻击链",
"viewVulnerabilities": "查看漏洞",
"downloadMarkdown": "下载 Markdown",
"downloadMarkdownSummary": "简版",
"downloadMarkdownFull": "完整版",
@@ -1768,6 +1913,10 @@
"vulnerabilityModal": {
"conversationId": "会话ID",
"conversationIdPlaceholder": "输入会话ID",
"conversationTag": "对话标签",
"conversationTagPlaceholder": "如:红队演练A、客户A周报",
"taskTag": "任务标签",
"taskTagPlaceholder": "如:批量扫描Q2、专项复测",
"title": "标题",
"titlePlaceholder": "漏洞标题",
"description": "描述",
@@ -1795,6 +1944,25 @@
"recommendation": "修复建议",
"recommendationPlaceholder": "修复建议"
},
"vulnerabilityMd": {
"headingBasic": "基本信息",
"labelId": "漏洞ID",
"labelSeverity": "严重程度",
"labelStatus": "状态",
"labelType": "类型",
"labelTarget": "目标",
"labelConversationId": "会话ID",
"labelTaskId": "任务ID",
"labelTaskQueueId": "任务队列ID",
"labelConversationTag": "对话标签",
"labelTaskTag": "任务标签",
"labelCreated": "创建时间",
"labelUpdated": "更新时间",
"headingDescription": "描述",
"headingProof": "证明(POC",
"headingImpact": "影响",
"headingRecommendation": "修复建议"
},
"roleModal": {
"addRole": "添加角色",
"editRole": "编辑角色",
+1134 -655
View File
File diff suppressed because it is too large Load Diff
+1251 -98
View File
File diff suppressed because it is too large Load Diff
+42 -16
View File
@@ -7,6 +7,19 @@ function hitlModeNormalize(m) {
return allowed.indexOf(v) >= 0 ? v : 'off';
}
function hitlT(key, fallback, params) {
const fullKey = 'hitl.' + key;
try {
if (typeof window.t === 'function') {
const translated = window.t(fullKey, params || {});
if (typeof translated === 'string' && translated && translated !== fullKey) {
return translated;
}
}
} catch (e) {}
return fallback;
}
function hitlEffectiveEnabled(cfg) {
if (!cfg) return false;
if (cfg.enabled === true) return true;
@@ -36,6 +49,18 @@ function hitlSensitiveToolsToArray(config) {
return [];
}
function normalizeHitlTimeoutSeconds(v, fallback) {
const n = Number(v);
if (Number.isFinite(n)) {
return n > 0 ? Math.floor(n) : 0;
}
const f = Number(fallback);
if (Number.isFinite(f)) {
return f > 0 ? Math.floor(f) : 0;
}
return 0;
}
function getCurrentConversationIdForHitl() {
if (typeof window.currentConversationId === 'string' && window.currentConversationId) {
return window.currentConversationId;
@@ -84,6 +109,7 @@ async function saveHitlConversationConfig(conversationId, config) {
const mode = hitlModeNormalize(config.mode || 'off');
const enabled = typeof config.enabled === 'boolean' ? config.enabled : (mode !== 'off');
const sensitiveTools = hitlSensitiveToolsToArray(config);
const timeoutSeconds = normalizeHitlTimeoutSeconds(config.timeoutSeconds, 0);
const resp = await hitlApiFetch('/api/hitl/config', {
method: 'PUT',
credentials: 'same-origin',
@@ -93,7 +119,7 @@ async function saveHitlConversationConfig(conversationId, config) {
enabled: enabled,
mode: mode,
sensitiveTools: sensitiveTools,
timeoutSeconds: config.timeoutSeconds || 300
timeoutSeconds: timeoutSeconds
})
});
if (!resp.ok) {
@@ -126,7 +152,7 @@ async function syncHitlConfigFromServer(conversationId) {
enabled: true,
mode: localMode,
sensitiveTools: localToolsStr.split(/[,\n\r]+/).map(function (s) { return s.trim(); }).filter(Boolean),
timeoutSeconds: cfg.timeoutSeconds || 300
timeoutSeconds: normalizeHitlTimeoutSeconds(cfg.timeoutSeconds, 0)
};
saveHitlConversationConfig(conversationId, {
mode: localMode,
@@ -146,7 +172,7 @@ async function syncHitlConfigFromServer(conversationId) {
enabled: true,
mode: glMode,
sensitiveTools: glToolsStr.split(/[,\n\r]+/).map(function (s) { return s.trim(); }).filter(Boolean),
timeoutSeconds: cfg.timeoutSeconds || 300
timeoutSeconds: normalizeHitlTimeoutSeconds(cfg.timeoutSeconds, 0)
};
saveHitlConversationConfig(conversationId, {
mode: glMode,
@@ -265,7 +291,7 @@ async function followAgentRunAfterHitlDecision(conversationId) {
async function refreshHitlPending() {
const container = document.getElementById('hitl-pending-list');
if (!container) return;
container.innerHTML = '<div class="loading-spinner">Loading...</div>';
container.innerHTML = '<div class="loading-spinner">' + escapeHtml(hitlT('loading', 'Loading...')) + '</div>';
try {
const resp = await hitlApiFetch('/api/hitl/pending', { credentials: 'same-origin' });
if (!resp.ok) {
@@ -274,7 +300,7 @@ async function refreshHitlPending() {
const data = await resp.json();
const items = Array.isArray(data.items) ? data.items : [];
if (!items.length) {
container.innerHTML = '<div class="empty-state">暂无待审批项</div>';
container.innerHTML = '<div class="empty-state">' + escapeHtml(hitlT('emptyState', 'No pending approvals')) + '</div>';
return;
}
container.innerHTML = items.map(function (item) {
@@ -292,25 +318,25 @@ async function refreshHitlPending() {
'<span class="hitl-tool-badge">' + escapeHtml(item.toolName || '-') + '</span>' +
'<span class="hitl-mode-tag hitl-mode-tag--' + escapeHtml(mode) + '">' + escapeHtml(item.mode || '-') + '</span>' +
'</div>' +
'<button class="hitl-dismiss-btn" title="忽略" onclick="dismissHitlItem(' + qId + ')">&times;</button>' +
'<button class="hitl-dismiss-btn" title="' + escapeHtml(hitlT('dismiss', 'Dismiss')) + '" onclick="dismissHitlItem(' + qId + ')">&times;</button>' +
'</div>' +
'<div class="hitl-pending-meta">会话:' + escapeHtml(item.conversationId || '-') + '</div>' +
'<div class="hitl-pending-meta">' + escapeHtml(hitlT('conversationLabel', 'Conversation:')) + ' ' + escapeHtml(item.conversationId || '-') + '</div>' +
'<pre class="hitl-pending-payload">' + escapeHtml(preview) + '</pre>' +
(allowEdit
? ('<div class="hitl-input-help">审查编辑模式:可填写 JSON 对象覆盖参数。示例:{"command":"ls -la"}</div>' +
? ('<div class="hitl-input-help">' + escapeHtml(hitlT('reviewEditHelp', 'Review & edit mode: provide a JSON object to override tool arguments. Example: {"command":"ls -la"}')) + '</div>' +
'<textarea id="hitl-edit-' + escId + '" class="hitl-edit-args" placeholder=\'{"command":"ls -la"}\'></textarea>')
: '<div class="hitl-input-help">审批模式:仅通过/拒绝,不支持改参。</div>') +
'<div class="hitl-input-help">备注(可选):建议写审批依据。</div>' +
'<input id="hitl-comment-' + escId + '" class="hitl-config-input hitl-inline-comment" type="text" placeholder="例如:允许只读命令">' +
: '<div class="hitl-input-help">' + escapeHtml(hitlT('approvalHelp', 'Approval mode: only approve/reject, argument editing is disabled.')) + '</div>') +
'<div class="hitl-input-help">' + escapeHtml(hitlT('commentHelp', 'Comment (optional): briefly note the approval reason.')) + '</div>' +
'<input id="hitl-comment-' + escId + '" class="hitl-config-input hitl-inline-comment" type="text" placeholder="' + escapeHtml(hitlT('commentPlaceholder', 'e.g. allow read-only command')) + '">' +
'<div class="hitl-pending-actions">' +
'<button class="btn-secondary" onclick="submitHitlDecision(' + qId + ',&quot;reject&quot;,' + qConv + ')">拒绝</button>' +
'<button class="btn-primary" onclick="submitHitlDecision(' + qId + ',&quot;approve&quot;,' + qConv + ')">通过</button>' +
'<button class="btn-secondary" onclick="submitHitlDecision(' + qId + ',&quot;reject&quot;,' + qConv + ')">' + escapeHtml(hitlT('reject', 'Reject')) + '</button>' +
'<button class="btn-primary" onclick="submitHitlDecision(' + qId + ',&quot;approve&quot;,' + qConv + ')">' + escapeHtml(hitlT('approve', 'Approve')) + '</button>' +
'</div>' +
'</div>'
);
}).join('');
} catch (e) {
container.innerHTML = '<div class="empty-state">加载失败</div>';
container.innerHTML = '<div class="empty-state">' + escapeHtml(hitlT('loadFailed', 'Failed to load')) + '</div>';
}
}
@@ -323,7 +349,7 @@ async function submitHitlDecision(interruptId, decision, conversationIdOpt) {
try {
editedArguments = JSON.parse(editBox.value.trim());
} catch (e) {
alert('JSON 参数格式错误');
alert(hitlT('invalidJson', 'Invalid JSON arguments'));
return;
}
}
@@ -344,7 +370,7 @@ async function submitHitlDecisionWithPayload(interruptId, decision, comment, edi
await dismissHitlItem(interruptId, true);
return true;
}
alert('提交失败:' + errText);
alert(hitlT('submitFailedPrefix', 'Submit failed:') + ' ' + errText);
return false;
}
refreshHitlPending();
+24 -39
View File
@@ -1133,24 +1133,6 @@ function handleStreamEvent(event, progressElement, progressId,
});
break;
case 'eino_recovery': {
const d = event.data || {};
const runIdx = d.runIndex != null ? d.runIndex : (d.einoRetry != null ? d.einoRetry + 1 : 1);
const maxRuns = d.maxRuns != null ? d.maxRuns : 3;
const title = typeof window.t === 'function'
? window.t('chat.einoRecoveryTitle', { n: runIdx, max: maxRuns })
: ('🔄 工具参数无效 · 第 ' + runIdx + '/' + maxRuns + ' 轮(已追加提示)');
addTimelineItem(timeline, 'eino_recovery', {
title: title,
message: event.message || '',
data: event.data
});
// If the backend triggers a recovery run, any "running" tool_call items in this progress
// should be closed to avoid being stuck forever.
finalizeOutstandingToolCallsForProgress(progressId, 'failed');
break;
}
case 'eino_stream_error': {
const d = event.data || {};
const agent = d.einoAgent ? String(d.einoAgent) : '';
@@ -2190,15 +2172,6 @@ function addTimelineItem(timeline, type, options) {
if (type === 'progress' && options.message) {
item.dataset.progressMessage = options.message;
}
if (type === 'eino_recovery' && options.data) {
const d = options.data;
if (d.runIndex != null) {
item.dataset.recoveryRunIndex = String(d.runIndex);
}
if (d.maxRuns != null) {
item.dataset.recoveryMaxRuns = String(d.maxRuns);
}
}
if (type === 'tool_calls_detected' && options.data && options.data.count != null) {
item.dataset.toolCallsCount = String(options.data.count);
}
@@ -2309,12 +2282,6 @@ function addTimelineItem(timeline, type, options) {
</div>
</div>
`;
} else if (type === 'eino_recovery' && options.message) {
content += `
<div class="timeline-item-content timeline-eino-recovery">
${escapeHtml(options.message).replace(/\n/g, '<br>')}
</div>
`;
} else if (type === 'cancelled') {
const taskCancelledLabel = typeof window.t === 'function' ? window.t('chat.taskCancelled') : '任务已取消';
content += `
@@ -2381,9 +2348,28 @@ function renderActiveTasks(tasks) {
bar.style.display = 'flex';
bar.innerHTML = '';
function openActiveTaskConversation(conversationId) {
if (!conversationId) return;
if (typeof switchPage === 'function') {
switchPage('chat');
}
if (typeof window.loadConversation === 'function') {
setTimeout(function () {
window.loadConversation(conversationId);
}, 120);
return;
}
window.location.hash = 'chat?conversation=' + encodeURIComponent(conversationId);
}
normalizedTasks.forEach(task => {
const item = document.createElement('div');
item.className = 'active-task-item';
item.className = 'active-task-item active-task-item-clickable';
if (task && task.conversationId) {
item.title = (typeof window.t === 'function' ? window.t('tasks.viewConversation') : '查看会话');
item.setAttribute('role', 'button');
item.onclick = () => openActiveTaskConversation(task.conversationId);
}
const startedTime = task.startedAt ? new Date(task.startedAt) : null;
const taskTimeLocale = getCurrentTimeLocale();
@@ -2421,7 +2407,10 @@ function renderActiveTasks(tasks) {
if (!isFinalStatus) {
const cancelBtn = item.querySelector('.active-task-cancel');
if (cancelBtn) {
cancelBtn.onclick = () => cancelActiveTask(task.conversationId, cancelBtn);
cancelBtn.onclick = (evt) => {
evt.stopPropagation();
cancelActiveTask(task.conversationId, cancelBtn);
};
if (task.status === 'cancelling') {
cancelBtn.disabled = true;
cancelBtn.textContent = typeof window.t === 'function' ? window.t('tasks.cancelling') : '取消中...';
@@ -3197,10 +3186,6 @@ function refreshProgressAndTimelineI18n() {
titleSpan.textContent = ap + icon + (success ? _t('chat.toolExecComplete', { name: name }) : _t('chat.toolExecFailed', { name: name }));
} else if (type === 'eino_agent_reply') {
titleSpan.textContent = ap + '\uD83D\uDCAC ' + _t('chat.einoAgentReplyTitle');
} else if (type === 'eino_recovery' && item.dataset.recoveryRunIndex) {
const n = parseInt(item.dataset.recoveryRunIndex, 10) || 1;
const mx = parseInt(item.dataset.recoveryMaxRuns, 10) || 3;
titleSpan.textContent = _t('chat.einoRecoveryTitle', { n: n, max: mx });
} else if (type === 'cancelled') {
titleSpan.textContent = '\u26D4 ' + _t('chat.taskCancelled');
} else if (type === 'progress' && item.dataset.progressMessage !== undefined) {
+329
View File
@@ -0,0 +1,329 @@
(function () {
const STORAGE_LAST_SEEN_KEY = 'cyberstrike-notification-last-seen-at';
const POLL_INTERVAL_ACTIVE_MS = 15000;
const POLL_INTERVAL_HIDDEN_MS = 60000;
const MAX_RENDER_ITEMS = 20;
const state = {
inFlight: false,
timerId: null,
dropdownOpen: false,
lastSeenAt: readLastSeenAt(),
items: [],
unreadCount: 0,
};
function readLastSeenAt() {
try {
const raw = localStorage.getItem(STORAGE_LAST_SEEN_KEY);
const n = Number(raw);
if (Number.isFinite(n) && n > 0) return n;
} catch (e) {
console.warn('读取通知已读时间失败:', e);
}
return 0;
}
function persistLastSeenAt(ts) {
try {
localStorage.setItem(STORAGE_LAST_SEEN_KEY, String(ts));
} catch (e) {
console.warn('保存通知已读时间失败:', e);
}
}
function getTimeMs(value) {
if (!value) return 0;
const d = new Date(value);
const ms = d.getTime();
return Number.isFinite(ms) ? ms : 0;
}
function getLocale() {
if (typeof window !== 'undefined') {
if (typeof window.__locale === 'string' && window.__locale) {
return window.__locale;
}
if (typeof window.currentLang === 'string' && window.currentLang) {
return window.currentLang;
}
}
return 'zh-CN';
}
function formatTime(value) {
const ms = getTimeMs(value);
if (!ms) return '-';
return new Date(ms).toLocaleString(getLocale());
}
function htmlEscape(value) {
if (typeof window.escapeHtml === 'function') {
return window.escapeHtml(value == null ? '' : String(value));
}
const div = document.createElement('div');
div.textContent = value == null ? '' : String(value);
return div.innerHTML;
}
function t(key, fallback, params) {
if (typeof window !== 'undefined' && typeof window.t === 'function') {
try {
const translated = window.t(key, params || {});
if (translated && translated !== key) return translated;
} catch (_ignored) {}
}
return fallback;
}
async function apiJson(url, options) {
if (typeof window.apiFetch !== 'function') return null;
const res = await window.apiFetch(url, options || {});
if (!res.ok) return null;
return res.json();
}
async function fetchNotificationSummary() {
const url = '/api/notifications/summary?since='
+ encodeURIComponent(String(state.lastSeenAt || 0))
+ '&limit=80&lang=' + encodeURIComponent(getLocale());
try {
const summary = await apiJson(url);
if (summary && typeof summary === 'object') {
return summary;
}
} catch (_ignored) {}
return null;
}
function renderBadge(count) {
const badge = document.getElementById('notification-badge');
const btn = document.getElementById('notification-bell-btn');
if (!badge || !btn) return;
if (count <= 0) {
badge.style.display = 'none';
btn.classList.remove('has-alert');
return;
}
const text = count > 99 ? '99+' : String(count);
badge.innerHTML = '<span class="notification-badge-text">' + htmlEscape(text) + '</span>';
badge.style.display = 'inline-block';
btn.classList.add('has-alert');
}
function countP0(items) {
return (Array.isArray(items) ? items : []).reduce((acc, item) => {
if (!item || item.level !== 'p0') return acc;
if (typeof item.count === 'number' && item.count > 0) return acc + item.count;
return acc + 1;
}, 0);
}
function markableItems(items) {
return (Array.isArray(items) ? items : []).filter(item => item && item.actionable !== true && item.id);
}
function hasAction(item) {
if (!item || !item.type) return false;
if (item.type === 'vulnerability_created' && item.vulnerabilityId) return true;
if ((item.type === 'task_completed' || item.type === 'long_running_tasks') && item.conversationId) return true;
if (item.type === 'task_failed' && item.executionId) return true;
if (item.type === 'hitl_pending') return true;
return false;
}
function openNotificationTarget(item) {
if (!item || !item.type) return;
if (item.type === 'vulnerability_created' && item.vulnerabilityId) {
window.location.hash = 'vulnerabilities?id=' + encodeURIComponent(item.vulnerabilityId);
return;
}
if ((item.type === 'task_completed' || item.type === 'long_running_tasks') && item.conversationId) {
window.location.hash = 'chat?conversation=' + encodeURIComponent(item.conversationId);
return;
}
if (item.type === 'task_failed' && item.executionId) {
window.location.hash = 'mcp-monitor';
setTimeout(function () {
if (typeof showMCPDetail === 'function') {
showMCPDetail(item.executionId);
}
}, 450);
return;
}
if (item.type === 'hitl_pending') {
window.location.hash = 'hitl';
}
}
async function markItemsRead(eventIds) {
if (!Array.isArray(eventIds) || !eventIds.length) return true;
const payload = { eventIds: eventIds };
try {
const result = await apiJson('/api/notifications/read', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(payload),
});
return !!result;
} catch (_ignored) {
return false;
}
}
function renderNotificationList(items) {
const list = document.getElementById('notification-list');
if (!list) return;
const renderItems = Array.isArray(items) ? items.slice(0, MAX_RENDER_ITEMS) : [];
if (!renderItems.length) {
list.innerHTML = '<div class="notification-empty">' + htmlEscape(t('notifications.empty', '暂无新事件')) + '</div>';
return;
}
const html = renderItems.map(item => {
const canMarkRead = item.actionable !== true && !!item.id;
const canView = hasAction(item);
return `
<div class="notification-item notification-level-${htmlEscape(item.level || 'p2')}">
<div class="notification-item-header">
<div class="notification-item-title">${htmlEscape(item.title || t('notifications.itemDefaultTitle', '通知'))}</div>
<div class="notification-item-actions">
${canView ? `<button class="notification-item-action-btn notification-item-view-btn" type="button" data-action-id="${htmlEscape(item.id || '')}">${htmlEscape(t('common.view', '查看'))}</button>` : ''}
${canMarkRead ? `<button class="notification-item-action-btn notification-item-read-btn" type="button" data-notification-id="${htmlEscape(item.id)}">${htmlEscape(t('notifications.markSingleRead', '已读'))}</button>` : ''}
</div>
</div>
<div class="notification-item-desc">${htmlEscape(item.desc || '')}</div>
<div class="notification-item-time">${htmlEscape(formatTime(item.ts))}</div>
</div>
`;
}).join('');
list.innerHTML = html;
const viewButtons = list.querySelectorAll('.notification-item-view-btn');
viewButtons.forEach(btn => {
btn.addEventListener('click', function (event) {
event.preventDefault();
event.stopPropagation();
const eventID = btn.getAttribute('data-action-id') || '';
if (!eventID) return;
const item = state.items.find(it => it && it.id === eventID);
if (!item) return;
openNotificationTarget(item);
closeDropdown();
});
});
const readButtons = list.querySelectorAll('.notification-item-read-btn');
readButtons.forEach(btn => {
btn.addEventListener('click', async function (event) {
event.preventDefault();
event.stopPropagation();
const eventID = btn.getAttribute('data-notification-id') || '';
if (!eventID) return;
const ok = await markItemsRead([eventID]);
if (ok) {
await refreshNotifications();
}
});
});
}
function closeDropdown() {
const dropdown = document.getElementById('notification-dropdown');
const bellBtn = document.getElementById('notification-bell-btn');
if (dropdown) dropdown.style.display = 'none';
if (bellBtn) bellBtn.classList.remove('active');
state.dropdownOpen = false;
}
function markSeenNow() {
state.lastSeenAt = Date.now();
persistLastSeenAt(state.lastSeenAt);
}
async function refreshNotifications() {
if (state.inFlight) return;
state.inFlight = true;
try {
const summary = await fetchNotificationSummary();
const items = summary && Array.isArray(summary.items) ? summary.items : [];
state.items = items;
const unreadCount = summary && Number.isFinite(Number(summary.unreadCount))
? Number(summary.unreadCount)
: countP0(items);
state.unreadCount = Math.max(0, unreadCount);
renderBadge(state.unreadCount);
renderNotificationList(items);
} catch (e) {
console.warn('刷新通知失败:', e);
} finally {
state.inFlight = false;
}
}
function scheduleNextPoll() {
if (state.timerId) {
window.clearTimeout(state.timerId);
state.timerId = null;
}
const interval = document.hidden ? POLL_INTERVAL_HIDDEN_MS : POLL_INTERVAL_ACTIVE_MS;
state.timerId = window.setTimeout(async function () {
await refreshNotifications();
scheduleNextPoll();
}, interval);
}
function handleDocumentClick(event) {
const container = document.querySelector('.notification-menu-container');
if (!container) return;
if (!container.contains(event.target)) {
closeDropdown();
}
}
async function toggleDropdown() {
const dropdown = document.getElementById('notification-dropdown');
const bellBtn = document.getElementById('notification-bell-btn');
if (!dropdown || !bellBtn) return;
const isOpen = dropdown.style.display !== 'none';
if (isOpen) {
closeDropdown();
return;
}
// 从仪表盘「查看全部」等容器外入口打开时,同一 click 会冒泡到 document
// handleDocumentClick 会误判为「点在外面」并立刻关掉。推迟到宏任务再展开即可。
const runOpen = async function () {
if (dropdown.style.display !== 'none') return;
dropdown.style.display = 'block';
bellBtn.classList.add('active');
state.dropdownOpen = true;
await refreshNotifications();
};
window.setTimeout(function () {
void runOpen();
}, 0);
}
async function markAllSeen() {
const ids = markableItems(state.items).map(item => item.id);
const ok = await markItemsRead(ids);
if (ok) {
markSeenNow();
await refreshNotifications();
}
}
function initNotifications() {
const bellBtn = document.getElementById('notification-bell-btn');
if (!bellBtn) return;
document.addEventListener('click', handleDocumentClick);
document.addEventListener('visibilitychange', scheduleNextPoll);
document.addEventListener('languagechange', function () {
refreshNotifications();
});
refreshNotifications();
scheduleNextPoll();
}
window.toggleNotificationDropdown = toggleDropdown;
window.markAllNotificationsSeen = markAllSeen;
document.addEventListener('DOMContentLoaded', initNotifications);
})();
+35 -5
View File
@@ -1,6 +1,28 @@
// 角色管理相关功能
function _t(key, opts) {
return typeof window.t === 'function' ? window.t(key, opts) : key;
if (typeof window.t === 'function') {
try {
var translated = window.t(key, opts);
if (typeof translated === 'string' && translated && translated !== key) {
return translated;
}
} catch (e) { /* ignore */ }
}
// i18n 未就绪或词条缺失时避免把 key 暴露给用户(与 zh-CN 默认一致)
if (key === 'roles.noDescription') return '暂无描述';
if (key === 'roles.noDescriptionShort') return '无描述';
if (key === 'roles.defaultRoleDescription') {
return '默认角色,不额外携带用户提示词,使用默认MCP';
}
return key;
}
/** 角色配置中的描述:trim,并把误存为 i18n key 的字面量视为空 */
function rolePlainDescription(role) {
const raw = typeof role.description === 'string' ? role.description.trim() : '';
if (!raw) return '';
if (raw === 'roles.noDescription' || raw === 'roles.noDescriptionShort') return '';
return raw;
}
let currentRole = localStorage.getItem('currentRole') || '';
let roles = [];
@@ -56,6 +78,11 @@ function sortRoles(rolesArray) {
// 加载所有角色
async function loadRoles() {
if (window.i18nReady && typeof window.i18nReady.then === 'function') {
try {
await window.i18nReady;
} catch (e) { /* ignore */ }
}
try {
const response = await apiFetch('/api/roles');
if (!response.ok) {
@@ -189,8 +216,9 @@ function renderRoleSelectionSidebar() {
const icon = getRoleIcon(role);
// 处理默认角色的描述
let description = role.description || _t('roles.noDescription');
if (isDefaultRole && !role.description) {
const plainDesc = rolePlainDescription(role);
let description = plainDesc || _t('roles.noDescription');
if (isDefaultRole && !plainDesc) {
description = _t('roles.defaultRoleDescription');
}
@@ -316,6 +344,7 @@ function renderRolesList() {
const sortedRoles = sortRoles(filteredRoles);
rolesList.innerHTML = sortedRoles.map(role => {
const plainDesc = rolePlainDescription(role);
// 获取角色图标,如果是Unicode转义格式则转换为emoji
let roleIcon = role.icon || '👤';
if (roleIcon && typeof roleIcon === 'string') {
@@ -369,7 +398,7 @@ function renderRolesList() {
${role.enabled !== false ? _t('roles.enabled') : _t('roles.disabled')}
</span>
</div>
<div class="role-card-description">${escapeHtml(role.description || _t('roles.noDescriptionShort'))}</div>
<div class="role-card-description">${escapeHtml(plainDesc || _t('roles.noDescriptionShort'))}</div>
<div class="role-card-tools">
<span class="role-card-tools-label">${_t('roleModal.toolsLabel')}</span>
<span class="role-card-tools-value">${toolsDisplay}</span>
@@ -1575,9 +1604,10 @@ document.addEventListener('DOMContentLoaded', () => {
updateRoleSelectorDisplay();
});
// 语言切换后刷新角色选择器显示(默认/自定义角色名)
// 语言切换后刷新角色选择器与「选择角色」列表文案
document.addEventListener('languagechange', () => {
updateRoleSelectorDisplay();
renderRoleSelectionSidebar();
});
// 获取当前选中的角色(供chat.js使用)
+31 -19
View File
@@ -1,19 +1,19 @@
// 页面路由管理
let currentPage = 'dashboard';
/** 仅当停留在 chat 时保留 ?conversation= 等查询串,其它页面只使用 pageId */
/** chat、漏洞管理页在切换时保留当前 hash 上的查询串(如 ?conversation= / ?conversation_id= */
function buildHashForPage(pageId) {
if (pageId !== 'chat') {
if (pageId !== 'chat' && pageId !== 'vulnerabilities') {
return pageId;
}
const full = window.location.hash.slice(1);
const parts = full.split('?');
const curPage = parts[0];
const q = parts.length > 1 ? parts.slice(1).join('?') : '';
if (curPage === 'chat' && q) {
return 'chat?' + q;
if (curPage === pageId && q) {
return pageId + '?' + q;
}
return 'chat';
return pageId;
}
let chatConversationFromHashSeq = 0;
@@ -301,26 +301,38 @@ async function initPage(pageId) {
break;
case 'mcp-management':
// 初始化MCP管理
const startLoadMcpTools = () => {
// 加载工具列表(MCP工具配置已移到MCP管理页面)
// 使用异步加载,避免阻塞页面渲染
if (typeof loadToolsList === 'function') {
// 确保工具分页设置已初始化
if (typeof getToolsPageSize === 'function' && typeof toolsPagination !== 'undefined') {
toolsPagination.pageSize = getToolsPageSize();
}
// 延迟加载,让页面先渲染
setTimeout(() => {
loadToolsList(1, '').catch(err => {
console.error('加载工具列表失败:', err);
});
}, 100);
}
};
// 先拉取全局配置,确保 tool_search 常驻状态按后端生效集合展示
if (typeof loadConfig === 'function') {
loadConfig(false)
.catch(err => {
console.warn('加载配置失败(将继续加载工具列表):', err);
})
.finally(startLoadMcpTools);
} else {
startLoadMcpTools();
}
// 先加载外部MCP列表(快速),然后加载工具列表
if (typeof loadExternalMCPs === 'function') {
loadExternalMCPs().catch(err => {
console.warn('加载外部MCP列表失败:', err);
});
}
// 加载工具列表(MCP工具配置已移到MCP管理页面)
// 使用异步加载,避免阻塞页面渲染
if (typeof loadToolsList === 'function') {
// 确保工具分页设置已初始化
if (typeof getToolsPageSize === 'function' && typeof toolsPagination !== 'undefined') {
toolsPagination.pageSize = getToolsPageSize();
}
// 延迟加载,让页面先渲染
setTimeout(() => {
loadToolsList(1, '').catch(err => {
console.error('加载工具列表失败:', err);
});
}, 100);
}
break;
case 'vulnerabilities':
// 初始化漏洞管理页面
+45 -6
View File
@@ -1,6 +1,8 @@
// 设置相关功能
let currentConfig = null;
let allTools = [];
let alwaysVisibleToolNames = new Set();
let alwaysVisibleBuiltinToolNames = new Set();
// 全局工具状态映射,用于保存用户在所有页面的修改
// key: 唯一工具标识符(toolKey),value: { enabled: boolean, is_external: boolean, external_mcp: string }
let toolStateMap = new Map();
@@ -100,6 +102,14 @@ async function loadConfig(loadTools = true) {
}
currentConfig = await response.json();
const alwaysVisibleList = currentConfig?.multi_agent?.tool_search_always_visible_effective_tools;
const alwaysVisibleConfigured = currentConfig?.multi_agent?.tool_search_always_visible_tools;
alwaysVisibleToolNames = new Set(Array.isArray(alwaysVisibleList) ? alwaysVisibleList.filter(Boolean) : []);
alwaysVisibleBuiltinToolNames = new Set(
alwaysVisibleToolNames.size > 0 && Array.isArray(alwaysVisibleConfigured)
? Array.from(alwaysVisibleToolNames).filter(name => !alwaysVisibleConfigured.includes(name))
: []
);
// 填充OpenAI配置
const providerEl = document.getElementById('openai-provider');
@@ -395,10 +405,13 @@ async function loadToolsList(page = 1, searchKeyword = '') {
}
}
// 每行有两类复选框:行首「启用工具」与名称旁「常驻」;统计/全选只应针对行首启用复选框
const TOOL_ENABLE_CHECKBOX_SELECTOR = '#tools-list .tool-item > input[type="checkbox"]';
// 保存当前页的工具状态到全局映射
function saveCurrentPageToolStates() {
document.querySelectorAll('#tools-list .tool-item').forEach(item => {
const checkbox = item.querySelector('input[type="checkbox"]');
const checkbox = item.querySelector(':scope > input[type="checkbox"]');
const toolKey = item.dataset.toolKey; // 使用唯一标识符
const toolName = item.dataset.toolName;
const isExternal = item.dataset.isExternal === 'true';
@@ -498,6 +511,8 @@ function renderToolsList() {
is_external: tool.is_external || false,
external_mcp: tool.external_mcp || ''
};
const alwaysVisibleChecked = alwaysVisibleToolNames.has(tool.name);
const alwaysVisibleLocked = alwaysVisibleBuiltinToolNames.has(tool.name);
// 外部工具标签,显示来源信息(可点击跳转到对应 MCP 卡片)
let externalBadge = '';
@@ -521,6 +536,11 @@ function renderToolsList() {
<div class="tool-item-name">
${escapeHtml(tool.name)}
${externalBadge}
<label class="tool-resident-toggle" title="${typeof window.t === 'function' ? window.t('mcp.alwaysVisibleHint') : '始终常驻在 Tool Search 可见列表'}" onclick="event.stopPropagation()">
<input type="checkbox" ${alwaysVisibleChecked ? 'checked' : ''} ${alwaysVisibleLocked ? 'disabled' : ''} onchange="handleToolAlwaysVisibleChange('${escapeHtml(tool.name)}', this.checked)" />
<span>${typeof window.t === 'function' ? window.t('mcp.alwaysVisibleLabel') : '常驻'}</span>
</label>
${alwaysVisibleLocked ? `<span class="external-tool-badge" title="${typeof window.t === 'function' ? window.t('mcp.alwaysVisibleBuiltinHint') : '后端内置工具默认常驻,不可关闭'}">${typeof window.t === 'function' ? window.t('mcp.alwaysVisibleBuiltinLabel') : '内置默认'}</span>` : ''}
<span class="tool-expand-icon"></span>
</div>
<div class="tool-item-desc">${escapeHtml(tool.description || (typeof window.t === 'function' ? window.t('mcp.noDescription') : '无描述'))}</div>
@@ -716,9 +736,19 @@ function handleToolCheckboxChange(toolKey, enabled) {
updateToolsStats();
}
function handleToolAlwaysVisibleChange(toolName, alwaysVisible) {
const name = (toolName || '').trim();
if (!name) return;
if (alwaysVisible) {
alwaysVisibleToolNames.add(name);
} else {
alwaysVisibleToolNames.delete(name);
}
}
// 全选工具
function selectAllTools() {
document.querySelectorAll('#tools-list input[type="checkbox"]').forEach(checkbox => {
document.querySelectorAll(TOOL_ENABLE_CHECKBOX_SELECTOR).forEach(checkbox => {
checkbox.checked = true;
// 更新全局状态映射
const toolItem = checkbox.closest('.tool-item');
@@ -742,7 +772,7 @@ function selectAllTools() {
// 全不选工具
function deselectAllTools() {
document.querySelectorAll('#tools-list input[type="checkbox"]').forEach(checkbox => {
document.querySelectorAll(TOOL_ENABLE_CHECKBOX_SELECTOR).forEach(checkbox => {
checkbox.checked = false;
// 更新全局状态映射
const toolItem = checkbox.closest('.tool-item');
@@ -799,9 +829,9 @@ async function updateToolsStats() {
// 先保存当前页的状态到全局映射
saveCurrentPageToolStates();
// 计算当前页的启用工具数
const currentPageEnabled = Array.from(document.querySelectorAll('#tools-list input[type="checkbox"]:checked')).length;
const currentPageTotal = document.querySelectorAll('#tools-list input[type="checkbox"]').length;
// 计算当前页的启用工具数(仅行首「启用」复选框,不含「常驻」)
const currentPageEnabled = Array.from(document.querySelectorAll(`${TOOL_ENABLE_CHECKBOX_SELECTOR}:checked`)).length;
const currentPageTotal = document.querySelectorAll(TOOL_ENABLE_CHECKBOX_SELECTOR).length;
// 计算所有工具的启用数
let totalEnabled = 0;
@@ -886,9 +916,11 @@ async function updateToolsStats() {
}
const tStats = typeof window.t === 'function' ? window.t : (k) => k;
const pinnedCount = alwaysVisibleToolNames.size;
statsEl.innerHTML = `
<span title="${tStats('mcp.currentPageEnabled')}"> ${tStats('mcp.currentPageEnabled')}: <strong>${currentPageEnabled}</strong> / ${currentPageTotal}</span>
<span title="${tStats('mcp.totalEnabled')}">📊 ${tStats('mcp.totalEnabled')}: <strong>${totalEnabled}</strong> / ${totalTools}</span>
<span title="${tStats('mcp.alwaysVisibleHint')}">📌 ${tStats('mcp.alwaysVisibleLabel')}: <strong>${pinnedCount}</strong></span>
`;
}
@@ -1230,6 +1262,13 @@ async function saveToolsConfig() {
const config = {
openai: currentConfig.openai || {},
agent: currentConfig.agent || {},
multi_agent: {
enabled: currentConfig?.multi_agent?.enabled === true,
robot_use_multi_agent: currentConfig?.multi_agent?.robot_use_multi_agent === true,
batch_use_multi_agent: currentConfig?.multi_agent?.batch_use_multi_agent === true,
plan_execute_loop_max_iterations: Number(currentConfig?.multi_agent?.plan_execute_loop_max_iterations || 0),
tool_search_always_visible_tools: Array.from(alwaysVisibleToolNames).filter(name => !alwaysVisibleBuiltinToolNames.has(name))
},
tools: []
};
+16 -2
View File
@@ -531,6 +531,7 @@ function renderTaskItem(task, statusMap, isHistory = false) {
${isHistory && completedText ? completedText : timeText}
</span>
${canCancel ? `<button class="btn-secondary btn-small" onclick="cancelTask('${task.conversationId}', this)">` + _t('tasks.cancelTask') + `</button>` : ''}
${task.conversationId ? `<button class="btn-secondary btn-small" onclick="navigateToVulnerabilitiesFromTasksPage('conversation', '${task.conversationId}')">` + _t('tasks.viewVulnerabilities') + `</button>` : ''}
${task.conversationId ? `<button class="btn-secondary btn-small" onclick="viewConversation('${task.conversationId}')">` + _t('tasks.viewConversation') + `</button>` : ''}
</div>
</div>
@@ -708,6 +709,17 @@ function viewConversation(conversationId) {
}
}
// 跳转漏洞管理并按对话 ID 或批量队列 ID 筛选(队列 ID 走 task_id,与列表筛选项一致)
function navigateToVulnerabilitiesFromTasksPage(kind, id) {
if (!id) return;
const enc = encodeURIComponent(id);
if (kind === 'queue') {
window.location.hash = 'vulnerabilities?task_id=' + enc;
} else if (kind === 'conversation') {
window.location.hash = 'vulnerabilities?conversation_id=' + enc;
}
}
// 刷新任务列表
async function refreshTasks() {
await loadTasks();
@@ -1134,6 +1146,8 @@ function renderBatchQueues() {
const progress = stats.total > 0 ? Math.round((stats.completed + stats.failed + stats.cancelled) / stats.total * 100) : 0;
// 允许删除待执行、已完成或已取消状态的队列
const canDelete = queue.status === 'pending' || queue.status === 'completed' || queue.status === 'cancelled';
// 操作列常驻「查看漏洞」,不再使用 --no-actions 隐藏整列(否则无法从运行中队列跳转漏洞页)
const noActionsClass = '';
const loadedRoles = batchQueuesState.loadedRoles || [];
const roleIcon = getRoleIconForDisplay(queue.role, loadedRoles);
@@ -1157,7 +1171,6 @@ function renderBatchQueues() {
: `<h4 class="batch-queue-card-title batch-queue-card-title--muted">${escapeHtml(_t('tasks.batchQueueUntitled'))}</h4>`;
const doneCount = stats.completed + stats.failed + stats.cancelled;
const noActionsClass = canDelete ? '' : ' batch-queue-item--no-actions';
return `
<div class="batch-queue-item batch-queue-item--compact${cardMod}${noActionsClass}" data-queue-id="${queue.id}" onclick="showBatchQueueDetail('${queue.id}')">
<div class="batch-queue-item__inner batch-queue-item__inner--grid">
@@ -1182,7 +1195,8 @@ function renderBatchQueues() {
</div>
</div>
<div class="batch-queue-item__actions-col" onclick="event.stopPropagation();">
${canDelete ? `<button type="button" class="batch-queue-icon-btn" onclick="deleteBatchQueueFromList('${queue.id}')" title="${escapeHtml(_t('tasks.deleteQueue'))}" aria-label="${escapeHtml(_t('tasks.deleteQueue'))}"><svg class="batch-queue-icon-btn__svg" width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><path d="M3 6h18"/><path d="M19 6v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6"/><path d="M8 6V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2"/><path d="M10 11v6"/><path d="M14 11v6"/></svg></button>` : ''}
<button type="button" class="batch-queue-icon-btn" onclick="navigateToVulnerabilitiesFromTasksPage('queue', '${queue.id}')" title="${escapeHtml(_t('tasks.viewVulnerabilitiesQueueTitle'))}" aria-label="${escapeHtml(_t('tasks.viewVulnerabilitiesQueueTitle'))}"><svg class="batch-queue-icon-btn__svg" width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><path d="M12 22s8-4 8-10V5l-8-3-8 3v7c0 6 8 10 8 10z"/><path d="M9 12l2 2 4-4"/></svg></button>
${canDelete ? `<button type="button" class="batch-queue-icon-btn batch-queue-icon-btn--danger" onclick="deleteBatchQueueFromList('${queue.id}')" title="${escapeHtml(_t('tasks.deleteQueue'))}" aria-label="${escapeHtml(_t('tasks.deleteQueue'))}"><svg class="batch-queue-icon-btn__svg" width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><path d="M3 6h18"/><path d="M19 6v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6"/><path d="M8 6V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2"/><path d="M10 11v6"/><path d="M14 11v6"/></svg></button>` : ''}
</div>
</div>
</div>
+331 -82
View File
@@ -1,5 +1,43 @@
// 漏洞管理相关功能
function vulnT(key, opts) {
if (typeof window.t === 'function') {
return window.t(key, opts);
}
return key;
}
function vulnDateLocale() {
try {
const lang = (window.__locale || '').toLowerCase();
if (lang.indexOf('zh') === 0) {
return 'zh-CN';
}
} catch (e) { /* ignore */ }
return 'en-US';
}
function vulnSeverityLabel(code) {
const m = {
critical: 'dashboard.severityCritical',
high: 'dashboard.severityHigh',
medium: 'dashboard.severityMedium',
low: 'dashboard.severityLow',
info: 'dashboard.severityInfo'
};
return m[code] ? vulnT(m[code]) : code;
}
function vulnStatusLabel(code) {
const m = {
open: 'vulnerabilityPage.statusOpen',
confirmed: 'vulnerabilityPage.statusConfirmed',
fixed: 'vulnerabilityPage.statusFixed',
false_positive: 'vulnerabilityPage.statusFalsePositive'
};
return m[code] ? vulnT(m[code]) : code;
}
// 从localStorage读取每页显示数量,默认为20
const getVulnerabilityPageSize = () => {
const saved = localStorage.getItem('vulnerabilityPageSize');
@@ -10,6 +48,9 @@ let currentVulnerabilityId = null;
let vulnerabilityFilters = {
id: '',
conversation_id: '',
task_id: '',
conversation_tag: '',
task_tag: '',
severity: '',
status: ''
};
@@ -20,10 +61,51 @@ let vulnerabilityPagination = {
totalPages: 1
};
// 从地址栏 #vulnerabilities?conversation_id= / ?task_id= / ?id= 同步筛选(通知/对话菜单/任务管理联动)
function syncVulnerabilityFiltersFromLocationHash() {
const hash = window.location.hash.slice(1);
const hashParts = hash.split('?');
if (hashParts[0] !== 'vulnerabilities' || hashParts.length < 2) {
return;
}
const params = new URLSearchParams(hashParts.slice(1).join('?'));
const vid = (params.get('id') || '').trim();
const cid = (params.get('conversation_id') || '').trim();
const tid = (params.get('task_id') || '').trim();
if (!vid && !cid && !tid) {
return;
}
vulnerabilityFilters.id = '';
vulnerabilityFilters.conversation_id = '';
vulnerabilityFilters.task_id = '';
const idEl = document.getElementById('vulnerability-id-filter');
const convEl = document.getElementById('vulnerability-conversation-filter');
const taskEl = document.getElementById('vulnerability-task-filter');
if (idEl) idEl.value = '';
if (convEl) convEl.value = '';
if (taskEl) taskEl.value = '';
if (vid) {
vulnerabilityFilters.id = vid;
if (idEl) idEl.value = vid;
}
if (cid) {
vulnerabilityFilters.conversation_id = cid;
if (convEl) convEl.value = cid;
}
if (tid) {
vulnerabilityFilters.task_id = tid;
if (taskEl) taskEl.value = tid;
}
vulnerabilityPagination.currentPage = 1;
}
// 初始化漏洞管理页面
function initVulnerabilityPage() {
// 从localStorage加载每页条数设置
vulnerabilityPagination.pageSize = getVulnerabilityPageSize();
syncVulnerabilityFiltersFromLocationHash();
loadVulnerabilityStats();
loadVulnerabilities();
}
@@ -41,6 +123,9 @@ async function loadVulnerabilityStats() {
if (vulnerabilityFilters.conversation_id) {
params.append('conversation_id', vulnerabilityFilters.conversation_id);
}
if (vulnerabilityFilters.task_id) {
params.append('task_id', vulnerabilityFilters.task_id);
}
const response = await apiFetch(`/api/vulnerabilities/stats?${params.toString()}`);
if (!response.ok) {
@@ -82,7 +167,7 @@ function updateVulnerabilityStats(stats) {
// 加载漏洞列表
async function loadVulnerabilities(page = null) {
const listContainer = document.getElementById('vulnerabilities-list');
listContainer.innerHTML = '<div class="loading-spinner">加载中...</div>';
listContainer.innerHTML = `<div class="loading-spinner">${escapeHtml(vulnT('vulnerabilityPage.loading'))}</div>`;
try {
// 检查apiFetch是否可用
@@ -106,6 +191,15 @@ async function loadVulnerabilities(page = null) {
if (vulnerabilityFilters.conversation_id) {
params.append('conversation_id', vulnerabilityFilters.conversation_id);
}
if (vulnerabilityFilters.task_id) {
params.append('task_id', vulnerabilityFilters.task_id);
}
if (vulnerabilityFilters.conversation_tag) {
params.append('conversation_tag', vulnerabilityFilters.conversation_tag);
}
if (vulnerabilityFilters.task_tag) {
params.append('task_tag', vulnerabilityFilters.task_tag);
}
if (vulnerabilityFilters.severity) {
params.append('severity', vulnerabilityFilters.severity);
}
@@ -148,7 +242,7 @@ async function loadVulnerabilities(page = null) {
renderVulnerabilityPagination();
} catch (error) {
console.error('加载漏洞列表失败:', error);
listContainer.innerHTML = `<div class="error-message">加载失败: ${error.message}</div>`;
listContainer.innerHTML = `<div class="error-message">${escapeHtml(vulnT('vulnerabilityPage.loadListFailed'))}: ${escapeHtml(error.message)}</div>`;
}
}
@@ -180,22 +274,12 @@ function renderVulnerabilities(vulnerabilities) {
const html = vulnerabilities.map(vuln => {
const severityClass = `severity-${vuln.severity}`;
const severityText = {
'critical': '严重',
'high': '高危',
'medium': '中危',
'low': '低危',
'info': '信息'
}[vuln.severity] || vuln.severity;
const statusText = {
'open': '待处理',
'confirmed': '已确认',
'fixed': '已修复',
'false_positive': '误报'
}[vuln.status] || vuln.status;
const createdDate = new Date(vuln.created_at).toLocaleString('zh-CN');
const severityText = vulnSeverityLabel(vuln.severity);
const statusText = vulnStatusLabel(vuln.status);
const createdDate = new Date(vuln.created_at).toLocaleString(vulnDateLocale());
const dlTitle = escapeHtml(vulnT('vulnerabilityPage.downloadMarkdownTitle'));
const editTitle = escapeHtml(vulnT('common.edit'));
const deleteTitle = escapeHtml(vulnT('common.delete'));
return `
<div class="vulnerability-card ${severityClass}">
@@ -214,20 +298,20 @@ function renderVulnerabilities(vulnerabilities) {
</div>
</div>
<div class="vulnerability-actions" onclick="event.stopPropagation();">
<button class="btn-ghost" onclick="downloadVulnerabilityAsMarkdown('${vuln.id}', event)" title="下载Markdown">
<button class="btn-ghost" onclick="downloadVulnerabilityAsMarkdown('${vuln.id}', event)" title="${dlTitle}">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M21 15v4a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2v-4" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<polyline points="7 10 12 15 17 10" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<line x1="12" y1="15" x2="12" y2="3" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
</button>
<button class="btn-ghost" onclick="editVulnerability('${vuln.id}')" title="编辑">
<button class="btn-ghost" onclick="editVulnerability('${vuln.id}')" title="${editTitle}">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
</button>
<button class="btn-ghost" onclick="deleteVulnerability('${vuln.id}')" title="删除">
<button class="btn-ghost" onclick="deleteVulnerability('${vuln.id}')" title="${deleteTitle}">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M3 6h18M8 6V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2m3 0v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6h14z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
@@ -237,20 +321,34 @@ function renderVulnerabilities(vulnerabilities) {
<div class="vulnerability-content" id="content-${vuln.id}" style="display: none;">
${vuln.description ? `<div class="vulnerability-description">${escapeHtml(vuln.description)}</div>` : ''}
<div class="vulnerability-details">
<div class="detail-item"><strong>漏洞ID:</strong> <code>${escapeHtml(vuln.id)}</code></div>
${vuln.type ? `<div class="detail-item"><strong>类型:</strong> ${escapeHtml(vuln.type)}</div>` : ''}
${vuln.target ? `<div class="detail-item"><strong>目标:</strong> ${escapeHtml(vuln.target)}</div>` : ''}
<div class="detail-item"><strong>会话ID:</strong> <code>${escapeHtml(vuln.conversation_id)}</code></div>
${vulnDetailField(vulnT('vulnerabilityPage.detailVulnId'), vuln.id, true)}
${vuln.type ? vulnDetailField(vulnT('vulnerabilityPage.detailType'), vuln.type, false) : ''}
${vuln.target ? vulnDetailField(vulnT('vulnerabilityPage.detailTarget'), vuln.target, false) : ''}
${vulnDetailField(vulnT('vulnerabilityPage.detailConversationId'), vuln.conversation_id, true)}
${vuln.task_id ? vulnDetailField(vulnT('vulnerabilityPage.detailTaskId'), vuln.task_id, true) : ''}
${vuln.task_queue_id ? vulnDetailField(vulnT('vulnerabilityPage.detailTaskQueueId'), vuln.task_queue_id, true) : ''}
${vuln.conversation_tag ? vulnDetailField(vulnT('vulnerabilityPage.detailConversationTag'), vuln.conversation_tag, false) : ''}
${vuln.task_tag ? vulnDetailField(vulnT('vulnerabilityPage.detailTaskTag'), vuln.task_tag, false) : ''}
</div>
${vuln.proof ? `<div class="vulnerability-proof"><strong>证明:</strong><pre>${escapeHtml(vuln.proof)}</pre></div>` : ''}
${vuln.impact ? `<div class="vulnerability-impact"><strong>影响:</strong> ${escapeHtml(vuln.impact)}</div>` : ''}
${vuln.recommendation ? `<div class="vulnerability-recommendation"><strong>修复建议:</strong> ${escapeHtml(vuln.recommendation)}</div>` : ''}
${vuln.proof ? `<div class="vulnerability-proof"><strong>${escapeHtml(vulnT('vulnerabilityPage.detailProof'))}:</strong><pre>${escapeHtml(vuln.proof)}</pre></div>` : ''}
${vuln.impact ? `<div class="vulnerability-impact"><strong>${escapeHtml(vulnT('vulnerabilityPage.detailImpact'))}:</strong> ${escapeHtml(vuln.impact)}</div>` : ''}
${vuln.recommendation ? `<div class="vulnerability-recommendation"><strong>${escapeHtml(vulnT('vulnerabilityPage.detailRecommendation'))}:</strong> ${escapeHtml(vuln.recommendation)}</div>` : ''}
</div>
</div>
`;
}).join('');
listContainer.innerHTML = html;
if (typeof window.applyTranslations === 'function') {
window.applyTranslations(listContainer);
}
// 如果通过漏洞ID筛选且只返回一条记录,自动展开详情(提升“点击查看”的用户体验)
if (vulnerabilities.length === 1 && vulnerabilityFilters.id && vulnerabilityFilters.id === vulnerabilities[0].id) {
setTimeout(() => {
toggleVulnerabilityDetails(vulnerabilities[0].id);
}, 300);
}
}
// 渲染分页控件
@@ -277,9 +375,9 @@ function renderVulnerabilityPagination() {
// 左侧:显示范围信息和每页数量选择器(参考Skills样式)
paginationHTML += `
<div class="pagination-info">
<span>显示 ${start}-${end} / ${total} </span>
<span>${escapeHtml(vulnT('skillsPage.paginationShow', { start, end, total }))}</span>
<label class="pagination-page-size">
每页显示
${escapeHtml(vulnT('skillsPage.perPageLabel'))}
<select id="vulnerability-page-size-pagination" onchange="changeVulnerabilityPageSize()">
<option value="10" ${pageSize === 10 ? 'selected' : ''}>10</option>
<option value="20" ${pageSize === 20 ? 'selected' : ''}>20</option>
@@ -293,17 +391,20 @@ function renderVulnerabilityPagination() {
// 右侧:分页按钮(参考Skills样式:首页、上一页、第X/Y页、下一页、末页)
paginationHTML += `
<div class="pagination-controls">
<button class="btn-secondary" onclick="loadVulnerabilities(1)" ${currentPage === 1 || total === 0 ? 'disabled' : ''}>首页</button>
<button class="btn-secondary" onclick="loadVulnerabilities(${currentPage - 1})" ${currentPage === 1 || total === 0 ? 'disabled' : ''}>上一页</button>
<span class="pagination-page"> ${currentPage} / ${totalPages || 1} </span>
<button class="btn-secondary" onclick="loadVulnerabilities(${currentPage + 1})" ${currentPage >= totalPages || total === 0 ? 'disabled' : ''}>下一页</button>
<button class="btn-secondary" onclick="loadVulnerabilities(${totalPages || 1})" ${currentPage >= totalPages || total === 0 ? 'disabled' : ''}>末页</button>
<button class="btn-secondary" onclick="loadVulnerabilities(1)" ${currentPage === 1 || total === 0 ? 'disabled' : ''}>${escapeHtml(vulnT('skillsPage.firstPage'))}</button>
<button class="btn-secondary" onclick="loadVulnerabilities(${currentPage - 1})" ${currentPage === 1 || total === 0 ? 'disabled' : ''}>${escapeHtml(vulnT('skillsPage.prevPage'))}</button>
<span class="pagination-page">${escapeHtml(vulnT('skillsPage.pageOf', { current: currentPage, total: totalPages || 1 }))}</span>
<button class="btn-secondary" onclick="loadVulnerabilities(${currentPage + 1})" ${currentPage >= totalPages || total === 0 ? 'disabled' : ''}>${escapeHtml(vulnT('skillsPage.nextPage'))}</button>
<button class="btn-secondary" onclick="loadVulnerabilities(${totalPages || 1})" ${currentPage >= totalPages || total === 0 ? 'disabled' : ''}>${escapeHtml(vulnT('skillsPage.lastPage'))}</button>
</div>
`;
paginationHTML += '</div>';
paginationContainer.innerHTML = paginationHTML;
if (typeof window.applyTranslations === 'function') {
window.applyTranslations(paginationContainer);
}
}
// 改变每页显示数量
@@ -334,10 +435,12 @@ async function changeVulnerabilityPageSize() {
// 显示添加漏洞模态框
function showAddVulnerabilityModal() {
currentVulnerabilityId = null;
document.getElementById('vulnerability-modal-title').textContent = (typeof window.t === 'function' ? window.t('vulnerability.addVuln') : '添加漏洞');
document.getElementById('vulnerability-modal-title').textContent = vulnT('vulnerability.addVuln');
// 清空表单
document.getElementById('vulnerability-conversation-id').value = '';
document.getElementById('vulnerability-conversation-tag').value = '';
document.getElementById('vulnerability-task-tag').value = '';
document.getElementById('vulnerability-title').value = '';
document.getElementById('vulnerability-description').value = '';
document.getElementById('vulnerability-severity').value = '';
@@ -355,14 +458,16 @@ function showAddVulnerabilityModal() {
async function editVulnerability(id) {
try {
const response = await apiFetch(`/api/vulnerabilities/${id}`);
if (!response.ok) throw new Error('获取漏洞失败');
if (!response.ok) throw new Error(vulnT('vulnerabilityPage.fetchFailed'));
const vuln = await response.json();
currentVulnerabilityId = id;
document.getElementById('vulnerability-modal-title').textContent = (typeof window.t === 'function' ? window.t('vulnerability.editVuln') : '编辑漏洞');
document.getElementById('vulnerability-modal-title').textContent = vulnT('vulnerability.editVuln');
// 填充表单
document.getElementById('vulnerability-conversation-id').value = vuln.conversation_id || '';
document.getElementById('vulnerability-conversation-tag').value = vuln.conversation_tag || '';
document.getElementById('vulnerability-task-tag').value = vuln.task_tag || '';
document.getElementById('vulnerability-title').value = vuln.title || '';
document.getElementById('vulnerability-description').value = vuln.description || '';
document.getElementById('vulnerability-severity').value = vuln.severity || '';
@@ -376,7 +481,7 @@ async function editVulnerability(id) {
document.getElementById('vulnerability-modal').style.display = 'block';
} catch (error) {
console.error('加载漏洞失败:', error);
alert('加载漏洞失败: ' + error.message);
alert(vulnT('vulnerability.loadFailed') + ': ' + error.message);
}
}
@@ -387,12 +492,14 @@ async function saveVulnerability() {
const severity = document.getElementById('vulnerability-severity').value;
if (!conversationId || !title || !severity) {
alert('请填写必填字段:会话ID、标题和严重程度');
alert(vulnT('vulnerabilityPage.saveRequiredFields'));
return;
}
const data = {
conversation_id: conversationId,
conversation_tag: document.getElementById('vulnerability-conversation-tag').value.trim(),
task_tag: document.getElementById('vulnerability-task-tag').value.trim(),
title: title,
description: document.getElementById('vulnerability-description').value.trim(),
severity: severity,
@@ -420,7 +527,7 @@ async function saveVulnerability() {
if (!response.ok) {
const error = await response.json();
throw new Error(error.error || '保存失败');
throw new Error(error.error || vulnT('vulnerabilityPage.saveFailed'));
}
closeVulnerabilityModal();
@@ -430,13 +537,13 @@ async function saveVulnerability() {
loadVulnerabilities();
} catch (error) {
console.error('保存漏洞失败:', error);
alert('保存漏洞失败: ' + error.message);
alert(vulnT('vulnerabilityPage.saveFailed') + ': ' + error.message);
}
}
// 删除漏洞
async function deleteVulnerability(id) {
if (!confirm('确定要删除此漏洞吗?')) {
if (!confirm(vulnT('vulnerability.deleteConfirm'))) {
return;
}
@@ -445,7 +552,7 @@ async function deleteVulnerability(id) {
method: 'DELETE'
});
if (!response.ok) throw new Error('删除失败');
if (!response.ok) throw new Error(vulnT('vulnerabilityPage.deleteFailed'));
loadVulnerabilityStats();
// 删除后,如果当前页没有数据了,回到上一页
@@ -458,7 +565,7 @@ async function deleteVulnerability(id) {
loadVulnerabilities();
} catch (error) {
console.error('删除漏洞失败:', error);
alert('删除漏洞失败: ' + error.message);
alert(vulnT('vulnerabilityPage.deleteFailed') + ': ' + error.message);
}
}
@@ -472,6 +579,9 @@ function closeVulnerabilityModal() {
function filterVulnerabilities() {
vulnerabilityFilters.id = document.getElementById('vulnerability-id-filter').value.trim();
vulnerabilityFilters.conversation_id = document.getElementById('vulnerability-conversation-filter').value.trim();
vulnerabilityFilters.task_id = document.getElementById('vulnerability-task-filter').value.trim();
vulnerabilityFilters.conversation_tag = document.getElementById('vulnerability-conversation-tag-filter').value.trim();
vulnerabilityFilters.task_tag = document.getElementById('vulnerability-task-tag-filter').value.trim();
vulnerabilityFilters.severity = document.getElementById('vulnerability-severity-filter').value;
vulnerabilityFilters.status = document.getElementById('vulnerability-status-filter').value;
@@ -486,12 +596,18 @@ function filterVulnerabilities() {
function clearVulnerabilityFilters() {
document.getElementById('vulnerability-id-filter').value = '';
document.getElementById('vulnerability-conversation-filter').value = '';
document.getElementById('vulnerability-task-filter').value = '';
document.getElementById('vulnerability-conversation-tag-filter').value = '';
document.getElementById('vulnerability-task-tag-filter').value = '';
document.getElementById('vulnerability-severity-filter').value = '';
document.getElementById('vulnerability-status-filter').value = '';
vulnerabilityFilters = {
id: '',
conversation_id: '',
task_id: '',
conversation_tag: '',
task_tag: '',
severity: '',
status: ''
};
@@ -532,67 +648,193 @@ function escapeHtml(text) {
return div.innerHTML;
}
// 将漏洞格式化为Markdown
/** 复制详情字段(编码由 encodeURIComponent 传入,避免引号截断) */
function vulnerabilityCopyEncoded(evt, encoded) {
if (evt && evt.stopPropagation) {
evt.stopPropagation();
}
let text = '';
try {
text = decodeURIComponent(encoded);
} catch (e) {
return;
}
const done = () => {
if (evt && evt.target && evt.target.closest) {
const btn = evt.target.closest('.vuln-detail-field__copy');
if (btn) {
const t0 = btn.getAttribute('title') || '';
btn.setAttribute('title', vulnT('common.copied'));
setTimeout(() => btn.setAttribute('title', t0), 1600);
}
}
};
if (navigator.clipboard && typeof navigator.clipboard.writeText === 'function') {
navigator.clipboard.writeText(text).then(done).catch(() => {
try {
const ta = document.createElement('textarea');
ta.value = text;
ta.style.position = 'fixed';
ta.style.left = '-9999px';
document.body.appendChild(ta);
ta.select();
document.execCommand('copy');
document.body.removeChild(ta);
done();
} catch (err) {
console.error('copy failed', err);
}
});
} else {
try {
const ta = document.createElement('textarea');
ta.value = text;
ta.style.position = 'fixed';
ta.style.left = '-9999px';
document.body.appendChild(ta);
ta.select();
document.execCommand('copy');
document.body.removeChild(ta);
done();
} catch (err) {
console.error('copy failed', err);
}
}
}
function vulnDetailField(label, value, asCode) {
if (value === undefined || value === null || String(value) === '') {
return '';
}
const s = String(value);
const enc = encodeURIComponent(s);
const copyTitle = escapeHtml(vulnT('common.copy'));
const valueEl = asCode
? `<code class="vuln-detail-field-value">${escapeHtml(s)}</code>`
: `<span class="vuln-detail-field-value">${escapeHtml(s)}</span>`;
const copyBtn = `<button type="button" class="vuln-detail-field__copy" onclick="vulnerabilityCopyEncoded(event, '${enc}')" title="${copyTitle}" aria-label="${copyTitle}">
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="9" y="9" width="13" height="13" rx="2"/><path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"/></svg>
</button>`;
return `<div class="vuln-detail-field">
<div class="vuln-detail-field__label">${escapeHtml(label)}</div>
<div class="vuln-detail-field__row">${valueEl}${copyBtn}</div>
</div>`;
}
// 将漏洞格式化为Markdown(章节标题随界面语言)
function formatVulnerabilityAsMarkdown(vuln) {
const severityText = {
'critical': '严重',
'high': '高危',
'medium': '中危',
'low': '低危',
'info': '信息'
}[vuln.severity] || vuln.severity;
const statusText = {
'open': '待处理',
'confirmed': '已确认',
'fixed': '已修复',
'false_positive': '误报'
}[vuln.status] || vuln.status;
const createdDate = new Date(vuln.created_at).toLocaleString('zh-CN');
const updatedDate = new Date(vuln.updated_at).toLocaleString('zh-CN');
const severityText = vulnSeverityLabel(vuln.severity);
const statusText = vulnStatusLabel(vuln.status);
const loc = vulnDateLocale();
const createdDate = new Date(vuln.created_at).toLocaleString(loc);
const updatedDate = new Date(vuln.updated_at).toLocaleString(loc);
const L = (k) => vulnT('vulnerabilityMd.' + k);
let markdown = `# ${vuln.title}\n\n`;
markdown += `## 基本信息\n\n`;
markdown += `- **漏洞ID**: \`${vuln.id}\`\n`;
markdown += `- **严重程度**: ${severityText}\n`;
markdown += `- **状态**: ${statusText}\n`;
markdown += `## ${L('headingBasic')}\n\n`;
markdown += `- **${L('labelId')}**: \`${vuln.id}\`\n`;
markdown += `- **${L('labelSeverity')}**: ${severityText}\n`;
markdown += `- **${L('labelStatus')}**: ${statusText}\n`;
if (vuln.type) {
markdown += `- **类型**: ${vuln.type}\n`;
markdown += `- **${L('labelType')}**: ${vuln.type}\n`;
}
if (vuln.target) {
markdown += `- **目标**: ${vuln.target}\n`;
markdown += `- **${L('labelTarget')}**: ${vuln.target}\n`;
}
markdown += `- **会话ID**: \`${vuln.conversation_id}\`\n`;
markdown += `- **创建时间**: ${createdDate}\n`;
markdown += `- **更新时间**: ${updatedDate}\n\n`;
markdown += `- **${L('labelConversationId')}**: \`${vuln.conversation_id}\`\n`;
if (vuln.task_id) {
markdown += `- **${L('labelTaskId')}**: \`${vuln.task_id}\`\n`;
}
if (vuln.task_queue_id) {
markdown += `- **${L('labelTaskQueueId')}**: \`${vuln.task_queue_id}\`\n`;
}
if (vuln.conversation_tag) {
markdown += `- **${L('labelConversationTag')}**: ${vuln.conversation_tag}\n`;
}
if (vuln.task_tag) {
markdown += `- **${L('labelTaskTag')}**: ${vuln.task_tag}\n`;
}
markdown += `- **${L('labelCreated')}**: ${createdDate}\n`;
markdown += `- **${L('labelUpdated')}**: ${updatedDate}\n\n`;
if (vuln.description) {
markdown += `## 描述\n\n${vuln.description}\n\n`;
markdown += `## ${L('headingDescription')}\n\n${vuln.description}\n\n`;
}
if (vuln.proof) {
markdown += `## 证明(POC\n\n\`\`\`\n${vuln.proof}\n\`\`\`\n\n`;
markdown += `## ${L('headingProof')}\n\n\`\`\`\n${vuln.proof}\n\`\`\`\n\n`;
}
if (vuln.impact) {
markdown += `## 影响\n\n${vuln.impact}\n\n`;
markdown += `## ${L('headingImpact')}\n\n${vuln.impact}\n\n`;
}
if (vuln.recommendation) {
markdown += `## 修复建议\n\n${vuln.recommendation}\n\n`;
markdown += `## ${L('headingRecommendation')}\n\n${vuln.recommendation}\n\n`;
}
return markdown;
}
function buildVulnerabilityFilterParams() {
const params = new URLSearchParams();
const keys = ['id', 'conversation_id', 'task_id', 'conversation_tag', 'task_tag', 'severity', 'status'];
keys.forEach((k) => {
if (vulnerabilityFilters[k]) {
params.append(k, vulnerabilityFilters[k]);
}
});
return params;
}
function triggerTextDownload(fileName, content) {
const blob = new Blob([content], { type: 'text/markdown;charset=utf-8' });
const url = URL.createObjectURL(blob);
const link = document.createElement('a');
link.href = url;
link.download = fileName;
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
URL.revokeObjectURL(url);
}
async function exportVulnerabilityReports() {
try {
const params = buildVulnerabilityFilterParams();
params.set('mode', 'summary');
params.set('group_by', 'conversation');
const response = await apiFetch(`/api/vulnerabilities/export?${params.toString()}`);
if (!response.ok) {
const error = await response.json().catch(() => ({ error: vulnT('vulnerabilityPage.exportFailedMessage') }));
throw new Error(error.error || vulnT('vulnerabilityPage.exportFailedMessage'));
}
const data = await response.json();
const files = Array.isArray(data.files) ? data.files : [];
if (!files.length) {
alert(vulnT('vulnerabilityPage.exportNoResults'));
return;
}
files.forEach((file, idx) => {
setTimeout(() => triggerTextDownload(file.filename || `vulnerability-export-${idx + 1}.md`, file.content || ''), idx * 120);
});
if (files.length > 1) {
alert(vulnT('vulnerabilityPage.exportStarted', { count: files.length }));
}
} catch (error) {
console.error('导出漏洞报告失败:', error);
alert(vulnT('vulnerabilityPage.exportFailed') + ': ' + error.message);
}
}
// 下载漏洞为Markdown格式
async function downloadVulnerabilityAsMarkdown(id, event) {
try {
const response = await apiFetch(`/api/vulnerabilities/${id}`);
if (!response.ok) {
throw new Error('获取漏洞失败');
throw new Error(vulnT('vulnerabilityPage.fetchFailed'));
}
const vuln = await response.json();
@@ -626,8 +868,8 @@ async function downloadVulnerabilityAsMarkdown(id, event) {
if (event && event.target) {
const button = event.target.closest('button');
if (button) {
const originalTitle = button.title || '下载Markdown';
button.title = '下载成功!';
const originalTitle = button.title || vulnT('vulnerabilityPage.downloadMarkdownTitle');
button.title = vulnT('vulnerabilityPage.downloadOkTitle');
setTimeout(() => {
button.title = originalTitle;
}, 2000);
@@ -635,7 +877,7 @@ async function downloadVulnerabilityAsMarkdown(id, event) {
}
} catch (error) {
console.error('下载失败:', error);
alert('下载失败: ' + error.message);
alert(vulnT('vulnerabilityPage.downloadFailed') + ': ' + error.message);
}
}
@@ -645,5 +887,12 @@ window.onclick = function(event) {
if (event.target === modal) {
closeVulnerabilityModal();
}
}
};
document.addEventListener('languagechange', function () {
const page = document.getElementById('page-vulnerabilities');
if (page && page.classList.contains('active')) {
loadVulnerabilities();
}
});
+149 -39
View File
@@ -39,6 +39,100 @@ let webshellStreamingTypingId = 0;
let webshellProbeStatusById = {};
let webshellBatchProbeRunning = false;
/** 允许的响应编码,与后端 normalizeWebshellEncoding 对齐 */
const WEBSHELL_ALLOWED_ENCODINGS = ['auto', 'utf-8', 'gbk', 'gb18030'];
/** 归一化连接的 encoding 字段,返回 'auto' | 'utf-8' | 'gbk' | 'gb18030'(空/未知 → auto */
function normalizeWebshellEncoding(v) {
var s = (v == null ? '' : String(v)).trim().toLowerCase();
if (s === 'utf8') s = 'utf-8';
if (!s) return 'auto';
return WEBSHELL_ALLOWED_ENCODINGS.indexOf(s) >= 0 ? s : 'auto';
}
/** 从连接对象取编码,便于透传到 /api/webshell/exec 与 /api/webshell/file */
function webshellConnEncoding(conn) {
return normalizeWebshellEncoding(conn && conn.encoding);
}
/** 允许的目标 OS,与后端 normalizeWebshellOS 对齐 */
const WEBSHELL_ALLOWED_OS = ['auto', 'linux', 'windows'];
/** 归一化连接的 os 字段,返回 'auto' | 'linux' | 'windows'(空/未知 → auto */
function normalizeWebshellOS(v) {
var s = (v == null ? '' : String(v)).trim().toLowerCase();
if (!s) return 'auto';
return WEBSHELL_ALLOWED_OS.indexOf(s) >= 0 ? s : 'auto';
}
/** 从连接对象取目标 OS,便于透传到 /api/webshell/exec 与 /api/webshell/file */
function webshellConnOS(conn) {
return normalizeWebshellOS(conn && conn.os);
}
/**
* 组装 /api/webshell/file 的公共请求体
* 所有文件管理调用点都应走此函数避免遗漏字段 connection_id
* @param {Object} conn 连接对象
* @param {Object} extra 额外字段action / path / content / target_path / chunk_index ...
* @returns {string} JSON 字符串
*/
function webshellFileRequestBody(conn, extra) {
const base = {
url: conn.url,
password: conn.password || '',
type: conn.type || 'php',
method: (conn.method || 'post').toLowerCase(),
cmd_param: conn.cmdParam || '',
encoding: webshellConnEncoding(conn),
os: webshellConnOS(conn),
connection_id: conn.id || ''
};
const merged = Object.assign(base, extra || {});
return JSON.stringify(merged);
}
/**
* 当服务端探活命中目标系统 auto 连接首次列目录时出现
* 把结果同步到本地 webshellConnections 缓存 + 持久化到数据库
* 后续刷新不再探活AI 也能直接看到正确的 OS 上下文
*/
function applyWebshellDetectedOS(conn, data) {
if (!conn || !data || !data.detected_os) return;
const detected = normalizeWebshellOS(data.detected_os);
if (detected !== 'linux' && detected !== 'windows') return;
if (webshellConnOS(conn) !== 'auto') return; // 用户已显式配置,尊重之
conn.os = detected;
if (Array.isArray(webshellConnections)) {
for (var i = 0; i < webshellConnections.length; i++) {
if (webshellConnections[i] && webshellConnections[i].id === conn.id) {
webshellConnections[i].os = detected;
break;
}
}
}
if (typeof renderWebshellList === 'function') {
try { renderWebshellList(); } catch (e) {}
}
// 服务端已经回写了 DB;但极少数情况下调用方未带 connection_id,这里再兜底 PUT 一次
if (conn.id && typeof apiFetch === 'function') {
apiFetch('/api/webshell/connections/' + encodeURIComponent(conn.id), {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
url: conn.url,
password: conn.password || '',
type: conn.type || 'php',
method: conn.method || 'post',
cmd_param: conn.cmdParam || '',
remark: conn.remark || '',
encoding: conn.encoding || 'auto',
os: detected
})
}).catch(function () {});
}
}
/** 与主对话页一致:Eino 模式走 /api/multi-agent/streambody 带 orchestration */
function resolveWebshellAiStreamRequest() {
if (typeof apiFetch === 'undefined') {
@@ -335,6 +429,17 @@ function wsT(key) {
'webshell.addConnection': '添加连接',
'webshell.cmdParam': '命令参数名',
'webshell.cmdParamPlaceholder': '不填默认为 cmd,如填 xxx 则请求为 xxx=命令',
'webshell.encoding': '响应编码',
'webshell.encodingAuto': '自动检测',
'webshell.encodingUtf8': 'UTF-8',
'webshell.encodingGbk': 'GBK(中文 Windows',
'webshell.encodingGb18030': 'GB18030',
'webshell.encodingHint': '中文 Windows 目标若出现乱码,请切换为 GBK 或 GB18030',
'webshell.os': '目标系统',
'webshell.osAuto': '自动(按 Shell 类型推断)',
'webshell.osLinux': 'Linux / Unix',
'webshell.osWindows': 'Windows',
'webshell.osHint': '决定文件管理/上传使用 Linux 还是 Windows 命令;PHP/JSP 跑在 Windows 上请选 Windows',
'webshell.connections': '连接列表',
'webshell.noConnections': '暂无连接,请点击「添加连接」',
'webshell.selectOrAdd': '请从左侧选择连接,或添加新的 WebShell 连接',
@@ -661,9 +766,20 @@ function renderWebshellList() {
} else if (probe && probe.state === 'fail') {
probeHtml = '<span class="webshell-probe-badge fail" title="' + escapeHtml(probe.message || '') + '">' + (wsT('webshell.probeOffline') || '离线') + '</span>';
}
var encNorm = normalizeWebshellEncoding(conn.encoding);
var encHtml = '';
if (encNorm && encNorm !== 'auto') {
encHtml = '<span class="webshell-probe-badge" title="' + escapeHtml(wsT('webshell.encoding') || '响应编码') + '">' + escapeHtml(encNorm.toUpperCase()) + '</span>';
}
var osNorm = normalizeWebshellOS(conn.os);
var osHtml = '';
if (osNorm && osNorm !== 'auto') {
var osLabel = osNorm === 'windows' ? 'WIN' : 'LINUX';
osHtml = '<span class="webshell-probe-badge" title="' + escapeHtml(wsT('webshell.os') || '目标系统') + '">' + osLabel + '</span>';
}
return (
'<div class="webshell-item' + active + '" data-id="' + safeId + '">' +
'<div class="webshell-item-remark-row"><div class="webshell-item-remark" title="' + urlTitle + '">' + remark + '</div>' + probeHtml + '</div>' +
'<div class="webshell-item-remark-row"><div class="webshell-item-remark" title="' + urlTitle + '">' + remark + '</div>' + probeHtml + osHtml + encHtml + '</div>' +
'<div class="webshell-item-url" title="' + urlTitle + '">' + url + '</div>' +
'<div class="webshell-item-actions">' +
'<details class="webshell-conn-actions"><summary class="btn-ghost btn-sm webshell-conn-actions-btn" title="' + actionsLabel + '">' + actionsLabel + '</summary>' +
@@ -709,6 +825,8 @@ function probeWebshellConnection(conn) {
type: conn.type || 'php',
method: ((conn.method || 'post').toLowerCase() === 'get') ? 'get' : 'post',
cmd_param: conn.cmdParam || '',
encoding: webshellConnEncoding(conn),
os: webshellConnOS(conn),
command: 'echo 1'
})
})
@@ -2881,17 +2999,6 @@ function runWebshellAiSend(conn, inputEl, sendBtn, messagesContainer) {
} else if (_et === 'warning') {
appendTimelineItem('warning', '⚠️ ' + (_em || ''), '', _ed);
// ─── Eino recovery ───
} else if (_et === 'eino_recovery') {
var runIdx = _ed.runIndex != null ? _ed.runIndex : (_ed.einoRetry != null ? _ed.einoRetry + 1 : 1);
var maxRuns = _ed.maxRuns != null ? _ed.maxRuns : 3;
var recTitle = wsTOr('chat.einoRecoveryTitle', '') ||
('🔄 工具参数无效 · 第 ' + runIdx + '/' + maxRuns + ' 轮(已追加提示)');
if (typeof window.t === 'function') {
try { recTitle = window.t('chat.einoRecoveryTitle', { n: runIdx, max: maxRuns }); } catch (e) { /* */ }
}
appendTimelineItem('eino_recovery', recTitle, _em, _ed);
// ─── Tool calls ───
} else if (_et === 'tool_calls_detected' && _ed) {
var count = _ed.count || 0;
@@ -3376,6 +3483,8 @@ function execWebshellCommand(conn, command) {
type: conn.type || 'php',
method: (conn.method || 'post').toLowerCase(),
cmd_param: conn.cmdParam || '',
encoding: webshellConnEncoding(conn),
os: webshellConnOS(conn),
command: command
})
}).then(function (r) { return r.json(); })
@@ -3402,17 +3511,10 @@ function webshellFileListDir(conn, path) {
apiFetch('/api/webshell/file', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
url: conn.url,
password: conn.password || '',
type: conn.type || 'php',
method: (conn.method || 'post').toLowerCase(),
cmd_param: conn.cmdParam || '',
action: 'list',
path: path
})
body: webshellFileRequestBody(conn, { action: 'list', path: path })
}).then(function (r) { return r.json(); })
.then(function (data) {
applyWebshellDetectedOS(conn, data);
if (!data.ok && data.error) {
listEl.innerHTML = '<div class="webshell-file-error">' + escapeHtml(data.error) + '</div><pre class="webshell-file-raw">' + escapeHtml(data.output || '') + '</pre>';
return;
@@ -3508,16 +3610,9 @@ function fetchWebshellDirectoryItems(conn, path) {
return apiFetch('/api/webshell/file', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
url: conn.url,
password: conn.password || '',
type: conn.type || 'php',
method: (conn.method || 'post').toLowerCase(),
cmd_param: conn.cmdParam || '',
action: 'list',
path: path
})
body: webshellFileRequestBody(conn, { action: 'list', path: path })
}).then(function (r) { return r.json(); }).then(function (data) {
applyWebshellDetectedOS(conn, data);
if (!data || data.error || !data.ok) return [];
return parseWebshellListItems(data.output || '');
}).catch(function () {
@@ -3812,7 +3907,7 @@ function webshellFileMkdir(conn, pathInput) {
var name = prompt(wsT('webshell.newDir') || '新建目录', 'newdir');
if (name == null || !name.trim()) return;
var path = base === '.' ? name.trim() : base + '/' + name.trim();
apiFetch('/api/webshell/file', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ url: conn.url, password: conn.password || '', type: conn.type || 'php', method: (conn.method || 'post').toLowerCase(), cmd_param: conn.cmdParam || '', action: 'mkdir', path: path }) })
apiFetch('/api/webshell/file', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: webshellFileRequestBody(conn, { action: 'mkdir', path: path }) })
.then(function (r) { return r.json(); })
.then(function () { webshellFileListDir(conn, base); })
.catch(function () { webshellFileListDir(conn, base); });
@@ -3859,7 +3954,7 @@ function webshellFileUpload(conn, pathInput) {
webshellFileListDir(conn, base);
return;
}
apiFetch('/api/webshell/file', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ url: conn.url, password: conn.password || '', type: conn.type || 'php', method: (conn.method || 'post').toLowerCase(), cmd_param: conn.cmdParam || '', action: 'upload_chunk', path: path, content: base64Chunks[idx], chunk_index: idx }) })
apiFetch('/api/webshell/file', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: webshellFileRequestBody(conn, { action: 'upload_chunk', path: path, content: base64Chunks[idx], chunk_index: idx }) })
.then(function (r) { return r.json(); })
.then(function () { idx++; sendNext(); })
.catch(function () { idx++; sendNext(); });
@@ -3878,7 +3973,7 @@ function webshellFileRename(conn, oldPath, oldName, listEl) {
var parts = oldPath.split('/');
var dir = parts.length > 1 ? parts.slice(0, -1).join('/') + '/' : '';
var newPath = dir + newName.trim();
apiFetch('/api/webshell/file', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ url: conn.url, password: conn.password || '', type: conn.type || 'php', method: (conn.method || 'post').toLowerCase(), cmd_param: conn.cmdParam || '', action: 'rename', path: oldPath, target_path: newPath }) })
apiFetch('/api/webshell/file', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: webshellFileRequestBody(conn, { action: 'rename', path: oldPath, target_path: newPath }) })
.then(function (r) { return r.json(); })
.then(function () { webshellFileListDir(conn, document.getElementById('webshell-file-path').value.trim() || '.'); })
.catch(function () { webshellFileListDir(conn, document.getElementById('webshell-file-path').value.trim() || '.'); });
@@ -3917,7 +4012,7 @@ function webshellFileDownload(conn, path) {
apiFetch('/api/webshell/file', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ url: conn.url, password: conn.password || '', type: conn.type || 'php', method: (conn.method || 'post').toLowerCase(), cmd_param: conn.cmdParam || '', action: 'read', path: path })
body: webshellFileRequestBody(conn, { action: 'read', path: path })
}).then(function (r) { return r.json(); })
.then(function (data) {
var content = (data && data.output) != null ? data.output : (data.error || '');
@@ -3938,7 +4033,7 @@ function webshellFileRead(conn, path, listEl, browsePath) {
apiFetch('/api/webshell/file', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ url: conn.url, password: conn.password || '', type: conn.type || 'php', method: (conn.method || 'post').toLowerCase(), cmd_param: conn.cmdParam || '', action: 'read', path: path })
body: webshellFileRequestBody(conn, { action: 'read', path: path })
}).then(function (r) { return r.json(); })
.then(function (data) {
const out = (data && data.output) ? data.output : (data.error || '');
@@ -3967,7 +4062,7 @@ function webshellFileEdit(conn, path, listEl) {
apiFetch('/api/webshell/file', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ url: conn.url, password: conn.password || '', type: conn.type || 'php', method: (conn.method || 'post').toLowerCase(), cmd_param: conn.cmdParam || '', action: 'read', path: path })
body: webshellFileRequestBody(conn, { action: 'read', path: path })
}).then(function (r) { return r.json(); })
.then(function (data) {
const content = (data && data.output) ? data.output : (data.error || '');
@@ -4003,7 +4098,7 @@ function webshellFileWrite(conn, path, content, onDone, listEl) {
apiFetch('/api/webshell/file', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ url: conn.url, password: conn.password || '', type: conn.type || 'php', method: (conn.method || 'post').toLowerCase(), cmd_param: conn.cmdParam || '', action: 'write', path: path, content: content })
body: webshellFileRequestBody(conn, { action: 'write', path: path, content: content })
}).then(function (r) { return r.json(); })
.then(function (data) {
if (data && !data.ok && data.error && listEl) {
@@ -4022,7 +4117,7 @@ function webshellFileDelete(conn, path, onDone) {
apiFetch('/api/webshell/file', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ url: conn.url, password: conn.password || '', type: conn.type || 'php', method: (conn.method || 'post').toLowerCase(), cmd_param: conn.cmdParam || '', action: 'delete', path: path })
body: webshellFileRequestBody(conn, { action: 'delete', path: path })
}).then(function (r) { return r.json(); })
.then(function () { if (onDone) onDone(); })
.catch(function () { if (onDone) onDone(); });
@@ -4074,6 +4169,10 @@ function showAddWebshellModal() {
document.getElementById('webshell-type').value = 'php';
document.getElementById('webshell-method').value = 'post';
document.getElementById('webshell-cmd-param').value = '';
var osSelEl = document.getElementById('webshell-os');
if (osSelEl) osSelEl.value = 'auto';
var encSelEl = document.getElementById('webshell-encoding');
if (encSelEl) encSelEl.value = 'auto';
document.getElementById('webshell-remark').value = '';
var titleEl = document.getElementById('webshell-modal-title');
if (titleEl) titleEl.textContent = wsT('webshell.addConnection');
@@ -4092,6 +4191,10 @@ function showEditWebshellModal(connId) {
document.getElementById('webshell-type').value = conn.type || 'php';
document.getElementById('webshell-method').value = (conn.method || 'post').toLowerCase();
document.getElementById('webshell-cmd-param').value = conn.cmdParam || '';
var osEditEl = document.getElementById('webshell-os');
if (osEditEl) osEditEl.value = normalizeWebshellOS(conn.os);
var encEditEl = document.getElementById('webshell-encoding');
if (encEditEl) encEditEl.value = normalizeWebshellEncoding(conn.encoding);
document.getElementById('webshell-remark').value = conn.remark || '';
var titleEl = document.getElementById('webshell-modal-title');
if (titleEl) titleEl.textContent = wsT('webshell.editConnectionTitle');
@@ -4319,6 +4422,8 @@ function testWebshellConnection() {
var method = ((document.getElementById('webshell-method') || {}).value || 'post').toLowerCase();
var cmdParam = (document.getElementById('webshell-cmd-param') || {}).value;
if (cmdParam && typeof cmdParam.trim === 'function') cmdParam = cmdParam.trim(); else cmdParam = '';
var osTag = normalizeWebshellOS((document.getElementById('webshell-os') || {}).value);
var encoding = normalizeWebshellEncoding((document.getElementById('webshell-encoding') || {}).value);
var btn = document.getElementById('webshell-test-btn');
if (btn) { btn.disabled = true; btn.textContent = (typeof wsT === 'function' ? wsT('common.refresh') : '刷新') + '...'; }
if (typeof apiFetch === 'undefined') {
@@ -4326,6 +4431,7 @@ function testWebshellConnection() {
alert(wsT('webshell.testFailed') || '连通性测试失败');
return;
}
// 连通性使用 Windows/Linux 都识别的最小内建命令作为探测(echo 1 在 cmd 和 sh 下行为等价)
apiFetch('/api/webshell/exec', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
@@ -4335,6 +4441,8 @@ function testWebshellConnection() {
type: type,
method: method === 'get' ? 'get' : 'post',
cmd_param: cmdParam || '',
encoding: encoding,
os: osTag,
command: 'echo 1'
})
})
@@ -4380,12 +4488,14 @@ function saveWebshellConnection() {
var method = ((document.getElementById('webshell-method') || {}).value || 'post').toLowerCase();
var cmdParam = (document.getElementById('webshell-cmd-param') || {}).value;
if (cmdParam && typeof cmdParam.trim === 'function') cmdParam = cmdParam.trim(); else cmdParam = '';
var osTag = normalizeWebshellOS((document.getElementById('webshell-os') || {}).value);
var encoding = normalizeWebshellEncoding((document.getElementById('webshell-encoding') || {}).value);
var remark = (document.getElementById('webshell-remark') || {}).value;
if (remark && typeof remark.trim === 'function') remark = remark.trim(); else remark = '';
var editIdEl = document.getElementById('webshell-edit-id');
var editId = editIdEl ? editIdEl.value.trim() : '';
var body = { url: url, password: password, type: type, method: method === 'get' ? 'get' : 'post', cmd_param: cmdParam, remark: remark || url };
var body = { url: url, password: password, type: type, method: method === 'get' ? 'get' : 'post', cmd_param: cmdParam, encoding: encoding, os: osTag, remark: remark || url };
if (typeof apiFetch === 'undefined') return;
var reqUrl = editId ? ('/api/webshell/connections/' + encodeURIComponent(editId)) : '/api/webshell/connections';
+367 -81
View File
@@ -63,6 +63,24 @@
<div class="lang-option" data-lang="en-US" onclick="onLanguageSelect('en-US')">English</div>
</div>
</div>
<div class="notification-menu-container">
<button class="notification-btn" id="notification-bell-btn" onclick="toggleNotificationDropdown()" data-i18n="notifications.title" data-i18n-attr="title" data-i18n-skip-text="true" title="事件通知">
<svg width="18" height="18" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M18 8a6 6 0 0 0-12 0c0 7-3 9-3 9h18s-3-2-3-9" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M13.73 21a2 2 0 0 1-3.46 0" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
<span class="notification-badge" id="notification-badge" style="display: none;">0</span>
</button>
<div id="notification-dropdown" class="notification-dropdown" style="display: none;">
<div class="notification-dropdown-header">
<span id="notification-dropdown-title" data-i18n="notifications.title">事件通知</span>
<button class="notification-mark-read-btn" id="notification-mark-all-read-btn" type="button" onclick="markAllNotificationsSeen()" data-i18n="notifications.markAllRead">标记已读</button>
</div>
<div id="notification-list" class="notification-list">
<div class="notification-empty" data-i18n="notifications.empty">暂无新事件</div>
</div>
</div>
</div>
<div class="user-menu-container">
<button class="user-avatar-btn" onclick="toggleUserMenu()" data-i18n="header.userMenu" data-i18n-attr="title" data-i18n-skip-text="true" title="用户菜单">
<svg width="20" height="20" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
@@ -287,41 +305,241 @@
<div class="page-header">
<h2 data-i18n="dashboard.title">仪表盘</h2>
<div class="page-header-actions">
<span class="dashboard-last-updated" id="dashboard-last-updated" aria-live="polite">
<svg class="dashboard-last-updated-icon" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><circle cx="12" cy="12" r="10"/><polyline points="12 6 12 12 16 14"/></svg>
<span data-i18n="dashboard.lastUpdated">上次更新</span>
<span class="dashboard-last-updated-time" id="dashboard-last-updated-time">-</span>
<span class="dashboard-last-updated-stale" id="dashboard-last-updated-stale" hidden data-i18n="dashboard.dataStale" data-i18n-attr="title" title="数据可能已过期">
<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.2" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><path d="M10.29 3.86L1.82 18a2 2 0 0 0 1.71 3h16.94a2 2 0 0 0 1.71-3L13.71 3.86a2 2 0 0 0-3.42 0z"/><line x1="12" y1="9" x2="12" y2="13"/><line x1="12" y1="17" x2="12.01" y2="17"/></svg>
</span>
</span>
<button class="btn-secondary" onclick="refreshDashboard()" data-i18n="dashboard.refreshData" data-i18n-attr="title" title="刷新数据"><span data-i18n="common.refresh">刷新</span></button>
</div>
</div>
<div class="dashboard-content">
<!-- 第一行:核心 KPI(仪表盘最佳实践:关键指标置顶) -->
<!-- 关键提醒条(仅当存在严重风险时渲染,默认 hidden);右侧 × 可在 session 内忽略 -->
<div class="dashboard-alert-banner" id="dashboard-alert-banner" hidden>
<span class="dashboard-alert-icon" aria-hidden="true">
<svg width="20" height="20" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M10.29 3.86L1.82 18a2 2 0 0 0 1.71 3h16.94a2 2 0 0 0 1.71-3L13.71 3.86a2 2 0 0 0-3.42 0z"/><line x1="12" y1="9" x2="12" y2="13"/><line x1="12" y1="17" x2="12.01" y2="17"/></svg>
</span>
<div class="dashboard-alert-content">
<div class="dashboard-alert-title" id="dashboard-alert-title" data-i18n="dashboard.alertTitle">需要关注</div>
<div class="dashboard-alert-desc" id="dashboard-alert-desc"></div>
</div>
<div class="dashboard-alert-actions" id="dashboard-alert-actions"></div>
<button type="button" class="dashboard-alert-close" id="dashboard-alert-close" data-i18n="dashboard.alertDismiss" data-i18n-attr="title" data-i18n-skip-text="true" title="忽略此提醒(仅本次会话)" aria-label="dismiss">
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2.4" stroke-linecap="round" stroke-linejoin="round" aria-hidden="true"><line x1="18" y1="6" x2="6" y2="18"/><line x1="6" y1="6" x2="18" y2="18"/></svg>
</button>
</div>
<!-- 第一行:核心 KPI(关键指标置顶 + 副标徽章承载次级信息) -->
<div class="dashboard-kpi-row" id="dashboard-cards">
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('tasks')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('tasks'); }" data-i18n="dashboard.clickToViewTasks" data-i18n-attr="title" title="点击查看任务管理"> <div class="dashboard-kpi-value" id="dashboard-running-tasks">-</div><div class="dashboard-kpi-label" data-i18n="dashboard.runningTasks">运行中任务</div></div>
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('vulnerabilities')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('vulnerabilities'); }" data-i18n="dashboard.clickToViewVuln" data-i18n-attr="title" title="点击查看漏洞管理"><div class="dashboard-kpi-value" id="dashboard-vuln-total">-</div><div class="dashboard-kpi-label" data-i18n="dashboard.vulnTotal">漏洞总数</div></div>
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('mcp-monitor')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('mcp-monitor'); }" data-i18n="dashboard.clickToViewMCP" data-i18n-attr="title" title="点击查看 MCP 监控"><div class="dashboard-kpi-value" id="dashboard-kpi-tools-calls">-</div><div class="dashboard-kpi-label" data-i18n="dashboard.toolCalls">工具调用次数</div></div>
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('mcp-monitor')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('mcp-monitor'); }" data-i18n="dashboard.clickToViewMCP" data-i18n-attr="title" title="点击查看 MCP 监控"><div class="dashboard-kpi-value" id="dashboard-kpi-success-rate">-</div><div class="dashboard-kpi-label" data-i18n="dashboard.successRate">工具执行成功率</div></div>
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('tasks')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('tasks'); }" data-i18n="dashboard.clickToViewTasks" data-i18n-attr="title" title="点击查看任务管理">
<div class="dashboard-kpi-head">
<div class="dashboard-kpi-label" data-i18n="dashboard.runningTasks">运行中任务</div>
<span class="dashboard-kpi-icon dashboard-kpi-icon-tasks" aria-hidden="true"><svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 2v4"/><path d="M12 18v4"/><path d="M4.93 4.93l2.83 2.83"/><path d="M16.24 16.24l2.83 2.83"/><path d="M2 12h4"/><path d="M18 12h4"/><path d="M4.93 19.07l2.83-2.83"/><path d="M16.24 7.76l2.83-2.83"/></svg></span>
</div>
<div class="dashboard-kpi-value" id="dashboard-running-tasks">-</div>
<div class="dashboard-kpi-sub" id="dashboard-kpi-tasks-sub">
<span class="dashboard-kpi-sub-text" id="dashboard-kpi-tasks-sub-text">-</span>
</div>
</div>
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('vulnerabilities')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('vulnerabilities'); }" data-i18n="dashboard.clickToViewVuln" data-i18n-attr="title" title="点击查看漏洞管理">
<div class="dashboard-kpi-head">
<div class="dashboard-kpi-label" data-i18n="dashboard.vulnTotal">漏洞总数</div>
<span class="dashboard-kpi-icon dashboard-kpi-icon-vuln" aria-hidden="true"><svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 22s8-4 8-10V5l-8-3-8 3v7c0 6 8 10 8 10z"/></svg></span>
</div>
<div class="dashboard-kpi-value" id="dashboard-vuln-total">-</div>
<div class="dashboard-kpi-sub" id="dashboard-kpi-vuln-sub">
<span class="dashboard-kpi-sub-badge dashboard-kpi-sub-badge-critical" id="dashboard-kpi-vuln-critical-badge" hidden>
<span class="dashboard-kpi-sub-badge-dot"></span>
<span data-i18n="dashboard.severityCritical">严重</span>
<span id="dashboard-kpi-vuln-critical-count">0</span>
</span>
<span class="dashboard-kpi-sub-text" id="dashboard-kpi-vuln-sub-text" data-i18n="dashboard.allClear">暂无新增风险</span>
</div>
</div>
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('mcp-monitor')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('mcp-monitor'); }" data-i18n="dashboard.clickToViewMCP" data-i18n-attr="title" title="点击查看 MCP 监控">
<div class="dashboard-kpi-head">
<div class="dashboard-kpi-label" data-i18n="dashboard.toolCalls">工具调用次数</div>
<span class="dashboard-kpi-icon dashboard-kpi-icon-calls" aria-hidden="true"><svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polyline points="22 12 18 12 15 21 9 3 6 12 2 12"/></svg></span>
</div>
<div class="dashboard-kpi-value" id="dashboard-kpi-tools-calls">-</div>
<div class="dashboard-kpi-sub">
<span class="dashboard-kpi-sub-text" id="dashboard-kpi-tools-sub-text">-</span>
</div>
</div>
<div class="dashboard-kpi-card" role="button" tabindex="0" onclick="switchPage('mcp-monitor')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('mcp-monitor'); }" data-i18n="dashboard.clickToViewMCP" data-i18n-attr="title" title="点击查看 MCP 监控">
<div class="dashboard-kpi-head">
<div class="dashboard-kpi-label" data-i18n="dashboard.successRate">工具执行成功率</div>
<span class="dashboard-kpi-icon dashboard-kpi-icon-rate" aria-hidden="true"><svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polyline points="20 6 9 17 4 12"/></svg></span>
</div>
<div class="dashboard-kpi-value" id="dashboard-kpi-success-rate">-</div>
<div class="dashboard-kpi-sub">
<span class="dashboard-kpi-sub-text" id="dashboard-kpi-rate-sub-text" data-i18n="dashboard.healthyStatus">运行平稳</span>
</div>
</div>
</div>
<!-- 两列主内容区 -->
<div class="dashboard-grid">
<div class="dashboard-main">
<section class="dashboard-section dashboard-section-chart">
<h3 class="dashboard-section-title" data-i18n="dashboard.severityDistribution">漏洞严重程度分布</h3>
<div class="dashboard-chart-wrap">
<div class="dashboard-stacked-bar" id="dashboard-stacked-bar">
<span class="dashboard-bar-seg seg-critical" id="dashboard-bar-critical" style="width: 0%"></span>
<span class="dashboard-bar-seg seg-high" id="dashboard-bar-high" style="width: 0%"></span>
<span class="dashboard-bar-seg seg-medium" id="dashboard-bar-medium" style="width: 0%"></span>
<span class="dashboard-bar-seg seg-low" id="dashboard-bar-low" style="width: 0%"></span>
<span class="dashboard-bar-seg seg-info" id="dashboard-bar-info" style="width: 0%"></span>
<div class="dashboard-section-header">
<h3 class="dashboard-section-title" data-i18n="dashboard.severityDistribution">漏洞严重程度分布</h3>
<a class="dashboard-section-link" onclick="switchPage('vulnerabilities')" data-i18n="dashboard.viewAll">查看全部 →</a>
</div>
<div class="dashboard-severity-wrap">
<!-- 风险概览卡:填充 donut 左侧留白;提供「结论性」洞察(风险等级/加权分/待处理计数/最新时间),
与右侧 legend 的「明细」形成互补,避免和下方「最近漏洞」列表重复 -->
<aside class="dashboard-severity-insights" aria-label="风险概览">
<div class="dashboard-severity-insight-risk" data-level="safe">
<div class="dashboard-severity-insight-head">
<span class="dashboard-severity-insight-label" data-i18n="dashboard.riskLevel">风险等级</span>
<span class="dashboard-severity-insight-risk-badge" id="dashboard-severity-risk-level" data-i18n="dashboard.riskSafe">安全</span>
</div>
<div class="dashboard-severity-insight-score-track" aria-hidden="true">
<div class="dashboard-severity-insight-score-fill" id="dashboard-severity-risk-fill" style="width: 0%"></div>
</div>
<div class="dashboard-severity-insight-score-meta">
<span class="dashboard-severity-insight-score-label" data-i18n="dashboard.riskScore">加权风险分</span>
<span class="dashboard-severity-insight-score-value" id="dashboard-severity-risk-score">0</span>
</div>
</div>
<div class="dashboard-severity-insight-urgent-group">
<span class="dashboard-severity-insight-label" data-i18n="dashboard.statusOpen">待处理</span>
<div class="dashboard-severity-insight-urgent">
<div class="dashboard-severity-insight-urgent-item u-critical" role="button" tabindex="0" onclick="switchPage('vulnerabilities')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('vulnerabilities'); }" title="查看待处理严重漏洞">
<span class="dashboard-severity-insight-urgent-value" id="dashboard-severity-urgent-critical">0</span>
<span class="dashboard-severity-insight-urgent-label" data-i18n="dashboard.severityCritical">严重</span>
</div>
<div class="dashboard-severity-insight-urgent-item u-high" role="button" tabindex="0" onclick="switchPage('vulnerabilities')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('vulnerabilities'); }" title="查看待处理高危漏洞">
<span class="dashboard-severity-insight-urgent-value" id="dashboard-severity-urgent-high">0</span>
<span class="dashboard-severity-insight-urgent-label" data-i18n="dashboard.severityHigh">高危</span>
</div>
</div>
</div>
<div class="dashboard-severity-insight-latest">
<span class="dashboard-severity-insight-label" data-i18n="dashboard.latestFound">最近发现</span>
<span class="dashboard-severity-insight-time" id="dashboard-severity-latest-time" data-i18n="dashboard.noneYet">暂无</span>
</div>
</aside>
<div class="dashboard-severity-chart">
<svg class="dashboard-severity-donut" id="dashboard-severity-donut" viewBox="0 0 480 260" preserveAspectRatio="xMidYMid meet" aria-hidden="true">
<g id="dashboard-severity-donut-track"></g>
<g id="dashboard-severity-donut-segments"></g>
<g id="dashboard-severity-donut-labels"></g>
</svg>
<div class="dashboard-severity-center">
<div class="dashboard-severity-center-value" id="dashboard-severity-total">0</div>
<div class="dashboard-severity-center-label" data-i18n="dashboard.totalVulns">总漏洞数</div>
</div>
</div>
<div class="dashboard-legend" id="dashboard-vuln-bars">
<div class="dashboard-legend-item"><span class="dashboard-legend-dot critical"></span><span class="dashboard-legend-label" data-i18n="dashboard.severityCritical">严重</span><span class="dashboard-legend-value" id="dashboard-severity-critical">0</span></div>
<div class="dashboard-legend-item"><span class="dashboard-legend-dot high"></span><span class="dashboard-legend-label" data-i18n="dashboard.severityHigh">高危</span><span class="dashboard-legend-value" id="dashboard-severity-high">0</span></div>
<div class="dashboard-legend-item"><span class="dashboard-legend-dot medium"></span><span class="dashboard-legend-label" data-i18n="dashboard.severityMedium">中危</span><span class="dashboard-legend-value" id="dashboard-severity-medium">0</span></div>
<div class="dashboard-legend-item"><span class="dashboard-legend-dot low"></span><span class="dashboard-legend-label" data-i18n="dashboard.severityLow">低危</span><span class="dashboard-legend-value" id="dashboard-severity-low">0</span></div>
<div class="dashboard-legend-item"><span class="dashboard-legend-dot info"></span><span class="dashboard-legend-label" data-i18n="dashboard.severityInfo">信息</span><span class="dashboard-legend-value" id="dashboard-severity-info">0</span></div>
<div class="dashboard-severity-legend" id="dashboard-vuln-bars">
<div class="dashboard-severity-legend-item">
<span class="dashboard-severity-legend-dot critical"></span>
<span class="dashboard-severity-legend-label" data-i18n="dashboard.severityCritical">严重</span>
<span class="dashboard-severity-legend-value" id="dashboard-severity-critical">0</span>
<span class="dashboard-severity-legend-pct" id="dashboard-severity-critical-pct">0%</span>
</div>
<div class="dashboard-severity-legend-item">
<span class="dashboard-severity-legend-dot high"></span>
<span class="dashboard-severity-legend-label" data-i18n="dashboard.severityHigh">高危</span>
<span class="dashboard-severity-legend-value" id="dashboard-severity-high">0</span>
<span class="dashboard-severity-legend-pct" id="dashboard-severity-high-pct">0%</span>
</div>
<div class="dashboard-severity-legend-item">
<span class="dashboard-severity-legend-dot medium"></span>
<span class="dashboard-severity-legend-label" data-i18n="dashboard.severityMedium">中危</span>
<span class="dashboard-severity-legend-value" id="dashboard-severity-medium">0</span>
<span class="dashboard-severity-legend-pct" id="dashboard-severity-medium-pct">0%</span>
</div>
<div class="dashboard-severity-legend-item">
<span class="dashboard-severity-legend-dot low"></span>
<span class="dashboard-severity-legend-label" data-i18n="dashboard.severityLow">低危</span>
<span class="dashboard-severity-legend-value" id="dashboard-severity-low">0</span>
<span class="dashboard-severity-legend-pct" id="dashboard-severity-low-pct">0%</span>
</div>
<div class="dashboard-severity-legend-item">
<span class="dashboard-severity-legend-dot info"></span>
<span class="dashboard-severity-legend-label" data-i18n="dashboard.severityInfo">信息</span>
<span class="dashboard-severity-legend-value" id="dashboard-severity-info">0</span>
<span class="dashboard-severity-legend-pct" id="dashboard-severity-info-pct">0%</span>
</div>
</div>
</div>
<!-- 处置状态 + 修复进度(利用 by_status 数据,避免下半部分留白) -->
<div class="dashboard-severity-status">
<div class="dashboard-severity-status-grid">
<div class="dashboard-severity-status-cell s-open" role="button" tabindex="0" onclick="switchPage('vulnerabilities')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('vulnerabilities'); }">
<span class="dashboard-severity-status-icon" aria-hidden="true">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="9"/><line x1="12" y1="8" x2="12" y2="12"/><line x1="12" y1="16" x2="12.01" y2="16"/></svg>
</span>
<div class="dashboard-severity-status-text">
<span class="dashboard-severity-status-value" id="dashboard-status-open">0</span>
<span class="dashboard-severity-status-label" data-i18n="dashboard.statusOpen">待处理</span>
</div>
</div>
<div class="dashboard-severity-status-cell s-confirmed" role="button" tabindex="0" onclick="switchPage('vulnerabilities')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('vulnerabilities'); }">
<span class="dashboard-severity-status-icon" aria-hidden="true">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M22 11.08V12a10 10 0 1 1-5.93-9.14"/><polyline points="22 4 12 14.01 9 11.01"/></svg>
</span>
<div class="dashboard-severity-status-text">
<span class="dashboard-severity-status-value" id="dashboard-status-confirmed">0</span>
<span class="dashboard-severity-status-label" data-i18n="dashboard.statusConfirmed">已确认</span>
</div>
</div>
<div class="dashboard-severity-status-cell s-fixed" role="button" tabindex="0" onclick="switchPage('vulnerabilities')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('vulnerabilities'); }">
<span class="dashboard-severity-status-icon" aria-hidden="true">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M12 22s8-4 8-10V5l-8-3-8 3v7c0 6 8 10 8 10z"/><polyline points="9 12 11 14 15 10"/></svg>
</span>
<div class="dashboard-severity-status-text">
<span class="dashboard-severity-status-value" id="dashboard-status-fixed">0</span>
<span class="dashboard-severity-status-label" data-i18n="dashboard.statusFixed">已修复</span>
</div>
</div>
<div class="dashboard-severity-status-cell s-fp" role="button" tabindex="0" onclick="switchPage('vulnerabilities')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('vulnerabilities'); }">
<span class="dashboard-severity-status-icon" aria-hidden="true">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="4.93" y1="4.93" x2="19.07" y2="19.07"/></svg>
</span>
<div class="dashboard-severity-status-text">
<span class="dashboard-severity-status-value" id="dashboard-status-fp">0</span>
<span class="dashboard-severity-status-label" data-i18n="dashboard.statusFalsePositive">误报</span>
</div>
</div>
</div>
<div class="dashboard-severity-progress">
<div class="dashboard-severity-progress-meta">
<span class="dashboard-severity-progress-title" data-i18n="dashboard.fixRate">修复率</span>
<span class="dashboard-severity-progress-value">
<span id="dashboard-fix-rate">0%</span>
<span class="dashboard-severity-progress-detail" id="dashboard-fix-detail">(0 / 0)</span>
</span>
</div>
<div class="dashboard-severity-progress-track" aria-hidden="true">
<div class="dashboard-severity-progress-fixed" id="dashboard-fix-progress-fixed" style="width: 0%"></div>
<div class="dashboard-severity-progress-confirmed" id="dashboard-fix-progress-confirmed" style="width: 0%"></div>
</div>
<div class="dashboard-severity-progress-legend">
<span class="dashboard-severity-progress-legend-item"><span class="dashboard-severity-progress-legend-dot legend-fixed"></span><span data-i18n="dashboard.statusFixed">已修复</span></span>
<span class="dashboard-severity-progress-legend-item"><span class="dashboard-severity-progress-legend-dot legend-confirmed"></span><span data-i18n="dashboard.statusConfirmed">已确认</span></span>
<span class="dashboard-severity-progress-legend-item"><span class="dashboard-severity-progress-legend-dot legend-open"></span><span data-i18n="dashboard.statusOpen">待处理</span></span>
</div>
</div>
</div>
</section>
<section class="dashboard-section dashboard-section-recent-vulns">
<div class="dashboard-section-header">
<h3 class="dashboard-section-title" data-i18n="dashboard.recentVulns">最近漏洞</h3>
<a class="dashboard-section-link" onclick="switchPage('vulnerabilities')" data-i18n="dashboard.viewAll">查看全部 →</a>
</div>
<div class="dashboard-recent-vulns" id="dashboard-recent-vulns">
<div class="dashboard-recent-vulns-empty" id="dashboard-recent-vulns-empty" data-i18n="dashboard.noVulnYet">暂无最近漏洞</div>
</div>
</section>
<section class="dashboard-section dashboard-section-overview">
<h3 class="dashboard-section-title" data-i18n="dashboard.runOverview">运行概览</h3>
<div class="dashboard-section-header">
<h3 class="dashboard-section-title" data-i18n="dashboard.batchQueues">批量任务队列</h3>
<a class="dashboard-section-link" onclick="switchPage('tasks')" data-i18n="dashboard.viewAll">查看全部 →</a>
</div>
<div class="dashboard-overview-list">
<div class="dashboard-overview-item dashboard-overview-item-batch" role="button" tabindex="0" onclick="switchPage('tasks')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('tasks'); }">
<span class="dashboard-overview-icon dashboard-overview-icon-batch" aria-hidden="true"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><rect x="3" y="3" width="7" height="7"/><rect x="14" y="3" width="7" height="7"/><rect x="14" y="14" width="7" height="7"/><rect x="3" y="14" width="7" height="7"/></svg></span>
@@ -356,80 +574,100 @@
</div>
</div>
</div>
<div class="dashboard-overview-item dashboard-overview-item-tools" role="button" tabindex="0" onclick="switchPage('mcp-monitor')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('mcp-monitor'); }">
<span class="dashboard-overview-icon dashboard-overview-icon-tools" aria-hidden="true"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M14.7 6.3a1 1 0 0 0 0 1.4l1.6 1.6a1 1 0 0 0 1.4 0l3.77-3.77a6 6 0 0 1-7.94 7.94l-6.91 6.91a2.12 2.12 0 0 1-3-3l6.91-6.91a6 6 0 0 1 7.94-7.94l-3.76 3.76z"/></svg></span>
<div class="dashboard-overview-content">
<div class="dashboard-overview-header">
<span class="dashboard-overview-label" data-i18n="dashboard.toolInvocations">工具调用</span>
<span class="dashboard-overview-success-rate" id="dashboard-tools-success-rate">-</span>
</div>
<div class="dashboard-overview-value-group">
<span class="dashboard-overview-value-large" id="dashboard-tools-calls">-</span>
<span class="dashboard-overview-value-unit" data-i18n="dashboard.callsUnit">次调用</span>
<span class="dashboard-overview-value-separator">·</span>
<span class="dashboard-overview-value-normal" id="dashboard-tools-count">-</span>
<span class="dashboard-overview-value-unit" data-i18n="dashboard.toolsUnit">个工具</span>
</div>
</div>
</div>
<div class="dashboard-overview-item dashboard-overview-item-knowledge" role="button" tabindex="0" onclick="switchPage('knowledge-management')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('knowledge-management'); }">
<span class="dashboard-overview-icon dashboard-overview-icon-knowledge" aria-hidden="true"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M4 19.5A2.5 2.5 0 0 1 6.5 17H20"></path><path d="M6.5 2H20v20H6.5A2.5 2.5 0 0 1 4 19.5v-15A2.5 2.5 0 0 1 6.5 2z"></path></svg></span>
<div class="dashboard-overview-content">
<div class="dashboard-overview-header">
<span class="dashboard-overview-label" data-i18n="dashboard.knowledgeLabel">知识</span>
<span class="dashboard-overview-status" id="dashboard-knowledge-status">-</span>
</div>
<div class="dashboard-overview-value-group">
<span class="dashboard-overview-value-large" id="dashboard-knowledge-items">-</span>
<span class="dashboard-overview-value-unit" data-i18n="dashboard.knowledgeItems">项知识</span>
<span class="dashboard-overview-value-separator">·</span>
<span class="dashboard-overview-value-normal" id="dashboard-knowledge-categories">-</span>
<span class="dashboard-overview-value-unit" data-i18n="dashboard.categoriesUnit">个分类</span>
</div>
</div>
</div>
<div class="dashboard-overview-item dashboard-overview-item-skills" role="button" tabindex="0" onclick="switchPage('skills-monitor')" onkeydown="if(event.key==='Enter'||event.key===' ') { event.preventDefault(); switchPage('skills-monitor'); }">
<span class="dashboard-overview-icon dashboard-overview-icon-skills" aria-hidden="true"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><polyline points="14 2 14 8 20 8"/><line x1="16" y1="13" x2="8" y2="13"/><line x1="16" y1="17" x2="8" y2="17"/></svg></span>
<div class="dashboard-overview-content">
<div class="dashboard-overview-header">
<span class="dashboard-overview-label" data-i18n="dashboard.skillsLabel">Skills</span>
<span class="dashboard-overview-status" id="dashboard-skills-status">-</span>
</div>
<div class="dashboard-overview-value-group">
<span class="dashboard-overview-value-large" id="dashboard-skills-calls">-</span>
<span class="dashboard-overview-value-unit" data-i18n="dashboard.callsUnit">次调用</span>
<span class="dashboard-overview-value-separator">·</span>
<span class="dashboard-overview-value-normal" id="dashboard-skills-count">-</span>
<span class="dashboard-overview-value-unit" data-i18n="dashboard.skillUnit">个 Skill</span>
</div>
</div>
</div>
</div>
</section>
<section class="dashboard-section dashboard-section-quick dashboard-quick-inline">
<h3 class="dashboard-section-title" data-i18n="dashboard.quickLinks">快捷入口</h3>
<div class="dashboard-quick-links dashboard-quick-links-row">
<a class="dashboard-quick-link" onclick="switchPage('chat')"><span class="dashboard-quick-icon"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z"></path></svg></span><span data-i18n="nav.chat">对话</span></a>
<a class="dashboard-quick-link" onclick="switchPage('tasks')"><span class="dashboard-quick-icon"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M9 11l3 3L22 4"></path><path d="M21 12v7a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2V5a2 2 0 0 1 2-2h11"></path></svg></span><span data-i18n="nav.tasks">任务管理</span></a>
<a class="dashboard-quick-link" onclick="switchPage('vulnerabilities')"><span class="dashboard-quick-icon"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M12 22s8-4 8-10V5l-8-3-8 3v7c0 6 8 10 8 10z"></path></svg></span><span data-i18n="nav.vulnerabilities">漏洞管理</span></a>
<a class="dashboard-quick-link" onclick="switchPage('mcp-management')"><span class="dashboard-quick-icon"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M13 2L3 14h9l-1 8 10-12h-9l1-8z"></path></svg></span><span data-i18n="nav.mcpManagement">MCP 管理</span></a>
<a class="dashboard-quick-link" onclick="switchPage('knowledge-management')"><span class="dashboard-quick-icon"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M2 3h6a4 4 0 0 1 4 4v14a3 3 0 0 0-3-3H2z"></path><path d="M22 3h-6a4 4 0 0 0-4 4v14a3 3 0 0 1 3-3h7z"></path></svg></span><span data-i18n="nav.knowledgeManagement">知识管理</span></a>
<a class="dashboard-quick-link" onclick="switchPage('skills-management')"><span class="dashboard-quick-icon"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M14.5 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V7.5L14.5 2z"></path><polyline points="14 2 14 8 20 8"></polyline></svg></span><span data-i18n="nav.skillsManagement">Skills 管理</span></a>
<a class="dashboard-quick-link" onclick="switchPage('roles-management')"><span class="dashboard-quick-icon"><svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M17 21v-2a4 4 0 0 0-4-4H5a4 4 0 0 0-4 4v2"></path><circle cx="9" cy="7" r="4"></circle><path d="M23 21v-2a4 4 0 0 0-3-3.87"></path><path d="M16 3.13a4 4 0 0 1 0 7.75"></path></svg></span><span data-i18n="nav.rolesManagement">角色管理</span></a>
<!-- 推荐操作:基于当前数据状态智能生成(如「修复 4 个待处理严重漏洞」「审批 2 个 HITL」),
比纯静态导航更有意义;当没有任何推荐时整个 section 隐藏 -->
<section class="dashboard-section dashboard-section-recommend" id="dashboard-section-recommend" hidden>
<div class="dashboard-section-header">
<h3 class="dashboard-section-title" data-i18n="dashboard.recommendedActions">推荐操作</h3>
<span class="dashboard-section-hint" data-i18n="dashboard.recommendedActionsHint">基于当前状态自动生成</span>
</div>
<div class="dashboard-recommend-list" id="dashboard-recommend-list"></div>
</section>
</div>
<div class="dashboard-side">
<section class="dashboard-section dashboard-section-tools">
<h3 class="dashboard-section-title" data-i18n="dashboard.toolsExecCount">工具执行次数</h3>
<div class="dashboard-section-header">
<h3 class="dashboard-section-title" data-i18n="dashboard.toolsExecCount">工具执行次数</h3>
<a class="dashboard-section-link" onclick="switchPage('mcp-monitor')" data-i18n="dashboard.viewAll">查看全部 →</a>
</div>
<div class="dashboard-tools-chart-wrap">
<div class="dashboard-tools-chart-placeholder" id="dashboard-tools-pie-placeholder" data-i18n="common.noData">暂无数据</div>
<div class="dashboard-tools-bar-chart" id="dashboard-tools-bar-chart"></div>
</div>
</section>
<!-- 最近事件:拉 /api/notifications/summary 取最新 3 条;空时整个隐藏 -->
<section class="dashboard-section dashboard-section-events" id="dashboard-section-events" hidden>
<div class="dashboard-section-header">
<h3 class="dashboard-section-title" data-i18n="dashboard.recentEvents">最近事件</h3>
<a class="dashboard-section-link" onclick="if(typeof toggleNotificationDropdown==='function') toggleNotificationDropdown()" data-i18n="dashboard.viewAll">查看全部 →</a>
</div>
<div class="dashboard-events-list" id="dashboard-events-list"></div>
</section>
<section class="dashboard-section dashboard-section-resources">
<h3 class="dashboard-section-title" data-i18n="dashboard.capabilities">能力总览</h3>
<div class="dashboard-resource-list" id="dashboard-resource-list">
<a class="dashboard-resource-item" onclick="switchPage('mcp-management')" role="button" tabindex="0">
<span class="dashboard-resource-icon dashboard-resource-icon-mcp" aria-hidden="true">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M13 2L3 14h9l-1 8 10-12h-9l1-8z"/></svg>
</span>
<span class="dashboard-resource-label" data-i18n="dashboard.mcpTools">MCP 工具</span>
<span class="dashboard-resource-value" id="dashboard-resource-tools">-</span>
</a>
<!-- External MCP 服务器健康度:N 运行 / N 异常;只有配置过 External MCP 才显示 -->
<a class="dashboard-resource-item" id="dashboard-resource-external-mcp-row" onclick="switchPage('mcp-management')" role="button" tabindex="0" hidden>
<span class="dashboard-resource-icon dashboard-resource-icon-external" aria-hidden="true">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>
</span>
<span class="dashboard-resource-label" data-i18n="dashboard.externalMcpServers">External MCP</span>
<span class="dashboard-resource-value" id="dashboard-resource-external-mcp">
<span id="dashboard-resource-external-mcp-text">-</span>
<span class="dashboard-resource-health" id="dashboard-resource-external-mcp-health" hidden></span>
</span>
</a>
<a class="dashboard-resource-item" onclick="switchPage('skills-management')" role="button" tabindex="0">
<span class="dashboard-resource-icon dashboard-resource-icon-skills" aria-hidden="true">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M14.5 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V7.5L14.5 2z"/><polyline points="14 2 14 8 20 8"/></svg>
</span>
<span class="dashboard-resource-label" data-i18n="dashboard.skillsLabel">Skills</span>
<span class="dashboard-resource-value" id="dashboard-resource-skills">-</span>
</a>
<a class="dashboard-resource-item" onclick="switchPage('knowledge-management')" role="button" tabindex="0">
<span class="dashboard-resource-icon dashboard-resource-icon-knowledge" aria-hidden="true">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M4 19.5A2.5 2.5 0 0 1 6.5 17H20"/><path d="M6.5 2H20v20H6.5A2.5 2.5 0 0 1 4 19.5v-15A2.5 2.5 0 0 1 6.5 2z"/></svg>
</span>
<span class="dashboard-resource-label" data-i18n="dashboard.knowledgeLabel">知识</span>
<span class="dashboard-resource-value" id="dashboard-resource-knowledge">-</span>
</a>
<a class="dashboard-resource-item" onclick="switchPage('roles-management')" role="button" tabindex="0">
<span class="dashboard-resource-icon dashboard-resource-icon-roles" aria-hidden="true">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M17 21v-2a4 4 0 0 0-4-4H5a4 4 0 0 0-4 4v2"/><circle cx="9" cy="7" r="4"/><path d="M23 21v-2a4 4 0 0 0-3-3.87"/><path d="M16 3.13a4 4 0 0 1 0 7.75"/></svg>
</span>
<span class="dashboard-resource-label" data-i18n="dashboard.rolesLabel">角色</span>
<span class="dashboard-resource-value" id="dashboard-resource-roles">-</span>
</a>
<a class="dashboard-resource-item" onclick="switchPage('agents-management')" role="button" tabindex="0">
<span class="dashboard-resource-icon dashboard-resource-icon-agents" aria-hidden="true">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polygon points="12 2 2 7 12 12 22 7 12 2"/><polyline points="2 17 12 22 22 17"/><polyline points="2 12 12 17 22 12"/></svg>
</span>
<span class="dashboard-resource-label" data-i18n="dashboard.agentsLabel">Agents</span>
<span class="dashboard-resource-value" id="dashboard-resource-agents">-</span>
</a>
<!-- WebShell 连接:渗透落地后建立的 foothold,对安全运维场景非常关键 -->
<a class="dashboard-resource-item" onclick="switchPage('webshell')" role="button" tabindex="0">
<span class="dashboard-resource-icon dashboard-resource-icon-webshell" aria-hidden="true">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><polyline points="4 17 10 11 4 5"/><line x1="12" y1="19" x2="20" y2="19"/></svg>
</span>
<span class="dashboard-resource-label" data-i18n="dashboard.webshellLabel">WebShell</span>
<span class="dashboard-resource-value" id="dashboard-resource-webshell">-</span>
</a>
</div>
</section>
</div>
</div>
<div class="dashboard-cta-block">
<!-- "开始你的安全之旅" CTA:默认显示;当用户已经有数据(任务/漏洞/调用)后,由 JS 隐藏避免冗余 -->
<div class="dashboard-cta-block" id="dashboard-cta-block">
<div class="dashboard-cta-content">
<div class="dashboard-cta-icon" aria-hidden="true">
<svg width="28" height="28" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.75" stroke-linecap="round" stroke-linejoin="round"><path d="M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z"></path></svg>
@@ -1097,6 +1335,18 @@
<span data-i18n="vulnerabilityPage.conversationId">会话ID</span>
<input type="text" id="vulnerability-conversation-filter" data-i18n="vulnerabilityPage.filterConversation" data-i18n-attr="placeholder" placeholder="筛选特定会话" />
</label>
<label>
<span data-i18n="vulnerabilityPage.taskOrQueueId">任务ID/队列ID</span>
<input type="text" id="vulnerability-task-filter" data-i18n="vulnerabilityPage.filterTaskOrQueue" data-i18n-attr="placeholder" placeholder="筛选任务ID或队列ID" />
</label>
<label>
<span data-i18n="vulnerabilityPage.conversationTag">对话标签</span>
<input type="text" id="vulnerability-conversation-tag-filter" data-i18n="vulnerabilityPage.filterConversationTag" data-i18n-attr="placeholder" placeholder="筛选对话标签" />
</label>
<label>
<span data-i18n="vulnerabilityPage.taskTag">任务标签</span>
<input type="text" id="vulnerability-task-tag-filter" data-i18n="vulnerabilityPage.filterTaskTag" data-i18n-attr="placeholder" placeholder="筛选任务标签" />
</label>
<label>
<span data-i18n="vulnerabilityPage.severity">严重程度</span>
<select id="vulnerability-severity-filter">
@@ -1120,6 +1370,7 @@
</label>
<button class="btn-secondary" onclick="filterVulnerabilities()" data-i18n="vulnerabilityPage.filter">筛选</button>
<button class="btn-secondary" onclick="clearVulnerabilityFilters()" data-i18n="vulnerabilityPage.clear">清除</button>
<button class="btn-primary" onclick="exportVulnerabilityReports()" data-i18n="vulnerabilityPage.batchExport">批量导出</button>
</div>
</div>
@@ -2411,6 +2662,13 @@
</div>
</div>
</div>
<div class="context-menu-item" onclick="navigateToVulnerabilitiesForContextConversation()">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M12 22s8-4 8-10V5l-8-3-8 3v7c0 6 8 10 8 10z" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M9 12l2 2 4-4" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
<span data-i18n="contextMenu.viewVulnerabilities">查看漏洞</span>
</div>
<div class="context-menu-divider"></div>
<div class="context-menu-item" onclick="renameConversation()">
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
@@ -2599,6 +2857,14 @@
<label for="vulnerability-conversation-id"><span data-i18n="vulnerabilityModal.conversationId">会话ID</span> <span style="color: red;">*</span></label>
<input type="text" id="vulnerability-conversation-id" data-i18n="vulnerabilityModal.conversationIdPlaceholder" data-i18n-attr="placeholder" placeholder="输入会话ID" required />
</div>
<div class="form-group">
<label for="vulnerability-conversation-tag" data-i18n="vulnerabilityModal.conversationTag">对话标签</label>
<input type="text" id="vulnerability-conversation-tag" data-i18n="vulnerabilityModal.conversationTagPlaceholder" data-i18n-attr="placeholder" placeholder="如:红队演练A、客户A周报" />
</div>
<div class="form-group">
<label for="vulnerability-task-tag" data-i18n="vulnerabilityModal.taskTag">任务标签</label>
<input type="text" id="vulnerability-task-tag" data-i18n="vulnerabilityModal.taskTagPlaceholder" data-i18n-attr="placeholder" placeholder="如:批量扫描Q2、专项复测" />
</div>
<div class="form-group">
<label for="vulnerability-title"><span data-i18n="vulnerabilityModal.title">标题</span> <span style="color: red;">*</span></label>
<input type="text" id="vulnerability-title" data-i18n="vulnerabilityModal.titlePlaceholder" data-i18n-attr="placeholder" placeholder="漏洞标题" required />
@@ -2693,6 +2959,25 @@
<label for="webshell-cmd-param" data-i18n="webshell.cmdParam">命令参数名</label>
<input type="text" id="webshell-cmd-param" data-i18n="webshell.cmdParamPlaceholder" data-i18n-attr="placeholder" placeholder="不填默认为 cmd,如 xxx 则请求为 xxx=命令" />
</div>
<div class="form-group">
<label for="webshell-os" data-i18n="webshell.os">目标系统</label>
<select id="webshell-os">
<option value="auto" data-i18n="webshell.osAuto">自动(按 Shell 类型推断)</option>
<option value="linux" data-i18n="webshell.osLinux">Linux / Unix</option>
<option value="windows" data-i18n="webshell.osWindows">Windows</option>
</select>
<small class="form-hint" data-i18n="webshell.osHint">决定文件管理/上传使用 Linux 还是 Windows 命令;PHP/JSP 跑在 Windows 上请选 Windows</small>
</div>
<div class="form-group">
<label for="webshell-encoding" data-i18n="webshell.encoding">响应编码</label>
<select id="webshell-encoding">
<option value="auto" data-i18n="webshell.encodingAuto">自动检测</option>
<option value="utf-8" data-i18n="webshell.encodingUtf8">UTF-8</option>
<option value="gbk" data-i18n="webshell.encodingGbk">GBK(中文 Windows</option>
<option value="gb18030" data-i18n="webshell.encodingGb18030">GB18030</option>
</select>
<small class="form-hint" data-i18n="webshell.encodingHint">中文 Windows 目标若出现乱码,请切换为 GBK 或 GB18030</small>
</div>
<div class="form-group">
<label for="webshell-remark" data-i18n="webshell.remark">备注</label>
<input type="text" id="webshell-remark" data-i18n="webshell.remarkPlaceholder" data-i18n-attr="placeholder" placeholder="便于识别的备注名" />
@@ -2804,6 +3089,7 @@
<script src="/static/js/i18n.js"></script>
<script src="/static/js/builtin-tools.js"></script>
<script src="/static/js/auth.js"></script>
<script src="/static/js/notifications.js"></script>
<script src="/static/js/info-collect.js"></script>
<script src="/static/js/router.js"></script>
<script src="/static/js/agents.js"></script>
@@ -2817,7 +3103,7 @@
<script src="/static/js/terminal.js"></script>
<script src="/static/js/knowledge.js"></script>
<script src="/static/js/skills.js"></script>
<script src="/static/js/vulnerability.js?v=4"></script>
<script src="/static/js/vulnerability.js?v=7"></script>
<script src="/static/js/webshell.js"></script>
<script src="/static/js/chat-files.js"></script>
<script src="/static/js/tasks.js"></script>