Compare commits

...

52 Commits

Author SHA1 Message Date
公明 e28ae39b9a Update config.yaml 2026-06-24 02:04:49 +08:00
公明 df34ceda68 Add files via upload 2026-06-24 01:50:13 +08:00
公明 3e69a50f87 Add files via upload 2026-06-24 01:49:43 +08:00
公明 53325ce07d Add files via upload 2026-06-24 01:49:09 +08:00
公明 d85de3461b Add files via upload 2026-06-24 01:47:33 +08:00
公明 9306303d99 Add files via upload 2026-06-24 01:46:30 +08:00
公明 1e8f72ed74 Add files via upload 2026-06-24 01:44:47 +08:00
公明 0198f50314 Add files via upload 2026-06-24 01:43:37 +08:00
公明 560d0dca43 Add files via upload 2026-06-24 01:42:15 +08:00
公明 47486a49c2 Update version number to v1.6.44 2026-06-23 21:17:08 +08:00
公明 476727933d Update config.yaml 2026-06-23 21:16:41 +08:00
公明 8bb50e8323 Add files via upload 2026-06-23 21:15:45 +08:00
公明 e74f2a2292 Add files via upload 2026-06-23 21:14:08 +08:00
公明 4799d0dba7 Add files via upload 2026-06-23 21:12:26 +08:00
公明 1db917061d Add files via upload 2026-06-23 21:10:47 +08:00
公明 41cd7db30f Add files via upload 2026-06-23 21:08:59 +08:00
公明 68b3265f3f Add files via upload 2026-06-23 21:07:01 +08:00
公明 05dc4395a1 Add files via upload 2026-06-23 21:06:14 +08:00
公明 637a35748b Add files via upload 2026-06-23 21:03:59 +08:00
公明 5d77a99236 Add files via upload 2026-06-23 21:01:35 +08:00
公明 e84d936f85 Add files via upload 2026-06-23 20:59:20 +08:00
公明 e748201ae8 Add files via upload 2026-06-23 20:57:47 +08:00
公明 7a3c67458c Add files via upload 2026-06-23 16:53:32 +08:00
公明 6e9e43eec8 Add files via upload 2026-06-23 15:43:15 +08:00
公明 bca86e48ae Add files via upload 2026-06-23 15:40:04 +08:00
公明 3f3b8b4db4 Add files via upload 2026-06-23 15:37:23 +08:00
公明 b366dc0287 Add files via upload 2026-06-23 15:35:12 +08:00
公明 a52452ceea Add files via upload 2026-06-23 15:32:41 +08:00
公明 5b87667782 Update config.yaml 2026-06-23 15:32:18 +08:00
公明 4f0e812d37 Add files via upload 2026-06-23 15:31:23 +08:00
公明 79691c021f Add files via upload 2026-06-23 15:09:53 +08:00
公明 5a8309a015 Add files via upload 2026-06-23 15:07:41 +08:00
公明 6244197339 Add files via upload 2026-06-23 15:06:02 +08:00
公明 eb14aca05a Add files via upload 2026-06-23 15:03:23 +08:00
公明 091e8a4da8 Add files via upload 2026-06-23 15:00:44 +08:00
公明 48ce0c519e Add files via upload 2026-06-23 12:34:50 +08:00
公明 afc37051c0 Add files via upload 2026-06-23 12:33:35 +08:00
公明 2964247361 Add files via upload 2026-06-23 12:31:05 +08:00
公明 02919df476 Add files via upload 2026-06-23 12:28:37 +08:00
公明 c3294d96a2 Add files via upload 2026-06-23 12:28:07 +08:00
公明 c8b8b41bda Add files via upload 2026-06-23 12:26:40 +08:00
公明 9a4c333b90 Add files via upload 2026-06-23 12:25:20 +08:00
公明 8e21ae290a Add files via upload 2026-06-23 12:22:50 +08:00
公明 b9d102d046 Add files via upload 2026-06-23 11:54:28 +08:00
公明 8c85494a05 Add files via upload 2026-06-23 11:52:15 +08:00
公明 c3d2a41301 Add files via upload 2026-06-23 01:54:29 +08:00
公明 1a2e282d46 Add files via upload 2026-06-23 01:39:55 +08:00
公明 8129f2147f Delete internal/multiagent/eino_empty_response_test.go 2026-06-23 01:37:34 +08:00
公明 4a9889f0af Add files via upload 2026-06-23 01:36:48 +08:00
公明 732d47a965 Add files via upload 2026-06-22 23:31:42 +08:00
公明 e22382aab0 Add files via upload 2026-06-22 23:29:57 +08:00
公明 b6ff80adf2 Add files via upload 2026-06-22 23:27:30 +08:00
55 changed files with 3392 additions and 1252 deletions
+4 -1
View File
@@ -10,7 +10,7 @@
# ============================================ # ============================================
# 前端显示的版本号(可选,不填则显示默认版本) # 前端显示的版本号(可选,不填则显示默认版本)
version: "v1.6.42" version: "v1.6.45"
# 服务器配置 # 服务器配置
server: server:
host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口 host: 0.0.0.0 # 监听地址,0.0.0.0 表示监听所有网络接口
@@ -40,6 +40,9 @@ audit:
retention_days: 15 # 0 表示不自动清理 retention_days: 15 # 0 表示不自动清理
max_detail_bytes: 8192 max_detail_bytes: 8192
auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流 auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流
# MCP 状态监控执行记录保留(tool_executions 表)
monitor:
retention_days: 90 # 省略时默认 90;0 表示不自动清理
# ============================================ # ============================================
# 对话相关配置 # 对话相关配置
# ============================================ # ============================================
+9 -1
View File
@@ -25,6 +25,7 @@ import (
"cyberstrike-ai/internal/logger" "cyberstrike-ai/internal/logger"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/monitor"
"cyberstrike-ai/internal/robot" "cyberstrike-ai/internal/robot"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
"cyberstrike-ai/internal/skillpackage" "cyberstrike-ai/internal/skillpackage"
@@ -99,6 +100,10 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
auditSvc.PurgeExpired() auditSvc.PurgeExpired()
audit.StartRetentionLoop(auditSvc, log.Logger) audit.StartRetentionLoop(auditSvc, log.Logger)
monitorRetention := monitor.NewService(db, cfg, log.Logger)
monitorRetention.PurgeExpired()
monitor.StartRetentionLoop(monitorRetention, log.Logger)
// 创建MCP服务器(带数据库持久化) // 创建MCP服务器(带数据库持久化)
mcpServer := mcp.NewServerWithStorage(log.Logger, db) mcpServer := mcp.NewServerWithStorage(log.Logger, db)
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes) mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
@@ -298,7 +303,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
plantaskBase := filepath.Join(skillsDir, plantaskRel) plantaskBase := filepath.Join(skillsDir, plantaskRel)
// Match eino_adk_run_loop: checkpoint_dir is used as configured (relative to process CWD when not absolute). // Match eino_adk_run_loop: checkpoint_dir is used as configured (relative to process CWD when not absolute).
checkpointBase := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.CheckpointDir) checkpointBase := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.CheckpointDir)
db.SetEinoConversationDirs(plantaskBase, checkpointBase) reductionRoot := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.ReductionRootDir)
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot)
agent.SetPromptBaseDir(configDir) agent.SetPromptBaseDir(configDir)
agentsDir := cfg.AgentsDir agentsDir := cfg.AgentsDir
@@ -325,6 +331,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
} }
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger) monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
monitorHandler.SetAudit(auditSvc) monitorHandler.SetAudit(auditSvc)
monitorHandler.SetMonitorRetention(monitorRetention)
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录 monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger) notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
groupHandler := handler.NewGroupHandler(db, log.Logger) groupHandler := handler.NewGroupHandler(db, log.Logger)
@@ -368,6 +375,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
// 创建OpenAPI处理器 // 创建OpenAPI处理器
conversationHandler := handler.NewConversationHandler(db, log.Logger) conversationHandler := handler.NewConversationHandler(db, log.Logger)
conversationHandler.SetAudit(auditSvc) conversationHandler.SetAudit(auditSvc)
conversationHandler.SetTaskStopper(agentHandler)
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger) auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger) robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, conversationHandler, agentHandler) openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, conversationHandler, agentHandler)
+25 -3
View File
@@ -27,6 +27,7 @@ type Config struct {
Database DatabaseConfig `yaml:"database"` Database DatabaseConfig `yaml:"database"`
Auth AuthConfig `yaml:"auth"` Auth AuthConfig `yaml:"auth"`
Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"` Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"`
Monitor MonitorConfig `yaml:"monitor,omitempty" json:"monitor,omitempty"`
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"` ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"` Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用 C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
@@ -249,7 +250,7 @@ type MultiAgentEinoMiddlewareConfig struct {
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"` SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
// SummarizationEmitInternalEvents controls middleware internal event emission (default true). // SummarizationEmitInternalEvents controls middleware internal event emission (default true).
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"` SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
// SummarizationRetryMaxAttempts is extra retries after the first summarization Generate attempt; 0 = default 3. // SummarizationRetryMaxAttempts 已废弃:summarization 与 run loop 共用 run_retry_max_attempts 及 isEinoTransientRunError。
SummarizationRetryMaxAttempts int `yaml:"summarization_retry_max_attempts,omitempty" json:"summarization_retry_max_attempts,omitempty"` SummarizationRetryMaxAttempts int `yaml:"summarization_retry_max_attempts,omitempty" json:"summarization_retry_max_attempts,omitempty"`
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35). // PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"` PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
@@ -263,9 +264,9 @@ type MultiAgentEinoMiddlewareConfig struct {
CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"` CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"`
// DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off. // DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off.
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"` DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries). // DeepModelRetryMaxRetries 已废弃:临时错误统一由 run loop 内 isEinoTransientRunError + run_retry_max_attempts 处理。
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"` DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
// RunRetryMaxAttempts > 0429/5xx/网络抖动时 handler 分段续跑次数0=默认 10。 // RunRetryMaxAttempts > 0429/5xx/网络抖动时可退避重试次数(run loop 与 summarization 共用)0=默认 10。
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"` RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。 // RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"` RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
@@ -623,6 +624,23 @@ type AuthConfig struct {
GeneratedPasswordPersistErr string `yaml:"-" json:"-"` GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
} }
// MonitorConfig MCP 状态监控(tool_executions)保留策略。
type MonitorConfig struct {
// RetentionDays 执行记录保留天数;省略时默认 90;0 表示不自动清理。
RetentionDays *int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
}
// RetentionDaysEffective returns retention; 0 means keep forever; omitted defaults to 90.
func (m MonitorConfig) RetentionDaysEffective() int {
if m.RetentionDays == nil {
return 90
}
if *m.RetentionDays < 0 {
return 0
}
return *m.RetentionDays
}
// AuditConfig platform operation audit log settings (not chat/tool execution bodies). // AuditConfig platform operation audit log settings (not chat/tool execution bodies).
type AuditConfig struct { type AuditConfig struct {
// Enabled nil or true enables persistence; explicit false disables. // Enabled nil or true enables persistence; explicit false disables.
@@ -1274,6 +1292,10 @@ func Default() *Config {
Enabled: &on, Enabled: &on,
} }
}(), }(),
Monitor: func() MonitorConfig {
days := 90
return MonitorConfig{RetentionDays: &days}
}(),
Robots: RobotsConfig{ Robots: RobotsConfig{
Session: RobotSessionConfig{ Session: RobotSessionConfig{
StrictUserIdentity: &strictRobotIdentity, StrictUserIdentity: &strictRobotIdentity,
+16 -12
View File
@@ -23,6 +23,7 @@ type BatchTaskQueueRow struct {
LastScheduleError sql.NullString LastScheduleError sql.NullString
LastRunError sql.NullString LastRunError sql.NullString
ProjectID sql.NullString ProjectID sql.NullString
Concurrency sql.NullInt64
Status string Status string
CreatedAt time.Time CreatedAt time.Time
StartedAt sql.NullTime StartedAt sql.NullTime
@@ -53,6 +54,7 @@ func (db *DB) CreateBatchQueue(
cronExpr string, cronExpr string,
nextRunAt *time.Time, nextRunAt *time.Time,
projectID string, projectID string,
concurrency int,
tasks []map[string]interface{}, tasks []map[string]interface{},
) error { ) error {
tx, err := db.Begin() tx, err := db.Begin()
@@ -72,8 +74,8 @@ func (db *DB) CreateBatchQueue(
projectIDVal = strings.TrimSpace(projectID) projectIDVal = strings.TrimSpace(projectID)
} }
_, err = tx.Exec( _, err = tx.Exec(
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", "INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, concurrency, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0, queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, concurrency, "pending", now, 0,
) )
if err != nil { if err != nil {
return fmt.Errorf("创建批量任务队列失败: %w", err) return fmt.Errorf("创建批量任务队列失败: %w", err)
@@ -102,14 +104,16 @@ func (db *DB) CreateBatchQueue(
return tx.Commit() return tx.Commit()
} }
const batchQueueSelectColumns = `id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, concurrency, status, created_at, started_at, completed_at, current_index`
// GetBatchQueue 获取批量任务队列 // GetBatchQueue 获取批量任务队列
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) { func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
var row BatchTaskQueueRow var row BatchTaskQueueRow
var createdAt string var createdAt string
err := db.QueryRow( err := db.QueryRow(
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?", "SELECT "+batchQueueSelectColumns+" FROM batch_task_queues WHERE id = ?",
queueID, queueID,
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex) ).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
@@ -133,7 +137,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
// GetAllBatchQueues 获取所有批量任务队列 // GetAllBatchQueues 获取所有批量任务队列
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
rows, err := db.Query( rows, err := db.Query(
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC", "SELECT "+batchQueueSelectColumns+" FROM batch_task_queues ORDER BY created_at DESC",
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err) return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
@@ -144,7 +148,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
for rows.Next() { for rows.Next() {
var row BatchTaskQueueRow var row BatchTaskQueueRow
var createdAt string var createdAt string
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
} }
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
@@ -164,7 +168,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
// ListBatchQueues 列出批量任务队列(支持筛选和分页) // ListBatchQueues 列出批量任务队列(支持筛选和分页)
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) { func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1" query := "SELECT " + batchQueueSelectColumns + " FROM batch_task_queues WHERE 1=1"
args := []interface{}{} args := []interface{}{}
// 状态筛选 // 状态筛选
@@ -192,7 +196,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat
for rows.Next() { for rows.Next() {
var row BatchTaskQueueRow var row BatchTaskQueueRow
var createdAt string var createdAt string
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
} }
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
@@ -358,11 +362,11 @@ func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) err
return nil return nil
} }
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色代理模式 // UpdateBatchQueueMetadata 更新批量任务队列标题、角色代理模式和并发数
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string) error { func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string, concurrency int) error {
_, err := db.Exec( _, err := db.Exec(
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ? WHERE id = ?", "UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ?, concurrency = ? WHERE id = ?",
title, role, agentMode, queueID, title, role, agentMode, concurrency, queueID,
) )
if err != nil { if err != nil {
return fmt.Errorf("更新批量任务队列元数据失败: %w", err) return fmt.Errorf("更新批量任务队列元数据失败: %w", err)
+186 -5
View File
@@ -352,8 +352,8 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) {
conv.Pinned = pinned != 0 conv.Pinned = pinned != 0
// 加载消息(不加载 process_details // 加载消息(不加载 process_details / reasoning_content,减少历史会话切换 payload
messages, err := db.GetMessages(id) messages, err := db.GetMessagesLite(id)
if err != nil { if err != nil {
return nil, fmt.Errorf("加载消息失败: %w", err) return nil, fmt.Errorf("加载消息失败: %w", err)
} }
@@ -585,12 +585,14 @@ func (db *DB) DeleteConversation(id string) error {
// 不返回错误,继续删除对话 // 不返回错误,继续删除对话
} }
projectID, _ := db.GetConversationProjectID(id)
// 删除对话(外键CASCADE会自动删除其他相关数据) // 删除对话(外键CASCADE会自动删除其他相关数据)
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", id) _, err = db.Exec("DELETE FROM conversations WHERE id = ?", id)
if err != nil { if err != nil {
return fmt.Errorf("删除对话失败: %w", err) return fmt.Errorf("删除对话失败: %w", err)
} }
db.removeConversationScopedDirs(id) db.removeConversationScopedDirs(id, projectID)
db.logger.Info("对话已删除(漏洞记录已保留)", zap.String("conversationId", id)) db.logger.Info("对话已删除(漏洞记录已保留)", zap.String("conversationId", id))
return nil return nil
@@ -628,13 +630,35 @@ func (db *DB) removeConversationScopedDir(base, conversationID, label string) {
} }
} }
func (db *DB) removeConversationScopedDirs(conversationID string) { func (db *DB) einoReductionBaseDir() string {
// summarization transcript, reduction files, etc. if db == nil {
return ""
}
if base := strings.TrimSpace(db.einoReductionRootDir); base != "" {
return base
}
return filepath.Join("tmp", "reduction")
}
func (db *DB) removeConversationScopedDirs(conversationID, projectID string) {
// summarization transcript, etc.
db.removeConversationScopedDir(db.conversationArtifactsDir, conversationID, "conversation_artifacts") db.removeConversationScopedDir(db.conversationArtifactsDir, conversationID, "conversation_artifacts")
// Eino plantask JSON boards (skills_dir/.eino/plantask/<id>/). // Eino plantask JSON boards (skills_dir/.eino/plantask/<id>/).
db.removeConversationScopedDir(db.einoPlantaskBaseDir, conversationID, "plantask") db.removeConversationScopedDir(db.einoPlantaskBaseDir, conversationID, "plantask")
// Eino ADK runner checkpoints (checkpoint_dir/<id>/). // Eino ADK runner checkpoints (checkpoint_dir/<id>/).
db.removeConversationScopedDir(db.einoCheckpointBaseDir, conversationID, "eino_checkpoint") db.removeConversationScopedDir(db.einoCheckpointBaseDir, conversationID, "eino_checkpoint")
// Eino reduction persisted tool outputs (tmp/reduction/conversations/<id>/).
// Project-bound sessions share projects/<id>/ — skip on single conversation delete.
if strings.TrimSpace(projectID) == "" {
reductionBase := filepath.Join(db.einoReductionBaseDir(), "conversations")
db.removeConversationScopedDir(reductionBase, conversationID, "reduction")
}
}
func (db *DB) removeProjectScopedDirs(projectID string) {
// Eino reduction persisted tool outputs (tmp/reduction/projects/<id>/).
reductionBase := filepath.Join(db.einoReductionBaseDir(), "projects")
db.removeConversationScopedDir(reductionBase, projectID, "reduction")
} }
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。 // SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
@@ -811,6 +835,62 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) {
return messages, nil return messages, nil
} }
// GetMessagesLite 获取对话消息(不含 reasoning_content),用于历史会话快速切换。
func (db *DB) GetMessagesLite(conversationID string) ([]Message, error) {
rows, err := db.Query(
"SELECT id, conversation_id, role, content, mcp_execution_ids, created_at, updated_at FROM messages WHERE conversation_id = ? ORDER BY created_at ASC, rowid ASC",
conversationID,
)
if err != nil {
return nil, fmt.Errorf("查询消息失败: %w", err)
}
defer rows.Close()
var messages []Message
for rows.Next() {
var msg Message
var mcpIDsJSON sql.NullString
var createdAt string
var updatedAt sql.NullString
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Role, &msg.Content, &mcpIDsJSON, &createdAt, &updatedAt); err != nil {
return nil, fmt.Errorf("扫描消息失败: %w", err)
}
var err error
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if err != nil {
msg.CreatedAt, err = time.Parse("2006-01-02 15:04:05", createdAt)
}
if err != nil {
msg.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
if updatedAt.Valid && strings.TrimSpace(updatedAt.String) != "" {
msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05.999999999-07:00", updatedAt.String)
if err != nil {
msg.UpdatedAt, err = time.Parse("2006-01-02 15:04:05", updatedAt.String)
}
if err != nil {
msg.UpdatedAt, _ = time.Parse(time.RFC3339, updatedAt.String)
}
}
if msg.UpdatedAt.IsZero() {
msg.UpdatedAt = msg.CreatedAt
}
if mcpIDsJSON.Valid && mcpIDsJSON.String != "" {
if err := json.Unmarshal([]byte(mcpIDsJSON.String), &msg.MCPExecutionIDs); err != nil {
db.logger.Warn("解析MCP执行ID失败", zap.Error(err))
}
}
messages = append(messages, msg)
}
return messages, nil
}
// turnSliceRange 根据任意一条消息 ID 定位「一轮对话」在 msgs 中的 [start, end) 下标区间(msgs 须已按时间升序,与 GetMessages 一致)。 // turnSliceRange 根据任意一条消息 ID 定位「一轮对话」在 msgs 中的 [start, end) 下标区间(msgs 须已按时间升序,与 GetMessages 一致)。
// 一轮 = 从某条 user 消息起,至下一条 user 之前(含中间所有 assistant)。 // 一轮 = 从某条 user 消息起,至下一条 user 之前(含中间所有 assistant)。
func turnSliceRange(msgs []Message, anchorID string) (start, end int, err error) { func turnSliceRange(msgs []Message, anchorID string) (start, end int, err error) {
@@ -979,6 +1059,107 @@ func (db *DB) GetProcessDetails(messageID string) ([]ProcessDetail, error) {
return details, nil return details, nil
} }
// ProcessDetailsSummary 过程详情摘要(用于折叠态展示,避免全量加载)。
type ProcessDetailsSummary struct {
Total int `json:"total"`
IterationCount int `json:"iterationCount"`
MaxIteration int `json:"maxIteration"`
}
// GetProcessDetailsSummary 统计消息的过程详情数量与迭代轮次。
func (db *DB) GetProcessDetailsSummary(messageID string) (*ProcessDetailsSummary, error) {
var total int
if err := db.QueryRow(
"SELECT COUNT(*) FROM process_details WHERE message_id = ?",
messageID,
).Scan(&total); err != nil {
return nil, fmt.Errorf("统计过程详情失败: %w", err)
}
summary := &ProcessDetailsSummary{Total: total}
if total == 0 {
return summary, nil
}
rows, err := db.Query(
"SELECT data FROM process_details WHERE message_id = ? AND event_type = 'iteration' ORDER BY created_at ASC, rowid ASC",
messageID,
)
if err != nil {
return nil, fmt.Errorf("查询迭代详情失败: %w", err)
}
defer rows.Close()
maxIter := 0
iterCount := 0
for rows.Next() {
var dataJSON string
if err := rows.Scan(&dataJSON); err != nil {
return nil, fmt.Errorf("扫描迭代详情失败: %w", err)
}
iterCount++
if dataJSON == "" {
continue
}
var payload map[string]interface{}
if err := json.Unmarshal([]byte(dataJSON), &payload); err != nil {
continue
}
if n, ok := payload["iteration"].(float64); ok && int(n) > maxIter {
maxIter = int(n)
}
}
summary.IterationCount = iterCount
summary.MaxIteration = maxIter
return summary, nil
}
// GetProcessDetailsPage 分页获取消息的过程详情(按时间升序)。
func (db *DB) GetProcessDetailsPage(messageID string, limit, offset int) ([]ProcessDetail, int, error) {
var total int
if err := db.QueryRow(
"SELECT COUNT(*) FROM process_details WHERE message_id = ?",
messageID,
).Scan(&total); err != nil {
return nil, 0, fmt.Errorf("统计过程详情失败: %w", err)
}
if total == 0 || offset >= total {
return nil, total, nil
}
rows, err := db.Query(
"SELECT id, message_id, conversation_id, event_type, message, data, created_at FROM process_details WHERE message_id = ? ORDER BY created_at ASC, rowid ASC LIMIT ? OFFSET ?",
messageID, limit, offset,
)
if err != nil {
return nil, 0, fmt.Errorf("查询过程详情失败: %w", err)
}
defer rows.Close()
var details []ProcessDetail
for rows.Next() {
var detail ProcessDetail
var createdAt string
if err := rows.Scan(&detail.ID, &detail.MessageID, &detail.ConversationID, &detail.EventType, &detail.Message, &detail.Data, &createdAt); err != nil {
return nil, 0, fmt.Errorf("扫描过程详情失败: %w", err)
}
var parseErr error
detail.CreatedAt, parseErr = time.Parse("2006-01-02 15:04:05.999999999-07:00", createdAt)
if parseErr != nil {
detail.CreatedAt, parseErr = time.Parse("2006-01-02 15:04:05", createdAt)
}
if parseErr != nil {
detail.CreatedAt, _ = time.Parse(time.RFC3339, createdAt)
}
details = append(details, detail)
}
return details, total, nil
}
// GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组) // GetProcessDetailsByConversation 获取对话的所有过程详情(按消息分组)
func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) { func (db *DB) GetProcessDetailsByConversation(conversationID string) (map[string][]ProcessDetail, error) {
rows, err := db.Query( rows, err := db.Query(
+39 -2
View File
@@ -19,7 +19,8 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask") plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask")
checkpointBase := filepath.Join(tmp, "eino-checkpoints") checkpointBase := filepath.Join(tmp, "eino-checkpoints")
db.SetEinoConversationDirs(plantaskBase, checkpointBase) reductionBase := filepath.Join(tmp, "reduction")
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionBase)
conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{}) conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{})
if err != nil { if err != nil {
@@ -34,6 +35,7 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
{db.conversationArtifactsDir, "transcript.txt"}, {db.conversationArtifactsDir, "transcript.txt"},
{plantaskBase, "task-1.json"}, {plantaskBase, "task-1.json"},
{checkpointBase, "runner-deep.ckpt"}, {checkpointBase, "runner-deep.ckpt"},
{filepath.Join(reductionBase, "conversations"), "tool-output.txt"},
} { } {
dir := filepath.Join(base.root, seg) dir := filepath.Join(base.root, seg)
if err := os.MkdirAll(dir, 0o755); err != nil { if err := os.MkdirAll(dir, 0o755); err != nil {
@@ -48,10 +50,45 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
t.Fatalf("DeleteConversation: %v", err) t.Fatalf("DeleteConversation: %v", err)
} }
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase} { for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase, filepath.Join(reductionBase, "conversations")} {
dir := filepath.Join(base, seg) dir := filepath.Join(base, seg)
if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) { if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) {
t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr) t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr)
} }
} }
} }
func TestDeleteProjectRemovesReductionDir(t *testing.T) {
tmp := t.TempDir()
dbPath := filepath.Join(tmp, "conversations.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
reductionBase := filepath.Join(tmp, "reduction")
db.SetEinoConversationDirs("", "", reductionBase)
project, err := db.CreateProject(&Project{Name: "cleanup test"})
if err != nil {
t.Fatalf("CreateProject: %v", err)
}
seg := sanitizeConversationPathSegment(project.ID)
reductionDir := filepath.Join(reductionBase, "projects", seg, "clear")
if err := os.MkdirAll(reductionDir, 0o755); err != nil {
t.Fatalf("mkdir %s: %v", reductionDir, err)
}
if err := os.WriteFile(filepath.Join(reductionDir, "call-1.txt"), []byte("x"), 0o644); err != nil {
t.Fatalf("write: %v", err)
}
if err := db.DeleteProject(project.ID); err != nil {
t.Fatalf("DeleteProject: %v", err)
}
projectReductionDir := filepath.Join(reductionBase, "projects", seg)
if _, statErr := os.Stat(projectReductionDir); !os.IsNotExist(statErr) {
t.Fatalf("expected removed dir %s, stat err=%v", projectReductionDir, statErr)
}
}
+21 -1
View File
@@ -51,6 +51,7 @@ type DB struct {
conversationArtifactsDir string conversationArtifactsDir string
einoPlantaskBaseDir string // skills_dir + plantask_rel_dir (per-conversation subdirs) einoPlantaskBaseDir string // skills_dir + plantask_rel_dir (per-conversation subdirs)
einoCheckpointBaseDir string // checkpoint_dir root (per-conversation subdirs) einoCheckpointBaseDir string // checkpoint_dir root (per-conversation subdirs)
einoReductionRootDir string // reduction_root_dir or default tmp/reduction (conversations/<id> subdirs)
checkpointLoopName string checkpointLoopName string
checkpointStop chan struct{} checkpointStop chan struct{}
checkpointDone chan struct{} checkpointDone chan struct{}
@@ -159,12 +160,14 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
// SetEinoConversationDirs configures best-effort filesystem cleanup on DeleteConversation. // SetEinoConversationDirs configures best-effort filesystem cleanup on DeleteConversation.
// plantaskBase is skills_root/plantask_rel (no conversation id); checkpointBase is checkpoint_dir root. // plantaskBase is skills_root/plantask_rel (no conversation id); checkpointBase is checkpoint_dir root.
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase string) { // reductionRoot is reduction_root_dir from config; empty uses tmp/reduction (conversation-scoped subdirs only).
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot string) {
if db == nil { if db == nil {
return return
} }
db.einoPlantaskBaseDir = strings.TrimSpace(plantaskBase) db.einoPlantaskBaseDir = strings.TrimSpace(plantaskBase)
db.einoCheckpointBaseDir = strings.TrimSpace(checkpointBase) db.einoCheckpointBaseDir = strings.TrimSpace(checkpointBase)
db.einoReductionRootDir = strings.TrimSpace(reductionRoot)
} }
// initTables 初始化数据库表 // initTables 初始化数据库表
@@ -405,6 +408,8 @@ func (db *DB) initTables() error {
last_schedule_trigger_at DATETIME, last_schedule_trigger_at DATETIME,
last_schedule_error TEXT, last_schedule_error TEXT,
last_run_error TEXT, last_run_error TEXT,
project_id TEXT,
concurrency INTEGER NOT NULL DEFAULT 1,
status TEXT NOT NULL, status TEXT NOT NULL,
created_at DATETIME NOT NULL, created_at DATETIME NOT NULL,
started_at DATETIME, started_at DATETIME,
@@ -1134,6 +1139,21 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
} }
} }
var concurrencyCount int
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='concurrency'").Scan(&concurrencyCount)
if err != nil {
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); addErr != nil {
errMsg := strings.ToLower(addErr.Error())
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(addErr))
}
}
} else if concurrencyCount == 0 {
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); err != nil {
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(err))
}
}
return nil return nil
} }
+70
View File
@@ -410,6 +410,76 @@ func (db *DB) GetToolExecutionsByIds(ids []string) ([]*mcp.ToolExecution, error)
return executions, nil return executions, nil
} }
type toolExecutionStatDelta struct {
totalCalls int
successCalls int
failedCalls int
}
// PurgeToolExecutionsBefore deletes executions older than cutoff and adjusts tool_stats.
func (db *DB) PurgeToolExecutionsBefore(cutoff time.Time) (int64, error) {
query := `
SELECT tool_name, status, COUNT(*) AS cnt
FROM tool_executions
WHERE ` + sqliteEpochGE("start_time", "<") + `
GROUP BY tool_name, status
`
rows, err := db.Query(query, formatSQLiteUTC(cutoff))
if err != nil {
return 0, err
}
defer rows.Close()
deltas := make(map[string]*toolExecutionStatDelta)
for rows.Next() {
var toolName, status string
var count int
if err := rows.Scan(&toolName, &status, &count); err != nil {
db.logger.Warn("读取待清理执行记录统计失败", zap.Error(err))
continue
}
toolName = strings.TrimSpace(toolName)
if toolName == "" || count <= 0 {
continue
}
delta := deltas[toolName]
if delta == nil {
delta = &toolExecutionStatDelta{}
deltas[toolName] = delta
}
delta.totalCalls += count
switch status {
case "failed", "cancelled":
delta.failedCalls += count
case "completed":
delta.successCalls += count
}
}
if err := rows.Err(); err != nil {
return 0, err
}
res, err := db.Exec(`DELETE FROM tool_executions WHERE `+sqliteEpochGE("start_time", "<"), formatSQLiteUTC(cutoff))
if err != nil {
return 0, err
}
deleted, err := res.RowsAffected()
if err != nil {
return 0, err
}
for toolName, delta := range deltas {
if err := db.DecreaseToolStats(toolName, delta.totalCalls, delta.successCalls, delta.failedCalls); err != nil {
db.logger.Warn("清理过期执行记录后更新统计失败",
zap.Error(err),
zap.String("toolName", toolName),
)
}
}
return deleted, nil
}
// SaveToolStats 保存工具统计信息 // SaveToolStats 保存工具统计信息
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error { func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
var lastCallTime sql.NullTime var lastCallTime sql.NullTime
+122
View File
@@ -0,0 +1,122 @@
package database
import (
"path/filepath"
"testing"
"time"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
func TestPurgeToolExecutionsBefore(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "monitor.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
oldStart := time.Now().AddDate(0, 0, -100)
newStart := time.Now().AddDate(0, 0, -1)
oldExec := &mcp.ToolExecution{
ID: "old-completed",
ToolName: "nmap::scan",
Arguments: map[string]interface{}{"target": "127.0.0.1"},
Status: "completed",
StartTime: oldStart,
}
oldFailed := &mcp.ToolExecution{
ID: "old-failed",
ToolName: "nmap::scan",
Arguments: map[string]interface{}{"target": "127.0.0.1"},
Status: "failed",
Error: "timeout",
StartTime: oldStart,
}
newExec := &mcp.ToolExecution{
ID: "new-completed",
ToolName: "nmap::scan",
Arguments: map[string]interface{}{"target": "127.0.0.1"},
Status: "completed",
StartTime: newStart,
}
for _, exec := range []*mcp.ToolExecution{oldExec, oldFailed, newExec} {
if err := db.SaveToolExecution(exec); err != nil {
t.Fatalf("SaveToolExecution(%s): %v", exec.ID, err)
}
}
if err := db.UpdateToolStats("nmap::scan", 3, 2, 1, &newStart); err != nil {
t.Fatalf("UpdateToolStats: %v", err)
}
cutoff := time.Now().AddDate(0, 0, -90)
deleted, err := db.PurgeToolExecutionsBefore(cutoff)
if err != nil {
t.Fatalf("PurgeToolExecutionsBefore: %v", err)
}
if deleted != 2 {
t.Fatalf("deleted = %d, want 2", deleted)
}
if _, err := db.GetToolExecution("old-completed"); err == nil {
t.Fatal("old-completed should be deleted")
}
if _, err := db.GetToolExecution("old-failed"); err == nil {
t.Fatal("old-failed should be deleted")
}
if _, err := db.GetToolExecution("new-completed"); err != nil {
t.Fatalf("new-completed should remain: %v", err)
}
stats, err := db.LoadToolStats()
if err != nil {
t.Fatalf("LoadToolStats: %v", err)
}
stat := stats["nmap::scan"]
if stat == nil {
t.Fatal("expected stats for nmap::scan")
}
if stat.TotalCalls != 1 || stat.SuccessCalls != 1 || stat.FailedCalls != 0 {
t.Fatalf("stats after purge = %+v, want total=1 success=1 failed=0", stat)
}
total, err := db.CountToolExecutions("", "")
if err != nil {
t.Fatalf("CountToolExecutions: %v", err)
}
if total != 1 {
t.Fatalf("remaining executions = %d, want 1", total)
}
}
func TestPurgeToolExecutionsBefore_zeroRetentionSkipsViaService(t *testing.T) {
// RetentionDaysEffective: 0 means no purge at service layer; DB method still works when called directly.
dbPath := filepath.Join(t.TempDir(), "monitor.db")
db, err := NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
exec := &mcp.ToolExecution{
ID: "ancient",
ToolName: "curl::get",
Arguments: map[string]interface{}{},
Status: "completed",
StartTime: time.Now().AddDate(-1, 0, 0),
}
if err := db.SaveToolExecution(exec); err != nil {
t.Fatalf("SaveToolExecution: %v", err)
}
deleted, err := db.PurgeToolExecutionsBefore(time.Now())
if err != nil {
t.Fatalf("PurgeToolExecutionsBefore: %v", err)
}
if deleted != 1 {
t.Fatalf("deleted = %d, want 1", deleted)
}
}
+1
View File
@@ -195,6 +195,7 @@ func (db *DB) DeleteProject(id string) error {
if err != nil { if err != nil {
return fmt.Errorf("删除项目失败: %w", err) return fmt.Errorf("删除项目失败: %w", err)
} }
db.removeProjectScopedDirs(id)
return nil return nil
} }
+67 -427
View File
@@ -21,7 +21,6 @@ import (
"cyberstrike-ai/internal/config" "cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/reasoning" "cyberstrike-ai/internal/reasoning"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/mcp/builtin" "cyberstrike-ai/internal/mcp/builtin"
"cyberstrike-ai/internal/multiagent" "cyberstrike-ai/internal/multiagent"
"cyberstrike-ai/internal/openai" "cyberstrike-ai/internal/openai"
@@ -178,8 +177,6 @@ type AgentHandler struct {
} }
agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并) agentsMarkdownDir string // 多代理:Markdown 子 Agent 目录(绝对路径,空则不从磁盘合并)
batchCronParser cron.Parser batchCronParser cron.Parser
batchRunnerMu sync.Mutex
batchRunning map[string]struct{}
// hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选) // hitlWhitelistSaver 侧栏「应用」HITL 时将会话增量白名单合并写入 config.yaml(可选)
hitlWhitelistSaver HitlToolWhitelistSaver hitlWhitelistSaver HitlToolWhitelistSaver
audit *audit.Service audit *audit.Service
@@ -190,6 +187,21 @@ func (h *AgentHandler) SetAudit(s *audit.Service) {
h.audit = s h.audit = s
} }
// CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent).
func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
if h == nil || conversationID == "" || h.tasks == nil {
return
}
if execID := h.tasks.ActiveMCPExecutionID(conversationID); execID != "" {
h.agent.CancelMCPToolExecutionWithNote(execID, "")
}
if ok, err := h.tasks.CancelTask(conversationID, ErrTaskCancelled); ok {
h.logger.Info("已取消会话运行中任务", zap.String("conversationId", conversationID))
} else if err != nil {
h.logger.Warn("取消会话运行中任务失败", zap.String("conversationId", conversationID), zap.Error(err))
}
}
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘 // HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
type HitlToolWhitelistSaver interface { type HitlToolWhitelistSaver interface {
MergeHitlToolWhitelistIntoConfig(add []string) error MergeHitlToolWhitelistIntoConfig(add []string) error
@@ -218,7 +230,6 @@ func NewAgentHandler(agent *agent.Agent, db *database.DB, cfg *config.Config, lo
config: cfg, config: cfg,
hitlManager: NewHITLManager(db, logger), hitlManager: NewHITLManager(db, logger),
batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor), batchCronParser: cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor),
batchRunning: make(map[string]struct{}),
} }
if err := handler.hitlManager.EnsureSchema(); err != nil { if err := handler.hitlManager.EnsureSchema(); err != nil {
logger.Warn("初始化 HITL 表失败", zap.Error(err)) logger.Warn("初始化 HITL 表失败", zap.Error(err))
@@ -631,40 +642,11 @@ func (h *AgentHandler) runRobotEinoSingleWithRetry(
assistantMessageID string, assistantMessageID string,
taskStatus *string, taskStatus *string,
) (string, string, error) { ) (string, string, error) {
curHist := history resultMA, errMA := multiagent.RunEinoSingleChatModelAgent(
curMsg := finalMessage taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
segmentUserMessage := finalMessage conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
var resultMA *multiagent.RunResult )
var errMA error if errMA != nil {
var transientRunAttempts int
var emptyResponseAttempts int
for {
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
conversationID, h.conversationProjectID(conversationID), curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
)
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
taskCtx, conversationID, resultMA, errMA, &emptyResponseAttempts,
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
)
if exhaustedEmpty {
errMA = nil
break
}
if handledEmpty {
continue
}
if errMA == nil {
transientRunAttempts = 0
emptyResponseAttempts = 0
break
}
if handled, _ := h.handleEinoTransientRetryContinue(
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
); handled {
continue
}
*taskStatus = "failed" *taskStatus = "failed"
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA) return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
} }
@@ -680,41 +662,12 @@ func (h *AgentHandler) runRobotMultiAgentWithRetry(
assistantMessageID string, assistantMessageID string,
taskStatus *string, taskStatus *string,
) (string, string, error) { ) (string, string, error) {
curHist := history resultMA, errMA := multiagent.RunDeepAgent(
curMsg := finalMessage taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
segmentUserMessage := finalMessage conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback,
var resultMA *multiagent.RunResult h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
var errMA error )
var transientRunAttempts int if errMA != nil {
var emptyResponseAttempts int
for {
resultMA, errMA = multiagent.RunDeepAgent(
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
conversationID, h.conversationProjectID(conversationID), curMsg, curHist, roleTools, progressCallback,
h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
)
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
taskCtx, conversationID, resultMA, errMA, &emptyResponseAttempts,
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
)
if exhaustedEmpty {
errMA = nil
break
}
if handledEmpty {
continue
}
if errMA == nil {
transientRunAttempts = 0
emptyResponseAttempts = 0
break
}
if handled, _ := h.handleEinoTransientRetryContinue(
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
); handled {
continue
}
*taskStatus = "failed" *taskStatus = "failed"
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA) return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
} }
@@ -1379,6 +1332,21 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
}) })
return return
} }
if h.tasks.AbortActiveEinoExecute(req.ConversationID, note) {
h.logger.Info("对话页仅终止当前 Eino execute",
zap.String("conversationId", req.ConversationID),
zap.Bool("hasNote", note != ""),
)
c.JSON(http.StatusOK, gin.H{
"status": "tool_abort_requested",
"conversationId": req.ConversationID,
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
"continueAfter": true,
"interruptWithNote": note != "",
"continueWithoutTool": false,
})
return
}
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。 // 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
h.tasks.SetInterruptContinueNote(req.ConversationID, note) h.tasks.SetInterruptContinueNote(req.ConversationID, note)
ok, err := h.tasks.CancelTask(req.ConversationID, multiagent.ErrInterruptContinue) ok, err := h.tasks.CancelTask(req.ConversationID, multiagent.ErrInterruptContinue)
@@ -1498,6 +1466,7 @@ type BatchTaskRequest struct {
CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填 CronExpr string `json:"cronExpr,omitempty"` // scheduleMode=cron 时必填
ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false) ExecuteNow bool `json:"executeNow,omitempty"` // 创建后是否立即执行(默认 false)
ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选) ProjectID string `json:"projectId,omitempty"` // 队列内子对话绑定的项目(可选)
Concurrency int `json:"concurrency,omitempty"` // 同时执行的子任务数,默认 1,最大 8
} }
// batchQueueWantsEino 队列是否配置为走 Eino 多代理。 // batchQueueWantsEino 队列是否配置为走 Eino 多代理。
@@ -1557,7 +1526,7 @@ func (h *AgentHandler) CreateBatchQueue(c *gin.Context) {
nextRunAt = &next nextRunAt = &next
} }
queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, validTasks) queue, createErr := h.batchTaskManager.CreateBatchQueue(req.Title, req.Role, agentMode, scheduleMode, cronExpr, req.ProjectID, nextRunAt, req.Concurrency, validTasks)
if createErr != nil { if createErr != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": createErr.Error()})
return return
@@ -1747,15 +1716,16 @@ func (h *AgentHandler) PauseBatchQueue(c *gin.Context) {
func (h *AgentHandler) UpdateBatchQueueMetadata(c *gin.Context) { func (h *AgentHandler) UpdateBatchQueueMetadata(c *gin.Context) {
queueID := c.Param("queueId") queueID := c.Param("queueId")
var req struct { var req struct {
Title string `json:"title"` Title string `json:"title"`
Role string `json:"role"` Role string `json:"role"`
AgentMode string `json:"agentMode"` AgentMode string `json:"agentMode"`
Concurrency *int `json:"concurrency"`
} }
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode); err != nil { if err := h.batchTaskManager.UpdateQueueMetadata(queueID, req.Title, req.Role, req.AgentMode, req.Concurrency); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
@@ -1830,9 +1800,17 @@ func (h *AgentHandler) SetBatchQueueScheduleEnabled(c *gin.Context) {
// DeleteBatchQueue 删除批量任务队列 // DeleteBatchQueue 删除批量任务队列
func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) { func (h *AgentHandler) DeleteBatchQueue(c *gin.Context) {
queueID := c.Param("queueId") queueID := c.Param("queueId")
success := h.batchTaskManager.DeleteQueue(queueID) if err := h.batchTaskManager.DeleteQueue(queueID); err != nil {
if !success { switch {
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"}) case errors.Is(err, ErrBatchQueueNotFound):
c.JSON(http.StatusNotFound, gin.H{"error": "队列不存在"})
case errors.Is(err, ErrBatchQueueExecutorActive):
c.JSON(http.StatusConflict, gin.H{"error": "队列执行器仍在运行,请稍后再删除"})
case errors.Is(err, ErrBatchQueueStillRunning):
c.JSON(http.StatusConflict, gin.H{"error": "队列正在运行中,无法删除"})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return return
} }
if h.audit != nil { if h.audit != nil {
@@ -1926,7 +1904,7 @@ func (h *AgentHandler) RunSingleBatchTask(c *gin.Context) {
// 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动 // 暂停态单条执行:旧批量协程可能仍占用执行槽,先回收以便重新启动
if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused { if queue, ok := h.batchTaskManager.GetBatchQueue(queueID); ok && queue.Status == BatchQueueStatusPaused {
h.forceUnmarkBatchQueueRunning(queueID) h.batchTaskManager.ForceUnmarkQueueExecutor(queueID)
} }
autoStarted := true autoStarted := true
@@ -1985,26 +1963,6 @@ func (h *AgentHandler) DeleteBatchTask(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue}) c.JSON(http.StatusOK, gin.H{"message": "任务已删除", "queue": queue})
} }
func (h *AgentHandler) markBatchQueueRunning(queueID string) bool {
h.batchRunnerMu.Lock()
defer h.batchRunnerMu.Unlock()
if _, exists := h.batchRunning[queueID]; exists {
return false
}
h.batchRunning[queueID] = struct{}{}
return true
}
func (h *AgentHandler) unmarkBatchQueueRunning(queueID string) {
h.batchRunnerMu.Lock()
defer h.batchRunnerMu.Unlock()
delete(h.batchRunning, queueID)
}
func (h *AgentHandler) forceUnmarkBatchQueueRunning(queueID string) {
h.unmarkBatchQueueRunning(queueID)
}
func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) { func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*time.Time, error) {
expr := strings.TrimSpace(cronExpr) expr := strings.TrimSpace(cronExpr)
if expr == "" { if expr == "" {
@@ -2020,43 +1978,43 @@ func (h *AgentHandler) nextBatchQueueRunAt(cronExpr string, from time.Time) (*ti
func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) { func (h *AgentHandler) startBatchQueueExecution(queueID string, scheduled bool) (bool, error) {
// 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断 // 先获取执行互斥门,再读取队列状态,避免基于过时快照做判断
if !h.markBatchQueueRunning(queueID) { if !h.batchTaskManager.TryMarkQueueExecutor(queueID) {
return true, nil return true, nil
} }
queue, exists := h.batchTaskManager.GetBatchQueue(queueID) queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists { if !exists {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
return false, nil return false, nil
} }
if scheduled { if scheduled {
if queue.ScheduleMode != "cron" { if queue.ScheduleMode != "cron" {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
err := fmt.Errorf("队列未启用 cron 调度") err := fmt.Errorf("队列未启用 cron 调度")
h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
return true, err return true, err
} }
if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" { if queue.Status == "running" || queue.Status == "paused" || queue.Status == "cancelled" {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
err := fmt.Errorf("当前队列状态不允许被调度执行") err := fmt.Errorf("当前队列状态不允许被调度执行")
h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
return true, err return true, err
} }
if !h.batchTaskManager.ResetQueueForRerun(queueID) { if !h.batchTaskManager.ResetQueueForRerun(queueID) {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
err := fmt.Errorf("重置队列失败") err := fmt.Errorf("重置队列失败")
h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
return true, err return true, err
} }
queue, _ = h.batchTaskManager.GetBatchQueue(queueID) queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
} else if queue.Status != "pending" && queue.Status != "paused" { } else if queue.Status != "pending" && queue.Status != "paused" {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
return true, fmt.Errorf("队列状态不允许启动") return true, fmt.Errorf("队列状态不允许启动")
} }
if queue != nil && batchQueueWantsEino(queue.AgentMode) && (h.config == nil || !h.config.MultiAgent.Enabled) { if queue != nil && batchQueueWantsEino(queue.AgentMode) && (h.config == nil || !h.config.MultiAgent.Enabled) {
h.unmarkBatchQueueRunning(queueID) h.batchTaskManager.UnmarkQueueExecutor(queueID)
err := fmt.Errorf("当前队列配置为 Eino 多代理,但系统未启用多代理") err := fmt.Errorf("当前队列配置为 Eino 多代理,但系统未启用多代理")
if scheduled { if scheduled {
h.batchTaskManager.SetLastScheduleError(queueID, err.Error()) h.batchTaskManager.SetLastScheduleError(queueID, err.Error())
@@ -2108,324 +2066,6 @@ func (h *AgentHandler) batchQueueSchedulerLoop() {
} }
} }
// executeBatchQueue 执行批量任务队列
func (h *AgentHandler) executeBatchQueue(queueID string) {
defer h.unmarkBatchQueueRunning(queueID)
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID))
for {
// 检查队列状态
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists || queue.Status == "cancelled" || queue.Status == "completed" || queue.Status == "paused" {
break
}
// 获取下一个任务
task, hasNext := h.batchTaskManager.GetNextTask(queueID)
if !hasNext {
// 所有任务完成:汇总子任务失败信息便于排障
q, ok := h.batchTaskManager.GetBatchQueue(queueID)
lastRunErr := ""
if ok {
for _, t := range q.Tasks {
if t.Status == "failed" && t.Error != "" {
lastRunErr = t.Error
}
}
}
h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
h.batchTaskManager.UpdateQueueStatus(queueID, "completed")
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
break
}
// 更新任务状态为运行中
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "running", "", "")
// 创建新对话
title := safeTruncateString(task.Message, 50)
batchMeta := audit.ConversationCreateMeta("batch_task")
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
conv, err := h.db.CreateConversation(title, batchMeta)
var conversationID string
if err != nil {
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", "创建对话失败: "+err.Error())
h.batchTaskManager.MoveToNextTask(queueID)
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
break
}
continue
}
conversationID = conv.ID
// 保存conversationId到任务中(即使是运行中状态也要保存,以便查看对话)
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "running", "", "", conversationID)
// 应用角色用户提示词和工具配置
finalMessage := task.Message
var roleTools []string // 角色配置的工具列表
if queue.Role != "" && queue.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
// 应用用户提示词
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + task.Message
h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role))
}
// 获取角色配置的工具列表(优先使用tools字段,向后兼容mcps字段)
if len(role.Tools) > 0 {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
}
}
}
}
// 保存用户消息(保存原始消息,不包含角色提示词)
_, err = h.db.AddMessage(conversationID, "user", task.Message, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
// 预先创建助手消息,以便关联过程详情
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
if err != nil {
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
// 如果创建失败,继续执行但不保存过程详情
assistantMsg = nil
}
// 创建进度回调函数,复用统一逻辑(批量任务不需要流式事件,所以传入nil)
var assistantMessageID string
if assistantMsg != nil {
assistantMessageID = assistantMsg.ID
}
// 注意:批量任务没有前端直连的 POST /stream,因此若要支持「刷新后补流」,
// 需要把进度事件镜像到 TaskEventBusGET /api/agent-loop/task-events 会订阅这里)。
// progressCallback 将在子任务的 IIFE 内创建,以便拿到 taskCtx/cancelWithCause 与 sendEvent。
var progressCallback func(eventType, message string, data interface{})
// 执行任务(使用包含角色提示词的finalMessage和角色工具列表)
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
func() {
// 与对话流式接口一致:同 conversationId 仅允许一个运行中任务,并支持 /api/agent-loop/cancel 与会话锁对齐。
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
// 单个子任务超时:6 小时(与原先 WithTimeout(Background) 一致)
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
registered := false
finishStatus := "completed"
defer func() {
h.batchTaskManager.SetTaskCancel(queueID, nil)
timeoutCancel()
if registered {
// 与流式接口保持一致:结束前补一个 done,便于前端 task-events 侧及时收口 UI。
if h.taskEventBus != nil {
ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
if b, err := json.Marshal(ev); err == nil {
h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
}
}
h.tasks.FinishTask(conversationID, finishStatus)
}
cancelWithCause(nil)
}()
// 事件镜像:只发布到 TaskEventBus,不直接写 HTTP Response(用于刷新后的补流)。
sendEvent := func(eventType, message string, data interface{}) {
if h.taskEventBus == nil {
return
}
ev := StreamEvent{Type: eventType, Message: message, Data: data}
b, err := json.Marshal(ev)
if err != nil {
b = []byte(`{"type":"error","message":"marshal failed"}`)
}
line := make([]byte, 0, len(b)+8)
line = append(line, []byte("data: ")...)
line = append(line, b...)
line = append(line, '\n', '\n')
h.taskEventBus.Publish(conversationID, line)
}
if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
h.logger.Warn("批量队列子任务注册会话运行状态失败",
zap.String("queueId", queueID),
zap.String("taskId", task.ID),
zap.String("conversationId", conversationID),
zap.Error(err))
failMsg := err.Error()
if errors.Is(err, ErrTaskAlreadyRunning) {
failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
}
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", failMsg)
return
}
registered = true
// 存储取消函数:暂停队列时取消子任务 context(与原先语义一致)
h.batchTaskManager.SetTaskCancel(queueID, timeoutCancel)
// 创建进度回调函数:写 DB + 镜像到 task-events,支持刷新后继续流式展示。
progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
useBatchMulti := false
batchOrch := "deep"
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
if am == "multi" {
am = "deep"
}
if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
useBatchMulti = true
batchOrch = config.NormalizeMultiAgentOrchestration(am)
} else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
// 兼容历史数据:未配置队列代理模式时,沿用旧的系统级开关
useBatchMulti = true
batchOrch = "deep"
}
var resultMA *multiagent.RunResult
var runErr error
switch {
case useBatchMulti:
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
default:
if h.config == nil {
runErr = fmt.Errorf("服务器配置未加载")
} else {
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
}
}
if runErr != nil {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
}
errStr := runErr.Error()
partialResp := ""
if resultMA != nil {
partialResp = resultMA.Response
}
isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
errors.Is(runErr, context.Canceled) ||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)
if isTimeout {
finishStatus = "timeout"
} else if isCancelled {
finishStatus = "cancelled"
} else {
finishStatus = "failed"
}
if isCancelled {
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
cancelMsg := "任务已被用户取消,后续操作已停止。"
// 如果执行结果中有更具体的取消消息,使用它
if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) {
cancelMsg = partialResp
}
// 更新助手消息内容
if assistantMessageID != "" {
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
}
// 保存取消详情到数据库
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
} else {
// 如果没有预先创建的助手消息,创建一个新的
_, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil)
if errMsg != nil {
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
}
}
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "cancelled", cancelMsg, "", conversationID)
} else {
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
errorMsg := "执行失败: " + runErr.Error()
// 更新助手消息内容
if assistantMessageID != "" {
if _, updateErr := h.db.Exec(
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
errorMsg,
time.Now(), assistantMessageID,
); updateErr != nil {
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
}
// 保存错误详情到数据库
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
}
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, "failed", "", runErr.Error())
}
} else {
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
resText := resultMA.Response
mcpIDs := resultMA.MCPExecutionIDs
lastIn := resultMA.LastAgentTraceInput
lastOut := resultMA.LastAgentTraceOutput
// 更新助手消息内容
if assistantMessageID != "" {
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
// 如果更新失败,尝试创建新消息
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
if err != nil {
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
}
} else {
// 如果没有预先创建的助手消息,创建一个新的
_, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs)
if err != nil {
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
}
// 保存代理轨迹
if lastIn != "" || lastOut != "" {
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
} else {
h.logger.Info("已保存代理轨迹", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
}
}
// 保存结果
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, "completed", resText, "", conversationID)
}
}()
// 移动到下一个任务
h.batchTaskManager.MoveToNextTask(queueID)
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
h.batchTaskManager.UpdateQueueStatus(queueID, "paused")
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
break
}
// 检查是否被取消或暂停
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
if queue.Status == "cancelled" || queue.Status == "paused" {
break
}
}
}
// loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。 // loadHistoryFromAgentTrace 从库中保存的代理消息轨迹恢复历史(列 last_react_*;含单代理与 Eino)。
// 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。 // 逻辑与攻击链一致:优先用已保存的 JSON 消息带 + 最后一轮助手摘要,否则回退消息表。
func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) { func (h *AgentHandler) loadHistoryFromAgentTrace(conversationID string) ([]agent.ChatMessage, error) {
+352
View File
@@ -0,0 +1,352 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
"cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/multiagent"
"go.uber.org/zap"
)
const batchQueueWorkerIdlePoll = 200 * time.Millisecond
// executeBatchQueue 使用并发 worker 池执行批量任务队列。
func (h *AgentHandler) executeBatchQueue(queueID string) {
defer h.batchTaskManager.UnmarkQueueExecutor(queueID)
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists {
return
}
concurrency := normalizeBatchQueueConcurrency(queue.Concurrency)
h.logger.Info("开始执行批量任务队列", zap.String("queueId", queueID), zap.Int("concurrency", concurrency))
var wg sync.WaitGroup
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
h.runBatchQueueWorker(queueID)
}()
}
wg.Wait()
h.tryFinalizeBatchQueue(queueID)
}
func (h *AgentHandler) runBatchQueueWorker(queueID string) {
for {
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if batchQueueExecutionShouldStop(queue, exists) {
return
}
task, ok := h.batchTaskManager.ClaimNextPendingTask(queueID)
if !ok {
if !h.batchTaskManager.HasRunningTasks(queueID) {
return
}
time.Sleep(batchQueueWorkerIdlePoll)
continue
}
queue, _ = h.batchTaskManager.GetBatchQueue(queueID)
if queue == nil {
return
}
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusRunning, "", "")
h.executeOneBatchSubTask(queueID, queue, task)
if h.batchTaskManager.TakeSingleRunTaskIfMatch(queueID, task.ID) {
h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusPaused)
h.logger.Info("单条执行完成,队列已暂停", zap.String("queueId", queueID), zap.String("taskId", task.ID))
return
}
queue, exists = h.batchTaskManager.GetBatchQueue(queueID)
if batchQueueExecutionShouldStop(queue, exists) {
if !exists {
h.logger.Warn("批量队列在执行收尾时已不存在,安全退出", zap.String("queueId", queueID))
}
return
}
}
}
func (h *AgentHandler) tryFinalizeBatchQueue(queueID string) {
queue, exists := h.batchTaskManager.GetBatchQueue(queueID)
if !exists || queue == nil {
return
}
if queue.Status != BatchQueueStatusRunning {
return
}
if h.batchTaskManager.HasPendingOrRunningTasks(queueID) {
return
}
lastRunErr := ""
for _, t := range queue.Tasks {
if t != nil && t.Status == BatchTaskStatusFailed && t.Error != "" {
lastRunErr = t.Error
}
}
h.batchTaskManager.SetLastRunError(queueID, lastRunErr)
h.batchTaskManager.UpdateQueueStatus(queueID, BatchQueueStatusCompleted)
h.logger.Info("批量任务队列执行完成", zap.String("queueId", queueID))
}
// executeOneBatchSubTask 执行单条批量子任务(各自独立会话)。
func (h *AgentHandler) executeOneBatchSubTask(queueID string, queue *BatchTaskQueue, task *BatchTask) {
title := safeTruncateString(task.Message, 50)
batchMeta := audit.ConversationCreateMeta("batch_task")
batchMeta.ProjectID = effectiveProjectID(h.config, queue.ProjectID)
conv, err := h.db.CreateConversation(title, batchMeta)
if err != nil {
h.logger.Error("创建对话失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "创建对话失败: "+err.Error())
return
}
conversationID := conv.ID
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusRunning, "", "", conversationID)
finalMessage := task.Message
var roleTools []string
if queue.Role != "" && queue.Role != "默认" {
if h.config.Roles != nil {
if role, exists := h.config.Roles[queue.Role]; exists && role.Enabled {
if role.UserPrompt != "" {
finalMessage = role.UserPrompt + "\n\n" + task.Message
h.logger.Info("应用角色用户提示词", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role))
}
if len(role.Tools) > 0 {
roleTools = role.Tools
h.logger.Info("使用角色配置的工具列表", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("role", queue.Role), zap.Int("toolCount", len(roleTools)))
}
}
}
}
if _, err = h.db.AddMessage(conversationID, "user", task.Message, nil); err != nil {
h.logger.Error("保存用户消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
assistantMsg, err := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
if err != nil {
h.logger.Error("创建助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
assistantMsg = nil
}
var assistantMessageID string
if assistantMsg != nil {
assistantMessageID = assistantMsg.ID
}
h.logger.Info("执行批量任务", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("message", task.Message), zap.String("role", queue.Role), zap.String("conversationId", conversationID))
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())
taskCtx, timeoutCancel := context.WithTimeout(baseCtx, 6*time.Hour)
registered := false
finishStatus := "completed"
defer func() {
h.batchTaskManager.SetTaskCancel(queueID, task.ID, nil)
timeoutCancel()
if registered {
if h.taskEventBus != nil {
ev := StreamEvent{Type: "done", Message: "", Data: map[string]interface{}{"conversationId": conversationID}}
if b, err := json.Marshal(ev); err == nil {
h.taskEventBus.Publish(conversationID, append(append([]byte("data: "), b...), '\n', '\n'))
}
}
h.tasks.FinishTask(conversationID, finishStatus)
}
cancelWithCause(nil)
}()
sendEvent := func(eventType, message string, data interface{}) {
if h.taskEventBus == nil {
return
}
ev := StreamEvent{Type: eventType, Message: message, Data: data}
b, err := json.Marshal(ev)
if err != nil {
b = []byte(`{"type":"error","message":"marshal failed"}`)
}
line := make([]byte, 0, len(b)+8)
line = append(line, []byte("data: ")...)
line = append(line, b...)
line = append(line, '\n', '\n')
h.taskEventBus.Publish(conversationID, line)
}
if _, err := h.tasks.StartTask(conversationID, task.Message, cancelWithCause); err != nil {
h.logger.Warn("批量队列子任务注册会话运行状态失败",
zap.String("queueId", queueID),
zap.String("taskId", task.ID),
zap.String("conversationId", conversationID),
zap.Error(err))
failMsg := err.Error()
if errors.Is(err, ErrTaskAlreadyRunning) {
failMsg = "会话已有任务正在执行,无法在该会话上并行启动批量子任务"
}
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", failMsg)
return
}
registered = true
h.batchTaskManager.SetTaskCancel(queueID, task.ID, timeoutCancel)
progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks)
useBatchMulti := false
batchOrch := "deep"
am := strings.TrimSpace(strings.ToLower(queue.AgentMode))
if am == "multi" {
am = "deep"
}
if batchQueueWantsEino(queue.AgentMode) && h.config != nil && h.config.MultiAgent.Enabled {
useBatchMulti = true
batchOrch = config.NormalizeMultiAgentOrchestration(am)
} else if queue.AgentMode == "" && h.config != nil && h.config.MultiAgent.Enabled && h.config.MultiAgent.BatchUseMultiAgent {
useBatchMulti = true
batchOrch = "deep"
}
var resultMA *multiagent.RunResult
var runErr error
switch {
case useBatchMulti:
resultMA, runErr = multiagent.RunDeepAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, h.agentsMarkdownDir, batchOrch, nil, h.projectBlackboardBlock(conversationID))
default:
if h.config == nil {
runErr = fmt.Errorf("服务器配置未加载")
} else {
resultMA, runErr = multiagent.RunEinoSingleChatModelAgent(taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger, conversationID, h.conversationProjectID(conversationID), finalMessage, []agent.ChatMessage{}, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID))
}
}
if runErr != nil {
h.handleBatchSubTaskRunError(queueID, task, conversationID, assistantMessageID, baseCtx, taskCtx, resultMA, runErr, &finishStatus)
return
}
if resultMA == nil {
h.logger.Error("批量任务执行成功但无结果对象",
zap.String("queueId", queueID),
zap.String("taskId", task.ID),
zap.String("conversationId", conversationID))
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", "内部错误:无执行结果")
return
}
h.logger.Info("批量任务执行成功", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
resText := resultMA.Response
mcpIDs := resultMA.MCPExecutionIDs
lastIn := resultMA.LastAgentTraceInput
lastOut := resultMA.LastAgentTraceOutput
if assistantMessageID != "" {
if updateErr := h.db.UpdateAssistantMessageFinalize(assistantMessageID, resText, mcpIDs, multiagent.AggregatedReasoningFromTraceJSON(lastIn)); updateErr != nil {
h.logger.Warn("更新助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil {
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
}
} else if _, err = h.db.AddMessage(conversationID, "assistant", resText, mcpIDs); err != nil {
h.logger.Error("保存助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(err))
}
if lastIn != "" || lastOut != "" {
if err := h.db.SaveAgentTrace(conversationID, lastIn, lastOut); err != nil {
h.logger.Warn("保存代理轨迹失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
}
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCompleted, resText, "", conversationID)
}
func (h *AgentHandler) handleBatchSubTaskRunError(
queueID string,
task *BatchTask,
conversationID, assistantMessageID string,
baseCtx, taskCtx context.Context,
resultMA *multiagent.RunResult,
runErr error,
finishStatus *string,
) {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, resultMA)
}
errStr := runErr.Error()
partialResp := ""
if resultMA != nil {
partialResp = resultMA.Response
}
isCancelled := errors.Is(context.Cause(baseCtx), ErrTaskCancelled) ||
errors.Is(runErr, context.Canceled) ||
strings.Contains(strings.ToLower(errStr), "context canceled") ||
strings.Contains(strings.ToLower(errStr), "context cancelled") ||
(partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")))
isTimeout := errors.Is(runErr, context.DeadlineExceeded) || errors.Is(context.Cause(taskCtx), context.DeadlineExceeded)
if isTimeout {
*finishStatus = "timeout"
} else if isCancelled {
*finishStatus = "cancelled"
} else {
*finishStatus = "failed"
}
if isCancelled {
h.logger.Info("批量任务被取消", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID))
cancelMsg := "任务已被用户取消,后续操作已停止。"
if partialResp != "" && (strings.Contains(partialResp, "任务已被取消") || strings.Contains(partialResp, "任务执行中断")) {
cancelMsg = partialResp
}
if assistantMessageID != "" {
if updateErr := h.appendAssistantMessageNotice(assistantMessageID, cancelMsg); updateErr != nil {
h.logger.Warn("更新取消后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
}
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "cancelled", cancelMsg, nil); err != nil {
h.logger.Warn("保存取消详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
} else if _, errMsg := h.db.AddMessage(conversationID, "assistant", cancelMsg, nil); errMsg != nil {
h.logger.Warn("保存取消消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(errMsg))
}
h.batchTaskManager.UpdateTaskStatusWithConversationID(queueID, task.ID, BatchTaskStatusCancelled, cancelMsg, "", conversationID)
return
}
h.logger.Error("批量任务执行失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.String("conversationId", conversationID), zap.Error(runErr))
errorMsg := "执行失败: " + runErr.Error()
if assistantMessageID != "" {
if _, updateErr := h.db.Exec(
"UPDATE messages SET content = ?, updated_at = ? WHERE id = ?",
errorMsg,
time.Now(), assistantMessageID,
); updateErr != nil {
h.logger.Warn("更新失败后的助手消息失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(updateErr))
}
if err := h.db.AddProcessDetail(assistantMessageID, conversationID, "error", errorMsg, nil); err != nil {
h.logger.Warn("保存错误详情失败", zap.String("queueId", queueID), zap.String("taskId", task.ID), zap.Error(err))
}
}
h.batchTaskManager.UpdateTaskStatus(queueID, task.ID, BatchTaskStatusFailed, "", runErr.Error())
}
+216 -43
View File
@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"sort" "sort"
"strings" "strings"
@@ -17,6 +18,15 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
var (
// ErrBatchQueueNotFound 队列不存在或已从内存卸载。
ErrBatchQueueNotFound = errors.New("batch queue not found")
// ErrBatchQueueExecutorActive executeBatchQueue 协程仍在收尾,禁止删除。
ErrBatchQueueExecutorActive = errors.New("batch queue executor is still active")
// ErrBatchQueueStillRunning 队列状态仍为 running(无活跃执行器时的兜底保护)。
ErrBatchQueueStillRunning = errors.New("batch queue is still running")
)
// 批量任务状态常量 // 批量任务状态常量
const ( const (
BatchQueueStatusPending = "pending" BatchQueueStatusPending = "pending"
@@ -39,6 +49,12 @@ const (
// MaxBatchQueueRoleLen 角色名最大长度 // MaxBatchQueueRoleLen 角色名最大长度
MaxBatchQueueRoleLen = 100 MaxBatchQueueRoleLen = 100
// DefaultBatchQueueConcurrency 批量队列默认并发数(串行)
DefaultBatchQueueConcurrency = 1
// MaxBatchQueueConcurrency 批量队列最大并发数
MaxBatchQueueConcurrency = 8
) )
// BatchTask 批量任务项 // BatchTask 批量任务项
@@ -67,6 +83,7 @@ type BatchTaskQueue struct {
LastScheduleError string `json:"lastScheduleError,omitempty"` LastScheduleError string `json:"lastScheduleError,omitempty"`
LastRunError string `json:"lastRunError,omitempty"` LastRunError string `json:"lastRunError,omitempty"`
ProjectID string `json:"projectId,omitempty"` ProjectID string `json:"projectId,omitempty"`
Concurrency int `json:"concurrency"` // 同时执行的子任务数,默认 1
Tasks []*BatchTask `json:"tasks"` Tasks []*BatchTask `json:"tasks"`
Status string `json:"status"` // pending, running, paused, completed, cancelled Status string `json:"status"` // pending, running, paused, completed, cancelled
CreatedAt time.Time `json:"createdAt"` CreatedAt time.Time `json:"createdAt"`
@@ -80,8 +97,9 @@ type BatchTaskManager struct {
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
queues map[string]*BatchTaskQueue queues map[string]*BatchTaskQueue
taskCancels map[string]context.CancelFunc // 存储每个队列当前任务的取消函数 taskCancels map[string]map[string]context.CancelFunc // queueID -> taskID -> 取消函数
singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列 singleRunTasks map[string]string // queueID -> taskID,单条执行完成后暂停队列
queueExecutors map[string]struct{} // executeBatchQueue 协程活跃标记(与队列 status 解耦)
mu sync.RWMutex mu sync.RWMutex
} }
@@ -93,11 +111,56 @@ func NewBatchTaskManager(logger *zap.Logger) *BatchTaskManager {
return &BatchTaskManager{ return &BatchTaskManager{
logger: logger, logger: logger,
queues: make(map[string]*BatchTaskQueue), queues: make(map[string]*BatchTaskQueue),
taskCancels: make(map[string]context.CancelFunc), taskCancels: make(map[string]map[string]context.CancelFunc),
singleRunTasks: make(map[string]string), singleRunTasks: make(map[string]string),
queueExecutors: make(map[string]struct{}),
} }
} }
// batchQueueExecutionShouldStop 判断 executeBatchQueue 主循环是否应退出。
func batchQueueExecutionShouldStop(queue *BatchTaskQueue, exists bool) bool {
if !exists || queue == nil {
return true
}
switch queue.Status {
case BatchQueueStatusCancelled, BatchQueueStatusCompleted, BatchQueueStatusPaused:
return true
default:
return false
}
}
// TryMarkQueueExecutor 标记队列执行协程已启动;若已有执行协程则返回 false。
func (m *BatchTaskManager) TryMarkQueueExecutor(queueID string) bool {
m.mu.Lock()
defer m.mu.Unlock()
if _, exists := m.queueExecutors[queueID]; exists {
return false
}
m.queueExecutors[queueID] = struct{}{}
return true
}
// UnmarkQueueExecutor 清除队列执行协程标记(executeBatchQueue defer 调用)。
func (m *BatchTaskManager) UnmarkQueueExecutor(queueID string) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.queueExecutors, queueID)
}
// ForceUnmarkQueueExecutor 强制清除执行协程标记(暂停态单条重跑等场景回收陈旧槽位)。
func (m *BatchTaskManager) ForceUnmarkQueueExecutor(queueID string) {
m.UnmarkQueueExecutor(queueID)
}
// IsQueueExecutorActive 队列 executeBatchQueue 协程是否仍在运行。
func (m *BatchTaskManager) IsQueueExecutorActive(queueID string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
_, ok := m.queueExecutors[queueID]
return ok
}
// SetDB 设置数据库连接 // SetDB 设置数据库连接
func (m *BatchTaskManager) SetDB(db *database.DB) { func (m *BatchTaskManager) SetDB(db *database.DB) {
m.mu.Lock() m.mu.Lock()
@@ -105,10 +168,22 @@ func (m *BatchTaskManager) SetDB(db *database.DB) {
m.db = db m.db = db
} }
// normalizeBatchQueueConcurrency 规范化队列并发数。
func normalizeBatchQueueConcurrency(n int) int {
if n < 1 {
return DefaultBatchQueueConcurrency
}
if n > MaxBatchQueueConcurrency {
return MaxBatchQueueConcurrency
}
return n
}
// CreateBatchQueue 创建批量任务队列 // CreateBatchQueue 创建批量任务队列
func (m *BatchTaskManager) CreateBatchQueue( func (m *BatchTaskManager) CreateBatchQueue(
title, role, agentMode, scheduleMode, cronExpr, projectID string, title, role, agentMode, scheduleMode, cronExpr, projectID string,
nextRunAt *time.Time, nextRunAt *time.Time,
concurrency int,
tasks []string, tasks []string,
) (*BatchTaskQueue, error) { ) (*BatchTaskQueue, error) {
// 输入校验 // 输入校验
@@ -136,6 +211,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
CronExpr: strings.TrimSpace(cronExpr), CronExpr: strings.TrimSpace(cronExpr),
NextRunAt: nextRunAt, NextRunAt: nextRunAt,
ScheduleEnabled: true, ScheduleEnabled: true,
Concurrency: normalizeBatchQueueConcurrency(concurrency),
Tasks: make([]*BatchTask, 0, len(tasks)), Tasks: make([]*BatchTask, 0, len(tasks)),
Status: BatchQueueStatusPending, Status: BatchQueueStatusPending,
CreatedAt: time.Now(), CreatedAt: time.Now(),
@@ -177,6 +253,7 @@ func (m *BatchTaskManager) CreateBatchQueue(
queue.CronExpr, queue.CronExpr,
queue.NextRunAt, queue.NextRunAt,
queue.ProjectID, queue.ProjectID,
queue.Concurrency,
dbTasks, dbTasks,
); err != nil { ); err != nil {
m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err)) m.logger.Warn("batch queue DB create failed", zap.String("queueId", queueID), zap.Error(err))
@@ -272,6 +349,7 @@ func (m *BatchTaskManager) loadQueueFromDB(queueID string) *BatchTaskQueue {
if queueRow.ProjectID.Valid { if queueRow.ProjectID.Valid {
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String) queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
} }
queue.Concurrency = batchQueueConcurrencyFromRow(queueRow)
if queueRow.StartedAt.Valid { if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time queue.StartedAt = &queueRow.StartedAt.Time
} }
@@ -511,6 +589,7 @@ func (m *BatchTaskManager) LoadFromDB() error {
if queueRow.ProjectID.Valid { if queueRow.ProjectID.Valid {
queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String) queue.ProjectID = strings.TrimSpace(queueRow.ProjectID.String)
} }
queue.Concurrency = batchQueueConcurrencyFromRow(queueRow)
if queueRow.StartedAt.Valid { if queueRow.StartedAt.Valid {
queue.StartedAt = &queueRow.StartedAt.Time queue.StartedAt = &queueRow.StartedAt.Time
} }
@@ -651,8 +730,16 @@ func (m *BatchTaskManager) UpdateQueueSchedule(queueID, scheduleMode, cronExpr s
} }
} }
// UpdateQueueMetadata 更新队列标题、角色和代理模式(非 running 时可用) // batchQueueConcurrencyFromRow 从数据库行读取并发数(缺省为 1)。
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string) error { func batchQueueConcurrencyFromRow(row *database.BatchTaskQueueRow) int {
if row == nil || !row.Concurrency.Valid {
return DefaultBatchQueueConcurrency
}
return normalizeBatchQueueConcurrency(int(row.Concurrency.Int64))
}
// UpdateQueueMetadata 更新队列标题、角色、代理模式和并发数(非 running 时可用)
func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode string, concurrency *int) error {
if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen { if utf8.RuneCountInString(title) > MaxBatchQueueTitleLen {
return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen) return fmt.Errorf("标题不能超过 %d 个字符", MaxBatchQueueTitleLen)
} }
@@ -680,9 +767,12 @@ func (m *BatchTaskManager) UpdateQueueMetadata(queueID, title, role, agentMode s
queue.Title = title queue.Title = title
queue.Role = role queue.Role = role
queue.AgentMode = agentMode queue.AgentMode = agentMode
if concurrency != nil {
queue.Concurrency = normalizeBatchQueueConcurrency(*concurrency)
}
if m.db != nil { if m.db != nil {
if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode); err != nil { if err := m.db.UpdateBatchQueueMetadata(queueID, title, role, agentMode, queue.Concurrency); err != nil {
m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err)) m.logger.Warn("batch queue DB metadata update failed", zap.String("queueId", queueID), zap.Error(err))
} }
} }
@@ -868,7 +958,6 @@ func (m *BatchTaskManager) AddTaskToQueue(queueID, message string) (*BatchTask,
// PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引 // PrepareSingleTaskRun 准备单条执行:重置目标任务(若已有结果)并定位队列索引
func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error { func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
var cancelFunc context.CancelFunc
var siblingRunningIDs []string var siblingRunningIDs []string
m.mu.Lock() m.mu.Lock()
@@ -898,11 +987,9 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
} }
// 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项 // 暂停态:中止在途子任务并收口仍标记 running 的其它子任务,以便单条执行非冲突项
var cancelFuncs []context.CancelFunc
if queue.Status == BatchQueueStatusPaused { if queue.Status == BatchQueueStatusPaused {
if c, ok := m.taskCancels[queueID]; ok { cancelFuncs = m.drainTaskCancelsLocked(queueID)
cancelFunc = c
delete(m.taskCancels, queueID)
}
for _, t := range queue.Tasks { for _, t := range queue.Tasks {
if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning { if t != nil && t.ID != taskID && t.Status == BatchTaskStatusRunning {
siblingRunningIDs = append(siblingRunningIDs, t.ID) siblingRunningIDs = append(siblingRunningIDs, t.ID)
@@ -914,8 +1001,10 @@ func (m *BatchTaskManager) PrepareSingleTaskRun(queueID, taskID string) error {
resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled resumeQueue := queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusCancelled
m.mu.Unlock() m.mu.Unlock()
if cancelFunc != nil { for _, c := range cancelFuncs {
cancelFunc() if c != nil {
c()
}
} }
const staleRunMsg = "为单条执行其它任务,已中止" const staleRunMsg = "为单条执行其它任务,已中止"
for _, sid := range siblingRunningIDs { for _, sid := range siblingRunningIDs {
@@ -1089,7 +1178,90 @@ func queueAllowsSingleTaskRunLocked(queue *BatchTaskQueue, task *BatchTask) bool
} }
} }
// GetNextTask 取下一个待执行任务 // ClaimNextPendingTask 原子领取下一个待执行任务(并发 worker 安全)。
func (m *BatchTaskManager) ClaimNextPendingTask(queueID string) (*BatchTask, bool) {
m.mu.Lock()
defer m.mu.Unlock()
queue, exists := m.queues[queueID]
if !exists || queue == nil {
return nil, false
}
if queue.Status == BatchQueueStatusCancelled || queue.Status == BatchQueueStatusCompleted || queue.Status == BatchQueueStatusPaused {
return nil, false
}
onlyTaskID := ""
if m.singleRunTasks != nil {
onlyTaskID = m.singleRunTasks[queueID]
}
for i, task := range queue.Tasks {
if task == nil || task.Status != BatchTaskStatusPending {
continue
}
if onlyTaskID != "" && task.ID != onlyTaskID {
continue
}
task.Status = BatchTaskStatusRunning
queue.CurrentIndex = i
return task, true
}
return nil, false
}
// HasRunningTasks 队列是否仍有 running 状态的子任务。
func (m *BatchTaskManager) HasRunningTasks(queueID string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
queue, exists := m.queues[queueID]
if !exists || queue == nil {
return false
}
for _, task := range queue.Tasks {
if task != nil && task.Status == BatchTaskStatusRunning {
return true
}
}
return false
}
// HasPendingOrRunningTasks 队列是否仍有未完成的子任务。
func (m *BatchTaskManager) HasPendingOrRunningTasks(queueID string) bool {
m.mu.RLock()
defer m.mu.RUnlock()
queue, exists := m.queues[queueID]
if !exists || queue == nil {
return false
}
for _, task := range queue.Tasks {
if task == nil {
continue
}
if task.Status == BatchTaskStatusPending || task.Status == BatchTaskStatusRunning {
return true
}
}
return false
}
// drainTaskCancelsLocked 取出并清空队列下所有子任务取消函数(调用方须已持 m.mu)。
func (m *BatchTaskManager) drainTaskCancelsLocked(queueID string) []context.CancelFunc {
taskMap, ok := m.taskCancels[queueID]
if !ok || len(taskMap) == 0 {
return nil
}
cancels := make([]context.CancelFunc, 0, len(taskMap))
for _, c := range taskMap {
if c != nil {
cancels = append(cancels, c)
}
}
delete(m.taskCancels, queueID)
return cancels
}
// GetNextTask 获取下一个待执行的任务(串行兼容,优先使用 ClaimNextPendingTask
func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) { func (m *BatchTaskManager) GetNextTask(queueID string) (*BatchTask, bool) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@@ -1130,20 +1302,28 @@ func (m *BatchTaskManager) MoveToNextTask(queueID string) {
} }
} }
// SetTaskCancel 设置当前任务的取消函数 // SetTaskCancel 设置任务的取消函数
func (m *BatchTaskManager) SetTaskCancel(queueID string, cancel context.CancelFunc) { func (m *BatchTaskManager) SetTaskCancel(queueID, taskID string, cancel context.CancelFunc) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if cancel != nil { if cancel == nil {
m.taskCancels[queueID] = cancel if taskMap, ok := m.taskCancels[queueID]; ok {
} else { delete(taskMap, taskID)
delete(m.taskCancels, queueID) if len(taskMap) == 0 {
delete(m.taskCancels, queueID)
}
}
return
} }
if m.taskCancels[queueID] == nil {
m.taskCancels[queueID] = make(map[string]context.CancelFunc)
}
m.taskCancels[queueID][taskID] = cancel
} }
// PauseQueue 暂停队列 // PauseQueue 暂停队列
func (m *BatchTaskManager) PauseQueue(queueID string) bool { func (m *BatchTaskManager) PauseQueue(queueID string) bool {
var cancelFunc context.CancelFunc var cancelFuncs []context.CancelFunc
m.mu.Lock() m.mu.Lock()
queue, exists := m.queues[queueID] queue, exists := m.queues[queueID]
@@ -1168,17 +1348,11 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
} }
queue.Status = BatchQueueStatusPaused queue.Status = BatchQueueStatusPaused
cancelFuncs = m.drainTaskCancelsLocked(queueID)
// 取消当前正在执行的任务(通过取消context)
if cancel, ok := m.taskCancels[queueID]; ok {
cancelFunc = cancel
delete(m.taskCancels, queueID)
}
m.mu.Unlock() m.mu.Unlock()
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁) for _, c := range cancelFuncs {
if cancelFunc != nil { c()
cancelFunc()
} }
return true return true
@@ -1187,7 +1361,7 @@ func (m *BatchTaskManager) PauseQueue(queueID string) bool {
// CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue) // CancelQueue 取消队列(保留此方法以保持向后兼容,但建议使用PauseQueue)
func (m *BatchTaskManager) CancelQueue(queueID string) bool { func (m *BatchTaskManager) CancelQueue(queueID string) bool {
now := time.Now() now := time.Now()
var cancelFunc context.CancelFunc var cancelFuncs []context.CancelFunc
m.mu.Lock() m.mu.Lock()
queue, exists := m.queues[queueID] queue, exists := m.queues[queueID]
@@ -1228,34 +1402,33 @@ func (m *BatchTaskManager) CancelQueue(queueID string) bool {
} }
} }
// 取消当前正在执行的任务 cancelFuncs = m.drainTaskCancelsLocked(queueID)
if cancel, ok := m.taskCancels[queueID]; ok {
cancelFunc = cancel
delete(m.taskCancels, queueID)
}
m.mu.Unlock() m.mu.Unlock()
// 释放锁后执行取消回调(cancel 可能阻塞,不应持锁) for _, c := range cancelFuncs {
if cancelFunc != nil { c()
cancelFunc()
} }
return true return true
} }
// DeleteQueue 删除队列(运行中的队列不允许删除) // DeleteQueue 删除队列。执行协程活跃或 status 为 running 时拒绝删除,避免 executeBatchQueue 空指针 panic。
func (m *BatchTaskManager) DeleteQueue(queueID string) bool { func (m *BatchTaskManager) DeleteQueue(queueID string) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
queue, exists := m.queues[queueID] queue, exists := m.queues[queueID]
if !exists { if !exists {
return false return ErrBatchQueueNotFound
}
if _, exec := m.queueExecutors[queueID]; exec {
return ErrBatchQueueExecutorActive
} }
// 运行中的队列不允许删除,防止孤儿协程和数据丢失 // 运行中的队列不允许删除,防止孤儿协程和数据丢失
if queue.Status == BatchQueueStatusRunning { if queue.Status == BatchQueueStatusRunning {
return false return ErrBatchQueueStillRunning
} }
// 清理取消函数 // 清理取消函数
@@ -1269,7 +1442,7 @@ func (m *BatchTaskManager) DeleteQueue(queueID string) bool {
} }
delete(m.queues, queueID) delete(m.queues, queueID)
return true return nil
} }
// generateShortID 生成短ID // generateShortID 生成短ID
+121
View File
@@ -0,0 +1,121 @@
package handler
import (
"errors"
"testing"
"go.uber.org/zap"
)
func TestNormalizeBatchQueueConcurrency(t *testing.T) {
if got := normalizeBatchQueueConcurrency(0); got != DefaultBatchQueueConcurrency {
t.Fatalf("expected default %d, got %d", DefaultBatchQueueConcurrency, got)
}
if got := normalizeBatchQueueConcurrency(99); got != MaxBatchQueueConcurrency {
t.Fatalf("expected max %d, got %d", MaxBatchQueueConcurrency, got)
}
}
func TestClaimNextPendingTaskParallel(t *testing.T) {
m := NewBatchTaskManager(zap.NewNop())
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 3, []string{"a", "b", "c"})
if err != nil {
t.Fatalf("CreateBatchQueue: %v", err)
}
m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning)
t1, ok1 := m.ClaimNextPendingTask(queue.ID)
t2, ok2 := m.ClaimNextPendingTask(queue.ID)
if !ok1 || !ok2 || t1.ID == t2.ID {
t.Fatalf("expected two distinct claims, got ok1=%v ok2=%v t1=%v t2=%v", ok1, ok2, t1, t2)
}
if t1.Status != BatchTaskStatusRunning || t2.Status != BatchTaskStatusRunning {
t.Fatalf("claimed tasks should be running")
}
t3, ok3 := m.ClaimNextPendingTask(queue.ID)
if !ok3 {
t.Fatal("expected third claim")
}
_, ok4 := m.ClaimNextPendingTask(queue.ID)
if ok4 {
t.Fatal("expected no fourth pending task")
}
_ = t3
}
func TestBatchQueueExecutionShouldStop(t *testing.T) {
t.Parallel()
if !batchQueueExecutionShouldStop(nil, false) {
t.Fatal("expected stop when queue missing")
}
if !batchQueueExecutionShouldStop(nil, true) {
t.Fatal("expected stop when queue is nil but exists=true")
}
q := &BatchTaskQueue{Status: BatchQueueStatusRunning}
if batchQueueExecutionShouldStop(q, true) {
t.Fatal("expected continue when running")
}
q.Status = BatchQueueStatusCancelled
if !batchQueueExecutionShouldStop(q, true) {
t.Fatal("expected stop when cancelled")
}
}
func TestDeleteQueueBlockedWhileExecutorActive(t *testing.T) {
t.Parallel()
m := NewBatchTaskManager(zap.NewNop())
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"})
if err != nil {
t.Fatalf("CreateBatchQueue: %v", err)
}
if !m.TryMarkQueueExecutor(queue.ID) {
t.Fatal("expected to mark executor")
}
m.UpdateQueueStatus(queue.ID, BatchQueueStatusCancelled)
err = m.DeleteQueue(queue.ID)
if !errors.Is(err, ErrBatchQueueExecutorActive) {
t.Fatalf("expected ErrBatchQueueExecutorActive, got %v", err)
}
if _, ok := m.GetBatchQueue(queue.ID); !ok {
t.Fatal("queue should still exist while executor active")
}
m.UnmarkQueueExecutor(queue.ID)
if err := m.DeleteQueue(queue.ID); err != nil {
t.Fatalf("expected delete after executor unmarked, got %v", err)
}
if _, ok := m.GetBatchQueue(queue.ID); ok {
t.Fatal("queue should be deleted")
}
}
func TestDeleteQueueBlockedWhileRunning(t *testing.T) {
t.Parallel()
m := NewBatchTaskManager(zap.NewNop())
queue, err := m.CreateBatchQueue("test", "", "eino_single", "manual", "", "", nil, 1, []string{"hello"})
if err != nil {
t.Fatalf("CreateBatchQueue: %v", err)
}
m.UpdateQueueStatus(queue.ID, BatchQueueStatusRunning)
err = m.DeleteQueue(queue.ID)
if !errors.Is(err, ErrBatchQueueStillRunning) {
t.Fatalf("expected ErrBatchQueueStillRunning, got %v", err)
}
}
func TestTryMarkQueueExecutorDedupes(t *testing.T) {
t.Parallel()
m := NewBatchTaskManager(zap.NewNop())
if !m.TryMarkQueueExecutor("q-1") {
t.Fatal("first mark should succeed")
}
if m.TryMarkQueueExecutor("q-1") {
t.Fatal("second mark should fail")
}
m.UnmarkQueueExecutor("q-1")
if !m.TryMarkQueueExecutor("q-1") {
t.Fatal("mark after unmark should succeed")
}
}
+30 -4
View File
@@ -3,6 +3,7 @@ package handler
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@@ -181,6 +182,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
"type": "string", "type": "string",
"description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id", "description": "队列内子对话绑定的项目 ID(可选,未指定时使用 config.project.default_project_id",
}, },
"concurrency": map[string]interface{}{
"type": "integer",
"description": "同时执行的子任务数,默认 1(串行),最大 8。含扫描类工具时建议 1-2。",
},
}, },
}, },
}, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) { }, func(ctx context.Context, args map[string]interface{}) (*mcp.ToolResult, error) {
@@ -210,7 +215,8 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
executeNow = false executeNow = false
} }
projectID := strings.TrimSpace(mcpArgString(args, "project_id")) projectID := strings.TrimSpace(mcpArgString(args, "project_id"))
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, tasks) concurrency := int(mcpArgFloat(args, "concurrency"))
queue, createErr := h.batchTaskManager.CreateBatchQueue(title, role, agentMode, scheduleMode, cronExpr, projectID, nextRunAt, concurrency, tasks)
if createErr != nil { if createErr != nil {
return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil return batchMCPTextResult("创建队列失败: "+createErr.Error(), true), nil
} }
@@ -365,8 +371,17 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
if qid == "" { if qid == "" {
return batchMCPTextResult("queue_id 不能为空", true), nil return batchMCPTextResult("queue_id 不能为空", true), nil
} }
if !h.batchTaskManager.DeleteQueue(qid) { if err := h.batchTaskManager.DeleteQueue(qid); err != nil {
return batchMCPTextResult("删除失败:队列不存在", true), nil switch {
case errors.Is(err, ErrBatchQueueNotFound):
return batchMCPTextResult("删除失败:队列不存在", true), nil
case errors.Is(err, ErrBatchQueueExecutorActive):
return batchMCPTextResult("删除失败:队列执行器仍在运行,请稍后再试", true), nil
case errors.Is(err, ErrBatchQueueStillRunning):
return batchMCPTextResult("删除失败:队列正在运行中", true), nil
default:
return batchMCPTextResult("删除失败:"+err.Error(), true), nil
}
} }
logger.Info("MCP batch_task_delete", zap.String("queueId", qid)) logger.Info("MCP batch_task_delete", zap.String("queueId", qid))
return batchMCPTextResult("队列已删除。", false), nil return batchMCPTextResult("队列已删除。", false), nil
@@ -397,6 +412,10 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
"description": "代理模式:eino_single、deep、plan_execute、supervisor", "description": "代理模式:eino_single、deep、plan_execute、supervisor",
"enum": []string{"eino_single", "deep", "plan_execute", "supervisor"}, "enum": []string{"eino_single", "deep", "plan_execute", "supervisor"},
}, },
"concurrency": map[string]interface{}{
"type": "integer",
"description": "同时执行的子任务数,默认 1,最大 8",
},
}, },
"required": []string{"queue_id"}, "required": []string{"queue_id"},
}, },
@@ -408,7 +427,12 @@ func RegisterBatchTaskMCPTools(mcpServer *mcp.Server, h *AgentHandler, logger *z
title := mcpArgString(args, "title") title := mcpArgString(args, "title")
role := mcpArgString(args, "role") role := mcpArgString(args, "role")
agentMode := mcpArgString(args, "agent_mode") agentMode := mcpArgString(args, "agent_mode")
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode); err != nil { var concurrency *int
if raw, ok := args["concurrency"]; ok && raw != nil {
v := int(mcpArgFloat(args, "concurrency"))
concurrency = &v
}
if err := h.batchTaskManager.UpdateQueueMetadata(qid, title, role, agentMode, concurrency); err != nil {
return batchMCPTextResult(err.Error(), true), nil return batchMCPTextResult(err.Error(), true), nil
} }
updated, _ := h.batchTaskManager.GetBatchQueue(qid) updated, _ := h.batchTaskManager.GetBatchQueue(qid)
@@ -652,6 +676,7 @@ type batchTaskQueueMCPListItem struct {
StartedAt *time.Time `json:"startedAt,omitempty"` StartedAt *time.Time `json:"startedAt,omitempty"`
CompletedAt *time.Time `json:"completedAt,omitempty"` CompletedAt *time.Time `json:"completedAt,omitempty"`
CurrentIndex int `json:"currentIndex"` CurrentIndex int `json:"currentIndex"`
Concurrency int `json:"concurrency"`
TaskTotal int `json:"task_total"` TaskTotal int `json:"task_total"`
TaskCounts map[string]int `json:"task_counts"` TaskCounts map[string]int `json:"task_counts"`
Tasks []batchTaskMCPListSummary `json:"tasks"` Tasks []batchTaskMCPListSummary `json:"tasks"`
@@ -715,6 +740,7 @@ func toBatchTaskQueueMCPListItem(q *BatchTaskQueue) batchTaskQueueMCPListItem {
StartedAt: q.StartedAt, StartedAt: q.StartedAt,
CompletedAt: q.CompletedAt, CompletedAt: q.CompletedAt,
CurrentIndex: q.CurrentIndex, CurrentIndex: q.CurrentIndex,
Concurrency: q.Concurrency,
TaskTotal: len(tasks), TaskTotal: len(tasks),
TaskCounts: counts, TaskCounts: counts,
Tasks: tasks, Tasks: tasks,
+72 -7
View File
@@ -12,11 +12,17 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
// ConversationTaskStopper cancels in-flight agent work when a conversation is removed.
type ConversationTaskStopper interface {
CancelRunningTaskForConversation(conversationID string)
}
// ConversationHandler 对话处理器 // ConversationHandler 对话处理器
type ConversationHandler struct { type ConversationHandler struct {
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
audit *audit.Service audit *audit.Service
taskStopper ConversationTaskStopper
} }
// SetAudit wires platform audit logging. // SetAudit wires platform audit logging.
@@ -24,6 +30,11 @@ func (h *ConversationHandler) SetAudit(s *audit.Service) {
h.audit = s h.audit = s
} }
// SetTaskStopper wires cancellation of in-flight agent tasks on conversation delete.
func (h *ConversationHandler) SetTaskStopper(stopper ConversationTaskStopper) {
h.taskStopper = stopper
}
// NewConversationHandler 创建新的对话处理器 // NewConversationHandler 创建新的对话处理器
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler { func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
return &ConversationHandler{ return &ConversationHandler{
@@ -165,6 +176,9 @@ func (h *ConversationHandler) GetConversation(c *gin.Context) {
} }
// GetMessageProcessDetails 获取指定消息的过程详情(按需加载) // GetMessageProcessDetails 获取指定消息的过程详情(按需加载)
// 查询参数:
// - summary=1:仅返回摘要(total / iterationCount / maxIteration
// - limit + offset:分页返回 processDetails(未指定 limit 时保持全量兼容)
func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) { func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
messageID := c.Param("id") messageID := c.Param("id")
if messageID == "" { if messageID == "" {
@@ -172,6 +186,51 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
return return
} }
summaryStr := strings.TrimSpace(c.Query("summary"))
if summaryStr == "1" || strings.EqualFold(summaryStr, "true") || strings.EqualFold(summaryStr, "yes") {
summary, err := h.db.GetProcessDetailsSummary(messageID)
if err != nil {
h.logger.Error("获取过程详情摘要失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"summary": summary})
return
}
limitStr := strings.TrimSpace(c.Query("limit"))
if limitStr != "" {
limit, err := strconv.Atoi(limitStr)
if err != nil || limit <= 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid limit"})
return
}
if limit > 500 {
limit = 500
}
offset, _ := strconv.Atoi(strings.TrimSpace(c.Query("offset")))
if offset < 0 {
offset = 0
}
details, total, err := h.db.GetProcessDetailsPage(messageID, limit, offset)
if err != nil {
h.logger.Error("分页获取过程详情失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
details = database.DedupeConsecutiveProcessDetails(details)
out := processDetailsToJSON(h.logger, details)
c.JSON(http.StatusOK, gin.H{
"processDetails": out,
"total": total,
"offset": offset,
"limit": limit,
"hasMore": offset+len(out) < total,
})
return
}
details, err := h.db.GetProcessDetails(messageID) details, err := h.db.GetProcessDetails(messageID)
if err != nil { if err != nil {
h.logger.Error("获取过程详情失败", zap.Error(err)) h.logger.Error("获取过程详情失败", zap.Error(err))
@@ -180,14 +239,17 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
} }
details = database.DedupeConsecutiveProcessDetails(details) details = database.DedupeConsecutiveProcessDetails(details)
out := processDetailsToJSON(h.logger, details)
c.JSON(http.StatusOK, gin.H{"processDetails": out, "total": len(out)})
}
// 转换为前端期望的 JSON 结构(与 GetConversation 中 processDetails 结构一致) func processDetailsToJSON(logger *zap.Logger, details []database.ProcessDetail) []map[string]interface{} {
out := make([]map[string]interface{}, 0, len(details)) out := make([]map[string]interface{}, 0, len(details))
for _, d := range details { for _, d := range details {
var data interface{} var data interface{}
if d.Data != "" { if d.Data != "" {
if err := json.Unmarshal([]byte(d.Data), &data); err != nil { if err := json.Unmarshal([]byte(d.Data), &data); err != nil {
h.logger.Warn("解析过程详情数据失败", zap.Error(err)) logger.Warn("解析过程详情数据失败", zap.Error(err))
} }
} }
out = append(out, map[string]interface{}{ out = append(out, map[string]interface{}{
@@ -200,8 +262,7 @@ func (h *ConversationHandler) GetMessageProcessDetails(c *gin.Context) {
"createdAt": d.CreatedAt, "createdAt": d.CreatedAt,
}) })
} }
return out
c.JSON(http.StatusOK, gin.H{"processDetails": out})
} }
// UpdateConversationRequest 更新对话请求 // UpdateConversationRequest 更新对话请求
@@ -245,6 +306,10 @@ func (h *ConversationHandler) UpdateConversation(c *gin.Context) {
func (h *ConversationHandler) DeleteConversation(c *gin.Context) { func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
id := c.Param("id") id := c.Param("id")
if h.taskStopper != nil {
h.taskStopper.CancelRunningTaskForConversation(id)
}
if err := h.db.DeleteConversation(id); err != nil { if err := h.db.DeleteConversation(id); err != nil {
h.logger.Error("删除对话失败", zap.Error(err)) h.logger.Error("删除对话失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -0,0 +1,30 @@
package handler
import (
"context"
"testing"
"time"
"go.uber.org/zap"
)
func TestConversationHandlerDeleteConversationCancelsRunningTask(t *testing.T) {
tm := NewAgentTaskManager()
ctx, cancel := context.WithCancelCause(context.Background())
_, err := tm.StartTask("conv-1", "hello", cancel)
if err != nil {
t.Fatalf("StartTask: %v", err)
}
h := &AgentHandler{tasks: tm, logger: zap.NewNop()}
h.CancelRunningTaskForConversation("conv-1")
select {
case <-ctx.Done():
case <-time.After(2 * time.Second):
t.Fatal("task context was not cancelled")
}
if cause := context.Cause(ctx); cause != ErrTaskCancelled {
t.Fatalf("expected ErrTaskCancelled, got %v", cause)
}
}
-153
View File
@@ -2,31 +2,11 @@ package handler
import ( import (
"context" "context"
"errors"
"fmt"
"strings"
"time"
"cyberstrike-ai/internal/agent" "cyberstrike-ai/internal/agent"
"cyberstrike-ai/internal/multiagent" "cyberstrike-ai/internal/multiagent"
"go.uber.org/zap"
) )
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
if h.config != nil {
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
}
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
}
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
}
return 0
}
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。 // applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
func (h *AgentHandler) applyEinoTraceResumeSegment( func (h *AgentHandler) applyEinoTraceResumeSegment(
conversationID string, conversationID string,
@@ -45,136 +25,3 @@ func (h *AgentHandler) applyEinoTraceResumeSegment(
*curFinalMessage = segmentUserMessage *curFinalMessage = segmentUserMessage
} }
} }
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
func (h *AgentHandler) applyEinoTransientRetrySegment(
conversationID string,
result *multiagent.RunResult,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
) {
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
*curHistory = hist
}
if s := strings.TrimSpace(segmentUserMessage); s != "" {
*curFinalMessage = segmentUserMessage
}
}
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
func (h *AgentHandler) handleEinoTransientRetryContinue(
baseCtx context.Context,
conversationID string,
result *multiagent.RunResult,
runErr error,
transientAttempts *int,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
progressCallback func(eventType, message string, data interface{}),
sendProgress func(msg string, extra map[string]interface{}),
) (handled bool, fatal error) {
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
return false, nil
}
maxAttempts := h.einoRunRetryMaxAttempts()
*transientAttempts++
if *transientAttempts > maxAttempts {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
return false, errors.New("transient retry exhausted: " + runErr.Error())
}
attemptNo := *transientAttempts
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
if progressCallback != nil {
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": attemptNo,
"maxAttempts": maxAttempts,
"backoffSec": int(backoff.Seconds()),
})
}
select {
case <-baseCtx.Done():
return false, context.Cause(baseCtx)
case <-time.After(backoff):
}
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
if progressCallback != nil {
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": attemptNo,
})
}
if sendProgress != nil {
sendProgress("正在重试…", map[string]interface{}{
"conversationId": conversationID,
"source": "transient_retry",
})
}
return true, nil
}
// handleEinoEmptyResponseContinue 在 SSE 任务循环内处理「正常结束但无助手正文」;返回 exhausted=true 时由外层按成功结束(保留占位文案)。
// 与临时错误重试一致:仅恢复轨迹并保留本请求原始 user 文案,不向模型注入续跑说明。
func (h *AgentHandler) handleEinoEmptyResponseContinue(
baseCtx context.Context,
conversationID string,
result *multiagent.RunResult,
runErr error,
emptyResponseAttempts *int,
curHistory *[]agent.ChatMessage,
curFinalMessage *string,
segmentUserMessage string,
progressCallback func(eventType, message string, data interface{}),
sendProgress func(msg string, extra map[string]interface{}),
) (handled bool, exhausted bool) {
if !errors.Is(runErr, multiagent.ErrEmptyResponseContinue) {
return false, false
}
maxAttempts := h.einoRunRetryMaxAttempts()
*emptyResponseAttempts++
if *emptyResponseAttempts > maxAttempts {
if h.logger != nil {
h.logger.Warn("eino empty response auto resume exhausted",
zap.String("conversationId", conversationID),
zap.Int("maxAttempts", maxAttempts))
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(conversationID, result)
}
return false, true
}
attemptNo := *emptyResponseAttempts
if h.logger != nil {
h.logger.Info("eino empty response, auto resume from trace",
zap.String("conversationId", conversationID),
zap.Int("attempt", attemptNo),
zap.Int("maxAttempts", maxAttempts))
}
if progressCallback != nil {
progressCallback("eino_empty_response_continue", fmt.Sprintf("未捕获到助手正文,正在基于轨迹自动续跑(%d/%d)…", attemptNo, maxAttempts), map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"attempt": attemptNo,
"maxAttempts": maxAttempts,
"resumeKind": "trace_segment",
})
}
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
if sendProgress != nil {
sendProgress("已恢复上下文,正在继续推理…", map[string]interface{}{
"conversationId": conversationID,
"source": "empty_response_continue",
})
}
return true, false
}
+1 -69
View File
@@ -119,7 +119,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
var cancelWithCause context.CancelCauseFunc var cancelWithCause context.CancelCauseFunc
curFinalMessage := prep.FinalMessage curFinalMessage := prep.FinalMessage
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
curHistory := prep.History curHistory := prep.History
roleTools := prep.RoleTools roleTools := prep.RoleTools
@@ -177,8 +176,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
taskOwned = true taskOwned = true
var cumulativeMCPExecutionIDs []string var cumulativeMCPExecutionIDs []string
var transientRunAttempts int
var emptyResponseAttempts int
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。 // 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
var mainIterationOffset int var mainIterationOffset int
@@ -215,6 +212,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
} }
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID) taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks) taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = mcp.WithEinoExecuteRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) { taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
}) })
@@ -240,54 +238,11 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs) cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
} }
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if exhaustedEmpty {
runErr = nil
transientRunAttempts = 0
timeoutCancel()
break
}
if handledEmpty {
mainIterationOffset += segmentMainIterationMax
transientRunAttempts = 0
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if runErr == nil { if runErr == nil {
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
transientRunAttempts = 0
emptyResponseAttempts = 0
timeoutCancel() timeoutCancel()
break break
} }
handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, conversationID, result, runErr, &transientRunAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if handled {
mainIterationOffset += segmentMainIterationMax
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if fatalErr != nil {
runErr = fatalErr
}
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
if errors.Is(cause, multiagent.ErrInterruptContinue) { if errors.Is(cause, multiagent.ErrInterruptContinue) {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
@@ -312,8 +267,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
"source": "interrupt_continue", "source": "interrupt_continue",
}) })
mainIterationOffset += segmentMainIterationMax mainIterationOffset += segmentMainIterationMax
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
transientRunAttempts = 0
timeoutCancel() timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause) h.tasks.BindTaskCancel(conversationID, cancelWithCause)
@@ -448,8 +401,6 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
curMsg := prep.FinalMessage curMsg := prep.FinalMessage
var result *multiagent.RunResult var result *multiagent.RunResult
var runErr error var runErr error
var transientRunAttempts int
var emptyResponseAttempts int
for { for {
result, runErr = multiagent.RunEinoSingleChatModelAgent( result, runErr = multiagent.RunEinoSingleChatModelAgent(
taskCtx, taskCtx,
@@ -467,28 +418,9 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
chatReasoningToClientIntent(req.Reasoning), chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(prep.ConversationID), h.projectBlackboardBlock(prep.ConversationID),
) )
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
)
if exhaustedEmpty {
runErr = nil
break
}
if handledEmpty {
continue
}
if runErr == nil { if runErr == nil {
break break
} }
if handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
); handled {
continue
} else if fatalErr != nil {
runErr = fatalErr
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(prep.ConversationID, result) h.persistEinoAgentTraceForResume(prep.ConversationID, result)
} }
+37 -20
View File
@@ -10,8 +10,10 @@ import (
"time" "time"
"cyberstrike-ai/internal/audit" "cyberstrike-ai/internal/audit"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database" "cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/monitor"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap" "go.uber.org/zap"
@@ -19,12 +21,18 @@ import (
// MonitorHandler 监控处理器 // MonitorHandler 监控处理器
type MonitorHandler struct { type MonitorHandler struct {
mcpServer *mcp.Server mcpServer *mcp.Server
externalMCPMgr *mcp.ExternalMCPManager externalMCPMgr *mcp.ExternalMCPManager
executor *security.Executor executor *security.Executor
db *database.DB db *database.DB
logger *zap.Logger logger *zap.Logger
audit *audit.Service audit *audit.Service
monitorRetention *monitor.Service
}
// SetMonitorRetention wires MCP execution retention settings.
func (h *MonitorHandler) SetMonitorRetention(s *monitor.Service) {
h.monitorRetention = s
} }
// SetAudit wires platform audit logging. // SetAudit wires platform audit logging.
@@ -50,13 +58,14 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
// MonitorResponse 监控响应 // MonitorResponse 监控响应
type MonitorResponse struct { type MonitorResponse struct {
Executions []*mcp.ToolExecution `json:"executions"` Executions []*mcp.ToolExecution `json:"executions"`
Stats map[string]*mcp.ToolStats `json:"stats"` Stats map[string]*mcp.ToolStats `json:"stats"`
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
Total int `json:"total,omitempty"` Total int `json:"total,omitempty"`
Page int `json:"page,omitempty"` Page int `json:"page,omitempty"`
PageSize int `json:"page_size,omitempty"` PageSize int `json:"page_size,omitempty"`
TotalPages int `json:"total_pages,omitempty"` TotalPages int `json:"total_pages,omitempty"`
RetentionDays int `json:"retention_days,omitempty"`
} }
// Monitor 获取监控信息 // Monitor 获取监控信息
@@ -89,16 +98,24 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
} }
c.JSON(http.StatusOK, MonitorResponse{ c.JSON(http.StatusOK, MonitorResponse{
Executions: executions, Executions: executions,
Stats: stats, Stats: stats,
Timestamp: time.Now(), Timestamp: time.Now(),
Total: total, Total: total,
Page: page, Page: page,
PageSize: pageSize, PageSize: pageSize,
TotalPages: totalPages, TotalPages: totalPages,
RetentionDays: h.monitorRetentionDays(),
}) })
} }
func (h *MonitorHandler) monitorRetentionDays() int {
if h.monitorRetention != nil {
return h.monitorRetention.RetentionDays()
}
return config.MonitorConfig{}.RetentionDaysEffective()
}
func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution { func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution {
executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "") executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "")
return executions return executions
+1 -69
View File
@@ -136,7 +136,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
var cancelWithCause context.CancelCauseFunc var cancelWithCause context.CancelCauseFunc
curFinalMessage := prep.FinalMessage curFinalMessage := prep.FinalMessage
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
curHistory := prep.History curHistory := prep.History
roleTools := prep.RoleTools roleTools := prep.RoleTools
orch := strings.TrimSpace(req.Orchestration) orch := strings.TrimSpace(req.Orchestration)
@@ -187,8 +186,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表 // 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
var cumulativeMCPExecutionIDs []string var cumulativeMCPExecutionIDs []string
var transientRunAttempts int
var emptyResponseAttempts int
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。 // 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
var mainIterationOffset int var mainIterationOffset int
@@ -225,6 +222,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
} }
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID) taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks) taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = mcp.WithEinoExecuteRunRegistry(taskCtxLoop, h.tasks)
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) { taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments) return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
}) })
@@ -252,54 +250,11 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs) cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
} }
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if exhaustedEmpty {
runErr = nil
transientRunAttempts = 0
timeoutCancel()
break
}
if handledEmpty {
mainIterationOffset += segmentMainIterationMax
transientRunAttempts = 0
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if runErr == nil { if runErr == nil {
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
transientRunAttempts = 0
emptyResponseAttempts = 0
timeoutCancel() timeoutCancel()
break break
} }
handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, conversationID, result, runErr, &transientRunAttempts,
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
)
if handled {
mainIterationOffset += segmentMainIterationMax
timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
h.tasks.UpdateTaskStatus(conversationID, "running")
continue
}
if fatalErr != nil {
runErr = fatalErr
}
cause := context.Cause(baseCtx) cause := context.Cause(baseCtx)
if errors.Is(cause, multiagent.ErrInterruptContinue) { if errors.Is(cause, multiagent.ErrInterruptContinue) {
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
@@ -324,8 +279,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
"source": "interrupt_continue", "source": "interrupt_continue",
}) })
mainIterationOffset += segmentMainIterationMax mainIterationOffset += segmentMainIterationMax
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
transientRunAttempts = 0
timeoutCancel() timeoutCancel()
baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
h.tasks.BindTaskCancel(conversationID, cancelWithCause) h.tasks.BindTaskCancel(conversationID, cancelWithCause)
@@ -460,8 +413,6 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
curMsg := prep.FinalMessage curMsg := prep.FinalMessage
var result *multiagent.RunResult var result *multiagent.RunResult
var runErr error var runErr error
var transientRunAttempts int
var emptyResponseAttempts int
for { for {
result, runErr = multiagent.RunDeepAgent( result, runErr = multiagent.RunDeepAgent(
taskCtx, taskCtx,
@@ -481,28 +432,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
chatReasoningToClientIntent(req.Reasoning), chatReasoningToClientIntent(req.Reasoning),
h.projectBlackboardBlock(prep.ConversationID), h.projectBlackboardBlock(prep.ConversationID),
) )
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
)
if exhaustedEmpty {
runErr = nil
break
}
if handledEmpty {
continue
}
if runErr == nil { if runErr == nil {
break break
} }
if handled, fatalErr := h.handleEinoTransientRetryContinue(
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
); handled {
continue
} else if fatalErr != nil {
runErr = fatalErr
}
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
h.persistEinoAgentTraceForResume(prep.ConversationID, result) h.persistEinoAgentTraceForResume(prep.ConversationID, result)
} }
+3
View File
@@ -594,6 +594,9 @@ func (h *RobotHandler) cmdDelete(platform, userID, convID string) string {
h.mu.Unlock() h.mu.Unlock()
h.deleteSessionBinding(sk) h.deleteSessionBinding(sk)
} }
if h.agentHandler != nil {
h.agentHandler.CancelRunningTaskForConversation(convID)
}
if err := h.db.DeleteConversation(convID); err != nil { if err := h.db.DeleteConversation(convID); err != nil {
return "删除失败: " + err.Error() return "删除失败: " + err.Error()
} }
+68
View File
@@ -37,6 +37,11 @@ type AgentTask struct {
// InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空) // InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空)
InterruptContinueNote string `json:"-"` InterruptContinueNote string `json:"-"`
// activeEinoExecuteCancel 当前进行中的 Eino filesystem execute 取消函数(与 MCP 工具并行,供中断并继续)
activeEinoExecuteCancel context.CancelFunc
// activeEinoExecuteAbortNote AbortActiveEinoExecute 写入的用户说明,由 execute 收尾时合并进工具结果
activeEinoExecuteAbortNote string
cancel func(error) cancel func(error)
} }
@@ -70,6 +75,69 @@ func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID str
} }
} }
// RegisterActiveEinoExecute 登记进行中的 Eino filesystem execute(每会话同时仅一条)。
func (m *AgentTaskManager) RegisterActiveEinoExecute(conversationID string, cancel context.CancelFunc) {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" || cancel == nil {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
t.activeEinoExecuteCancel = cancel
t.activeEinoExecuteAbortNote = ""
}
}
// UnregisterActiveEinoExecute execute 正常结束或已取消后清除登记。
func (m *AgentTaskManager) UnregisterActiveEinoExecute(conversationID string) {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
t.activeEinoExecuteCancel = nil
t.activeEinoExecuteAbortNote = ""
}
}
// AbortActiveEinoExecute 终止当前 Eino execute 并暂存用户说明(与 MCP 工具终止一致)。
func (m *AgentTaskManager) AbortActiveEinoExecute(conversationID, note string) bool {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return false
}
m.mu.Lock()
t, ok := m.tasks[conversationID]
if !ok || t == nil || t.activeEinoExecuteCancel == nil {
m.mu.Unlock()
return false
}
t.activeEinoExecuteAbortNote = strings.TrimSpace(note)
cancel := t.activeEinoExecuteCancel
m.mu.Unlock()
cancel()
return true
}
// TakeEinoExecuteAbortNote 读取并清空 execute 终止说明(execute 收尾时调用一次)。
func (m *AgentTaskManager) TakeEinoExecuteAbortNote(conversationID string) string {
conversationID = strings.TrimSpace(conversationID)
if conversationID == "" {
return ""
}
m.mu.Lock()
defer m.mu.Unlock()
if t, ok := m.tasks[conversationID]; ok && t != nil {
n := t.activeEinoExecuteAbortNote
t.activeEinoExecuteAbortNote = ""
return n
}
return ""
}
// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。 // SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。
func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) { func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) {
conversationID = strings.TrimSpace(conversationID) conversationID = strings.TrimSpace(conversationID)
@@ -0,0 +1,40 @@
package handler
import (
"context"
"testing"
"time"
)
func TestAbortActiveEinoExecute(t *testing.T) {
m := NewAgentTaskManager()
conv := "conv-eino-exec-abort"
ctx, cancel := context.WithCancel(context.Background())
_, err := m.StartTask(conv, "test", func(error) {})
if err != nil {
t.Fatalf("StartTask: %v", err)
}
m.RegisterActiveEinoExecute(conv, cancel)
done := make(chan struct{})
go func() {
<-ctx.Done()
close(done)
}()
if !m.AbortActiveEinoExecute(conv, "跳过域名收集") {
t.Fatal("expected abort to succeed")
}
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("execute cancel did not propagate")
}
if got := m.TakeEinoExecuteAbortNote(conv); got != "跳过域名收集" {
t.Fatalf("abort note = %q, want 跳过域名收集", got)
}
m.UnregisterActiveEinoExecute(conv)
if m.AbortActiveEinoExecute(conv, "") {
t.Fatal("second abort should fail when no active execute")
}
}
+26
View File
@@ -11,7 +11,16 @@ type ToolRunRegistry interface {
UnregisterRunningTool(conversationID, executionID string) UnregisterRunningTool(conversationID, executionID string)
} }
// EinoExecuteRunRegistry 登记进行中的 Eino filesystem execute,供「中断并继续」终止 amass 等长命令。
type EinoExecuteRunRegistry interface {
RegisterActiveEinoExecute(conversationID string, cancel context.CancelFunc)
UnregisterActiveEinoExecute(conversationID string)
AbortActiveEinoExecute(conversationID, note string) bool
TakeEinoExecuteAbortNote(conversationID string) string
}
type toolRunRegistryCtxKey struct{} type toolRunRegistryCtxKey struct{}
type einoExecuteRunRegistryCtxKey struct{}
type mcpConversationIDCtxKey struct{} type mcpConversationIDCtxKey struct{}
// WithToolRunRegistry 将登记器注入 ctxEino / 原生 Agent 任务 ctx)。 // WithToolRunRegistry 将登记器注入 ctxEino / 原生 Agent 任务 ctx)。
@@ -31,6 +40,23 @@ func ToolRunRegistryFromContext(ctx context.Context) ToolRunRegistry {
return v return v
} }
// WithEinoExecuteRunRegistry 将 Eino execute 取消登记器注入 ctx。
func WithEinoExecuteRunRegistry(ctx context.Context, reg EinoExecuteRunRegistry) context.Context {
if ctx == nil || reg == nil {
return ctx
}
return context.WithValue(ctx, einoExecuteRunRegistryCtxKey{}, reg)
}
// EinoExecuteRunRegistryFromContext 取出 Eino execute 登记器(无则 nil)。
func EinoExecuteRunRegistryFromContext(ctx context.Context) EinoExecuteRunRegistry {
if ctx == nil {
return nil
}
v, _ := ctx.Value(einoExecuteRunRegistryCtxKey{}).(EinoExecuteRunRegistry)
return v
}
// WithMCPConversationID 将对话 ID 注入 ctx,供 CallTool 内与 executionId 关联。 // WithMCPConversationID 将对话 ID 注入 ctx,供 CallTool 内与 executionId 关联。
func WithMCPConversationID(ctx context.Context, conversationID string) context.Context { func WithMCPConversationID(ctx context.Context, conversationID string) context.Context {
if ctx == nil { if ctx == nil {
+71
View File
@@ -0,0 +1,71 @@
package monitor
import (
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"go.uber.org/zap"
)
const retentionPurgeInterval = time.Hour
// Service manages MCP tool execution monitor retention.
type Service struct {
db *database.DB
cfg *config.Config
logger *zap.Logger
}
// NewService creates a monitor retention service.
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
return &Service{db: db, cfg: cfg, logger: logger}
}
// RetentionDays returns configured retention; 0 means keep forever.
func (s *Service) RetentionDays() int {
if s == nil || s.cfg == nil {
return config.MonitorConfig{}.RetentionDaysEffective()
}
return s.cfg.Monitor.RetentionDaysEffective()
}
// PurgeExpired deletes tool execution rows older than retention_days when configured.
func (s *Service) PurgeExpired() {
if s == nil || s.db == nil || s.cfg == nil {
return
}
days := s.cfg.Monitor.RetentionDaysEffective()
if days <= 0 {
return
}
cutoff := time.Now().AddDate(0, 0, -days)
n, err := s.db.PurgeToolExecutionsBefore(cutoff)
if err != nil {
if s.logger != nil {
s.logger.Warn("清理过期 MCP 执行记录失败", zap.Error(err))
}
return
}
if n > 0 && s.logger != nil {
s.logger.Info("已清理过期 MCP 执行记录", zap.Int64("deleted", n), zap.Int("retention_days", days))
}
}
// StartRetentionLoop periodically purges expired tool execution rows.
func StartRetentionLoop(s *Service, logger *zap.Logger) {
if s == nil {
return
}
go func() {
ticker := time.NewTicker(retentionPurgeInterval)
defer ticker.Stop()
for range ticker.C {
s.PurgeExpired()
if logger != nil {
logger.Debug("monitor retention tick completed")
}
}
}()
}
+94
View File
@@ -0,0 +1,94 @@
package monitor
import (
"path/filepath"
"testing"
"time"
"cyberstrike-ai/internal/config"
"cyberstrike-ai/internal/database"
"cyberstrike-ai/internal/mcp"
"go.uber.org/zap"
)
func TestServicePurgeExpired_respectsZeroRetention(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "monitor.db")
db, err := database.NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
exec := &mcp.ToolExecution{
ID: "ancient",
ToolName: "curl::get",
Arguments: map[string]interface{}{},
Status: "completed",
StartTime: mustParseTime(t, "2020-01-01T00:00:00Z"),
}
if err := db.SaveToolExecution(exec); err != nil {
t.Fatalf("SaveToolExecution: %v", err)
}
zero := 0
svc := NewService(db, &config.Config{
Monitor: config.MonitorConfig{RetentionDays: &zero},
}, zap.NewNop())
svc.PurgeExpired()
if _, err := db.GetToolExecution("ancient"); err != nil {
t.Fatalf("record should remain when retention_days=0: %v", err)
}
}
func TestServicePurgeExpired_deletesOldRows(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "monitor.db")
db, err := database.NewDB(dbPath, zap.NewNop())
if err != nil {
t.Fatalf("NewDB: %v", err)
}
defer db.Close()
exec := &mcp.ToolExecution{
ID: "ancient",
ToolName: "curl::get",
Arguments: map[string]interface{}{},
Status: "completed",
StartTime: mustParseTime(t, "2020-01-01T00:00:00Z"),
}
if err := db.SaveToolExecution(exec); err != nil {
t.Fatalf("SaveToolExecution: %v", err)
}
days := 90
svc := NewService(db, &config.Config{
Monitor: config.MonitorConfig{RetentionDays: &days},
}, zap.NewNop())
svc.PurgeExpired()
if _, err := db.GetToolExecution("ancient"); err == nil {
t.Fatal("record should be purged when older than retention_days")
}
}
func TestRetentionDaysEffective_defaults(t *testing.T) {
got := config.MonitorConfig{}.RetentionDaysEffective()
if got != 90 {
t.Fatalf("default = %d, want 90", got)
}
zero := 0
cfg := config.MonitorConfig{RetentionDays: &zero}
if cfg.RetentionDaysEffective() != 0 {
t.Fatalf("zero = %d, want 0", cfg.RetentionDaysEffective())
}
}
func mustParseTime(t *testing.T, value string) time.Time {
t.Helper()
parsed, err := time.Parse(time.RFC3339, value)
if err != nil {
t.Fatalf("parse time: %v", err)
}
return parsed
}
@@ -0,0 +1,104 @@
package multiagent
import (
"context"
"strings"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// continuationSessionMarker matches Cursor / IDE session-resume user injections.
const continuationSessionMarker = "This session is being continued from a previous conversation"
// continuationUserDedupMiddleware keeps only the latest session-resume user message when
// multiple continuation injections were stacked (e.g. after repeated out-of-context resumes).
type continuationUserDedupMiddleware struct {
adk.BaseChatModelAgentMiddleware
logger *zap.Logger
phase string
}
func newContinuationUserDedupMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware {
return &continuationUserDedupMiddleware{logger: logger, phase: phase}
}
func (m *continuationUserDedupMiddleware) BeforeModelRewriteState(
ctx context.Context,
state *adk.ChatModelAgentState,
mc *adk.ModelContext,
) (context.Context, *adk.ChatModelAgentState, error) {
_ = mc
if m == nil || state == nil || len(state.Messages) == 0 {
return ctx, state, nil
}
deduped, dropped := dedupContinuationUserMessages(state.Messages)
if dropped == 0 {
return ctx, state, nil
}
if m.logger != nil {
m.logger.Info("eino continuation user messages deduplicated",
zap.String("phase", m.phase),
zap.Int("dropped", dropped),
zap.Int("messages_before", len(state.Messages)),
zap.Int("messages_after", len(deduped)),
)
}
out := *state
out.Messages = deduped
return ctx, &out, nil
}
func adkUserMessageText(msg adk.Message) string {
if msg == nil {
return ""
}
var b strings.Builder
if s := strings.TrimSpace(msg.Content); s != "" {
b.WriteString(s)
}
for _, part := range msg.UserInputMultiContent {
if part.Type == schema.ChatMessagePartTypeText {
if s := strings.TrimSpace(part.Text); s != "" {
if b.Len() > 0 {
b.WriteByte('\n')
}
b.WriteString(s)
}
}
}
return b.String()
}
func isContinuationUserMessage(msg adk.Message) bool {
if msg == nil || msg.Role != schema.User {
return false
}
return strings.Contains(adkUserMessageText(msg), continuationSessionMarker)
}
func dedupContinuationUserMessages(msgs []adk.Message) ([]adk.Message, int) {
lastIdx := -1
contCount := 0
for i, msg := range msgs {
if !isContinuationUserMessage(msg) {
continue
}
contCount++
lastIdx = i
}
if contCount <= 1 {
return msgs, 0
}
out := make([]adk.Message, 0, len(msgs)-(contCount-1))
dropped := 0
for i, msg := range msgs {
if isContinuationUserMessage(msg) && i != lastIdx {
dropped++
continue
}
out = append(out, msg)
}
return out, dropped
}
@@ -0,0 +1,65 @@
package multiagent
import (
"context"
"strings"
"testing"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
)
func continuationUser(text string) adk.Message {
return &schema.Message{
Role: schema.User,
UserInputMultiContent: []schema.MessageInputPart{
{Type: schema.ChatMessagePartTypeText, Text: continuationSessionMarker + "\n" + text},
{Type: schema.ChatMessagePartTypeText, Text: "Please continue the conversation from where we left it off."},
},
}
}
func TestDedupContinuationUserMessages_KeepsLatest(t *testing.T) {
msgs := []adk.Message{
continuationUser("summary old"),
schema.UserMessage("real task"),
continuationUser("summary new"),
}
out, dropped := dedupContinuationUserMessages(msgs)
if dropped != 1 {
t.Fatalf("dropped=%d want 1", dropped)
}
if len(out) != 2 {
t.Fatalf("len=%d want 2", len(out))
}
if out[0].Role != schema.User || adkUserMessageText(out[0]) != "real task" {
t.Fatalf("first should remain real task, got %q", adkUserMessageText(out[0]))
}
if !strings.Contains(adkUserMessageText(out[1]), "summary new") {
t.Fatalf("latest continuation not kept: %q", adkUserMessageText(out[1]))
}
}
func TestDedupContinuationUserMessages_NoOpSingle(t *testing.T) {
msgs := []adk.Message{continuationUser("only"), schema.UserMessage("task")}
out, dropped := dedupContinuationUserMessages(msgs)
if dropped != 0 || len(out) != 2 {
t.Fatalf("unexpected change dropped=%d len=%d", dropped, len(out))
}
}
func TestContinuationUserDedupMiddleware(t *testing.T) {
mw := newContinuationUserDedupMiddleware(nil, "test")
state := &adk.ChatModelAgentState{Messages: []adk.Message{
continuationUser("old"),
continuationUser("new"),
schema.UserMessage("task"),
}}
_, out, err := mw.(*continuationUserDedupMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
if err != nil {
t.Fatal(err)
}
if len(out.Messages) != 2 {
t.Fatalf("want 2 messages after dedup, got %d", len(out.Messages))
}
}
+158 -73
View File
@@ -90,7 +90,7 @@ type einoADKRunLoopArgs struct {
FilesystemMonitorRecord einomcp.ExecutionRecorder FilesystemMonitorRecord einomcp.ExecutionRecorder
MCPExecutionBinder *MCPExecutionBinder MCPExecutionBinder *MCPExecutionBinder
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 SetMCP 桥 Fire 以补全 tool_result。 // ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Setexecute/MCP 桥 Fire 时立即推送 tool_resultADK 晚到经 toolResultSent 去重)
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
DA adk.Agent DA adk.Agent
@@ -341,8 +341,22 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
if args.ToolInvokeNotify != nil { if args.ToolInvokeNotify != nil {
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) { args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
removePendingByID(strings.TrimSpace(toolCallID)) // Eino execute / MCP 桥在工具返回时 Fire;若 ADK schema.Tool 事件迟迟不到,此处立即推送
// tool_result 仅由下方 ADK schema.Tool 事件推送,正文与送入模型的上下文一致(含 reduction 截断) // tool_result 解除 UI「执行中」。tryEmitToolResultProgress 经 toolResultSent 去重,ADK 晚到不重复
isErr := !success || invokeErr != nil
body := content
if strings.HasPrefix(body, einomcp.ToolErrorPrefix) {
isErr = true
body = strings.TrimPrefix(body, einomcp.ToolErrorPrefix)
}
if tail := friendlyEinoExecuteInvokeTail(invokeErr); tail != "" {
if body == "" {
body = tail
} else if !strings.Contains(body, tail) {
body = strings.TrimSpace(body) + "\n\n" + tail
}
}
tryEmitToolResultProgress(toolName, body, toolCallID, isErr, einoAgent)
}) })
} }
@@ -383,6 +397,12 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
} }
runner := adk.NewRunner(ctx, runnerCfg) runner := adk.NewRunner(ctx, runnerCfg)
startRunnerIter := func(runMsgs []adk.Message) *adk.AsyncIterator[*adk.AgentEvent] {
if checkPointID != "" {
return runner.Run(ctx, runMsgs, adk.WithCheckPointID(checkPointID))
}
return runner.Run(ctx, runMsgs)
}
var iter *adk.AsyncIterator[*adk.AgentEvent] var iter *adk.AsyncIterator[*adk.AgentEvent]
if cpStore != nil && checkPointID != "" { if cpStore != nil && checkPointID != "" {
if _, existed, getErr := cpStore.Get(ctx, checkPointID); getErr != nil { if _, existed, getErr := cpStore.Get(ctx, checkPointID); getErr != nil {
@@ -422,12 +442,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
} }
if iter == nil { if iter == nil {
if checkPointID != "" { iter = startRunnerIter(msgs)
iter = runner.Run(ctx, msgs, adk.WithCheckPointID(checkPointID))
} else {
iter = runner.Run(ctx, msgs)
}
} }
transientRetrier := newEinoTransientRunRetrier(einoTransientRunRetryPolicyFromArgs(args))
handleRunErr := func(runErr error) error { handleRunErr := func(runErr error) error {
if runErr == nil { if runErr == nil {
return nil return nil
@@ -480,26 +497,60 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
return runErr return runErr
} }
// maybeRetryTransientRun:不在此层 runner.Run/Resume;由 handler 落库 + loadHistoryFromAgentTrace 分段续跑(同中断并继续)。 maybeRetryTransientRun := func(runErr error) (restarted bool, fatal error) {
maybeRetryTransientRun := func(runErr error) (retry bool, fatal error) { if runErr == nil {
if runErr == nil || !isEinoTransientRunError(runErr) { return false, nil
}
if !isEinoTransientRunError(runErr) {
return false, handleRunErr(runErr) return false, handleRunErr(runErr)
} }
restarted, restartMsgs, ctxSource, backoff, retErr := transientRetrier.tryRetry(
ctx, runErr, args, baseMsgs, runAccumulatedMsgs, baseAccumulatedCount,
)
if retErr != nil {
flushAllPendingAsFailed(runErr)
if logger != nil {
logger.Warn("eino transient retry exhausted",
zap.Error(retErr),
zap.String("orchestration", orchMode),
zap.Int("maxAttempts", transientRetrier.maxAttempts()))
}
return false, retErr
}
if !restarted {
return false, nil
}
attemptNo := transientRetrier.attempt()
maxAttempts := transientRetrier.maxAttempts()
if logger != nil { if logger != nil {
logger.Warn("eino transient error, ending run segment for handler resume", logger.Warn("eino transient error, retrying after backoff",
zap.Error(runErr), zap.Error(runErr),
zap.String("orchestration", orchMode)) zap.String("orchestration", orchMode),
zap.Int("attempt", attemptNo),
zap.Int("maxAttempts", maxAttempts),
zap.Duration("backoff", backoff))
} }
if progress != nil { if progress != nil {
progress("eino_run_retry", "遇到临时错误(限流或网络波动),将保存上下文并重试…", map[string]interface{}{ progress("eino_run_retry", fmt.Sprintf("遇到临时错误(限流或网络波动),%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
"source": "eino", "source": "eino",
"orchestration": orchMode, "orchestration": orchMode,
"error": runErr.Error(), "error": runErr.Error(),
"resumeKind": "trace_segment", "attempt": attemptNo,
"maxAttempts": maxAttempts,
"backoffSec": int(backoff.Seconds()),
})
progress("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"orchestration": orchMode,
"attempt": attemptNo,
"contextSource": string(ctxSource),
}) })
} }
return false, ErrTransientRetryContinue msgs = restartMsgs
iter = startRunnerIter(msgs)
return true, nil
} }
takePartial := func(runErr error) (*RunResult, error) { takePartial := func(runErr error) (*RunResult, error) {
@@ -514,10 +565,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
} }
for { for {
// 检测 context 取消(用户关闭浏览器、请求超时等),flush pending 工具状态避免 UI 卡在 "执行中" // iter.Next 可能长时间阻塞(工具执行、模型推理);须与 ctx 联动,否则取消/超时无法及时 flush pending
select { ev, ok, iterCtxErr := nextAgentEventWithContext(ctx, iter)
case <-ctx.Done(): if iterCtxErr != nil {
flushAllPendingAsFailed(ctx.Err()) flushAllPendingAsFailed(iterCtxErr)
if progress != nil { if progress != nil {
if isInterruptContinue(ctx) { if isInterruptContinue(ctx) {
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{ progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
@@ -526,17 +577,14 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"kind": "interrupt_continue", "kind": "interrupt_continue",
}) })
} else { } else {
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{ progress("error", iterCtxErr.Error(), map[string]interface{}{
"conversationId": conversationID, "conversationId": conversationID,
"source": "eino", "source": "eino",
}) })
} }
} }
return takePartial(ctx.Err()) return takePartial(iterCtxErr)
default:
} }
ev, ok := iter.Next()
if !ok { if !ok {
// iter 结束并不总是“正常完成”: // iter 结束并不总是“正常完成”:
// 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。 // 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。
@@ -583,9 +631,15 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
continue continue
} }
if ev.Err != nil { if ev.Err != nil {
if _, retErr := maybeRetryTransientRun(ev.Err); retErr != nil { restarted, retErr := maybeRetryTransientRun(ev.Err)
if retErr != nil {
return takePartial(retErr) return takePartial(retErr)
} }
if restarted {
continue
}
} else {
transientRetrier.reset()
} }
if ev.AgentName != "" && progress != nil { if ev.AgentName != "" && progress != nil {
iterEinoAgent := orchestratorName iterEinoAgent := orchestratorName
@@ -648,29 +702,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool { if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
toolName := strings.TrimSpace(mv.ToolName) toolName := strings.TrimSpace(mv.ToolName)
var toolBuf strings.Builder content, streamToolCallID, toolStreamRecvErr := recvSchemaMessageStream(ctx, mv.MessageStream)
streamToolCallID := ""
var toolStreamRecvErr error
for {
chunk, rerr := mv.MessageStream.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
toolStreamRecvErr = rerr
break
}
if chunk == nil {
continue
}
if chunk.Content != "" {
toolBuf.WriteString(chunk.Content)
}
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
streamToolCallID = tid
}
}
content := toolBuf.String()
isErr := false isErr := false
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) { if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
isErr = true isErr = true
@@ -951,9 +983,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
"einoRole": einoRoleTag(ev.AgentName), "einoRole": einoRoleTag(ev.AgentName),
}) })
} }
if _, retErr := maybeRetryTransientRun(streamRecvErr); retErr != nil { restarted, retErr := maybeRetryTransientRun(streamRecvErr)
if retErr != nil {
return takePartial(retErr) return takePartial(retErr)
} }
if restarted {
continue
}
} }
continue continue
} }
@@ -1057,32 +1093,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs), orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs),
lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false, lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
) )
if shouldEinoEmptyResponseContinue(out, emptyHint, len(runAccumulatedMsgs), baseAccumulatedCount) {
if logger != nil {
logger.Info("eino empty response, ending run segment for handler resume",
zap.String("conversationId", conversationID),
zap.String("orchestration", orchMode),
zap.Int("traceMessages", len(runAccumulatedMsgs)))
}
if progress != nil {
progress("eino_empty_response_continue", "会话已结束但未产生助手正文,正在基于轨迹自动续跑…", map[string]interface{}{
"conversationId": conversationID,
"source": "eino",
"resumeKind": "trace_segment",
})
}
return out, ErrEmptyResponseContinue
}
return out, nil return out, nil
} }
func shouldEinoEmptyResponseContinue(out *RunResult, emptyHint string, accumulatedLen, baseCount int) bool {
if out == nil || accumulatedLen <= baseCount {
return false
}
return strings.TrimSpace(out.Response) == strings.TrimSpace(emptyHint)
}
func persistTraceSource(args *einoADKRunLoopArgs, fallback []adk.Message) []adk.Message { func persistTraceSource(args *einoADKRunLoopArgs, fallback []adk.Message) []adk.Message {
if args != nil && args.ModelFacingTrace != nil { if args != nil && args.ModelFacingTrace != nil {
if snap := args.ModelFacingTrace.Snapshot(); len(snap) > 0 { if snap := args.ModelFacingTrace.Snapshot(); len(snap) > 0 {
@@ -1108,6 +1121,78 @@ func friendlyEinoExecuteInvokeTail(invokeErr error) string {
return "[执行未正常结束] " + invokeErr.Error() return "[执行未正常结束] " + invokeErr.Error()
} }
// nextAgentEventWithContext 在 ctx 取消时不再无限阻塞于 iter.Next()(工具执行/模型推理期间常见)。
func nextAgentEventWithContext(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) (ev *adk.AgentEvent, ok bool, ctxErr error) {
if iter == nil {
return nil, false, nil
}
type nextRes struct {
ev *adk.AgentEvent
ok bool
}
ch := make(chan nextRes, 1)
go func() {
e, o := iter.Next()
ch <- nextRes{e, o}
}()
select {
case <-ctx.Done():
return nil, false, ctx.Err()
case res := <-ch:
return res.ev, res.ok, nil
}
}
// recvSchemaMessageStream 消费 ADK Tool 流式结果;ctx 取消时立即返回,避免 amass 等无输出时永久阻塞。
func recvSchemaMessageStream(ctx context.Context, stream *schema.StreamReader[*schema.Message]) (content, toolCallID string, recvErr error) {
if stream == nil {
return "", "", nil
}
type streamMsg struct {
chunk *schema.Message
err error
}
recvCh := make(chan streamMsg, 8)
go func() {
defer close(recvCh)
for {
ch, rerr := stream.Recv()
recvCh <- streamMsg{chunk: ch, err: rerr}
if rerr != nil {
return
}
}
}()
var buf strings.Builder
for {
select {
case <-ctx.Done():
return buf.String(), toolCallID, ctx.Err()
case sm, open := <-recvCh:
if !open {
return buf.String(), toolCallID, nil
}
rerr := sm.err
if errors.Is(rerr, io.EOF) {
return buf.String(), toolCallID, nil
}
if rerr != nil {
return buf.String(), toolCallID, rerr
}
chunk := sm.chunk
if chunk == nil {
continue
}
if chunk.Content != "" {
buf.WriteString(chunk.Content)
}
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
toolCallID = tid
}
}
}
}
func buildEinoRunResultFromAccumulated( func buildEinoRunResultFromAccumulated(
orchMode string, orchMode string,
runAccumulatedMsgs []adk.Message, runAccumulatedMsgs []adk.Message,
@@ -0,0 +1,74 @@
package multiagent
import (
"context"
"errors"
"io"
"testing"
"time"
"github.com/cloudwego/eino/schema"
)
func TestRecvSchemaMessageStream_EOF(t *testing.T) {
sr, sw := schema.Pipe[*schema.Message](4)
_ = sw.Send(schema.ToolMessage("hello", "tc-1"), nil)
sw.Close()
content, tid, err := recvSchemaMessageStream(context.Background(), sr)
if err != nil {
t.Fatalf("unexpected err: %v", err)
}
if content != "hello" {
t.Fatalf("content=%q want hello", content)
}
if tid != "tc-1" {
t.Fatalf("toolCallID=%q want tc-1", tid)
}
}
func TestRecvSchemaMessageStream_ContextCancel(t *testing.T) {
sr, sw := schema.Pipe[*schema.Message](4)
t.Cleanup(func() { sw.Close() })
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(30 * time.Millisecond)
cancel()
}()
content, _, err := recvSchemaMessageStream(ctx, sr)
if !errors.Is(err, context.Canceled) {
t.Fatalf("want context.Canceled, got %v content=%q", err, content)
}
}
func TestRecvSchemaMessageStream_RecvError(t *testing.T) {
sr, sw := schema.Pipe[*schema.Message](4)
want := errors.New("stream broken")
_ = sw.Send(nil, want)
sw.Close()
_, _, err := recvSchemaMessageStream(context.Background(), sr)
if !errors.Is(err, want) {
t.Fatalf("want %v, got %v", want, err)
}
}
func TestRecvSchemaMessageStream_NilStream(t *testing.T) {
content, tid, err := recvSchemaMessageStream(context.Background(), nil)
if err != nil || content != "" || tid != "" {
t.Fatalf("nil stream: content=%q tid=%q err=%v", content, tid, err)
}
}
func TestRecvSchemaMessageStream_EOFViaEmptyRead(t *testing.T) {
sr, sw := schema.Pipe[*schema.Message](4)
_ = sw.Send(nil, io.EOF)
sw.Close()
_, _, err := recvSchemaMessageStream(context.Background(), sr)
if err != nil {
t.Fatalf("EOF should not surface as error, got %v", err)
}
}
@@ -0,0 +1,50 @@
package multiagent
import (
"github.com/cloudwego/eino/adk"
"go.uber.org/zap"
)
// einoChatModelTailConfig configures middleware appended after reduction/skill/plantask
// and immediately before each ChatModel invocation pipeline completes.
//
// Order (best practice):
// 1. system merge — accurate token count for summarization
// 2. continuation user dedup — drop stale session-resume injections
// 3. summarization
// 4. orphan tool prune
// 5. telemetry
// 6. model-facing trace snapshot
type einoChatModelTailConfig struct {
logger *zap.Logger
phase string
summarization adk.ChatModelAgentMiddleware
modelName string
conversationID string
trace *modelFacingTraceHolder
skipOrphanPruner bool
skipTelemetry bool
skipTrace bool
}
func appendEinoChatModelTailMiddlewares(handlers []adk.ChatModelAgentMiddleware, cfg einoChatModelTailConfig) []adk.ChatModelAgentMiddleware {
handlers = append(handlers, newSystemMessageNormalizerMiddleware(cfg.logger, cfg.phase))
handlers = append(handlers, newContinuationUserDedupMiddleware(cfg.logger, cfg.phase))
if cfg.summarization != nil {
handlers = append(handlers, cfg.summarization)
}
if !cfg.skipOrphanPruner {
handlers = append(handlers, newOrphanToolPrunerMiddleware(cfg.logger, cfg.phase))
}
if !cfg.skipTelemetry {
if teleMw := newEinoModelInputTelemetryMiddleware(cfg.logger, cfg.modelName, cfg.conversationID, cfg.phase); teleMw != nil {
handlers = append(handlers, teleMw)
}
}
if !cfg.skipTrace && cfg.trace != nil {
if capMw := newModelFacingTraceMiddleware(cfg.trace); capMw != nil {
handlers = append(handlers, capMw)
}
}
return handlers
}
@@ -1,21 +0,0 @@
package multiagent
import "testing"
func TestShouldEinoEmptyResponseContinue(t *testing.T) {
t.Parallel()
hint := "(empty hint)"
out := &RunResult{Response: hint}
if !shouldEinoEmptyResponseContinue(out, hint, 3, 1) {
t.Fatal("expected continue when response is empty hint and trace grew")
}
if shouldEinoEmptyResponseContinue(out, hint, 1, 1) {
t.Fatal("expected no continue when trace did not grow")
}
if shouldEinoEmptyResponseContinue(&RunResult{Response: "hello"}, hint, 3, 1) {
t.Fatal("expected no continue when response has content")
}
if shouldEinoEmptyResponseContinue(nil, hint, 3, 1) {
t.Fatal("expected no continue for nil result")
}
}
@@ -6,9 +6,11 @@ import (
"fmt" "fmt"
"io" "io"
"strings" "strings"
"sync"
"time" "time"
"cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/einomcp"
"cyberstrike-ai/internal/mcp"
"cyberstrike-ai/internal/security" "cyberstrike-ai/internal/security"
"github.com/cloudwego/eino/adk/filesystem" "github.com/cloudwego/eino/adk/filesystem"
@@ -49,7 +51,7 @@ func einoExecuteRecvErrIsToolTimeout(rerr error, tctx context.Context) bool {
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。 // 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
// //
// 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire // 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire
// 保证 run loop 在模型开始下一轮输出前已记录 execute 结果(用于 UI 与「重复助手复述」去重) // run loop 收到 Fire 后立即推送 tool_resulttoolResultSent 去重),避免 ADK Tool 事件迟到时 UI 卡在「执行中」
// //
// 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire // 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire
// 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。 // 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。
@@ -80,15 +82,23 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
req.Command = prependPythonUnbufferedEnv(req.Command) req.Command = prependPythonUnbufferedEnv(req.Command)
tid := strings.TrimSpace(compose.GetToolCallID(ctx)) tid := strings.TrimSpace(compose.GetToolCallID(ctx))
agentTag := strings.TrimSpace(w.einoAgentName) agentTag := strings.TrimSpace(w.einoAgentName)
convID := mcp.MCPConversationIDFromContext(ctx)
execReg := mcp.EinoExecuteRunRegistryFromContext(ctx)
execCtx := ctx execCtx, execCancel := context.WithCancel(ctx)
var execCancel context.CancelFunc var timeoutCancel context.CancelFunc
if w.toolTimeoutMinutes > 0 { if w.toolTimeoutMinutes > 0 {
execCtx, execCancel = context.WithTimeout(ctx, time.Duration(w.toolTimeoutMinutes)*time.Minute) execCtx, timeoutCancel = context.WithTimeout(execCtx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
}
if execReg != nil && convID != "" {
execReg.RegisterActiveEinoExecute(convID, execCancel)
} }
sr, err := w.inner.ExecuteStreaming(execCtx, &req) sr, err := w.inner.ExecuteStreaming(execCtx, &req)
if err != nil { if err != nil {
if timeoutCancel != nil {
timeoutCancel()
}
if execCancel != nil { if execCancel != nil {
execCancel() execCancel()
} }
@@ -111,6 +121,9 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
return nil, err return nil, err
} }
if sr == nil || w.invokeNotify == nil { if sr == nil || w.invokeNotify == nil {
if timeoutCancel != nil {
timeoutCancel()
}
if execCancel != nil { if execCancel != nil {
execCancel() execCancel()
} }
@@ -119,11 +132,32 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32) outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, tctx context.Context) { go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry) {
defer inner.Close() var innerCloseOnce sync.Once
closeInner := func() {
innerCloseOnce.Do(func() { inner.Close() })
}
defer closeInner()
if timeoutCleanup != nil {
defer timeoutCleanup()
}
if cancel != nil { if cancel != nil {
defer cancel() defer cancel()
} }
if reg != nil && conversationID != "" {
defer reg.UnregisterActiveEinoExecute(conversationID)
}
// ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。
stopWatch := make(chan struct{})
go func() {
select {
case <-tctx.Done():
closeInner()
case <-stopWatch:
}
}()
defer close(stopWatch)
var sb strings.Builder var sb strings.Builder
success := true success := true
@@ -144,6 +178,10 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
invokeErr = context.DeadlineExceeded invokeErr = context.DeadlineExceeded
break break
} }
if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) {
invokeErr = context.Canceled
break
}
_ = outW.Send(nil, rerr) _ = outW.Send(nil, rerr)
break break
} }
@@ -178,6 +216,21 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
success = false success = false
invokeErr = context.DeadlineExceeded invokeErr = context.DeadlineExceeded
} }
// 用户「中断并继续」终止 execute:合并说明进工具结果(与 MCP CancelToolExecutionWithNote 一致)。
partialStreamed := sb.String()
var abortNote string
if reg != nil && conversationID != "" && (invokeErr != nil || errors.Is(tctx.Err(), context.Canceled)) {
if note := reg.TakeEinoExecuteAbortNote(conversationID); note != "" {
abortNote = note
merged := mcp.MergePartialToolOutputAndAbortNote(partialStreamed, note)
sb.Reset()
sb.WriteString(merged)
if invokeErr == nil {
success = false
invokeErr = context.Canceled
}
}
}
// ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。 // ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。
if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) { if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) {
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n" hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
@@ -187,12 +240,20 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
} }
sb.WriteString(hint) sb.WriteString(hint)
} }
// 中断时循环内已逐行写入 stdout;此处只追加 USER INTERRUPT NOTE,避免整段输出重复。
if invokeErr != nil && errors.Is(invokeErr, context.Canceled) && abortNote != "" {
if partialStreamed != "" {
_ = outW.Send(&filesystem.ExecuteResponse{Output: "\n\n" + mcp.AbortNoteBannerForModel + "\n" + abortNote}, nil)
} else if text := strings.TrimSpace(sb.String()); text != "" {
_ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil)
}
}
if w.recordMonitor != nil { if w.recordMonitor != nil {
w.recordMonitor(tid, command, sb.String(), success, invokeErr) w.recordMonitor(tid, command, sb.String(), success, invokeErr)
} }
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr) w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
outW.Close() outW.Close()
}(sr, userCmd, execCancel, execCtx) }(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg)
return outR, nil return outR, nil
} }
@@ -9,6 +9,7 @@ import (
"time" "time"
"cyberstrike-ai/internal/einomcp" "cyberstrike-ai/internal/einomcp"
"cyberstrike-ai/internal/mcp"
"github.com/cloudwego/eino/adk/filesystem" "github.com/cloudwego/eino/adk/filesystem"
"github.com/cloudwego/eino/schema" "github.com/cloudwego/eino/schema"
@@ -122,6 +123,94 @@ func TestEinoStreamingShellWrap_ToolTimeoutRecvErrIsSoft(t *testing.T) {
} }
} }
func TestEinoStreamingShellWrap_CapturesOutputWithToolTimeout(t *testing.T) {
inner := &mockStreamingShell{output: "100\n"}
notify := einomcp.NewToolInvokeNotifyHolder()
var firedContent string
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
firedContent = content
})
wrap := &einoStreamingShellWrap{
inner: inner,
invokeNotify: notify,
toolTimeoutMinutes: 60,
}
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "echo 100"})
if err != nil {
t.Fatalf("ExecuteStreaming: %v", err)
}
defer sr.Close()
var got strings.Builder
for {
resp, rerr := sr.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
t.Fatalf("unexpected stream error: %v", rerr)
}
if resp != nil && resp.Output != "" {
got.WriteString(resp.Output)
}
}
if !strings.Contains(got.String(), "100") {
t.Fatalf("stream output = %q, want contains 100", got.String())
}
if !strings.Contains(firedContent, "100") {
t.Fatalf("notify content = %q, want contains 100", firedContent)
}
}
func TestEinoStreamingShellWrap_AbortNoteDoesNotDuplicateStreamedOutput(t *testing.T) {
inner := &mockStreamingShell{output: "line1\nline2\n", recvErr: context.Canceled}
notify := einomcp.NewToolInvokeNotifyHolder()
wrap := &einoStreamingShellWrap{
inner: inner,
invokeNotify: notify,
}
reg := &abortNoteTestRegistry{note: "改成20次"}
ctx := mcp.WithEinoExecuteRunRegistry(
mcp.WithMCPConversationID(context.Background(), "conv-abort-dup"),
reg,
)
sr, err := wrap.ExecuteStreaming(ctx, &filesystem.ExecuteRequest{Command: "ping -c 10 baidu.com"})
if err != nil {
t.Fatalf("ExecuteStreaming: %v", err)
}
defer sr.Close()
var got strings.Builder
for {
resp, rerr := sr.Recv()
if errors.Is(rerr, io.EOF) {
break
}
if rerr != nil {
t.Fatalf("unexpected stream error: %v", rerr)
}
if resp != nil && resp.Output != "" {
got.WriteString(resp.Output)
}
}
out := got.String()
if strings.Count(out, "line1") != 1 || strings.Count(out, "line2") != 1 {
t.Fatalf("stream duplicated stdout: %q", out)
}
if !strings.Contains(out, "改成20次") {
t.Fatalf("stream missing abort note: %q", out)
}
}
type abortNoteTestRegistry struct {
note string
}
func (r *abortNoteTestRegistry) RegisterActiveEinoExecute(string, context.CancelFunc) {}
func (r *abortNoteTestRegistry) UnregisterActiveEinoExecute(string) {}
func (r *abortNoteTestRegistry) AbortActiveEinoExecute(string, string) bool { return false }
func (r *abortNoteTestRegistry) TakeEinoExecuteAbortNote(string) string { return r.note }
func TestEinoStreamingShellWrap_NonTimeoutRecvErrStillHard(t *testing.T) { func TestEinoStreamingShellWrap_NonTimeoutRecvErrStillHard(t *testing.T) {
inner := &mockStreamingShell{recvErr: errors.New("broken pipe")} inner := &mockStreamingShell{recvErr: errors.New("broken pipe")}
wrap := &einoStreamingShellWrap{inner: inner} wrap := &einoStreamingShellWrap{inner: inner}
+3 -6
View File
@@ -243,17 +243,14 @@ func prependEinoMiddlewares(
return outTools, extraHandlers, toolSearchActive, nil return outTools, extraHandlers, toolSearchActive, nil
} }
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry *adk.ModelRetryConfig, taskDesc func(context.Context, []adk.Agent) (string, error)) { func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, taskDesc func(context.Context, []adk.Agent) (string, error)) {
if ma == nil { if ma == nil {
return "", nil, nil return "", nil
} }
mw := ma.EinoMiddleware mw := ma.EinoMiddleware
if k := strings.TrimSpace(mw.DeepOutputKey); k != "" { if k := strings.TrimSpace(mw.DeepOutputKey); k != "" {
outputKey = k outputKey = k
} }
if mw.DeepModelRetryMaxRetries > 0 {
retry = &adk.ModelRetryConfig{MaxRetries: mw.DeepModelRetryMaxRetries}
}
prefix := strings.TrimSpace(mw.TaskToolDescriptionPrefix) prefix := strings.TrimSpace(mw.TaskToolDescriptionPrefix)
if prefix != "" { if prefix != "" {
taskDesc = func(ctx context.Context, agents []adk.Agent) (string, error) { taskDesc = func(ctx context.Context, agents []adk.Agent) (string, error) {
@@ -274,5 +271,5 @@ func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry
return prefix + "\n可用子代理(按名称 transfer / task 调用):" + strings.Join(names, "、"), nil return prefix + "\n可用子代理(按名称 transfer / task 调用):" + strings.Join(names, "、"), nil
} }
} }
return outputKey, retry, taskDesc return outputKey, taskDesc
} }
+9 -13
View File
@@ -94,24 +94,20 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
if a.SkillMiddleware != nil { if a.SkillMiddleware != nil {
execHandlers = append(execHandlers, a.SkillMiddleware) execHandlers = append(execHandlers, a.SkillMiddleware)
} }
// 4. summarization(最后,与 Deep/Supervisor 一致 // 4. pre-summarization normalize + continuation dedup, then summarization (与 Deep/Supervisor 一致)
if a.AppCfg != nil { if a.AppCfg != nil {
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.DB, a.ProjectID, a.Logger) sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.DB, a.ProjectID, a.Logger)
if sumErr != nil { if sumErr != nil {
return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr) return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr)
} }
execHandlers = append(execHandlers, sumMw) execHandlers = appendEinoChatModelTailMiddlewares(execHandlers, einoChatModelTailConfig{
} logger: a.Logger,
// 5. 孤儿 tool 消息兜底:必须挂在所有改写历史中间件(summarization/reduction/skill)之后、 phase: "plan_execute_executor",
// telemetry 之前,保证送入 ChatModel 的消息序列 tool_call ↔ tool_result 配对完整。 summarization: sumMw,
execHandlers = append(execHandlers, newOrphanToolPrunerMiddleware(a.Logger, "plan_execute_executor")) modelName: a.ModelName,
if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil { conversationID: a.ConversationID,
execHandlers = append(execHandlers, teleMw) trace: a.ModelFacingTrace,
} })
if a.ModelFacingTrace != nil {
if capMw := newModelFacingTraceMiddleware(a.ModelFacingTrace); capMw != nil {
execHandlers = append(execHandlers, capMw)
}
} }
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{ executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
Model: a.ExecModel, Model: a.ExecModel,
+9 -11
View File
@@ -144,13 +144,14 @@ func RunEinoSingleChatModelAgent(
} }
handlers = append(handlers, einoSkillMW) handlers = append(handlers, einoSkillMW)
} }
handlers = append(handlers, mainSumMw) handlers = appendEinoChatModelTailMiddlewares(handlers, einoChatModelTailConfig{
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "eino_single"); teleMw != nil { logger: logger,
handlers = append(handlers, teleMw) phase: "eino_single",
} summarization: mainSumMw,
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil { modelName: appCfg.OpenAI.Model,
handlers = append(handlers, capMw) conversationID: conversationID,
} trace: modelFacingTrace,
})
maxIter := agentMaxIterations(appCfg) maxIter := agentMaxIterations(appCfg)
@@ -188,13 +189,10 @@ func RunEinoSingleChatModelAgent(
MaxIterations: maxIter, MaxIterations: maxIter,
Handlers: handlers, Handlers: handlers,
} }
outKey, modelRetry, _ := deepExtrasFromConfig(ma) outKey, _ := deepExtrasFromConfig(ma)
if outKey != "" { if outKey != "" {
chatCfg.OutputKey = outKey chatCfg.OutputKey = outKey
} }
if modelRetry != nil {
chatCfg.ModelRetryConfig = modelRetry
}
chatAgent, err := adk.NewChatModelAgent(ctx, chatCfg) chatAgent, err := adk.NewChatModelAgent(ctx, chatCfg)
if err != nil { if err != nil {
+14 -15
View File
@@ -22,8 +22,6 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
const defaultSummarizationRetryMax = 3
// einoSummarizeUserInstruction:压缩历史时保留渗透测试关键信息。 // einoSummarizeUserInstruction:压缩历史时保留渗透测试关键信息。
const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史 const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史
@@ -97,10 +95,8 @@ func newEinoSummarizationMiddleware(
} }
} }
retryMax := defaultSummarizationRetryMax retryPolicy := einoTransientRunRetryPolicyFromMW(mwCfg)
if mwCfg != nil && mwCfg.SummarizationRetryMaxAttempts > 0 { retryMax := retryPolicy.maxAttempts
retryMax = mwCfg.SummarizationRetryMaxAttempts
}
// ModelOptions apply only to summarization Generate (same ChatModel instance as the agent). // ModelOptions apply only to summarization Generate (same ChatModel instance as the agent).
// Strip thinking/reasoning on this call path; mark requests for empty-choices diagnostics. // Strip thinking/reasoning on this call path; mark requests for empty-choices diagnostics.
@@ -137,13 +133,14 @@ func newEinoSummarizationMiddleware(
Retry: &summarization.RetryConfig{ Retry: &summarization.RetryConfig{
MaxRetries: &retryMax, MaxRetries: &retryMax,
ShouldRetry: func(_ context.Context, _ adk.Message, err error) bool { ShouldRetry: func(_ context.Context, _ adk.Message, err error) bool {
if err != nil && logger != nil { retry := isEinoTransientRunError(err)
logger.Warn("eino summarization generate attempt failed, will retry if attempts remain", if retry && logger != nil {
logger.Warn("eino summarization generate transient error, will retry if attempts remain",
zap.Error(err), zap.Error(err),
zap.Int("max_retries", retryMax), zap.Int("max_retries", retryMax),
) )
} }
return err != nil return retry
}, },
}, },
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) { Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
@@ -260,17 +257,19 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
nonSystem = append(nonSystem, msg) nonSystem = append(nonSystem, msg)
} }
mergedSystem := mergeCollectedSystemMessages(systemMsgs)
if recentTrailTokenBudget <= 0 || len(nonSystem) == 0 { if recentTrailTokenBudget <= 0 || len(nonSystem) == 0 {
out := make([]adk.Message, 0, len(systemMsgs)+1) out := make([]adk.Message, 0, len(mergedSystem)+1)
out = append(out, systemMsgs...) out = append(out, mergedSystem...)
out = append(out, summary) out = append(out, summary)
return out, nil return out, nil
} }
rounds := splitMessagesIntoRounds(nonSystem) rounds := splitMessagesIntoRounds(nonSystem)
if len(rounds) == 0 { if len(rounds) == 0 {
out := make([]adk.Message, 0, len(systemMsgs)+1) out := make([]adk.Message, 0, len(mergedSystem)+1)
out = append(out, systemMsgs...) out = append(out, mergedSystem...)
out = append(out, summary) out = append(out, summary)
return out, nil return out, nil
} }
@@ -322,8 +321,8 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...) selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
} }
out := make([]adk.Message, 0, len(systemMsgs)+1+len(selectedMsgs)) out := make([]adk.Message, 0, len(mergedSystem)+1+len(selectedMsgs))
out = append(out, systemMsgs...) out = append(out, mergedSystem...)
out = append(out, summary) out = append(out, summary)
out = append(out, selectedMsgs...) out = append(out, selectedMsgs...)
return out, nil return out, nil
+15 -6
View File
@@ -192,8 +192,8 @@ func TestSummarizeFinalize_KeepsToolRoundIntact(t *testing.T) {
if len(out) < 2 { if len(out) < 2 {
t.Fatalf("output too short: %d", len(out)) t.Fatalf("output too short: %d", len(out))
} }
if out[0] != sys { if out[0].Role != schema.System || out[0].Content != "sys" {
t.Fatalf("first message must be system") t.Fatalf("first message must be system sys, got %s: %q", out[0].Role, out[0].Content)
} }
if out[1] != summary { if out[1] != summary {
t.Fatalf("second message must be summary") t.Fatalf("second message must be summary")
@@ -293,12 +293,12 @@ func TestSummarizeFinalize_BudgetZeroFallsBackToSummaryOnly(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
if len(out) != 2 || out[0] != sys || out[1] != summary { if len(out) != 2 || out[0].Role != schema.System || out[0].Content != "sys" || out[1] != summary {
t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out) t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out)
} }
} }
func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) { func TestSummarizeFinalize_MergesSystemMessages(t *testing.T) {
sys1 := schema.SystemMessage("sys1") sys1 := schema.SystemMessage("sys1")
sys2 := schema.SystemMessage("sys2") sys2 := schema.SystemMessage("sys2")
summary := schema.AssistantMessage("s", nil) summary := schema.AssistantMessage("s", nil)
@@ -321,10 +321,13 @@ func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) {
for _, m := range out { for _, m := range out {
if m != nil && m.Role == schema.System { if m != nil && m.Role == schema.System {
systemCount++ systemCount++
if got := m.Content; got != "sys1\n\nsys2" {
t.Fatalf("unexpected merged system content: %q", got)
}
} }
} }
if systemCount != 2 { if systemCount != 1 {
t.Fatalf("want 2 system messages retained, got %d", systemCount) t.Fatalf("want 1 merged system message, got %d", systemCount)
} }
} }
@@ -378,6 +381,12 @@ func TestWriteSummarizationTranscript(t *testing.T) {
if !strings.Contains(text, "tool_calls:") || !strings.Contains(text, "nmap output") { if !strings.Contains(text, "tool_calls:") || !strings.Contains(text, "nmap output") {
t.Fatalf("missing tool round: %q", text) t.Fatalf("missing tool round: %q", text)
} }
if !strings.Contains(text, `"name":"stub_tool"`) || !strings.Contains(text, `"arguments":"{}"`) {
t.Fatalf("missing tool name/arguments: %q", text)
}
if strings.Contains(text, "tool_call_id") || strings.Contains(text, `"id":"tc1"`) {
t.Fatalf("transcript should omit tool_call_id: %q", text)
}
} }
func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) { func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
@@ -23,6 +23,11 @@ const (
transcriptSkillsSystemMarker = "# Skills System" transcriptSkillsSystemMarker = "# Skills System"
) )
type transcriptToolCall struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
// formatSummarizationTranscript renders pre-compaction messages for transcript.txt. // formatSummarizationTranscript renders pre-compaction messages for transcript.txt.
// Best practice: keep full user/assistant/tool turns; slim system to dynamic blocks only. // Best practice: keep full user/assistant/tool turns; slim system to dynamic blocks only.
func formatSummarizationTranscript(msgs []adk.Message) string { func formatSummarizationTranscript(msgs []adk.Message) string {
@@ -138,15 +143,21 @@ func appendTranscriptMessage(sb *strings.Builder, msg adk.Message) {
} }
} }
if len(msg.ToolCalls) > 0 { if len(msg.ToolCalls) > 0 {
if b, err := sonic.Marshal(msg.ToolCalls); err == nil { if b, err := sonic.Marshal(formatTranscriptToolCalls(msg.ToolCalls)); err == nil {
sb.WriteString("tool_calls: ") sb.WriteString("tool_calls: ")
sb.Write(b) sb.Write(b)
sb.WriteByte('\n') sb.WriteByte('\n')
} }
} }
if msg.ToolCallID != "" { }
sb.WriteString("tool_call_id: ")
sb.WriteString(msg.ToolCallID) func formatTranscriptToolCalls(calls []schema.ToolCall) []transcriptToolCall {
sb.WriteByte('\n') out := make([]transcriptToolCall, 0, len(calls))
} for _, tc := range calls {
out = append(out, transcriptToolCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
})
}
return out
} }
+74 -14
View File
@@ -3,6 +3,7 @@ package multiagent
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"strings" "strings"
"time" "time"
@@ -17,8 +18,9 @@ const (
defaultEinoRunRetryMaxBackoff = 30 * time.Second defaultEinoRunRetryMaxBackoff = 30 * time.Second
) )
// isEinoTransientRunError 判断 ADK 运行期错误是否适合指数退避续跑(429、5xx、网络抖动等) // isEinoTransientRunError 是 Eino 运行期「可退避重试 vs 直接失败」的唯一判据
// 用户取消、超时、迭代上限等由 run loop 单独处理,不在此列 // 429/5xx/网络抖动等返回 true;用户取消、超时、迭代上限、鉴权失败等返回 false
// 其它模块(run loop、summarization 等)只调用本函数,不在别处维护平行规则。
func isEinoTransientRunError(err error) bool { func isEinoTransientRunError(err error) bool {
if err == nil { if err == nil {
return false return false
@@ -60,6 +62,7 @@ func isEinoTransientRunError(err error) bool {
"dial tcp", "dial tcp",
"tls handshake timeout", "tls handshake timeout",
"stream error", "stream error",
"goaway", // http2: server sent GOAWAY and closed the connection
"unexpected eof", "unexpected eof",
`": eof`, // net/http: Post "url": EOF (often wraps io.EOF) `": eof`, // net/http: Post "url": EOF (often wraps io.EOF)
"unexpected end of json", "unexpected end of json",
@@ -78,6 +81,71 @@ func isEinoTransientRunError(err error) bool {
return false return false
} }
type einoTransientRunRetryPolicy struct {
maxAttempts int
maxBackoff time.Duration
}
func einoTransientRunRetryPolicyFromArgs(args *einoADKRunLoopArgs) einoTransientRunRetryPolicy {
return einoTransientRunRetryPolicy{
maxAttempts: einoRunRetryMaxAttempts(args),
maxBackoff: einoRunRetryMaxBackoff(args),
}
}
func einoTransientRunRetryPolicyFromMW(mw *config.MultiAgentEinoMiddlewareConfig) einoTransientRunRetryPolicy {
maxBackoff := defaultEinoRunRetryMaxBackoff
if mw != nil && mw.RunRetryMaxBackoffSec > 0 {
maxBackoff = time.Duration(mw.RunRetryMaxBackoffSec) * time.Second
}
return einoTransientRunRetryPolicy{
maxAttempts: RunRetryMaxAttemptsFromConfig(mw),
maxBackoff: maxBackoff,
}
}
// einoTransientRunRetrier 在 run loop 内对临时错误做指数退避并重启 Runner(唯一重试执行层)。
type einoTransientRunRetrier struct {
policy einoTransientRunRetryPolicy
attempts int
}
func newEinoTransientRunRetrier(policy einoTransientRunRetryPolicy) *einoTransientRunRetrier {
return &einoTransientRunRetrier{policy: policy}
}
// tryRetry 对临时错误退避后返回重启消息;次数用尽返回 exhausted 错误。
func (r *einoTransientRunRetrier) tryRetry(
ctx context.Context,
runErr error,
args *einoADKRunLoopArgs,
baseMsgs, accumulated []adk.Message,
baseCount int,
) (restarted bool, restartMsgs []adk.Message, ctxSource einoRunRestartContextSource, backoff time.Duration, fatal error) {
if runErr == nil || !isEinoTransientRunError(runErr) {
return false, nil, "", 0, runErr
}
r.attempts++
if r.attempts > r.policy.maxAttempts {
return false, nil, "", 0, fmt.Errorf("transient retry exhausted after %d attempts: %w", r.policy.maxAttempts, runErr)
}
backoff = einoTransientRetryBackoff(r.attempts-1, r.policy.maxBackoff)
select {
case <-ctx.Done():
return false, nil, "", 0, ctx.Err()
case <-time.After(backoff):
}
restartMsgs, ctxSource = einoMessagesForRunRestart(args, baseMsgs, accumulated, baseCount)
return true, restartMsgs, ctxSource, backoff, nil
}
func (r *einoTransientRunRetrier) attempt() int { return r.attempts }
func (r *einoTransientRunRetrier) maxAttempts() int { return r.policy.maxAttempts }
// reset 在一次成功推进后清零重试计数,使后续临时错误从第 1 次退避重新开始。
func (r *einoTransientRunRetrier) reset() { r.attempts = 0 }
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int { func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
if args != nil && args.RunRetryMaxAttempts > 0 { if args != nil && args.RunRetryMaxAttempts > 0 {
return args.RunRetryMaxAttempts return args.RunRetryMaxAttempts
@@ -85,7 +153,7 @@ func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
return defaultEinoRunRetryMaxAttempts return defaultEinoRunRetryMaxAttempts
} }
// RunRetryMaxAttemptsFromConfig 供 handler 分段续跑计数(与 eino_middleware.run_retry_max_attempts 一致 // RunRetryMaxAttemptsFromConfig 与 eino_middleware.run_retry_max_attempts 一致。
func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int { func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
if mw != nil && mw.RunRetryMaxAttempts > 0 { if mw != nil && mw.RunRetryMaxAttempts > 0 {
return mw.RunRetryMaxAttempts return mw.RunRetryMaxAttempts
@@ -93,15 +161,6 @@ func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) in
return defaultEinoRunRetryMaxAttempts return defaultEinoRunRetryMaxAttempts
} }
// TransientRetryBackoff 供 handler 在分段续跑前退避。
func TransientRetryBackoff(attempt int, maxBackoffSec int) time.Duration {
max := defaultEinoRunRetryMaxBackoff
if maxBackoffSec > 0 {
max = time.Duration(maxBackoffSec) * time.Second
}
return einoTransientRetryBackoff(attempt, max)
}
func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration { func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration {
if args != nil && args.RunRetryMaxBackoffSec > 0 { if args != nil && args.RunRetryMaxBackoffSec > 0 {
return time.Duration(args.RunRetryMaxBackoffSec) * time.Second return time.Duration(args.RunRetryMaxBackoffSec) * time.Second
@@ -122,10 +181,11 @@ const (
// 1) ModelFacingTrace(与模型实际入参一致) 2) 事件流累积的 runAccumulatedMsgs 3) 初始 msgs。 // 1) ModelFacingTrace(与模型实际入参一致) 2) 事件流累积的 runAccumulatedMsgs 3) 初始 msgs。
func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated []adk.Message, baseCount int) ([]adk.Message, einoRunRestartContextSource) { func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated []adk.Message, baseCount int) ([]adk.Message, einoRunRestartContextSource) {
if trace := persistTraceSource(args, nil); len(trace) > 0 { if trace := persistTraceSource(args, nil); len(trace) > 0 {
return append([]adk.Message(nil), trace...), einoRestartContextModelTrace // modelFacingTrace includes prior Instruction system message(s); genModelInput will prepend again.
return stripADKSystemMessages(trace), einoRestartContextModelTrace
} }
if len(accumulated) > baseCount { if len(accumulated) > baseCount {
return append([]adk.Message(nil), accumulated...), einoRestartContextAccumulated return stripADKSystemMessages(accumulated), einoRestartContextAccumulated
} }
return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial
} }
@@ -27,6 +27,7 @@ func TestIsEinoTransientRunError(t *testing.T) {
{"429", errors.New("HTTP 429 Too Many Requests"), true}, {"429", errors.New("HTTP 429 Too Many Requests"), true},
{"rate limit", errors.New(`{"error":"rate limit exceeded"}`), true}, {"rate limit", errors.New(`{"error":"rate limit exceeded"}`), true},
{"connection reset", errors.New("read tcp: connection reset by peer"), true}, {"connection reset", errors.New("read tcp: connection reset by peer"), true},
{"http2 goaway", errors.New("failed to receive stream chunk: error, http2: server sent GOAWAY and closed the connection; LastStreamID=791, ErrCode=NO_ERROR"), true},
{"unexpected eof", errors.New("unexpected EOF"), true}, {"unexpected eof", errors.New("unexpected EOF"), true},
{"503", errors.New("upstream returned 503"), true}, {"503", errors.New("upstream returned 503"), true},
{"iteration limit", errors.New("max iteration reached"), false}, {"iteration limit", errors.New("max iteration reached"), false},
@@ -90,6 +91,20 @@ func TestEinoRunRetryMaxAttemptsFromArgs(t *testing.T) {
} }
} }
func TestEinoTransientRunRetrierReset(t *testing.T) {
t.Parallel()
r := newEinoTransientRunRetrier(einoTransientRunRetryPolicy{maxAttempts: 10, maxBackoff: 30 * time.Second})
r.attempts = 3
r.reset()
if r.attempt() != 0 {
t.Fatalf("after reset: attempt=%d, want 0", r.attempt())
}
// 重置后下一次退避应从 2s 起算(attempt index 0)。
if got := einoTransientRetryBackoff(r.attempt(), r.policy.maxBackoff); got != 2*time.Second {
t.Fatalf("backoff after reset: got %v, want 2s", got)
}
}
func TestAppendUserMessageIfNeeded(t *testing.T) { func TestAppendUserMessageIfNeeded(t *testing.T) {
t.Parallel() t.Parallel()
msgs := []adk.Message{schema.UserMessage("old task")} msgs := []adk.Message{schema.UserMessage("old task")}
@@ -102,10 +117,3 @@ func TestAppendUserMessageIfNeeded(t *testing.T) {
t.Fatalf("should not duplicate user message: len=%d", len(dup)) t.Fatalf("should not duplicate user message: len=%d", len(dup))
} }
} }
func TestErrTransientRetryContinue(t *testing.T) {
t.Parallel()
if !errors.Is(ErrTransientRetryContinue, ErrTransientRetryContinue) {
t.Fatal("sentinel should match")
}
}
-8
View File
@@ -5,11 +5,3 @@ import "errors"
// ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时, // ErrInterruptContinue 作为 context.CancelCause 使用:用户选择「中断并继续」且当前无进行中的 MCP 工具时,
// 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。 // 取消当前推理/流式输出,并在同一会话任务内携带用户补充说明自动续跑下一轮(类似 Hermes 式人机回合)。
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context") var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
// ErrTransientRetryContinue 表示 Run 因 429/网络等临时错误结束,应由 handler 落库轨迹后
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue 同级的「分段续跑」语义)。
var ErrTransientRetryContinue = errors.New("agent transient: retry after persisting trace")
// ErrEmptyResponseContinue 表示 Eino ADK 会话正常结束但未捕获到助手正文,应由 handler 落库轨迹后
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue / ErrTransientRetryContinue 同级)。
var ErrEmptyResponseContinue = errors.New("agent empty response: continue after persisting trace")
@@ -27,7 +27,7 @@ import (
// 本中间件与之互补,专职兜底正向孤儿。 // 本中间件与之互补,专职兜底正向孤儿。
// - 仅剔除消息,不向历史里注入虚构 assistant(tc):虚构 tool_calls 反而会误导模型后续推理。 // - 仅剔除消息,不向历史里注入虚构 assistant(tc):虚构 tool_calls 反而会误导模型后续推理。
// 摘要已覆盖被裁剪段的语义,丢一条原始 tool 结果对对话连贯性影响最小。 // 摘要已覆盖被裁剪段的语义,丢一条原始 tool 结果对对话连贯性影响最小。
// - 位置建议:挂在所有可能改写历史的中间件(summarization / reduction / skill / plantask / // - 位置建议:挂在 summarization / reduction / skill / plantask / system 合并 / 续聊 dedup 之后,
// tool_search)之后,靠近 ChatModel 调用的那一端。 // tool_search)之后,靠近 ChatModel 调用的那一端。
type orphanToolPrunerMiddleware struct { type orphanToolPrunerMiddleware struct {
adk.BaseChatModelAgentMiddleware adk.BaseChatModelAgentMiddleware
+32 -36
View File
@@ -231,13 +231,13 @@ func RunDeepAgent(
} }
subHandlers = append(subHandlers, einoSkillMW) subHandlers = append(subHandlers, einoSkillMW)
} }
subHandlers = append(subHandlers, subSumMw) subHandlers = appendEinoChatModelTailMiddlewares(subHandlers, einoChatModelTailConfig{
// 孤儿 tool 消息兜底:放在 summarization 之后,telemetry 之前, logger: logger,
// 以便 telemetry 记录的 token 数与 LLM 实际入参一致。 phase: "sub_agent:" + id,
subHandlers = append(subHandlers, newOrphanToolPrunerMiddleware(logger, "sub_agent:"+id)) summarization: subSumMw,
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "sub_agent"); teleMw != nil { modelName: appCfg.OpenAI.Model,
subHandlers = append(subHandlers, teleMw) conversationID: conversationID,
} })
subInstrFinal := project.AppendVisionImageAnalysisIfReady(instr, appCfg.Vision.Ready()) subInstrFinal := project.AppendVisionImageAnalysisIfReady(instr, appCfg.Vision.Ready())
subInstrFinal = injectToolNamesOnlyInstruction(ctx, subInstrFinal, subTools, subToolSearchActive) subInstrFinal = injectToolNamesOnlyInstruction(ctx, subInstrFinal, subTools, subToolSearchActive)
@@ -379,14 +379,14 @@ func RunDeepAgent(
if einoSkillMW != nil { if einoSkillMW != nil {
deepHandlers = append(deepHandlers, einoSkillMW) deepHandlers = append(deepHandlers, einoSkillMW)
} }
deepHandlers = append(deepHandlers, mainSumMw) deepHandlers = appendEinoChatModelTailMiddlewares(deepHandlers, einoChatModelTailConfig{
deepHandlers = append(deepHandlers, newOrphanToolPrunerMiddleware(logger, "deep_orchestrator")) logger: logger,
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil { phase: "deep_orchestrator",
deepHandlers = append(deepHandlers, teleMw) summarization: mainSumMw,
} modelName: appCfg.OpenAI.Model,
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil { conversationID: conversationID,
deepHandlers = append(deepHandlers, capMw) trace: modelFacingTrace,
} })
supHandlers := []adk.ChatModelAgentMiddleware{} supHandlers := []adk.ChatModelAgentMiddleware{}
if len(mainOrchestratorPre) > 0 { if len(mainOrchestratorPre) > 0 {
@@ -395,14 +395,14 @@ func RunDeepAgent(
if einoSkillMW != nil { if einoSkillMW != nil {
supHandlers = append(supHandlers, einoSkillMW) supHandlers = append(supHandlers, einoSkillMW)
} }
supHandlers = append(supHandlers, mainSumMw) supHandlers = appendEinoChatModelTailMiddlewares(supHandlers, einoChatModelTailConfig{
supHandlers = append(supHandlers, newOrphanToolPrunerMiddleware(logger, "supervisor_orchestrator")) logger: logger,
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil { phase: "supervisor_orchestrator",
supHandlers = append(supHandlers, teleMw) summarization: mainSumMw,
} modelName: appCfg.OpenAI.Model,
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil { conversationID: conversationID,
supHandlers = append(supHandlers, capMw) trace: modelFacingTrace,
} })
mainToolsCfg := adk.ToolsConfig{ mainToolsCfg := adk.ToolsConfig{
ToolsNodeConfig: compose.ToolsNodeConfig{ ToolsNodeConfig: compose.ToolsNodeConfig{
@@ -416,7 +416,7 @@ func RunDeepAgent(
EmitInternalEvents: true, EmitInternalEvents: true,
} }
deepOutKey, modelRetry, taskGen := deepExtrasFromConfig(ma) deepOutKey, taskGen := deepExtrasFromConfig(ma)
var da adk.Agent var da adk.Agent
switch orchMode { switch orchMode {
@@ -451,12 +451,14 @@ func RunDeepAgent(
SkillMiddleware: einoSkillMW, SkillMiddleware: einoSkillMW,
FilesystemMiddleware: peFsMw, FilesystemMiddleware: peFsMw,
ModelFacingTrace: modelFacingTrace, ModelFacingTrace: modelFacingTrace,
PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{ PlannerReplannerRewriteHandlers: appendEinoChatModelTailMiddlewares(nil, einoChatModelTailConfig{
mainSumMw, logger: logger,
// 孤儿 tool 消息兜底:必须挂在 summarization 之后、telemetry 之前。 phase: "plan_execute_planner_replanner",
newOrphanToolPrunerMiddleware(logger, "plan_execute_planner_replanner"), summarization: mainSumMw,
newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "plan_execute_planner_replanner_rewrite"), modelName: appCfg.OpenAI.Model,
}, conversationID: conversationID,
skipTrace: true,
}),
}) })
if perr != nil { if perr != nil {
return nil, perr return nil, perr
@@ -473,9 +475,6 @@ func RunDeepAgent(
Handlers: supHandlers, Handlers: supHandlers,
Exit: &adk.ExitTool{}, Exit: &adk.ExitTool{},
} }
if modelRetry != nil {
supCfg.ModelRetryConfig = modelRetry
}
if deepOutKey != "" { if deepOutKey != "" {
supCfg.OutputKey = deepOutKey supCfg.OutputKey = deepOutKey
} }
@@ -509,9 +508,6 @@ func RunDeepAgent(
if deepOutKey != "" { if deepOutKey != "" {
dcfg.OutputKey = deepOutKey dcfg.OutputKey = deepOutKey
} }
if modelRetry != nil {
dcfg.ModelRetryConfig = modelRetry
}
if taskGen != nil { if taskGen != nil {
dcfg.TaskToolDescriptionGenerator = taskGen dcfg.TaskToolDescriptionGenerator = taskGen
} }
@@ -0,0 +1,86 @@
package multiagent
import (
"context"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
"go.uber.org/zap"
)
// systemMessageNormalizerMiddleware merges duplicate role=system messages into a single
// leading system message before summarization and each ChatModel call.
type systemMessageNormalizerMiddleware struct {
adk.BaseChatModelAgentMiddleware
logger *zap.Logger
phase string
}
func newSystemMessageNormalizerMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware {
return &systemMessageNormalizerMiddleware{logger: logger, phase: phase}
}
func (m *systemMessageNormalizerMiddleware) BeforeModelRewriteState(
ctx context.Context,
state *adk.ChatModelAgentState,
mc *adk.ModelContext,
) (context.Context, *adk.ChatModelAgentState, error) {
_ = mc
if m == nil || state == nil || len(state.Messages) == 0 {
return ctx, state, nil
}
before := countADKSystemMessages(state.Messages)
if before <= 1 {
return ctx, state, nil
}
normalized := normalizeSingleLeadingSystemMessage(state.Messages, "")
if len(normalized) == len(state.Messages) && countADKSystemMessages(normalized) >= before {
return ctx, state, nil
}
if m.logger != nil {
m.logger.Info("eino system messages merged",
zap.String("phase", m.phase),
zap.Int("system_before", before),
zap.Int("system_after", countADKSystemMessages(normalized)),
zap.Int("messages_before", len(state.Messages)),
zap.Int("messages_after", len(normalized)),
)
}
out := *state
out.Messages = normalized
return ctx, &out, nil
}
func countADKSystemMessages(msgs []adk.Message) int {
n := 0
for _, msg := range msgs {
if msg != nil && msg.Role == schema.System {
n++
}
}
return n
}
// stripADKSystemMessages removes all system messages. Use before runner.Run restart when
// genModelInput will prepend a fresh Instruction.
func stripADKSystemMessages(msgs []adk.Message) []adk.Message {
if len(msgs) == 0 {
return msgs
}
out := make([]adk.Message, 0, len(msgs))
for _, msg := range msgs {
if msg == nil || msg.Role == schema.System {
continue
}
out = append(out, msg)
}
return out
}
// mergeCollectedSystemMessages collapses multiple system messages into one (or none).
func mergeCollectedSystemMessages(systemMsgs []adk.Message) []adk.Message {
if len(systemMsgs) == 0 {
return nil
}
return normalizeSingleLeadingSystemMessage(systemMsgs, "")
}
@@ -0,0 +1,75 @@
package multiagent
import (
"context"
"testing"
"github.com/cloudwego/eino/adk"
"github.com/cloudwego/eino/schema"
)
func TestStripADKSystemMessages(t *testing.T) {
in := []adk.Message{
schema.SystemMessage("a"),
schema.UserMessage("u"),
schema.SystemMessage("b"),
schema.AssistantMessage("x", nil),
}
out := stripADKSystemMessages(in)
if len(out) != 2 {
t.Fatalf("got %d messages, want 2", len(out))
}
if out[0].Role != schema.User || out[1].Role != schema.Assistant {
t.Fatalf("unexpected roles: %s, %s", out[0].Role, out[1].Role)
}
}
func TestEinoMessagesForRunRestart_StripsSystemFromTrace(t *testing.T) {
holder := newModelFacingTraceHolder()
holder.storeFromState(&adk.ChatModelAgentState{Messages: []adk.Message{
schema.SystemMessage("sys-1"),
schema.SystemMessage("sys-2"),
schema.UserMessage("task"),
}})
msgs, src := einoMessagesForRunRestart(&einoADKRunLoopArgs{ModelFacingTrace: holder}, nil, nil, 0)
if src != einoRestartContextModelTrace {
t.Fatalf("source: got %q want model_trace", src)
}
if len(msgs) != 1 || msgs[0].Role != schema.User {
t.Fatalf("expected user-only restart msgs, got %+v", msgs)
}
}
func TestSystemMessageNormalizerMiddleware_MergesDuplicates(t *testing.T) {
mw := newSystemMessageNormalizerMiddleware(nil, "test")
state := &adk.ChatModelAgentState{Messages: []adk.Message{
schema.SystemMessage("a"),
schema.SystemMessage("b"),
schema.UserMessage("u"),
}}
_, out, err := mw.(*systemMessageNormalizerMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
if err != nil {
t.Fatal(err)
}
if countADKSystemMessages(out.Messages) != 1 {
t.Fatalf("want 1 system, got %d", countADKSystemMessages(out.Messages))
}
if out.Messages[0].Content != "a\n\nb" {
t.Fatalf("merged content: %q", out.Messages[0].Content)
}
}
func TestSystemMessageNormalizerMiddleware_NoOpSingleSystem(t *testing.T) {
mw := newSystemMessageNormalizerMiddleware(nil, "test")
state := &adk.ChatModelAgentState{Messages: []adk.Message{
schema.SystemMessage("only"),
schema.UserMessage("u"),
}}
_, out, err := mw.(*systemMessageNormalizerMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
if err != nil {
t.Fatal(err)
}
if out != state {
t.Fatalf("expected same state pointer for no-op")
}
}
+71 -9
View File
@@ -2456,16 +2456,68 @@ header {
flex-shrink: 0; flex-shrink: 0;
} }
.mcp-call-buttons { .mcp-call-buttons,
.mcp-call-toolbar {
display: flex; display: flex;
flex-wrap: wrap; flex-wrap: wrap;
gap: 6px; gap: 6px;
align-items: center;
}
.mcp-tool-list {
display: none;
flex-wrap: wrap;
gap: 6px;
margin-top: 8px;
width: 100%;
}
.mcp-tool-list.expanded {
display: flex;
}
.mcp-tools-toggle-btn {
background: rgba(25, 118, 210, 0.1) !important;
border-color: rgba(25, 118, 210, 0.35) !important;
color: #1976d2 !important;
}
.mcp-call-toolbar .process-detail-btn,
.mcp-call-toolbar .mcp-tools-toggle-btn {
display: inline-flex;
align-items: center;
justify-content: center;
gap: 6px;
min-height: 32px;
padding: 6px 12px;
font-size: 0.8125rem;
font-weight: 500;
line-height: 1.25;
box-sizing: border-box;
white-space: nowrap;
vertical-align: middle;
}
.mcp-call-toolbar .process-detail-btn span,
.mcp-call-toolbar .mcp-tools-toggle-btn span {
display: inline-flex;
align-items: center;
line-height: 1.25;
}
.mcp-tools-toggle-btn:hover {
background: rgba(25, 118, 210, 0.18) !important;
border-color: #1976d2 !important;
color: #1565c0 !important;
} }
.process-detail-btn { .process-detail-btn {
background: rgba(156, 39, 176, 0.1) !important; background: rgba(156, 39, 176, 0.1) !important;
border-color: rgba(156, 39, 176, 0.3) !important; border-color: rgba(156, 39, 176, 0.3) !important;
color: #9c27b0 !important; color: #9c27b0 !important;
}
.mcp-call-toolbar .process-detail-btn {
display: inline-flex; display: inline-flex;
align-items: center; align-items: center;
gap: 6px; gap: 6px;
@@ -11797,34 +11849,44 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
background: transparent; background: transparent;
color: var(--text-muted); color: var(--text-muted);
cursor: pointer; cursor: pointer;
border-radius: 4px; border-radius: 6px;
display: flex; display: flex;
align-items: center; align-items: center;
justify-content: center; justify-content: center;
transition: all 0.2s ease; transition: all 0.2s ease;
} }
.batch-delete-btn svg {
width: 16px;
height: 16px;
transition: transform 0.2s ease;
}
.batch-delete-btn:hover { .batch-delete-btn:hover {
background: rgba(220, 53, 69, 0.1); background: rgba(220, 53, 69, 0.1);
color: var(--error-color); color: var(--error-color);
} }
.batch-delete-btn:hover svg {
transform: scale(1.08);
}
.batch-delete-btn:active {
background: rgba(220, 53, 69, 0.2);
transform: scale(0.95);
}
.batch-manage-footer { .batch-manage-footer {
display: flex; display: flex;
align-items: center; align-items: center;
justify-content: space-between; justify-content: flex-end;
padding: 16px 24px; padding: 16px 24px;
border-top: 1px solid var(--border-color); border-top: 1px solid var(--border-color);
flex-shrink: 0; flex-shrink: 0;
} }
.select-all-checkbox { .batch-table-col-checkbox input[type="checkbox"] {
display: flex;
align-items: center;
gap: 8px;
cursor: pointer; cursor: pointer;
font-size: 0.875rem;
color: var(--text-primary);
} }
.batch-footer-actions { .batch-footer-actions {
+9 -2
View File
@@ -530,6 +530,8 @@
"noMatchTools": "No matching tools", "noMatchTools": "No matching tools",
"penetrationTestDetail": "Penetration test details", "penetrationTestDetail": "Penetration test details",
"expandDetail": "Expand details", "expandDetail": "Expand details",
"toolExecutionsCount": "{{n}} tool runs",
"collapseToolExecutions": "Collapse tool runs",
"noProcessDetail": "No process details (execution may be too fast or no detailed events)", "noProcessDetail": "No process details (execution may be too fast or no detailed events)",
"copyMessageTitle": "Copy message", "copyMessageTitle": "Copy message",
"deleteTurnTitle": "Delete this turn", "deleteTurnTitle": "Delete this turn",
@@ -1656,6 +1658,7 @@
"rateWarning": "Some failures detected", "rateWarning": "Some failures detected",
"rateCritical": "High failure rate", "rateCritical": "High failure rate",
"statsSubtitle": "Refreshed {{time}} · {{count}} tools", "statsSubtitle": "Refreshed {{time}} · {{count}} tools",
"retentionHint": "Execution records are kept for {{days}} days, then purged automatically.",
"timelineTitle": "Call trend", "timelineTitle": "Call trend",
"timelineHint": "All tools combined (not split by tool)", "timelineHint": "All tools combined (not split by tool)",
"timelineRange24h": "24h", "timelineRange24h": "24h",
@@ -2577,6 +2580,8 @@
"agentModeSingle": "Single-agent (Eino ADK)", "agentModeSingle": "Single-agent (Eino ADK)",
"agentModeMulti": "Multi-agent (Eino)", "agentModeMulti": "Multi-agent (Eino)",
"agentModeHint": "Same as chat: Eino single-agent (ADK), or Deep / Plan-Execute / Supervisor (last three require multi_agent.enabled).", "agentModeHint": "Same as chat: Eino single-agent (ADK), or Deep / Plan-Execute / Supervisor (last three require multi_agent.enabled).",
"concurrency": "Concurrency",
"concurrencyHint": "Number of subtasks to run in parallel (1-8). Default 1 is serial; use 1-2 for scan-heavy tasks.",
"scheduleMode": "Schedule mode", "scheduleMode": "Schedule mode",
"scheduleModeManual": "Manual", "scheduleModeManual": "Manual",
"scheduleModeCron": "Cron expression", "scheduleModeCron": "Cron expression",
@@ -2591,8 +2596,8 @@
"tasksList": "Task list (one task per line)", "tasksList": "Task list (one task per line)",
"tasksListPlaceholder": "Enter task list, one per line", "tasksListPlaceholder": "Enter task list, one per line",
"tasksListPlaceholderExample": "Enter task list, one per line, for example:\nScan open ports of 192.168.1.1\nCheck if https://example.com has SQL injection\nEnumerate subdomains of example.com", "tasksListPlaceholderExample": "Enter task list, one per line, for example:\nScan open ports of 192.168.1.1\nCheck if https://example.com has SQL injection\nEnumerate subdomains of example.com",
"tasksListHint": "Enter one task command per line; the system will execute them in order. Empty lines are ignored.", "tasksListHint": "Enter one task command per line; the system runs them via a concurrency pool. Empty lines are ignored.",
"tasksListHintFull": "Hint: Enter one task command per line; the system will execute these tasks in order. Empty lines are ignored.", "tasksListHintFull": "Hint: Enter one task command per line; the system runs them via a concurrency pool. Empty lines are ignored.",
"createQueue": "Create queue" "createQueue": "Create queue"
}, },
"batchQueueDetailModal": { "batchQueueDetailModal": {
@@ -2626,6 +2631,8 @@
"scheduleToggleFailed": "Failed to update schedule toggle", "scheduleToggleFailed": "Failed to update schedule toggle",
"completedAt": "Completed at", "completedAt": "Completed at",
"taskTotal": "Total tasks", "taskTotal": "Total tasks",
"concurrency": "Concurrency",
"concurrencyEditHint": "Click to edit. Cannot change while the queue is running.",
"taskList": "Task list", "taskList": "Task list",
"startLabel": "Start", "startLabel": "Start",
"completeLabel": "Complete", "completeLabel": "Complete",
+9 -2
View File
@@ -518,6 +518,8 @@
"noMatchTools": "没有匹配的工具", "noMatchTools": "没有匹配的工具",
"penetrationTestDetail": "渗透测试详情", "penetrationTestDetail": "渗透测试详情",
"expandDetail": "展开详情", "expandDetail": "展开详情",
"toolExecutionsCount": "{{n}}次工具执行",
"collapseToolExecutions": "收起工具执行",
"noProcessDetail": "暂无过程详情(可能执行过快或未触发详细事件)", "noProcessDetail": "暂无过程详情(可能执行过快或未触发详细事件)",
"copyMessageTitle": "复制消息内容", "copyMessageTitle": "复制消息内容",
"deleteTurnTitle": "删除本轮对话", "deleteTurnTitle": "删除本轮对话",
@@ -1644,6 +1646,7 @@
"rateWarning": "存在失败调用", "rateWarning": "存在失败调用",
"rateCritical": "失败率偏高", "rateCritical": "失败率偏高",
"statsSubtitle": "最后刷新 {{time}} · 共 {{count}} 个工具", "statsSubtitle": "最后刷新 {{time}} · 共 {{count}} 个工具",
"retentionHint": "执行记录保留 {{days}} 天,超期自动清理",
"timelineTitle": "调用趋势", "timelineTitle": "调用趋势",
"timelineHint": "全部工具合计,不按工具拆分", "timelineHint": "全部工具合计,不按工具拆分",
"timelineRange24h": "24 小时", "timelineRange24h": "24 小时",
@@ -2565,6 +2568,8 @@
"agentModeSingle": "单代理(Eino ADK", "agentModeSingle": "单代理(Eino ADK",
"agentModeMulti": "多代理(Eino", "agentModeMulti": "多代理(Eino",
"agentModeHint": "与对话页一致:Eino 单代理(ADK),或 Deep / Plan-Execute / Supervisor(后三种需已启用多代理)。", "agentModeHint": "与对话页一致:Eino 单代理(ADK),或 Deep / Plan-Execute / Supervisor(后三种需已启用多代理)。",
"concurrency": "并发数",
"concurrencyHint": "同时执行的子任务数量(1-8)。默认 1 为串行;含扫描类工具时建议 1-2。",
"scheduleMode": "调度方式", "scheduleMode": "调度方式",
"scheduleModeManual": "手工执行", "scheduleModeManual": "手工执行",
"scheduleModeCron": "调度表达式(Cron", "scheduleModeCron": "调度表达式(Cron",
@@ -2579,8 +2584,8 @@
"tasksList": "任务列表(每行一个任务)", "tasksList": "任务列表(每行一个任务)",
"tasksListPlaceholder": "请输入任务列表,每行一个任务", "tasksListPlaceholder": "请输入任务列表,每行一个任务",
"tasksListPlaceholderExample": "请输入任务列表,每行一个任务,例如:\n扫描 192.168.1.1 的开放端口\n检查 https://example.com 是否存在SQL注入\n枚举 example.com 的子域名", "tasksListPlaceholderExample": "请输入任务列表,每行一个任务,例如:\n扫描 192.168.1.1 的开放端口\n检查 https://example.com 是否存在SQL注入\n枚举 example.com 的子域名",
"tasksListHint": "每行输入一个任务指令,系统将依次执行这些任务。空行会被自动忽略。", "tasksListHint": "每行输入一个任务指令,系统将按并发池执行这些任务。空行会被自动忽略。",
"tasksListHintFull": "提示:每行输入一个任务指令,系统将依次执行这些任务。空行会被自动忽略。", "tasksListHintFull": "提示:每行输入一个任务指令,系统将按并发池执行这些任务。空行会被自动忽略。",
"createQueue": "创建队列" "createQueue": "创建队列"
}, },
"batchQueueDetailModal": { "batchQueueDetailModal": {
@@ -2614,6 +2619,8 @@
"scheduleToggleFailed": "更新调度开关失败", "scheduleToggleFailed": "更新调度开关失败",
"completedAt": "完成时间", "completedAt": "完成时间",
"taskTotal": "任务总数", "taskTotal": "任务总数",
"concurrency": "并发数",
"concurrencyEditHint": "点击可修改;队列运行中不可改。",
"taskList": "任务列表", "taskList": "任务列表",
"startLabel": "开始", "startLabel": "开始",
"completeLabel": "完成", "completeLabel": "完成",
+437 -108
View File
@@ -1,6 +1,39 @@
let currentConversationId = null; let currentConversationId = null;
let loadConversationRequestSeq = 0; let loadConversationRequestSeq = 0;
/** 轻量会话 LRU 缓存:来回切换已加载会话时避免重复网络 + 全量 DOM 重建 */
const CONVERSATION_LITE_CACHE_MAX = 12;
const conversationLiteCache = new Map();
function getConversationLiteFromCache(conversationId) {
if (!conversationId) return null;
const hit = conversationLiteCache.get(conversationId);
if (!hit) return null;
conversationLiteCache.delete(conversationId);
conversationLiteCache.set(conversationId, hit);
return hit;
}
function putConversationLiteCache(conversationId, data) {
if (!conversationId || !data) return;
conversationLiteCache.delete(conversationId);
conversationLiteCache.set(conversationId, data);
while (conversationLiteCache.size > CONVERSATION_LITE_CACHE_MAX) {
const oldest = conversationLiteCache.keys().next().value;
conversationLiteCache.delete(oldest);
}
}
function invalidateConversationLiteCache(conversationId) {
if (conversationId) {
conversationLiteCache.delete(conversationId);
} else {
conversationLiteCache.clear();
}
}
window.invalidateConversationLiteCache = invalidateConversationLiteCache;
// @ 提及相关状态 // @ 提及相关状态
let mentionTools = []; let mentionTools = [];
let mentionToolsLoaded = false; let mentionToolsLoaded = false;
@@ -886,6 +919,9 @@ async function sendMessage() {
window.CyberStrikeChatScroll.onUserSendMessage(); window.CyberStrikeChatScroll.onUserSendMessage();
} }
addMessage('user', displayMessage, null, null, null, { scroll: 'none' }); addMessage('user', displayMessage, null, null, null, { scroll: 'none' });
if (currentConversationId) {
invalidateConversationLiteCache(currentConversationId);
}
// 清除防抖定时器,防止在清空输入框后重新保存草稿 // 清除防抖定时器,防止在清空输入框后重新保存草稿
if (draftSaveTimer) { if (draftSaveTimer) {
@@ -2027,31 +2063,13 @@ function addMessage(role, content, mcpExecutionIds = null, progressId = null, cr
// 有 MCP 执行记录且非流式占位消息时展示调用按钮;带 progressId 的流式占位不挂此条(与进度卡片一致,结束时 integrate 再创建) // 有 MCP 执行记录且非流式占位消息时展示调用按钮;带 progressId 的流式占位不挂此条(与进度卡片一致,结束时 integrate 再创建)
if (role === 'assistant' && (mcpExecutionIds && Array.isArray(mcpExecutionIds) && mcpExecutionIds.length > 0) && !progressId) { if (role === 'assistant' && (mcpExecutionIds && Array.isArray(mcpExecutionIds) && mcpExecutionIds.length > 0) && !progressId) {
const mcpSection = document.createElement('div'); if (options && options.deferMcpButtons) {
mcpSection.className = 'mcp-call-section'; try {
messageDiv.dataset.pendingMcpExecutionIds = JSON.stringify(mcpExecutionIds);
const mcpLabel = document.createElement('div'); } catch (e) { /* ignore */ }
mcpLabel.className = 'mcp-call-label'; } else {
mcpLabel.textContent = '📋 ' + (typeof window.t === 'function' ? window.t('chat.penetrationTestDetail') : '渗透测试详情'); appendMcpCallButtons(messageDiv, mcpExecutionIds);
mcpSection.appendChild(mcpLabel); }
const buttonsContainer = document.createElement('div');
buttonsContainer.className = 'mcp-call-buttons';
mcpExecutionIds.forEach((execId, index) => {
const detailBtn = document.createElement('button');
detailBtn.className = 'mcp-detail-btn';
detailBtn.dataset.execId = execId;
detailBtn.dataset.execIndex = String(index + 1);
detailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.callNumber', { n: index + 1 }) : '调用 #' + (index + 1)) + '</span>';
detailBtn.onclick = () => showMCPDetail(execId);
buttonsContainer.appendChild(detailBtn);
});
// 使用批量 API 一次性获取所有工具名称(消除 N 次单独请求)
batchUpdateButtonToolNames(buttonsContainer, mcpExecutionIds);
mcpSection.appendChild(buttonsContainer);
contentWrapper.appendChild(mcpSection);
} }
messageDiv.appendChild(contentWrapper); messageDiv.appendChild(contentWrapper);
@@ -2151,11 +2169,13 @@ function copyMessageToClipboard(messageDiv, button) {
function showCopySuccess(button) { function showCopySuccess(button) {
if (button) { if (button) {
const originalText = button.innerHTML; const originalText = button.innerHTML;
button.dataset.copySuccessActive = '1';
button.innerHTML = '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M20 6L9 17l-5-5" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round" fill="none"/></svg><span>' + (typeof window.t === 'function' ? window.t('common.copied') : '已复制') + '</span>'; button.innerHTML = '<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M20 6L9 17l-5-5" stroke="currentColor" stroke-width="2.5" stroke-linecap="round" stroke-linejoin="round" fill="none"/></svg><span>' + (typeof window.t === 'function' ? window.t('common.copied') : '已复制') + '</span>';
button.style.color = '#10b981'; button.style.color = '#10b981';
button.style.background = 'rgba(16, 185, 129, 0.1)'; button.style.background = 'rgba(16, 185, 129, 0.1)';
button.style.borderColor = 'rgba(16, 185, 129, 0.3)'; button.style.borderColor = 'rgba(16, 185, 129, 0.3)';
setTimeout(() => { setTimeout(() => {
delete button.dataset.copySuccessActive;
button.innerHTML = originalText; button.innerHTML = originalText;
button.style.color = ''; button.style.color = '';
button.style.background = ''; button.style.background = '';
@@ -2301,47 +2321,20 @@ function processDetailRowFingerprint(d) {
} }
// 渲染过程详情 // 渲染过程详情
function renderProcessDetails(messageId, processDetails) { // options.append=true 时分页追加;options.markLoaded=false 时保留 lazy 标记(分页加载中)
function renderProcessDetails(messageId, processDetails, options) {
const renderOpts = options || {};
const appendMode = !!renderOpts.append;
const markLoaded = renderOpts.markLoaded !== false;
const messageElement = document.getElementById(messageId); const messageElement = document.getElementById(messageId);
if (!messageElement) { if (!messageElement) {
return; return;
} }
// 查找或创建MCP调用区域 // 查找或创建 MCP 区域(工具栏 + 工具列表 + 迭代时间线 分区)
let mcpSection = messageElement.querySelector('.mcp-call-section'); const chrome = ensureMcpCallSectionChrome(messageElement, messageId);
if (!mcpSection) { if (!chrome) return;
mcpSection = document.createElement('div'); const { mcpSection, toolbar: buttonsContainer } = chrome;
mcpSection.className = 'mcp-call-section';
const contentWrapper = messageElement.querySelector('.message-content');
if (contentWrapper) {
contentWrapper.appendChild(mcpSection);
} else {
return;
}
}
// 确保有标签和按钮容器(统一结构)
let mcpLabel = mcpSection.querySelector('.mcp-call-label');
let buttonsContainer = mcpSection.querySelector('.mcp-call-buttons');
// 如果没有标签,创建一个(当没有工具调用时)
if (!mcpLabel && !buttonsContainer) {
mcpLabel = document.createElement('div');
mcpLabel.className = 'mcp-call-label';
mcpLabel.textContent = '📋 ' + (typeof window.t === 'function' ? window.t('chat.penetrationTestDetail') : '渗透测试详情');
mcpSection.appendChild(mcpLabel);
} else if (mcpLabel && mcpLabel.textContent !== ('📋 ' + (typeof window.t === 'function' ? window.t('chat.penetrationTestDetail') : '渗透测试详情'))) {
// 如果标签存在但不是统一格式,更新它
mcpLabel.textContent = '📋 ' + (typeof window.t === 'function' ? window.t('chat.penetrationTestDetail') : '渗透测试详情');
}
// 如果没有按钮容器,创建一个
if (!buttonsContainer) {
buttonsContainer = document.createElement('div');
buttonsContainer.className = 'mcp-call-buttons';
mcpSection.appendChild(buttonsContainer);
}
// 添加过程详情按钮(如果还没有) // 添加过程详情按钮(如果还没有)
let processDetailBtn = buttonsContainer.querySelector('.process-detail-btn'); let processDetailBtn = buttonsContainer.querySelector('.process-detail-btn');
@@ -2352,17 +2345,20 @@ function renderProcessDetails(messageId, processDetails) {
processDetailBtn.onclick = () => toggleProcessDetails(null, messageId); processDetailBtn.onclick = () => toggleProcessDetails(null, messageId);
buttonsContainer.appendChild(processDetailBtn); buttonsContainer.appendChild(processDetailBtn);
} }
syncMcpToolsToggleButton(messageElement);
// 创建过程详情容器(放在按钮容器之后) // 创建过程详情容器(放在工具列表之后)
const detailsId = 'process-details-' + messageId; const detailsId = 'process-details-' + messageId;
let detailsContainer = document.getElementById(detailsId); let detailsContainer = document.getElementById(detailsId);
const toolListEl = chrome.toolList;
if (!detailsContainer) { if (!detailsContainer) {
detailsContainer = document.createElement('div'); detailsContainer = document.createElement('div');
detailsContainer.id = detailsId; detailsContainer.id = detailsId;
detailsContainer.className = 'process-details-container'; detailsContainer.className = 'process-details-container';
// 确保容器在按钮容器之后 if (toolListEl) {
if (buttonsContainer.nextSibling) { toolListEl.after(detailsContainer);
} else if (buttonsContainer.nextSibling) {
mcpSection.insertBefore(detailsContainer, buttonsContainer.nextSibling); mcpSection.insertBefore(detailsContainer, buttonsContainer.nextSibling);
} else { } else {
mcpSection.appendChild(detailsContainer); mcpSection.appendChild(detailsContainer);
@@ -2391,17 +2387,21 @@ function renderProcessDetails(messageId, processDetails) {
if (isLazyNotLoaded && !reasoningFromMessage) { if (isLazyNotLoaded && !reasoningFromMessage) {
detailsContainer.dataset.lazyNotLoaded = '1'; detailsContainer.dataset.lazyNotLoaded = '1';
detailsContainer.dataset.loaded = '0'; detailsContainer.dataset.loaded = '0';
timeline.innerHTML = '<div class="progress-timeline-empty">' + const expandLabel = typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情';
(typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') + let lazyHint = expandLabel + '(点击后加载迭代详情';
'(点击后加载)</div>'; timeline.innerHTML = '<div class="progress-timeline-empty">' + lazyHint + '</div>';
timeline.classList.remove('expanded'); timeline.classList.remove('expanded');
prefetchProcessDetailsSummaryHint(messageId, messageElement);
return; return;
} }
if (isLazyNotLoaded) { if (isLazyNotLoaded) {
detailsContainer.dataset.lazyNotLoaded = '1'; detailsContainer.dataset.lazyNotLoaded = '1';
detailsContainer.dataset.loaded = '0'; detailsContainer.dataset.loaded = '0';
processDetails = []; processDetails = [];
} else { if (!appendMode) {
prefetchProcessDetailsSummaryHint(messageId, messageElement);
}
} else if (markLoaded) {
detailsContainer.dataset.lazyNotLoaded = '0'; detailsContainer.dataset.lazyNotLoaded = '0';
detailsContainer.dataset.loaded = '1'; detailsContainer.dataset.loaded = '1';
} }
@@ -2413,15 +2413,16 @@ function renderProcessDetails(messageId, processDetails) {
} }
// 如果没有processDetails或为空,显示空状态 // 如果没有processDetails或为空,显示空状态
if (!processDetails || processDetails.length === 0) { if (!processDetails || processDetails.length === 0) {
// 显示空状态提示 if (!appendMode) {
timeline.innerHTML = '<div class="progress-timeline-empty">' + (typeof window.t === 'function' ? window.t('chat.noProcessDetail') : '暂无过程详情(可能执行过快或未触发详细事件)') + '</div>'; timeline.innerHTML = '<div class="progress-timeline-empty">' + (typeof window.t === 'function' ? window.t('chat.noProcessDetail') : '暂无过程详情(可能执行过快或未触发详细事件)') + '</div>';
// 默认折叠 timeline.classList.remove('expanded');
timeline.classList.remove('expanded'); }
return; return;
} }
// 清空时间线并重新渲染 if (!appendMode) {
timeline.innerHTML = ''; timeline.innerHTML = '';
}
function processDetailAgentPrefix(d) { function processDetailAgentPrefix(d) {
@@ -2430,14 +2431,12 @@ function renderProcessDetails(messageId, processDetails) {
return s ? ('[' + s + '] ') : ''; return s ? ('[' + s + '] ') : '';
} }
// 渲染每个过程详情事件 function renderOneProcessDetail(detail) {
processDetails.forEach(detail => {
const eventType = detail.eventType || ''; const eventType = detail.eventType || '';
const title = detail.message || ''; const title = detail.message || '';
const data = detail.data || {}; const data = detail.data || {};
const agPx = processDetailAgentPrefix(data); const agPx = processDetailAgentPrefix(data);
// 根据事件类型渲染不同的内容
let itemTitle = title; let itemTitle = title;
if (eventType === 'iteration') { if (eventType === 'iteration') {
const n = data.iteration || 1; const n = data.iteration || 1;
@@ -2530,15 +2529,38 @@ function renderProcessDetails(messageId, processDetails) {
title: itemTitle, title: itemTitle,
message: detail.message || '', message: detail.message || '',
data: data, data: data,
createdAt: detail.createdAt // 传递实际的事件创建时间 createdAt: detail.createdAt
}; };
if (eventType === 'tool_call' && data._mergedResult) { if (eventType === 'tool_call' && data._mergedResult) {
timelineOpts.mergedResult = data._mergedResult; timelineOpts.mergedResult = data._mergedResult;
} }
addTimelineItem(timeline, eventType, timelineOpts); addTimelineItem(timeline, eventType, timelineOpts);
}); }
if (isLazyNotLoaded && reasoningFromMessage) { const TIMELINE_RENDER_BATCH = 40;
const renderTimelineBatch = (startIdx) => {
const endIdx = Math.min(startIdx + TIMELINE_RENDER_BATCH, processDetails.length);
for (let i = startIdx; i < endIdx; i++) {
renderOneProcessDetail(processDetails[i]);
}
if (endIdx < processDetails.length) {
requestAnimationFrame(() => renderTimelineBatch(endIdx));
} else if (markLoaded) {
finishProcessDetailsRender(messageElement, processDetails, isLazyNotLoaded, timeline);
}
};
if (processDetails.length > TIMELINE_RENDER_BATCH) {
renderTimelineBatch(0);
} else {
processDetails.forEach(renderOneProcessDetail);
if (markLoaded) {
finishProcessDetailsRender(messageElement, processDetails, isLazyNotLoaded, timeline);
}
}
}
function finishProcessDetailsRender(messageElement, processDetails, isLazyNotLoaded, timeline) {
if (isLazyNotLoaded && getMessageReasoningContent(messageElement)) {
const lazyHint = document.createElement('div'); const lazyHint = document.createElement('div');
lazyHint.className = 'progress-timeline-empty progress-timeline-lazy-hint'; lazyHint.className = 'progress-timeline-empty progress-timeline-lazy-hint';
lazyHint.textContent = (typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') + lazyHint.textContent = (typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') +
@@ -2546,15 +2568,12 @@ function renderProcessDetails(messageId, processDetails) {
timeline.appendChild(lazyHint); timeline.appendChild(lazyHint);
} }
// 检查是否有错误或取消事件,如果有,确保详情默认折叠(但仍有待审批 HITL 时保持展开,由 restoreHitlInlineForConversation 处理)
const hasPendingHitlInDetails = processDetails.some(d => d && d.eventType === 'hitl_interrupt'); const hasPendingHitlInDetails = processDetails.some(d => d && d.eventType === 'hitl_interrupt');
const hasErrorOrCancelled = processDetails.some(d => const hasErrorOrCancelled = processDetails.some(d =>
d.eventType === 'error' || d.eventType === 'cancelled' d.eventType === 'error' || d.eventType === 'cancelled'
); );
if (hasErrorOrCancelled && !hasPendingHitlInDetails) { if (hasErrorOrCancelled && !hasPendingHitlInDetails) {
// 确保时间线是折叠的
timeline.classList.remove('expanded'); timeline.classList.remove('expanded');
// 更新按钮文本为"展开详情"
const processDetailBtn = messageElement.querySelector('.process-detail-btn'); const processDetailBtn = messageElement.querySelector('.process-detail-btn');
if (processDetailBtn) { if (processDetailBtn) {
processDetailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') + '</span>'; processDetailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') + '</span>';
@@ -2562,6 +2581,36 @@ function renderProcessDetails(messageId, processDetails) {
} }
} }
/** 懒加载折叠态:后台拉摘要,提示迭代规模而不加载全量详情 */
function prefetchProcessDetailsSummaryHint(messageId, messageElement) {
if (!messageElement || !messageElement.dataset || !messageElement.dataset.backendMessageId) return;
const backendId = String(messageElement.dataset.backendMessageId).trim();
if (!backendId || typeof apiFetch !== 'function') return;
const detailsContainer = document.getElementById('process-details-' + messageId);
if (!detailsContainer || detailsContainer.dataset.summaryFetched === '1') return;
detailsContainer.dataset.summaryFetched = '1';
apiFetch('/api/messages/' + encodeURIComponent(backendId) + '/process-details?summary=1')
.then(async (res) => {
const j = await res.json().catch(() => ({}));
if (!res.ok || !j.summary) return;
const s = j.summary;
const timeline = detailsContainer.querySelector('.progress-timeline');
if (!timeline || detailsContainer.dataset.loaded === '1') return;
const expandLabel = typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情';
let hint = expandLabel + '(点击后加载迭代详情)';
if (s.maxIteration > 0) {
hint = expandLabel + '(共 ' + s.maxIteration + ' 轮迭代,' + (s.total || 0) + ' 条详情)';
} else if (s.total > 0) {
hint = expandLabel + '(共 ' + (s.total || 0) + ' 条详情)';
}
const empty = timeline.querySelector('.progress-timeline-empty');
if (empty) {
empty.textContent = hint;
}
})
.catch(() => {});
}
// 移除消息 // 移除消息
function removeMessage(id) { function removeMessage(id) {
const messageDiv = document.getElementById(id); const messageDiv = document.getElementById(id);
@@ -2623,6 +2672,201 @@ async function updateButtonWithToolName(button, executionId, index) {
} }
} }
function getPendingMcpExecutionCount(messageElement) {
if (!messageElement || !messageElement.dataset || !messageElement.dataset.pendingMcpExecutionIds) {
return 0;
}
try {
const ids = JSON.parse(messageElement.dataset.pendingMcpExecutionIds);
return Array.isArray(ids) ? ids.length : 0;
} catch (e) {
return 0;
}
}
function getMcpExecutionCount(messageElement) {
const pending = getPendingMcpExecutionCount(messageElement);
if (pending > 0) return pending;
const toolList = messageElement && messageElement.querySelector('.mcp-tool-list');
if (toolList) {
return toolList.querySelectorAll('.mcp-detail-btn[data-exec-id]').length;
}
return 0;
}
function formatMcpToolsToggleLabel(count, expanded) {
if (expanded) {
if (typeof window.t === 'function') {
const s = window.t('chat.collapseToolExecutions');
if (s && s !== 'chat.collapseToolExecutions') return s;
}
return '收起工具执行';
}
if (typeof window.t === 'function') {
const s = window.t('chat.toolExecutionsCount', { n: count });
if (s && s !== 'chat.toolExecutionsCount') return s;
}
return count + '次工具执行';
}
/** 渗透测试区:工具栏(展开详情 | N次工具执行)+ 独立工具列表 + 迭代时间线 */
function ensureMcpCallSectionChrome(messageElement, messageId) {
const contentWrapper = messageElement && messageElement.querySelector('.message-content');
if (!contentWrapper) return null;
let mcpSection = messageElement.querySelector('.mcp-call-section');
if (!mcpSection) {
mcpSection = document.createElement('div');
mcpSection.className = 'mcp-call-section';
const mcpLabel = document.createElement('div');
mcpLabel.className = 'mcp-call-label';
mcpLabel.textContent = '📋 ' + (typeof window.t === 'function' ? window.t('chat.penetrationTestDetail') : '渗透测试详情');
mcpSection.appendChild(mcpLabel);
contentWrapper.appendChild(mcpSection);
} else {
const mcpLabel = mcpSection.querySelector('.mcp-call-label');
const labelText = '📋 ' + (typeof window.t === 'function' ? window.t('chat.penetrationTestDetail') : '渗透测试详情');
if (mcpLabel && mcpLabel.textContent !== labelText) {
mcpLabel.textContent = labelText;
}
}
let toolbar = mcpSection.querySelector('.mcp-call-toolbar');
const legacyButtons = mcpSection.querySelector('.mcp-call-buttons');
if (!toolbar) {
toolbar = document.createElement('div');
toolbar.className = 'mcp-call-toolbar';
if (legacyButtons) {
const processBtn = legacyButtons.querySelector('.process-detail-btn');
if (processBtn) toolbar.appendChild(processBtn);
mcpSection.replaceChild(toolbar, legacyButtons);
} else {
mcpSection.appendChild(toolbar);
}
}
let toolList = mcpSection.querySelector('.mcp-tool-list');
if (!toolList) {
toolList = document.createElement('div');
toolList.className = 'mcp-tool-list';
const detailsContainer = mcpSection.querySelector('.process-details-container');
if (detailsContainer) {
mcpSection.insertBefore(toolList, detailsContainer);
} else {
toolbar.after(toolList);
}
}
if (legacyButtons && legacyButtons.parentNode === mcpSection) {
legacyButtons.querySelectorAll('.mcp-detail-btn[data-exec-id]').forEach((btn) => toolList.appendChild(btn));
legacyButtons.remove();
}
const clientId = messageId || messageElement.id;
if (clientId && !toolbar.querySelector('.process-detail-btn')) {
const processDetailBtn = document.createElement('button');
processDetailBtn.className = 'mcp-detail-btn process-detail-btn';
processDetailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') + '</span>';
processDetailBtn.onclick = () => toggleProcessDetails(null, clientId);
toolbar.appendChild(processDetailBtn);
}
return { mcpSection, toolbar, toolList };
}
function syncMcpToolsToggleButton(messageElement) {
if (!messageElement) return;
const chrome = ensureMcpCallSectionChrome(messageElement, messageElement.id);
if (!chrome) return;
const { toolbar, toolList } = chrome;
const count = getMcpExecutionCount(messageElement);
let toolsToggle = toolbar.querySelector('.mcp-tools-toggle-btn');
if (count <= 0) {
if (toolsToggle) toolsToggle.remove();
return;
}
if (!toolsToggle) {
toolsToggle = document.createElement('button');
toolsToggle.type = 'button';
toolsToggle.className = 'mcp-detail-btn mcp-tools-toggle-btn';
toolsToggle.onclick = function (e) {
e.stopPropagation();
toggleMcpToolList(messageElement.id);
};
toolbar.appendChild(toolsToggle);
}
const expanded = toolList.classList.contains('expanded');
toolsToggle.innerHTML = '<span>' + formatMcpToolsToggleLabel(count, expanded) + '</span>';
}
function toggleMcpToolList(assistantMessageId) {
const messageEl = document.getElementById(assistantMessageId);
if (!messageEl) return;
const chrome = ensureMcpCallSectionChrome(messageEl, assistantMessageId);
if (!chrome) return;
const { toolList } = chrome;
const willExpand = !toolList.classList.contains('expanded');
if (willExpand) {
ensureMcpCallButtons(messageEl);
toolList.classList.add('expanded');
} else {
toolList.classList.remove('expanded');
}
syncMcpToolsToggleButton(messageEl);
}
window.toggleMcpToolList = toggleMcpToolList;
window.syncMcpToolsToggleButton = syncMcpToolsToggleButton;
window.ensureMcpCallSectionChrome = ensureMcpCallSectionChrome;
/** 将 MCP 工具按钮挂到独立工具列表,并批量解析工具名 */
function appendMcpCallButtons(messageElement, executionIds) {
if (!messageElement || !Array.isArray(executionIds) || executionIds.length === 0) {
return;
}
const chrome = ensureMcpCallSectionChrome(messageElement, messageElement.id);
if (!chrome) return;
const toolList = chrome.toolList;
executionIds.forEach((execId, index) => {
if (toolList.querySelector('.mcp-detail-btn[data-exec-id="' + CSS.escape(String(execId)) + '"]')) {
return;
}
const detailBtn = document.createElement('button');
detailBtn.className = 'mcp-detail-btn';
detailBtn.dataset.execId = execId;
detailBtn.dataset.execIndex = String(index + 1);
detailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.callNumber', { n: index + 1 }) : '调用 #' + (index + 1)) + '</span>';
detailBtn.onclick = () => showMCPDetail(execId);
toolList.appendChild(detailBtn);
});
batchUpdateButtonToolNames(toolList, executionIds);
syncMcpToolsToggleButton(messageElement);
}
/** 历史会话懒加载:用户展开工具列表时再渲染工具按钮 */
function ensureMcpCallButtons(messageElement) {
if (!messageElement || !messageElement.dataset || !messageElement.dataset.pendingMcpExecutionIds) {
return;
}
let executionIds;
try {
executionIds = JSON.parse(messageElement.dataset.pendingMcpExecutionIds);
} catch (e) {
delete messageElement.dataset.pendingMcpExecutionIds;
return;
}
if (!Array.isArray(executionIds) || executionIds.length === 0) {
delete messageElement.dataset.pendingMcpExecutionIds;
return;
}
appendMcpCallButtons(messageElement, executionIds);
delete messageElement.dataset.pendingMcpExecutionIds;
}
window.ensureMcpCallButtons = ensureMcpCallButtons;
window.appendMcpCallButtons = appendMcpCallButtons;
// 批量获取工具名称并更新按钮(消除 N 次单独 API 请求,合并为 1 次) // 批量获取工具名称并更新按钮(消除 N 次单独 API 请求,合并为 1 次)
async function batchUpdateButtonToolNames(buttonsContainer, executionIds) { async function batchUpdateButtonToolNames(buttonsContainer, executionIds) {
if (!executionIds || executionIds.length === 0) return; if (!executionIds || executionIds.length === 0) return;
@@ -3182,40 +3426,63 @@ function getConversationGroup(dateObj, todayStart, sevenDaysCutoff, yesterdaySta
} }
// 加载对话 // 加载对话
/** 轻量加载会话后,拉取最后一条助手消息的 process_details(机器人等 SSE 场景) */ /** 轻量加载会话后,仅对「处理中…」占位回复拉取过程详情(机器人等 SSE 场景);已完成会话不预取全量 */
async function prefetchLastAssistantProcessDetails() { async function prefetchLastAssistantProcessDetails() {
const nodes = document.querySelectorAll('#chat-messages .message.assistant'); const nodes = document.querySelectorAll('#chat-messages .message.assistant');
if (!nodes.length) return; if (!nodes.length) return;
const last = nodes[nodes.length - 1]; const last = nodes[nodes.length - 1];
if (!last || !last.id) return; if (!last || !last.id) return;
const bubble = last.querySelector('.message-bubble');
const visibleText = bubble ? String(bubble.textContent || '').trim() : '';
const isPlaceholder = visibleText === '处理中...' || visibleText === 'Processing...';
if (!isPlaceholder) return;
const container = document.getElementById('process-details-' + last.id); const container = document.getElementById('process-details-' + last.id);
if (!container || container.dataset.lazyNotLoaded !== '1') return; if (!container || container.dataset.lazyNotLoaded !== '1') return;
const backendId = last.dataset && last.dataset.backendMessageId; const backendId = last.dataset && last.dataset.backendMessageId;
if (!backendId || typeof apiFetch !== 'function') return; if (!backendId || typeof apiFetch !== 'function') return;
if (typeof window.loadProcessDetailsPaginated === 'function') {
await window.loadProcessDetailsPaginated(last.id, backendId);
return;
}
const res = await apiFetch('/api/messages/' + encodeURIComponent(String(backendId)) + '/process-details'); const res = await apiFetch('/api/messages/' + encodeURIComponent(String(backendId)) + '/process-details');
const j = await res.json().catch(() => ({})); const j = await res.json().catch(() => ({}));
if (!res.ok || !Array.isArray(j.processDetails) || j.processDetails.length === 0) return; if (!res.ok || !Array.isArray(j.processDetails) || j.processDetails.length === 0) return;
if (typeof renderProcessDetails === 'function') { if (typeof renderProcessDetails === 'function') {
renderProcessDetails(last.id, j.processDetails); renderProcessDetails(last.id, j.processDetails);
} }
if (typeof window.expandProcessDetailsTimeline === 'function') {
window.expandProcessDetailsTimeline(last.id);
}
} }
async function loadConversation(conversationId) { async function loadConversation(conversationId) {
const seq = ++loadConversationRequestSeq; const seq = ++loadConversationRequestSeq;
try { try {
// 轻量加载:不带 processDetails,避免历史会话切换卡顿;展开详情时再按需拉取 const cachedConversation = getConversationLiteFromCache(conversationId);
const response = await apiFetch(`/api/conversations/${conversationId}?include_process_details=0`); const fetchPromise = apiFetch(`/api/conversations/${conversationId}?include_process_details=0`)
if (seq !== loadConversationRequestSeq) { .then(async (response) => {
return; const data = await response.json();
} return { response, data };
const conversation = await response.json(); });
if (!response.ok) { let conversation;
showChatToast('加载对话失败: ' + (conversation.error || '未知错误'), 'error'); let response;
return; if (cachedConversation) {
conversation = cachedConversation;
fetchPromise.then(({ response: freshResp, data }) => {
if (freshResp.ok && data && seq === loadConversationRequestSeq && currentConversationId === conversationId) {
putConversationLiteCache(conversationId, data);
}
}).catch(() => {});
} else {
const fetched = await fetchPromise;
response = fetched.response;
conversation = fetched.data;
if (seq !== loadConversationRequestSeq) {
return;
}
if (!response.ok) {
showChatToast('加载对话失败: ' + (conversation.error || '未知错误'), 'error');
return;
}
putConversationLiteCache(conversationId, conversation);
} }
if (seq !== loadConversationRequestSeq) { if (seq !== loadConversationRequestSeq) {
return; return;
@@ -3265,11 +3532,15 @@ async function loadConversation(conversationId) {
if (typeof refreshChatProjectSelector === 'function') { if (typeof refreshChatProjectSelector === 'function') {
refreshChatProjectSelector(); refreshChatProjectSelector();
} }
if (typeof window.syncHitlConfigFromServer === 'function') { refreshHitlConfigByCurrentConversation();
await window.syncHitlConfigFromServer(conversationId); const hitlSyncPromise = (typeof window.syncHitlConfigFromServer === 'function')
} else { ? window.syncHitlConfigFromServer(conversationId).then(() => {
refreshHitlConfigByCurrentConversation(); if (seq === loadConversationRequestSeq && currentConversationId === conversationId) {
} refreshHitlConfigByCurrentConversation();
}
}).catch(() => {})
: Promise.resolve();
void hitlSyncPromise;
updateActiveConversation(); updateActiveConversation();
// 如果攻击链模态框打开且显示的不是当前对话,关闭它 // 如果攻击链模态框打开且显示的不是当前对话,关闭它
@@ -3336,7 +3607,9 @@ async function loadConversation(conversationId) {
// - user: createdAt 即可(发送后不会再更新) // - user: createdAt 即可(发送后不会再更新)
// - assistant: 如果后端提供 updatedAt(任务完成时写回),优先用它,避免占位消息“任务开始时间”误导 // - assistant: 如果后端提供 updatedAt(任务完成时写回),优先用它,避免占位消息“任务开始时间”误导
const msgTime = (msg && msg.role === 'assistant' && msg.updatedAt) ? msg.updatedAt : (msg ? msg.createdAt : null); const msgTime = (msg && msg.role === 'assistant' && msg.updatedAt) ? msg.updatedAt : (msg ? msg.createdAt : null);
const messageId = addMessage(msg.role, displayContent, msg.mcpExecutionIds || [], null, msgTime); const mcpIds = (msg.mcpExecutionIds && Array.isArray(msg.mcpExecutionIds)) ? msg.mcpExecutionIds : [];
const addOpts = (msg.role === 'assistant' && mcpIds.length > 0) ? { deferMcpButtons: true } : null;
const messageId = addMessage(msg.role, displayContent, mcpIds, null, msgTime, addOpts);
const messageEl = document.getElementById(messageId); const messageEl = document.getElementById(messageId);
if (messageEl && msg && msg.id) { if (messageEl && msg && msg.id) {
messageEl.dataset.backendMessageId = String(msg.id); messageEl.dataset.backendMessageId = String(msg.id);
@@ -3504,6 +3777,7 @@ async function deleteConversationTurnFromUI(anchorBackendMessageId) {
if (!response.ok) { if (!response.ok) {
throw new Error(data.error || data.message || 'delete failed'); throw new Error(data.error || data.message || 'delete failed');
} }
invalidateConversationLiteCache(currentConversationId);
await loadConversation(currentConversationId); await loadConversation(currentConversationId);
if (typeof loadConversationsWithGroups === 'function') { if (typeof loadConversationsWithGroups === 'function') {
loadConversationsWithGroups(); loadConversationsWithGroups();
@@ -3550,6 +3824,7 @@ async function deleteConversation(conversationId, skipConfirm = false) {
// 更新缓存 - 立即删除,确保后续加载时能正确识别 // 更新缓存 - 立即删除,确保后续加载时能正确识别
delete conversationGroupMappingCache[conversationId]; delete conversationGroupMappingCache[conversationId];
invalidateConversationLiteCache(conversationId);
// 同时从待保留映射中移除 // 同时从待保留映射中移除
delete pendingGroupMappings[conversationId]; delete pendingGroupMappings[conversationId];
@@ -7450,14 +7725,14 @@ async function showBatchManageModal() {
updateBatchManageTitle(allConversationsForBatch.length); updateBatchManageTitle(allConversationsForBatch.length);
renderBatchConversations(); renderBatchConversations();
openAppModal('batch-manage-modal'); openAppModal('batch-manage-modal', { focus: false });
} catch (error) { } catch (error) {
console.error('加载对话列表失败:', error); console.error('加载对话列表失败:', error);
// 错误时使用空数组,不显示错误提示(更友好的用户体验) // 错误时使用空数组,不显示错误提示(更友好的用户体验)
allConversationsForBatch = []; allConversationsForBatch = [];
updateBatchManageTitle(0); updateBatchManageTitle(0);
renderBatchConversations(); renderBatchConversations();
openAppModal('batch-manage-modal'); openAppModal('batch-manage-modal', { focus: false });
} }
} }
@@ -7517,6 +7792,7 @@ function renderBatchConversations(filtered = null) {
checkbox.type = 'checkbox'; checkbox.type = 'checkbox';
checkbox.className = 'batch-conversation-checkbox'; checkbox.className = 'batch-conversation-checkbox';
checkbox.dataset.conversationId = conv.id; checkbox.dataset.conversationId = conv.id;
checkbox.addEventListener('change', syncSelectAllBatchCheckbox);
const name = document.createElement('div'); const name = document.createElement('div');
name.className = 'batch-table-col-name'; name.className = 'batch-table-col-name';
@@ -7542,9 +7818,21 @@ function renderBatchConversations(filtered = null) {
const action = document.createElement('div'); const action = document.createElement('div');
action.className = 'batch-table-col-action'; action.className = 'batch-table-col-action';
const deleteBtn = document.createElement('button'); const deleteBtn = document.createElement('button');
deleteBtn.type = 'button';
deleteBtn.className = 'batch-delete-btn'; deleteBtn.className = 'batch-delete-btn';
deleteBtn.innerHTML = '🗑️'; deleteBtn.innerHTML = `
deleteBtn.onclick = () => deleteConversation(conv.id); <svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
<path d="M3 6h18M8 6V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2m3 0v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6h14zM10 11v6M14 11v6"
stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>
`;
const deleteLabel = typeof window.t === 'function' ? window.t('contextMenu.deleteConversation') : '删除此对话';
deleteBtn.title = deleteLabel;
deleteBtn.setAttribute('aria-label', deleteLabel);
deleteBtn.onclick = (e) => {
e.stopPropagation();
deleteConversation(conv.id);
};
action.appendChild(deleteBtn); action.appendChild(deleteBtn);
row.appendChild(checkbox); row.appendChild(checkbox);
@@ -7554,6 +7842,8 @@ function renderBatchConversations(filtered = null) {
list.appendChild(row); list.appendChild(row);
}); });
syncSelectAllBatchCheckbox();
} }
// 筛选批量管理对话 // 筛选批量管理对话
@@ -7575,12 +7865,35 @@ function filterBatchConversations(query) {
function toggleSelectAllBatch() { function toggleSelectAllBatch() {
const selectAll = document.getElementById('batch-select-all'); const selectAll = document.getElementById('batch-select-all');
const checkboxes = document.querySelectorAll('.batch-conversation-checkbox'); const checkboxes = document.querySelectorAll('.batch-conversation-checkbox');
if (selectAll) {
selectAll.indeterminate = false;
}
checkboxes.forEach(cb => { checkboxes.forEach(cb => {
cb.checked = selectAll.checked; cb.checked = selectAll.checked;
}); });
} }
function syncSelectAllBatchCheckbox() {
const selectAll = document.getElementById('batch-select-all');
if (!selectAll) return;
const checkboxes = document.querySelectorAll('.batch-conversation-checkbox');
const total = checkboxes.length;
const checked = document.querySelectorAll('.batch-conversation-checkbox:checked').length;
if (total === 0 || checked === 0) {
selectAll.checked = false;
selectAll.indeterminate = false;
} else if (checked === total) {
selectAll.checked = true;
selectAll.indeterminate = false;
} else {
selectAll.checked = false;
selectAll.indeterminate = true;
}
}
// 删除选中的对话 // 删除选中的对话
async function deleteSelectedConversations() { async function deleteSelectedConversations() {
const checkboxes = document.querySelectorAll('.batch-conversation-checkbox:checked'); const checkboxes = document.querySelectorAll('.batch-conversation-checkbox:checked');
@@ -7604,6 +7917,7 @@ async function deleteSelectedConversations() {
const selectAll = document.getElementById('batch-select-all'); const selectAll = document.getElementById('batch-select-all');
if (selectAll) { if (selectAll) {
selectAll.checked = false; selectAll.checked = false;
selectAll.indeterminate = false;
} }
} catch (error) { } catch (error) {
console.error('删除失败:', error); console.error('删除失败:', error);
@@ -7619,6 +7933,7 @@ function closeBatchManageModal() {
const selectAll = document.getElementById('batch-select-all'); const selectAll = document.getElementById('batch-select-all');
if (selectAll) { if (selectAll) {
selectAll.checked = false; selectAll.checked = false;
selectAll.indeterminate = false;
} }
allConversationsForBatch = []; allConversationsForBatch = [];
} }
@@ -7653,6 +7968,20 @@ function refreshChatPanelI18n() {
const expanded = timeline && timeline.classList.contains('expanded'); const expanded = timeline && timeline.classList.contains('expanded');
span.textContent = expanded ? t('tasks.collapseDetail') : t('chat.expandDetail'); span.textContent = expanded ? t('tasks.collapseDetail') : t('chat.expandDetail');
}); });
const copyLabel = t('common.copy');
const copyTitle = t('chat.copyMessageTitle');
messagesEl.querySelectorAll('.message-copy-btn').forEach(function (btn) {
if (btn.dataset.copySuccessActive === '1') return;
const span = btn.querySelector('span');
if (span) span.textContent = copyLabel;
btn.title = copyTitle;
btn.setAttribute('aria-label', copyTitle);
});
messagesEl.querySelectorAll('.message.assistant').forEach(function (msgEl) {
if (typeof window.syncMcpToolsToggleButton === 'function') {
window.syncMcpToolsToggleButton(msgEl);
}
});
} }
if (isAppModalOpen('mcp-detail-modal')) { if (isAppModalOpen('mcp-detail-modal')) {
+101 -85
View File
@@ -1259,101 +1259,60 @@ function integrateProgressToMCPSection(progressId, assistantMessageId, mcpExecut
return; return;
} }
// 查找或创建 MCP 区域 // 查找或创建 MCP 区域(工具栏 + 工具列表 + 迭代时间线)
let mcpSection = assistantElement.querySelector('.mcp-call-section'); if (typeof window.ensureMcpCallSectionChrome === 'function') {
if (!mcpSection) { window.ensureMcpCallSectionChrome(assistantElement, assistantMessageId);
mcpSection = document.createElement('div');
mcpSection.className = 'mcp-call-section';
const mcpLabel = document.createElement('div');
mcpLabel.className = 'mcp-call-label';
mcpLabel.textContent = '📋 ' + (typeof window.t === 'function' ? window.t('chat.penetrationTestDetail') : '渗透测试详情');
mcpSection.appendChild(mcpLabel);
const buttonsContainerInit = document.createElement('div');
buttonsContainerInit.className = 'mcp-call-buttons';
mcpSection.appendChild(buttonsContainerInit);
contentWrapper.appendChild(mcpSection);
} }
const mcpSection = assistantElement.querySelector('.mcp-call-section');
// 获取时间线内容 if (!mcpSection) {
const hasContent = timelineHTML.trim().length > 0; removeMessage(progressId);
return;
// 检查时间线中是否有错误项
const hasError = timeline && timeline.querySelector('.timeline-item-error');
// 确保按钮容器存在
let buttonsContainer = mcpSection.querySelector('.mcp-call-buttons');
if (!buttonsContainer) {
buttonsContainer = document.createElement('div');
buttonsContainer.className = 'mcp-call-buttons';
mcpSection.appendChild(buttonsContainer);
} }
let maxExecIndex = 0; const hasContent = timelineHTML.trim().length > 0;
const existingExecBtns = buttonsContainer.querySelectorAll('.mcp-detail-btn:not(.process-detail-btn)');
existingExecBtns.forEach(function (btn) { if (mcpIds.length > 0 && typeof window.appendMcpCallButtons === 'function') {
const n = parseInt(btn.dataset.execIndex, 10); window.appendMcpCallButtons(assistantElement, mcpIds);
if (!isNaN(n) && n > maxExecIndex) maxExecIndex = n; const toolList = mcpSection.querySelector('.mcp-tool-list');
}); if (toolList) toolList.classList.remove('expanded');
const seenExec = new Set();
existingExecBtns.forEach(function (btn) {
if (btn.dataset.execId) seenExec.add(String(btn.dataset.execId).trim());
});
let appendedAny = false;
if (mcpIds.length > 0) {
mcpIds.forEach(function (execId) {
const id = execId != null ? String(execId).trim() : '';
if (!id || seenExec.has(id)) return;
seenExec.add(id);
maxExecIndex += 1;
appendedAny = true;
const detailBtn = document.createElement('button');
detailBtn.className = 'mcp-detail-btn';
detailBtn.dataset.execId = id;
detailBtn.dataset.execIndex = String(maxExecIndex);
detailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.callNumber', { n: maxExecIndex }) : '调用 #' + maxExecIndex) + '</span>';
detailBtn.onclick = function () { showMCPDetail(id); };
buttonsContainer.appendChild(detailBtn);
});
if (appendedAny && typeof batchUpdateButtonToolNames === 'function') {
batchUpdateButtonToolNames(buttonsContainer, mcpIds);
}
} }
if (!buttonsContainer.querySelector('.process-detail-btn')) { if (typeof window.syncMcpToolsToggleButton === 'function') {
window.syncMcpToolsToggleButton(assistantElement);
}
const toolbar = mcpSection.querySelector('.mcp-call-toolbar');
if (toolbar && !toolbar.querySelector('.process-detail-btn')) {
const progressDetailBtn = document.createElement('button'); const progressDetailBtn = document.createElement('button');
progressDetailBtn.className = 'mcp-detail-btn process-detail-btn'; progressDetailBtn.className = 'mcp-detail-btn process-detail-btn';
progressDetailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') + '</span>'; progressDetailBtn.innerHTML = '<span>' + (typeof window.t === 'function' ? window.t('chat.expandDetail') : '展开详情') + '</span>';
progressDetailBtn.onclick = () => toggleProcessDetails(null, assistantMessageId); progressDetailBtn.onclick = () => toggleProcessDetails(null, assistantMessageId);
buttonsContainer.appendChild(progressDetailBtn); toolbar.appendChild(progressDetailBtn);
} }
// 创建详情容器,放在MCP按钮区域下方(统一结构)
const detailsId = 'process-details-' + assistantMessageId; const detailsId = 'process-details-' + assistantMessageId;
let detailsContainer = document.getElementById(detailsId); let detailsContainer = document.getElementById(detailsId);
const toolListEl = mcpSection.querySelector('.mcp-tool-list');
if (!detailsContainer) { if (!detailsContainer) {
detailsContainer = document.createElement('div'); detailsContainer = document.createElement('div');
detailsContainer.id = detailsId; detailsContainer.id = detailsId;
detailsContainer.className = 'process-details-container'; detailsContainer.className = 'process-details-container';
// 确保容器在按钮容器之后 if (toolListEl) {
if (buttonsContainer.nextSibling) { toolListEl.after(detailsContainer);
mcpSection.insertBefore(detailsContainer, buttonsContainer.nextSibling);
} else { } else {
mcpSection.appendChild(detailsContainer); mcpSection.appendChild(detailsContainer);
} }
} }
// 设置详情内容(如果有错误,默认折叠;否则默认折叠)
detailsContainer.innerHTML = ` detailsContainer.innerHTML = `
<div class="process-details-content"> <div class="process-details-content">
${hasContent ? `<div class="progress-timeline" id="${detailsId}-timeline">${timelineHTML}</div>` : '<div class="progress-timeline-empty">' + (typeof window.t === 'function' ? window.t('chat.noProcessDetail') : '暂无过程详情(可能执行过快或未触发详细事件)') + '</div>'} ${hasContent ? `<div class="progress-timeline" id="${detailsId}-timeline">${timelineHTML}</div>` : '<div class="progress-timeline-empty">' + (typeof window.t === 'function' ? window.t('chat.noProcessDetail') : '暂无过程详情(可能执行过快或未触发详细事件)') + '</div>'}
</div> </div>
`; `;
// 确保初始状态是折叠的(默认折叠,特别是错误时)
if (hasContent) { if (hasContent) {
const timeline = document.getElementById(detailsId + '-timeline'); const timeline = document.getElementById(detailsId + '-timeline');
if (timeline) { if (timeline) {
// 如果有错误,确保折叠;否则也默认折叠
timeline.classList.remove('expanded'); timeline.classList.remove('expanded');
} }
@@ -1363,10 +1322,47 @@ function integrateProgressToMCPSection(progressId, assistantMessageId, mcpExecut
}); });
} }
// 移除原来的进度消息(详情已快照到助手消息下的 process-details
removeMessage(progressId); removeMessage(progressId);
} }
const PROCESS_DETAILS_PAGE_SIZE = 100;
/**
* 分页加载过程详情并增量渲染避免数百轮迭代一次性阻塞主线程
*/
async function loadProcessDetailsPaginated(assistantMessageId, backendMessageId) {
if (!assistantMessageId || !backendMessageId || typeof apiFetch !== 'function' || typeof renderProcessDetails !== 'function') {
return;
}
const PAGE = PROCESS_DETAILS_PAGE_SIZE;
let offset = 0;
let isFirst = true;
while (true) {
const res = await apiFetch(
'/api/messages/' + encodeURIComponent(String(backendMessageId)) +
'/process-details?limit=' + PAGE + '&offset=' + offset
);
const j = await res.json().catch(() => ({}));
if (!res.ok) {
throw new Error((j && j.error) ? j.error : String(res.status));
}
const details = (j && Array.isArray(j.processDetails)) ? j.processDetails : [];
const hasMore = !!(j && j.hasMore);
renderProcessDetails(assistantMessageId, details, {
append: !isFirst,
markLoaded: !hasMore
});
if (!hasMore || details.length === 0) {
break;
}
offset += details.length;
isFirst = false;
await new Promise((resolve) => requestAnimationFrame(resolve));
}
}
window.loadProcessDetailsPaginated = loadProcessDetailsPaginated;
// 切换过程详情显示 // 切换过程详情显示
function toggleProcessDetails(progressId, assistantMessageId) { function toggleProcessDetails(progressId, assistantMessageId) {
const detailsId = 'process-details-' + assistantMessageId; const detailsId = 'process-details-' + assistantMessageId;
@@ -1383,26 +1379,17 @@ function toggleProcessDetails(progressId, assistantMessageId) {
// 正在加载中,避免重复请求 // 正在加载中,避免重复请求
} else { } else {
detailsContainer.dataset.loading = '1'; detailsContainer.dataset.loading = '1';
// 先展开容器,显示加载态
const timeline = detailsContainer.querySelector('.progress-timeline'); const timeline = detailsContainer.querySelector('.progress-timeline');
if (timeline) { if (timeline) {
timeline.innerHTML = '<div class="progress-timeline-empty">' + ((typeof window.t === 'function') ? window.t('common.loading') : '加载中…') + '</div>'; timeline.innerHTML = '<div class="progress-timeline-empty">' + ((typeof window.t === 'function') ? window.t('common.loading') : '加载中…') + '</div>';
} }
apiFetch(`/api/messages/${encodeURIComponent(String(backendMessageId))}/process-details`) loadProcessDetailsPaginated(assistantMessageId, backendMessageId)
.then(async (res) => {
const j = await res.json().catch(() => ({}));
if (!res.ok) throw new Error((j && j.error) ? j.error : res.status);
const details = (j && Array.isArray(j.processDetails)) ? j.processDetails : [];
// 重新渲染详情(renderProcessDetails 会清掉 lazy 标记并写入 loaded
renderProcessDetails(assistantMessageId, details);
})
.catch((e) => { .catch((e) => {
console.error('加载过程详情失败:', e); console.error('加载过程详情失败:', e);
const tl = detailsContainer.querySelector('.progress-timeline'); const tl = detailsContainer.querySelector('.progress-timeline');
if (tl) { if (tl) {
tl.innerHTML = '<div class="progress-timeline-empty">' + ((typeof window.t === 'function') ? window.t('chat.noProcessDetail') : '暂无过程详情(加载失败)') + '</div>'; tl.innerHTML = '<div class="progress-timeline-empty">' + ((typeof window.t === 'function') ? window.t('chat.noProcessDetail') : '暂无过程详情(加载失败)') + '</div>';
} }
// 失败时保留 lazy 状态,允许用户重试
detailsContainer.dataset.lazyNotLoaded = '1'; detailsContainer.dataset.lazyNotLoaded = '1';
detailsContainer.dataset.loaded = '0'; detailsContainer.dataset.loaded = '0';
}) })
@@ -1944,6 +1931,7 @@ function handleStreamEvent(event, progressElement, progressId,
message: event.message || '', message: event.message || '',
data: d data: d
}); });
finalizeOutstandingToolCallsForProgress(progressId, 'failed');
break; break;
} }
@@ -2755,12 +2743,16 @@ async function restoreHitlInlineForConversation(conversationId) {
if (detailsContainer.dataset.lazyNotLoaded === '1' && detailsContainer.dataset.loaded !== '1') { if (detailsContainer.dataset.lazyNotLoaded === '1' && detailsContainer.dataset.loaded !== '1') {
try { try {
detailsContainer.dataset.loading = '1'; detailsContainer.dataset.loading = '1';
const res = await apiFetch('/api/messages/' + encodeURIComponent(backendMsgId) + '/process-details'); if (typeof loadProcessDetailsPaginated === 'function') {
const j = await res.json().catch(function () { return {}; }); await loadProcessDetailsPaginated(clientMsgId, backendMsgId);
if (!res.ok) throw new Error((j && j.error) ? j.error : String(res.status)); } else {
const details = (j && Array.isArray(j.processDetails)) ? j.processDetails : []; const res = await apiFetch('/api/messages/' + encodeURIComponent(backendMsgId) + '/process-details');
if (typeof renderProcessDetails === 'function') { const j = await res.json().catch(function () { return {}; });
renderProcessDetails(clientMsgId, details); if (!res.ok) throw new Error((j && j.error) ? j.error : String(res.status));
const details = (j && Array.isArray(j.processDetails)) ? j.processDetails : [];
if (typeof renderProcessDetails === 'function') {
renderProcessDetails(clientMsgId, details);
}
} }
} catch (e) { } catch (e) {
console.error('加载过程详情失败(HITL 恢复):', e); console.error('加载过程详情失败(HITL 恢复):', e);
@@ -3531,6 +3523,7 @@ const monitorState = {
timelineRange: null, timelineRange: null,
timelineError: null, timelineError: null,
lastFetchedAt: null, lastFetchedAt: null,
retentionDays: 0,
pagination: { pagination: {
page: 1, page: 1,
pageSize: (() => { pageSize: (() => {
@@ -3626,6 +3619,7 @@ async function refreshMonitorPanel(page = null) {
monitorState.timeline = timeline; monitorState.timeline = timeline;
monitorState.timelineError = timelineError; monitorState.timelineError = timelineError;
monitorState.lastFetchedAt = new Date(); monitorState.lastFetchedAt = new Date();
monitorState.retentionDays = typeof result.retention_days === 'number' ? result.retention_days : 0;
// 更新分页信息 // 更新分页信息
if (result.total !== undefined) { if (result.total !== undefined) {
@@ -3709,6 +3703,7 @@ async function refreshMonitorPanelWithFilter(statusFilter = 'all', toolFilter =
monitorState.timeline = timeline; monitorState.timeline = timeline;
monitorState.timelineError = timelineError; monitorState.timelineError = timelineError;
monitorState.lastFetchedAt = new Date(); monitorState.lastFetchedAt = new Date();
monitorState.retentionDays = typeof result.retention_days === 'number' ? result.retention_days : 0;
// 更新分页信息 // 更新分页信息
if (result.total !== undefined) { if (result.total !== undefined) {
@@ -4526,15 +4521,20 @@ function renderMcpStatsStackedBar(success, failed) {
</div>`; </div>`;
} }
function updateMonitorStatsSubtitle(lastFetchedAt, toolCount) { function updateMonitorStatsSubtitle(lastFetchedAt, toolCount, retentionDays) {
const subtitle = document.getElementById('monitor-stats-subtitle'); const subtitle = document.getElementById('monitor-stats-subtitle');
if (!subtitle) return; if (!subtitle) return;
const locale = (typeof window.__locale === 'string' && window.__locale.startsWith('zh')) ? 'zh-CN' : 'en-US'; const locale = (typeof window.__locale === 'string' && window.__locale.startsWith('zh')) ? 'zh-CN' : 'en-US';
const timeText = lastFetchedAt const timeText = lastFetchedAt
? (lastFetchedAt.toLocaleString ? lastFetchedAt.toLocaleString(locale) : String(lastFetchedAt)) ? (lastFetchedAt.toLocaleString ? lastFetchedAt.toLocaleString(locale) : String(lastFetchedAt))
: '—'; : '—';
const text = mcpMonitorT('statsSubtitle', { time: timeText, count: toolCount }) let text = mcpMonitorT('statsSubtitle', { time: timeText, count: toolCount })
|| monitorFallback(`最后刷新 ${timeText} · 共 ${toolCount} 个工具`, `Refreshed ${timeText} · ${toolCount} tools`); || monitorFallback(`最后刷新 ${timeText} · 共 ${toolCount} 个工具`, `Refreshed ${timeText} · ${toolCount} tools`);
if (typeof retentionDays === 'number' && retentionDays > 0) {
const hint = mcpMonitorT('retentionHint', { days: retentionDays })
|| monitorFallback(`执行记录保留 ${retentionDays} 天,超期自动清理`, `Execution records are kept for ${retentionDays} days, then purged automatically.`);
text += ' · ' + hint;
}
subtitle.textContent = text; subtitle.textContent = text;
subtitle.hidden = false; subtitle.hidden = false;
} }
@@ -4959,7 +4959,7 @@ function renderMonitorStats(statsMap = {}, lastFetchedAt = null) {
} else if (toolFilterEl) { } else if (toolFilterEl) {
toolFilterEl.classList.remove('is-filter-active'); toolFilterEl.classList.remove('is-filter-active');
} }
updateMonitorStatsSubtitle(lastFetchedAt, entries.length); updateMonitorStatsSubtitle(lastFetchedAt, entries.length, monitorState.retentionDays);
} }
function renderMonitorExecutions(executions = [], statusFilter = 'all') { function renderMonitorExecutions(executions = [], statusFilter = 'all') {
@@ -5459,6 +5459,22 @@ function refreshProgressAndTimelineI18n() {
const expanded = timeline && timeline.classList.contains('expanded'); const expanded = timeline && timeline.classList.contains('expanded');
span.textContent = expanded ? _t('tasks.collapseDetail') : _t('chat.expandDetail'); span.textContent = expanded ? _t('tasks.collapseDetail') : _t('chat.expandDetail');
}); });
document.querySelectorAll('#chat-messages .message.assistant').forEach(function (msgEl) {
if (typeof window.syncMcpToolsToggleButton === 'function') {
window.syncMcpToolsToggleButton(msgEl);
}
});
const copyLabel = _t('common.copy');
const copyTitle = _t('chat.copyMessageTitle');
document.querySelectorAll('#chat-messages .message-copy-btn').forEach(function (btn) {
if (btn.dataset.copySuccessActive === '1') return;
const span = btn.querySelector('span');
if (span) span.textContent = copyLabel;
btn.title = copyTitle;
btn.setAttribute('aria-label', copyTitle);
});
} }
document.addEventListener('languagechange', function () { document.addEventListener('languagechange', function () {
+77
View File
@@ -990,6 +990,7 @@ async function createBatchQueue() {
const roleSelect = document.getElementById('batch-queue-role'); const roleSelect = document.getElementById('batch-queue-role');
const projectSelect = document.getElementById('batch-queue-project-id'); const projectSelect = document.getElementById('batch-queue-project-id');
const agentModeSelect = document.getElementById('batch-queue-agent-mode'); const agentModeSelect = document.getElementById('batch-queue-agent-mode');
const concurrencyInput = document.getElementById('batch-queue-concurrency');
const scheduleModeSelect = document.getElementById('batch-queue-schedule-mode'); const scheduleModeSelect = document.getElementById('batch-queue-schedule-mode');
const cronExprInput = document.getElementById('batch-queue-cron-expr'); const cronExprInput = document.getElementById('batch-queue-cron-expr');
const executeNowCheckbox = document.getElementById('batch-queue-execute-now'); const executeNowCheckbox = document.getElementById('batch-queue-execute-now');
@@ -1019,6 +1020,9 @@ async function createBatchQueue() {
const scheduleMode = scheduleModeSelect ? (scheduleModeSelect.value === 'cron' ? 'cron' : 'manual') : 'manual'; const scheduleMode = scheduleModeSelect ? (scheduleModeSelect.value === 'cron' ? 'cron' : 'manual') : 'manual';
const cronExpr = cronExprInput ? cronExprInput.value.trim() : ''; const cronExpr = cronExprInput ? cronExprInput.value.trim() : '';
const executeNow = executeNowCheckbox ? !!executeNowCheckbox.checked : false; const executeNow = executeNowCheckbox ? !!executeNowCheckbox.checked : false;
let concurrency = concurrencyInput ? parseInt(concurrencyInput.value, 10) : 1;
if (!Number.isFinite(concurrency) || concurrency < 1) concurrency = 1;
if (concurrency > 8) concurrency = 8;
if (scheduleMode === 'cron' && !cronExpr) { if (scheduleMode === 'cron' && !cronExpr) {
alert(_t('batchImportModal.cronExprRequired')); alert(_t('batchImportModal.cronExprRequired'));
return; return;
@@ -1043,6 +1047,7 @@ async function createBatchQueue() {
cronExpr, cronExpr,
executeNow, executeNow,
projectId, projectId,
concurrency,
}), }),
}); });
@@ -1489,6 +1494,7 @@ async function showBatchQueueDetail(queueId) {
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.role'))}</span><span class="bq-kv__v" id="bq-role-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditRole()" title="${escapeHtml(_t('common.edit'))}">${roleLineVal}</span>` : roleLineVal}</span></div> <div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.role'))}</span><span class="bq-kv__v" id="bq-role-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditRole()" title="${escapeHtml(_t('common.edit'))}">${roleLineVal}</span>` : roleLineVal}</span></div>
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchImportModal.agentMode'))}</span><span class="bq-kv__v" id="bq-agentmode-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditAgentMode()" title="${escapeHtml(_t('common.edit'))}">${escapeHtml(agentModeText)}</span>` : escapeHtml(agentModeText)}</span></div> <div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchImportModal.agentMode'))}</span><span class="bq-kv__v" id="bq-agentmode-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditAgentMode()" title="${escapeHtml(_t('common.edit'))}">${escapeHtml(agentModeText)}</span>` : escapeHtml(agentModeText)}</span></div>
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchImportModal.scheduleMode'))}</span><span class="bq-kv__v" id="bq-schedule-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditSchedule()" title="${escapeHtml(_t('common.edit'))}">${scheduleDetail}</span>` : scheduleDetail}</span></div> <div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchImportModal.scheduleMode'))}</span><span class="bq-kv__v" id="bq-schedule-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditSchedule()" title="${escapeHtml(_t('common.edit'))}">${scheduleDetail}</span>` : scheduleDetail}</span></div>
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.concurrency'))}</span><span class="bq-kv__v" id="bq-concurrency-val">${allowSubtaskMutation ? `<span class="bq-inline-editable" onclick="startInlineEditConcurrency()" title="${escapeHtml(_t('common.edit'))}">${escapeHtml(String(queue.concurrency && queue.concurrency > 0 ? queue.concurrency : 1))}</span>` : escapeHtml(String(queue.concurrency && queue.concurrency > 0 ? queue.concurrency : 1))}</span></div>
<div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.taskTotal'))}</span><span class="bq-kv__v">${queue.tasks.length}</span></div> <div class="bq-kv"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.taskTotal'))}</span><span class="bq-kv__v">${queue.tasks.length}</span></div>
${queue.scheduleMode === 'cron' ? `<div class="bq-kv bq-kv--block"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.scheduleCronAuto'))}</span><span class="bq-kv__v bq-kv__v--control"><label class="bq-cron-toggle"><input type="checkbox" ${queue.scheduleEnabled !== false ? 'checked' : ''} onchange="updateBatchQueueScheduleEnabled(this.checked)" /><span class="bq-cron-toggle__hint">${escapeHtml(_t('batchQueueDetailModal.scheduleCronAutoHint'))}</span></label></span></div>` : ''} ${queue.scheduleMode === 'cron' ? `<div class="bq-kv bq-kv--block"><span class="bq-kv__k">${escapeHtml(_t('batchQueueDetailModal.scheduleCronAuto'))}</span><span class="bq-kv__v bq-kv__v--control"><label class="bq-cron-toggle"><input type="checkbox" ${queue.scheduleEnabled !== false ? 'checked' : ''} onchange="updateBatchQueueScheduleEnabled(this.checked)" /><span class="bq-cron-toggle__hint">${escapeHtml(_t('batchQueueDetailModal.scheduleCronAutoHint'))}</span></label></span></div>` : ''}
</section> </section>
@@ -2287,6 +2293,75 @@ async function saveInlineAgentMode() {
} }
} }
function normalizeBatchQueueConcurrencyInput(raw) {
let n = parseInt(raw, 10);
if (!Number.isFinite(n) || n < 1) n = 1;
if (n > 8) n = 8;
return n;
}
// --- 内联编辑:并发数 ---
function startInlineEditConcurrency() {
const container = document.getElementById('bq-concurrency-val');
if (!container) return;
const queueId = batchQueuesState.currentQueueId;
if (!queueId) return;
apiFetch(`/api/batch-tasks/${queueId}`).then(r => r.json()).then(detail => {
const queue = detail.queue || {};
const current = normalizeBatchQueueConcurrencyInput(queue.concurrency || 1);
container.innerHTML = `<span class="bq-inline-edit-controls">
<input type="number" id="bq-edit-concurrency" min="1" max="8" value="${current}" style="width:72px;" />
</span>`;
const inp = document.getElementById('bq-edit-concurrency');
if (!inp) return;
inp.focus();
inp.select();
let cancelled = false;
inp.addEventListener('keydown', (e) => {
if (e.key === 'Enter') { e.preventDefault(); inp.blur(); }
if (e.key === 'Escape') { cancelled = true; cancelAllInlineEdits(); }
});
inp.addEventListener('blur', () => {
if (!cancelled) saveInlineConcurrency();
});
});
}
async function saveInlineConcurrency() {
if (_bqInlineSaving) return;
_bqInlineSaving = true;
const queueId = batchQueuesState.currentQueueId;
if (!queueId) { _bqInlineSaving = false; return; }
const inp = document.getElementById('bq-edit-concurrency');
const concurrency = normalizeBatchQueueConcurrencyInput(inp ? inp.value : 1);
try {
const detailResp = await apiFetch(`/api/batch-tasks/${queueId}`);
const detail = await detailResp.json();
const q = detail.queue || {};
const response = await apiFetch(`/api/batch-tasks/${queueId}/metadata`, {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
title: q.title || '',
role: q.role || '',
agentMode: q.agentMode || 'eino_single',
concurrency,
}),
});
if (!response.ok) {
const result = await response.json().catch(() => ({}));
throw new Error(result.error || _t('tasks.updateTaskFailed'));
}
_bqInlineSaving = false;
showBatchQueueDetail(queueId);
refreshBatchQueues();
} catch (e) {
_bqInlineSaving = false;
console.error(e);
alert(e.message);
}
}
// --- 单条执行 --- // --- 单条执行 ---
async function runSingleBatchTask(queueId, taskId) { async function runSingleBatchTask(queueId, taskId) {
if (!queueId || !taskId) return; if (!queueId || !taskId) return;
@@ -2441,6 +2516,8 @@ window.startInlineEditRole = startInlineEditRole;
window.saveInlineRole = saveInlineRole; window.saveInlineRole = saveInlineRole;
window.startInlineEditAgentMode = startInlineEditAgentMode; window.startInlineEditAgentMode = startInlineEditAgentMode;
window.saveInlineAgentMode = saveInlineAgentMode; window.saveInlineAgentMode = saveInlineAgentMode;
window.startInlineEditConcurrency = startInlineEditConcurrency;
window.saveInlineConcurrency = saveInlineConcurrency;
window.runSingleBatchTask = runSingleBatchTask; window.runSingleBatchTask = runSingleBatchTask;
window.startInlineEditSchedule = startInlineEditSchedule; window.startInlineEditSchedule = startInlineEditSchedule;
window.toggleInlineScheduleCron = toggleInlineScheduleCron; window.toggleInlineScheduleCron = toggleInlineScheduleCron;
+8 -5
View File
@@ -3777,7 +3777,9 @@
<div class="modal-body batch-manage-body"> <div class="modal-body batch-manage-body">
<div class="batch-conversations-table"> <div class="batch-conversations-table">
<div class="batch-table-header"> <div class="batch-table-header">
<div class="batch-table-col-checkbox"></div> <div class="batch-table-col-checkbox">
<input type="checkbox" id="batch-select-all" onchange="toggleSelectAllBatch()" data-i18n="batchManageModal.selectAll" data-i18n-attr="title" title="全选" />
</div>
<div class="batch-table-col-name" data-i18n="batchManageModal.conversationName">对话名称</div> <div class="batch-table-col-name" data-i18n="batchManageModal.conversationName">对话名称</div>
<div class="batch-table-col-time" data-i18n="batchManageModal.lastTime">最近一次对话时间</div> <div class="batch-table-col-time" data-i18n="batchManageModal.lastTime">最近一次对话时间</div>
<div class="batch-table-col-action" data-i18n="batchManageModal.action">操作</div> <div class="batch-table-col-action" data-i18n="batchManageModal.action">操作</div>
@@ -3786,10 +3788,6 @@
</div> </div>
</div> </div>
<div class="modal-footer batch-manage-footer"> <div class="modal-footer batch-manage-footer">
<label class="select-all-checkbox">
<input type="checkbox" id="batch-select-all" onchange="toggleSelectAllBatch()" />
<span data-i18n="batchManageModal.selectAll">全选</span>
</label>
<div class="batch-footer-actions"> <div class="batch-footer-actions">
<button class="btn-secondary" onclick="closeBatchManageModal()" data-i18n="common.cancel">取消</button> <button class="btn-secondary" onclick="closeBatchManageModal()" data-i18n="common.cancel">取消</button>
<button class="btn-primary" onclick="deleteSelectedConversations()" data-i18n="batchManageModal.deleteSelected">删除所选</button> <button class="btn-primary" onclick="deleteSelectedConversations()" data-i18n="batchManageModal.deleteSelected">删除所选</button>
@@ -4012,6 +4010,11 @@
</select> </select>
<div class="form-hint" style="margin-top: 4px;" data-i18n="batchImportModal.agentModeHint">与对话页一致:Eino 单代理(ADK),或 Deep / Plan-Execute / Supervisor(后三种需已启用多代理)。</div> <div class="form-hint" style="margin-top: 4px;" data-i18n="batchImportModal.agentModeHint">与对话页一致:Eino 单代理(ADK),或 Deep / Plan-Execute / Supervisor(后三种需已启用多代理)。</div>
</div> </div>
<div class="form-group">
<label for="batch-queue-concurrency" data-i18n="batchImportModal.concurrency">并发数</label>
<input type="number" id="batch-queue-concurrency" min="1" max="8" value="1" style="width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 0.875rem;" />
<div class="form-hint" style="margin-top: 4px;" data-i18n="batchImportModal.concurrencyHint">同时执行的子任务数量(1-8)。默认 1 为串行;含扫描类工具时建议 1-2。</div>
</div>
<div class="form-group"> <div class="form-group">
<label for="batch-queue-schedule-mode" data-i18n="batchImportModal.scheduleMode">调度方式</label> <label for="batch-queue-schedule-mode" data-i18n="batchImportModal.scheduleMode">调度方式</label>
<select id="batch-queue-schedule-mode" onchange="handleBatchScheduleModeChange()" style="width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 0.875rem;"> <select id="batch-queue-schedule-mode" onchange="handleBatchScheduleModeChange()" style="width: 100%; padding: 8px; border: 1px solid #ddd; border-radius: 4px; font-size: 0.875rem;">