mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-23 14:30:11 +02:00
Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7a3c67458c | |||
| 6e9e43eec8 | |||
| bca86e48ae | |||
| 3f3b8b4db4 | |||
| b366dc0287 | |||
| a52452ceea | |||
| 5b87667782 | |||
| 4f0e812d37 | |||
| 79691c021f | |||
| 5a8309a015 | |||
| 6244197339 | |||
| eb14aca05a | |||
| 091e8a4da8 | |||
| 48ce0c519e | |||
| afc37051c0 | |||
| 2964247361 | |||
| 02919df476 | |||
| c3294d96a2 | |||
| c8b8b41bda | |||
| 9a4c333b90 | |||
| 8e21ae290a | |||
| b9d102d046 | |||
| 8c85494a05 | |||
| c3d2a41301 | |||
| 1a2e282d46 | |||
| 8129f2147f | |||
| 4a9889f0af | |||
| 732d47a965 | |||
| e22382aab0 | |||
| b6ff80adf2 |
@@ -40,6 +40,9 @@ audit:
|
||||
retention_days: 15 # 0 表示不自动清理
|
||||
max_detail_bytes: 8192
|
||||
auth_failure_cooldown_seconds: 60 # 同一 IP 登录/改密失败审计最短间隔(秒);未配置时默认 60;-1 关闭节流
|
||||
# MCP 状态监控执行记录保留(tool_executions 表)
|
||||
monitor:
|
||||
retention_days: 90 # 省略时默认 90;0 表示不自动清理
|
||||
# ============================================
|
||||
# 对话相关配置
|
||||
# ============================================
|
||||
|
||||
+9
-1
@@ -25,6 +25,7 @@ import (
|
||||
"cyberstrike-ai/internal/logger"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/mcp/builtin"
|
||||
"cyberstrike-ai/internal/monitor"
|
||||
"cyberstrike-ai/internal/robot"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"cyberstrike-ai/internal/skillpackage"
|
||||
@@ -99,6 +100,10 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
auditSvc.PurgeExpired()
|
||||
audit.StartRetentionLoop(auditSvc, log.Logger)
|
||||
|
||||
monitorRetention := monitor.NewService(db, cfg, log.Logger)
|
||||
monitorRetention.PurgeExpired()
|
||||
monitor.StartRetentionLoop(monitorRetention, log.Logger)
|
||||
|
||||
// 创建MCP服务器(带数据库持久化)
|
||||
mcpServer := mcp.NewServerWithStorage(log.Logger, db)
|
||||
mcpServer.ConfigureHTTPToolCallTimeoutFromAgentMinutes(cfg.Agent.ToolTimeoutMinutes)
|
||||
@@ -298,7 +303,8 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
plantaskBase := filepath.Join(skillsDir, plantaskRel)
|
||||
// Match eino_adk_run_loop: checkpoint_dir is used as configured (relative to process CWD when not absolute).
|
||||
checkpointBase := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.CheckpointDir)
|
||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase)
|
||||
reductionRoot := strings.TrimSpace(cfg.MultiAgent.EinoMiddleware.ReductionRootDir)
|
||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot)
|
||||
agent.SetPromptBaseDir(configDir)
|
||||
|
||||
agentsDir := cfg.AgentsDir
|
||||
@@ -325,6 +331,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
}
|
||||
monitorHandler := handler.NewMonitorHandler(mcpServer, executor, db, log.Logger)
|
||||
monitorHandler.SetAudit(auditSvc)
|
||||
monitorHandler.SetMonitorRetention(monitorRetention)
|
||||
monitorHandler.SetExternalMCPManager(externalMCPMgr) // 设置外部MCP管理器,以便获取外部MCP执行记录
|
||||
notificationHandler := handler.NewNotificationHandler(db, agentHandler, log.Logger)
|
||||
groupHandler := handler.NewGroupHandler(db, log.Logger)
|
||||
@@ -368,6 +375,7 @@ func New(cfg *config.Config, log *logger.Logger, configPath string) (*App, error
|
||||
// 创建OpenAPI处理器
|
||||
conversationHandler := handler.NewConversationHandler(db, log.Logger)
|
||||
conversationHandler.SetAudit(auditSvc)
|
||||
conversationHandler.SetTaskStopper(agentHandler)
|
||||
auditHandler := handler.NewAuditHandler(db, auditSvc, log.Logger)
|
||||
robotHandler := handler.NewRobotHandler(cfg, db, agentHandler, log.Logger)
|
||||
openAPIHandler := handler.NewOpenAPIHandler(db, log.Logger, conversationHandler, agentHandler)
|
||||
|
||||
@@ -27,6 +27,7 @@ type Config struct {
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
Auth AuthConfig `yaml:"auth"`
|
||||
Audit AuditConfig `yaml:"audit,omitempty" json:"audit,omitempty"`
|
||||
Monitor MonitorConfig `yaml:"monitor,omitempty" json:"monitor,omitempty"`
|
||||
ExternalMCP ExternalMCPConfig `yaml:"external_mcp,omitempty"`
|
||||
Knowledge KnowledgeConfig `yaml:"knowledge,omitempty"`
|
||||
C2 C2Config `yaml:"c2,omitempty" json:"c2,omitempty"` // 内置 C2 总开关;未配置时默认启用
|
||||
@@ -249,7 +250,7 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
SummarizationTriggerRatio float64 `yaml:"summarization_trigger_ratio,omitempty" json:"summarization_trigger_ratio,omitempty"`
|
||||
// SummarizationEmitInternalEvents controls middleware internal event emission (default true).
|
||||
SummarizationEmitInternalEvents *bool `yaml:"summarization_emit_internal_events,omitempty" json:"summarization_emit_internal_events,omitempty"`
|
||||
// SummarizationRetryMaxAttempts is extra retries after the first summarization Generate attempt; 0 = default 3.
|
||||
// SummarizationRetryMaxAttempts 已废弃:summarization 与 run loop 共用 run_retry_max_attempts 及 isEinoTransientRunError。
|
||||
SummarizationRetryMaxAttempts int `yaml:"summarization_retry_max_attempts,omitempty" json:"summarization_retry_max_attempts,omitempty"`
|
||||
// PlanExecuteUserInputBudgetRatio caps planner/replanner/executor userInput prompt budget ratio (default 0.35).
|
||||
PlanExecuteUserInputBudgetRatio float64 `yaml:"plan_execute_user_input_budget_ratio,omitempty" json:"plan_execute_user_input_budget_ratio,omitempty"`
|
||||
@@ -263,9 +264,9 @@ type MultiAgentEinoMiddlewareConfig struct {
|
||||
CheckpointDir string `yaml:"checkpoint_dir,omitempty" json:"checkpoint_dir,omitempty"`
|
||||
// DeepOutputKey passed to deep.Config OutputKey (session final text); empty = off.
|
||||
DeepOutputKey string `yaml:"deep_output_key,omitempty" json:"deep_output_key,omitempty"`
|
||||
// DeepModelRetryMaxRetries > 0 enables deep.Config ModelRetryConfig (framework-level chat model retries).
|
||||
// DeepModelRetryMaxRetries 已废弃:临时错误统一由 run loop 内 isEinoTransientRunError + run_retry_max_attempts 处理。
|
||||
DeepModelRetryMaxRetries int `yaml:"deep_model_retry_max_retries,omitempty" json:"deep_model_retry_max_retries,omitempty"`
|
||||
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时 handler 分段续跑次数;0=默认 10。
|
||||
// RunRetryMaxAttempts > 0:429/5xx/网络抖动时可退避重试次数(run loop 与 summarization 共用);0=默认 10。
|
||||
RunRetryMaxAttempts int `yaml:"run_retry_max_attempts,omitempty" json:"run_retry_max_attempts,omitempty"`
|
||||
// RunRetryMaxBackoffSec 单次退避上限秒数;0=默认 30。
|
||||
RunRetryMaxBackoffSec int `yaml:"run_retry_max_backoff_sec,omitempty" json:"run_retry_max_backoff_sec,omitempty"`
|
||||
@@ -623,6 +624,23 @@ type AuthConfig struct {
|
||||
GeneratedPasswordPersistErr string `yaml:"-" json:"-"`
|
||||
}
|
||||
|
||||
// MonitorConfig holds the retention policy for MCP status-monitor execution
// records (the tool_executions table).
type MonitorConfig struct {
	// RetentionDays is the number of days execution records are kept.
	// Omitted (nil) falls back to 90; 0 disables automatic cleanup.
	RetentionDays *int `yaml:"retention_days,omitempty" json:"retention_days,omitempty"`
}

// RetentionDaysEffective resolves the retention window in days.
// An unset value yields the default of 90, a negative value is clamped to 0,
// and any other value is returned unchanged. A result of 0 means records are
// kept forever.
func (m MonitorConfig) RetentionDaysEffective() int {
	switch days := m.RetentionDays; {
	case days == nil:
		return 90
	case *days < 0:
		return 0
	default:
		return *days
	}
}
|
||||
|
||||
// AuditConfig platform operation audit log settings (not chat/tool execution bodies).
|
||||
type AuditConfig struct {
|
||||
// Enabled nil or true enables persistence; explicit false disables.
|
||||
@@ -1274,6 +1292,10 @@ func Default() *Config {
|
||||
Enabled: &on,
|
||||
}
|
||||
}(),
|
||||
Monitor: func() MonitorConfig {
|
||||
days := 90
|
||||
return MonitorConfig{RetentionDays: &days}
|
||||
}(),
|
||||
Robots: RobotsConfig{
|
||||
Session: RobotSessionConfig{
|
||||
StrictUserIdentity: &strictRobotIdentity,
|
||||
|
||||
@@ -585,12 +585,14 @@ func (db *DB) DeleteConversation(id string) error {
|
||||
// 不返回错误,继续删除对话
|
||||
}
|
||||
|
||||
projectID, _ := db.GetConversationProjectID(id)
|
||||
|
||||
// 删除对话(外键CASCADE会自动删除其他相关数据)
|
||||
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除对话失败: %w", err)
|
||||
}
|
||||
db.removeConversationScopedDirs(id)
|
||||
db.removeConversationScopedDirs(id, projectID)
|
||||
|
||||
db.logger.Info("对话已删除(漏洞记录已保留)", zap.String("conversationId", id))
|
||||
return nil
|
||||
@@ -628,13 +630,35 @@ func (db *DB) removeConversationScopedDir(base, conversationID, label string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) removeConversationScopedDirs(conversationID string) {
|
||||
// summarization transcript, reduction files, etc.
|
||||
func (db *DB) einoReductionBaseDir() string {
|
||||
if db == nil {
|
||||
return ""
|
||||
}
|
||||
if base := strings.TrimSpace(db.einoReductionRootDir); base != "" {
|
||||
return base
|
||||
}
|
||||
return filepath.Join("tmp", "reduction")
|
||||
}
|
||||
|
||||
func (db *DB) removeConversationScopedDirs(conversationID, projectID string) {
|
||||
// summarization transcript, etc.
|
||||
db.removeConversationScopedDir(db.conversationArtifactsDir, conversationID, "conversation_artifacts")
|
||||
// Eino plantask JSON boards (skills_dir/.eino/plantask/<id>/).
|
||||
db.removeConversationScopedDir(db.einoPlantaskBaseDir, conversationID, "plantask")
|
||||
// Eino ADK runner checkpoints (checkpoint_dir/<id>/).
|
||||
db.removeConversationScopedDir(db.einoCheckpointBaseDir, conversationID, "eino_checkpoint")
|
||||
// Eino reduction persisted tool outputs (tmp/reduction/conversations/<id>/).
|
||||
// Project-bound sessions share projects/<id>/ — skip on single conversation delete.
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
reductionBase := filepath.Join(db.einoReductionBaseDir(), "conversations")
|
||||
db.removeConversationScopedDir(reductionBase, conversationID, "reduction")
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) removeProjectScopedDirs(projectID string) {
|
||||
// Eino reduction persisted tool outputs (tmp/reduction/projects/<id>/).
|
||||
reductionBase := filepath.Join(db.einoReductionBaseDir(), "projects")
|
||||
db.removeConversationScopedDir(reductionBase, projectID, "reduction")
|
||||
}
|
||||
|
||||
// SaveAgentTrace 保存最后一轮代理消息轨迹与助手输出摘要。
|
||||
|
||||
@@ -19,7 +19,8 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
||||
|
||||
plantaskBase := filepath.Join(tmp, "skills", ".eino", "plantask")
|
||||
checkpointBase := filepath.Join(tmp, "eino-checkpoints")
|
||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase)
|
||||
reductionBase := filepath.Join(tmp, "reduction")
|
||||
db.SetEinoConversationDirs(plantaskBase, checkpointBase, reductionBase)
|
||||
|
||||
conv, err := db.CreateConversation("cleanup test", ConversationCreateMeta{})
|
||||
if err != nil {
|
||||
@@ -34,6 +35,7 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
||||
{db.conversationArtifactsDir, "transcript.txt"},
|
||||
{plantaskBase, "task-1.json"},
|
||||
{checkpointBase, "runner-deep.ckpt"},
|
||||
{filepath.Join(reductionBase, "conversations"), "tool-output.txt"},
|
||||
} {
|
||||
dir := filepath.Join(base.root, seg)
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
@@ -48,10 +50,45 @@ func TestDeleteConversationRemovesEinoScopedDirs(t *testing.T) {
|
||||
t.Fatalf("DeleteConversation: %v", err)
|
||||
}
|
||||
|
||||
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase} {
|
||||
for _, base := range []string{db.conversationArtifactsDir, plantaskBase, checkpointBase, filepath.Join(reductionBase, "conversations")} {
|
||||
dir := filepath.Join(base, seg)
|
||||
if _, statErr := os.Stat(dir); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("expected removed dir %s, stat err=%v", dir, statErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteProjectRemovesReductionDir(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
dbPath := filepath.Join(tmp, "conversations.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
reductionBase := filepath.Join(tmp, "reduction")
|
||||
db.SetEinoConversationDirs("", "", reductionBase)
|
||||
|
||||
project, err := db.CreateProject(&Project{Name: "cleanup test"})
|
||||
if err != nil {
|
||||
t.Fatalf("CreateProject: %v", err)
|
||||
}
|
||||
seg := sanitizeConversationPathSegment(project.ID)
|
||||
reductionDir := filepath.Join(reductionBase, "projects", seg, "clear")
|
||||
if err := os.MkdirAll(reductionDir, 0o755); err != nil {
|
||||
t.Fatalf("mkdir %s: %v", reductionDir, err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(reductionDir, "call-1.txt"), []byte("x"), 0o644); err != nil {
|
||||
t.Fatalf("write: %v", err)
|
||||
}
|
||||
|
||||
if err := db.DeleteProject(project.ID); err != nil {
|
||||
t.Fatalf("DeleteProject: %v", err)
|
||||
}
|
||||
|
||||
projectReductionDir := filepath.Join(reductionBase, "projects", seg)
|
||||
if _, statErr := os.Stat(projectReductionDir); !os.IsNotExist(statErr) {
|
||||
t.Fatalf("expected removed dir %s, stat err=%v", projectReductionDir, statErr)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -51,6 +51,7 @@ type DB struct {
|
||||
conversationArtifactsDir string
|
||||
einoPlantaskBaseDir string // skills_dir + plantask_rel_dir (per-conversation subdirs)
|
||||
einoCheckpointBaseDir string // checkpoint_dir root (per-conversation subdirs)
|
||||
einoReductionRootDir string // reduction_root_dir or default tmp/reduction (conversations/<id> subdirs)
|
||||
checkpointLoopName string
|
||||
checkpointStop chan struct{}
|
||||
checkpointDone chan struct{}
|
||||
@@ -159,12 +160,14 @@ func NewDB(dbPath string, logger *zap.Logger) (*DB, error) {
|
||||
|
||||
// SetEinoConversationDirs configures best-effort filesystem cleanup on DeleteConversation.
|
||||
// plantaskBase is skills_root/plantask_rel (no conversation id); checkpointBase is checkpoint_dir root.
|
||||
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase string) {
|
||||
// reductionRoot is reduction_root_dir from config; empty uses tmp/reduction (conversation-scoped subdirs only).
|
||||
func (db *DB) SetEinoConversationDirs(plantaskBase, checkpointBase, reductionRoot string) {
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
db.einoPlantaskBaseDir = strings.TrimSpace(plantaskBase)
|
||||
db.einoCheckpointBaseDir = strings.TrimSpace(checkpointBase)
|
||||
db.einoReductionRootDir = strings.TrimSpace(reductionRoot)
|
||||
}
|
||||
|
||||
// initTables 初始化数据库表
|
||||
|
||||
@@ -410,6 +410,76 @@ func (db *DB) GetToolExecutionsByIds(ids []string) ([]*mcp.ToolExecution, error)
|
||||
return executions, nil
|
||||
}
|
||||
|
||||
// toolExecutionStatDelta accumulates, for a single tool, how many call
// records are about to be purged so the same amounts can be subtracted from
// the tool's statistics afterwards.
type toolExecutionStatDelta struct {
	totalCalls   int // all purged records, whatever their status
	successCalls int // purged records with status "completed"
	failedCalls  int // purged records with status "failed" or "cancelled"
}
|
||||
|
||||
// PurgeToolExecutionsBefore deletes executions older than cutoff and adjusts tool_stats.
|
||||
func (db *DB) PurgeToolExecutionsBefore(cutoff time.Time) (int64, error) {
|
||||
query := `
|
||||
SELECT tool_name, status, COUNT(*) AS cnt
|
||||
FROM tool_executions
|
||||
WHERE ` + sqliteEpochGE("start_time", "<") + `
|
||||
GROUP BY tool_name, status
|
||||
`
|
||||
rows, err := db.Query(query, formatSQLiteUTC(cutoff))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
deltas := make(map[string]*toolExecutionStatDelta)
|
||||
for rows.Next() {
|
||||
var toolName, status string
|
||||
var count int
|
||||
if err := rows.Scan(&toolName, &status, &count); err != nil {
|
||||
db.logger.Warn("读取待清理执行记录统计失败", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
toolName = strings.TrimSpace(toolName)
|
||||
if toolName == "" || count <= 0 {
|
||||
continue
|
||||
}
|
||||
delta := deltas[toolName]
|
||||
if delta == nil {
|
||||
delta = &toolExecutionStatDelta{}
|
||||
deltas[toolName] = delta
|
||||
}
|
||||
delta.totalCalls += count
|
||||
switch status {
|
||||
case "failed", "cancelled":
|
||||
delta.failedCalls += count
|
||||
case "completed":
|
||||
delta.successCalls += count
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
res, err := db.Exec(`DELETE FROM tool_executions WHERE `+sqliteEpochGE("start_time", "<"), formatSQLiteUTC(cutoff))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
deleted, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for toolName, delta := range deltas {
|
||||
if err := db.DecreaseToolStats(toolName, delta.totalCalls, delta.successCalls, delta.failedCalls); err != nil {
|
||||
db.logger.Warn("清理过期执行记录后更新统计失败",
|
||||
zap.Error(err),
|
||||
zap.String("toolName", toolName),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// SaveToolStats 保存工具统计信息
|
||||
func (db *DB) SaveToolStats(toolName string, stats *mcp.ToolStats) error {
|
||||
var lastCallTime sql.NullTime
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestPurgeToolExecutionsBefore(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
oldStart := time.Now().AddDate(0, 0, -100)
|
||||
newStart := time.Now().AddDate(0, 0, -1)
|
||||
|
||||
oldExec := &mcp.ToolExecution{
|
||||
ID: "old-completed",
|
||||
ToolName: "nmap::scan",
|
||||
Arguments: map[string]interface{}{"target": "127.0.0.1"},
|
||||
Status: "completed",
|
||||
StartTime: oldStart,
|
||||
}
|
||||
oldFailed := &mcp.ToolExecution{
|
||||
ID: "old-failed",
|
||||
ToolName: "nmap::scan",
|
||||
Arguments: map[string]interface{}{"target": "127.0.0.1"},
|
||||
Status: "failed",
|
||||
Error: "timeout",
|
||||
StartTime: oldStart,
|
||||
}
|
||||
newExec := &mcp.ToolExecution{
|
||||
ID: "new-completed",
|
||||
ToolName: "nmap::scan",
|
||||
Arguments: map[string]interface{}{"target": "127.0.0.1"},
|
||||
Status: "completed",
|
||||
StartTime: newStart,
|
||||
}
|
||||
for _, exec := range []*mcp.ToolExecution{oldExec, oldFailed, newExec} {
|
||||
if err := db.SaveToolExecution(exec); err != nil {
|
||||
t.Fatalf("SaveToolExecution(%s): %v", exec.ID, err)
|
||||
}
|
||||
}
|
||||
if err := db.UpdateToolStats("nmap::scan", 3, 2, 1, &newStart); err != nil {
|
||||
t.Fatalf("UpdateToolStats: %v", err)
|
||||
}
|
||||
|
||||
cutoff := time.Now().AddDate(0, 0, -90)
|
||||
deleted, err := db.PurgeToolExecutionsBefore(cutoff)
|
||||
if err != nil {
|
||||
t.Fatalf("PurgeToolExecutionsBefore: %v", err)
|
||||
}
|
||||
if deleted != 2 {
|
||||
t.Fatalf("deleted = %d, want 2", deleted)
|
||||
}
|
||||
|
||||
if _, err := db.GetToolExecution("old-completed"); err == nil {
|
||||
t.Fatal("old-completed should be deleted")
|
||||
}
|
||||
if _, err := db.GetToolExecution("old-failed"); err == nil {
|
||||
t.Fatal("old-failed should be deleted")
|
||||
}
|
||||
if _, err := db.GetToolExecution("new-completed"); err != nil {
|
||||
t.Fatalf("new-completed should remain: %v", err)
|
||||
}
|
||||
|
||||
stats, err := db.LoadToolStats()
|
||||
if err != nil {
|
||||
t.Fatalf("LoadToolStats: %v", err)
|
||||
}
|
||||
stat := stats["nmap::scan"]
|
||||
if stat == nil {
|
||||
t.Fatal("expected stats for nmap::scan")
|
||||
}
|
||||
if stat.TotalCalls != 1 || stat.SuccessCalls != 1 || stat.FailedCalls != 0 {
|
||||
t.Fatalf("stats after purge = %+v, want total=1 success=1 failed=0", stat)
|
||||
}
|
||||
|
||||
total, err := db.CountToolExecutions("", "")
|
||||
if err != nil {
|
||||
t.Fatalf("CountToolExecutions: %v", err)
|
||||
}
|
||||
if total != 1 {
|
||||
t.Fatalf("remaining executions = %d, want 1", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPurgeToolExecutionsBefore_zeroRetentionSkipsViaService(t *testing.T) {
|
||||
// RetentionDaysEffective: 0 means no purge at service layer; DB method still works when called directly.
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
exec := &mcp.ToolExecution{
|
||||
ID: "ancient",
|
||||
ToolName: "curl::get",
|
||||
Arguments: map[string]interface{}{},
|
||||
Status: "completed",
|
||||
StartTime: time.Now().AddDate(-1, 0, 0),
|
||||
}
|
||||
if err := db.SaveToolExecution(exec); err != nil {
|
||||
t.Fatalf("SaveToolExecution: %v", err)
|
||||
}
|
||||
|
||||
deleted, err := db.PurgeToolExecutionsBefore(time.Now())
|
||||
if err != nil {
|
||||
t.Fatalf("PurgeToolExecutionsBefore: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("deleted = %d, want 1", deleted)
|
||||
}
|
||||
}
|
||||
@@ -195,6 +195,7 @@ func (db *DB) DeleteProject(id string) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("删除项目失败: %w", err)
|
||||
}
|
||||
db.removeProjectScopedDirs(id)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
+42
-69
@@ -190,6 +190,21 @@ func (h *AgentHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// CancelRunningTaskForConversation stops any in-flight agent work for the conversation (idempotent).
|
||||
func (h *AgentHandler) CancelRunningTaskForConversation(conversationID string) {
|
||||
if h == nil || conversationID == "" || h.tasks == nil {
|
||||
return
|
||||
}
|
||||
if execID := h.tasks.ActiveMCPExecutionID(conversationID); execID != "" {
|
||||
h.agent.CancelMCPToolExecutionWithNote(execID, "")
|
||||
}
|
||||
if ok, err := h.tasks.CancelTask(conversationID, ErrTaskCancelled); ok {
|
||||
h.logger.Info("已取消会话运行中任务", zap.String("conversationId", conversationID))
|
||||
} else if err != nil {
|
||||
h.logger.Warn("取消会话运行中任务失败", zap.String("conversationId", conversationID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// HitlToolWhitelistSaver 合并 HITL 免审批工具到全局配置并落盘
|
||||
type HitlToolWhitelistSaver interface {
|
||||
MergeHitlToolWhitelistIntoConfig(add []string) error
|
||||
@@ -631,40 +646,11 @@ func (h *AgentHandler) runRobotEinoSingleWithRetry(
|
||||
assistantMessageID string,
|
||||
taskStatus *string,
|
||||
) (string, string, error) {
|
||||
curHist := history
|
||||
curMsg := finalMessage
|
||||
segmentUserMessage := finalMessage
|
||||
var resultMA *multiagent.RunResult
|
||||
var errMA error
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
for {
|
||||
resultMA, errMA = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||
conversationID, h.conversationProjectID(conversationID), curMsg, curHist, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &emptyResponseAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
errMA = nil
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
continue
|
||||
}
|
||||
if errMA == nil {
|
||||
transientRunAttempts = 0
|
||||
emptyResponseAttempts = 0
|
||||
break
|
||||
}
|
||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
}
|
||||
resultMA, errMA := multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
if errMA != nil {
|
||||
*taskStatus = "failed"
|
||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||
}
|
||||
@@ -680,41 +666,12 @@ func (h *AgentHandler) runRobotMultiAgentWithRetry(
|
||||
assistantMessageID string,
|
||||
taskStatus *string,
|
||||
) (string, string, error) {
|
||||
curHist := history
|
||||
curMsg := finalMessage
|
||||
segmentUserMessage := finalMessage
|
||||
var resultMA *multiagent.RunResult
|
||||
var errMA error
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
for {
|
||||
resultMA, errMA = multiagent.RunDeepAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||
conversationID, h.conversationProjectID(conversationID), curMsg, curHist, roleTools, progressCallback,
|
||||
h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &emptyResponseAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
errMA = nil
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
continue
|
||||
}
|
||||
if errMA == nil {
|
||||
transientRunAttempts = 0
|
||||
emptyResponseAttempts = 0
|
||||
break
|
||||
}
|
||||
if handled, _ := h.handleEinoTransientRetryContinue(
|
||||
taskCtx, conversationID, resultMA, errMA, &transientRunAttempts,
|
||||
&curHist, &curMsg, segmentUserMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
}
|
||||
resultMA, errMA := multiagent.RunDeepAgent(
|
||||
taskCtx, h.config, &h.config.MultiAgent, h.agent, h.db, h.logger,
|
||||
conversationID, h.conversationProjectID(conversationID), finalMessage, history, roleTools, progressCallback,
|
||||
h.agentsMarkdownDir, orchestration, nil, h.projectBlackboardBlock(conversationID),
|
||||
)
|
||||
if errMA != nil {
|
||||
*taskStatus = "failed"
|
||||
return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA)
|
||||
}
|
||||
@@ -1379,6 +1336,21 @@ func (h *AgentHandler) CancelAgentLoop(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
if h.tasks.AbortActiveEinoExecute(req.ConversationID, note) {
|
||||
h.logger.Info("对话页仅终止当前 Eino execute",
|
||||
zap.String("conversationId", req.ConversationID),
|
||||
zap.Bool("hasNote", note != ""),
|
||||
)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "tool_abort_requested",
|
||||
"conversationId": req.ConversationID,
|
||||
"message": "已请求终止当前 execute 命令;命令返回后本轮推理将继续。",
|
||||
"continueAfter": true,
|
||||
"interruptWithNote": note != "",
|
||||
"continueWithoutTool": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
// 无进行中的 MCP 工具(模型纯推理/流式输出阶段):取消当前上下文并由 Eino 流式处理器合并用户补充后自动续跑。
|
||||
h.tasks.SetInterruptContinueNote(req.ConversationID, note)
|
||||
ok, err := h.tasks.CancelTask(req.ConversationID, multiagent.ErrInterruptContinue)
|
||||
@@ -2273,6 +2245,7 @@ func (h *AgentHandler) executeBatchQueue(queueID string) {
|
||||
progressCallback = h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent)
|
||||
taskCtx = mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtx = mcp.WithToolRunRegistry(taskCtx, h.tasks)
|
||||
taskCtx = mcp.WithEinoExecuteRunRegistry(taskCtx, h.tasks)
|
||||
|
||||
// 使用队列配置的角色工具列表(如果为空,表示使用所有工具)
|
||||
useBatchMulti := false
|
||||
|
||||
@@ -12,11 +12,17 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ConversationTaskStopper is the hook the conversation handler uses to stop
// in-flight agent work before a conversation is removed.
type ConversationTaskStopper interface {
	// CancelRunningTaskForConversation cancels whatever is currently running
	// for the conversation identified by conversationID.
	CancelRunningTaskForConversation(conversationID string)
}
|
||||
|
||||
// ConversationHandler 对话处理器
|
||||
type ConversationHandler struct {
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
taskStopper ConversationTaskStopper
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
@@ -24,6 +30,11 @@ func (h *ConversationHandler) SetAudit(s *audit.Service) {
|
||||
h.audit = s
|
||||
}
|
||||
|
||||
// SetTaskStopper wires cancellation of in-flight agent tasks on conversation delete.
|
||||
func (h *ConversationHandler) SetTaskStopper(stopper ConversationTaskStopper) {
|
||||
h.taskStopper = stopper
|
||||
}
|
||||
|
||||
// NewConversationHandler 创建新的对话处理器
|
||||
func NewConversationHandler(db *database.DB, logger *zap.Logger) *ConversationHandler {
|
||||
return &ConversationHandler{
|
||||
@@ -245,6 +256,10 @@ func (h *ConversationHandler) UpdateConversation(c *gin.Context) {
|
||||
func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
if h.taskStopper != nil {
|
||||
h.taskStopper.CancelRunningTaskForConversation(id)
|
||||
}
|
||||
|
||||
if err := h.db.DeleteConversation(id); err != nil {
|
||||
h.logger.Error("删除对话失败", zap.Error(err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestConversationHandlerDeleteConversationCancelsRunningTask(t *testing.T) {
|
||||
tm := NewAgentTaskManager()
|
||||
ctx, cancel := context.WithCancelCause(context.Background())
|
||||
_, err := tm.StartTask("conv-1", "hello", cancel)
|
||||
if err != nil {
|
||||
t.Fatalf("StartTask: %v", err)
|
||||
}
|
||||
|
||||
h := &AgentHandler{tasks: tm, logger: zap.NewNop()}
|
||||
h.CancelRunningTaskForConversation("conv-1")
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("task context was not cancelled")
|
||||
}
|
||||
if cause := context.Cause(ctx); cause != ErrTaskCancelled {
|
||||
t.Fatalf("expected ErrTaskCancelled, got %v", cause)
|
||||
}
|
||||
}
|
||||
@@ -2,31 +2,11 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/agent"
|
||||
"cyberstrike-ai/internal/multiagent"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (h *AgentHandler) einoRunRetryMaxAttempts() int {
|
||||
if h.config != nil {
|
||||
return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware)
|
||||
}
|
||||
return multiagent.RunRetryMaxAttemptsFromConfig(nil)
|
||||
}
|
||||
|
||||
func (h *AgentHandler) einoRunRetryMaxBackoffSec() int {
|
||||
if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 {
|
||||
return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。
|
||||
func (h *AgentHandler) applyEinoTraceResumeSegment(
|
||||
conversationID string,
|
||||
@@ -45,136 +25,3 @@ func (h *AgentHandler) applyEinoTraceResumeSegment(
|
||||
*curFinalMessage = segmentUserMessage
|
||||
}
|
||||
}
|
||||
|
||||
// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。
|
||||
// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。
|
||||
func (h *AgentHandler) applyEinoTransientRetrySegment(
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(context.Background()) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 {
|
||||
*curHistory = hist
|
||||
}
|
||||
if s := strings.TrimSpace(segmentUserMessage); s != "" {
|
||||
*curFinalMessage = segmentUserMessage
|
||||
}
|
||||
}
|
||||
|
||||
// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。
|
||||
func (h *AgentHandler) handleEinoTransientRetryContinue(
|
||||
baseCtx context.Context,
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
runErr error,
|
||||
transientAttempts *int,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
progressCallback func(eventType, message string, data interface{}),
|
||||
sendProgress func(msg string, extra map[string]interface{}),
|
||||
) (handled bool, fatal error) {
|
||||
if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) {
|
||||
return false, nil
|
||||
}
|
||||
maxAttempts := h.einoRunRetryMaxAttempts()
|
||||
*transientAttempts++
|
||||
if *transientAttempts > maxAttempts {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
return false, errors.New("transient retry exhausted: " + runErr.Error())
|
||||
}
|
||||
attemptNo := *transientAttempts
|
||||
backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec())
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
"maxAttempts": maxAttempts,
|
||||
"backoffSec": int(backoff.Seconds()),
|
||||
})
|
||||
}
|
||||
select {
|
||||
case <-baseCtx.Done():
|
||||
return false, context.Cause(baseCtx)
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
})
|
||||
}
|
||||
if sendProgress != nil {
|
||||
sendProgress("正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "transient_retry",
|
||||
})
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// handleEinoEmptyResponseContinue 在 SSE 任务循环内处理「正常结束但无助手正文」;返回 exhausted=true 时由外层按成功结束(保留占位文案)。
|
||||
// 与临时错误重试一致:仅恢复轨迹并保留本请求原始 user 文案,不向模型注入续跑说明。
|
||||
func (h *AgentHandler) handleEinoEmptyResponseContinue(
|
||||
baseCtx context.Context,
|
||||
conversationID string,
|
||||
result *multiagent.RunResult,
|
||||
runErr error,
|
||||
emptyResponseAttempts *int,
|
||||
curHistory *[]agent.ChatMessage,
|
||||
curFinalMessage *string,
|
||||
segmentUserMessage string,
|
||||
progressCallback func(eventType, message string, data interface{}),
|
||||
sendProgress func(msg string, extra map[string]interface{}),
|
||||
) (handled bool, exhausted bool) {
|
||||
if !errors.Is(runErr, multiagent.ErrEmptyResponseContinue) {
|
||||
return false, false
|
||||
}
|
||||
maxAttempts := h.einoRunRetryMaxAttempts()
|
||||
*emptyResponseAttempts++
|
||||
if *emptyResponseAttempts > maxAttempts {
|
||||
if h.logger != nil {
|
||||
h.logger.Warn("eino empty response auto resume exhausted",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("maxAttempts", maxAttempts))
|
||||
}
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(conversationID, result)
|
||||
}
|
||||
return false, true
|
||||
}
|
||||
attemptNo := *emptyResponseAttempts
|
||||
if h.logger != nil {
|
||||
h.logger.Info("eino empty response, auto resume from trace",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.Int("attempt", attemptNo),
|
||||
zap.Int("maxAttempts", maxAttempts))
|
||||
}
|
||||
if progressCallback != nil {
|
||||
progressCallback("eino_empty_response_continue", fmt.Sprintf("未捕获到助手正文,正在基于轨迹自动续跑(%d/%d)…", attemptNo, maxAttempts), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"attempt": attemptNo,
|
||||
"maxAttempts": maxAttempts,
|
||||
"resumeKind": "trace_segment",
|
||||
})
|
||||
}
|
||||
h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage)
|
||||
if sendProgress != nil {
|
||||
sendProgress("已恢复上下文,正在继续推理…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "empty_response_continue",
|
||||
})
|
||||
}
|
||||
return true, false
|
||||
}
|
||||
|
||||
@@ -119,7 +119,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
|
||||
@@ -177,8 +176,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
taskOwned = true
|
||||
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
|
||||
@@ -215,6 +212,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = mcp.WithEinoExecuteRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
@@ -240,54 +238,11 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||
transientRunAttempts = 0
|
||||
emptyResponseAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
|
||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if handled {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
@@ -312,8 +267,6 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) {
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
@@ -448,8 +401,6 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
curMsg := prep.FinalMessage
|
||||
var result *multiagent.RunResult
|
||||
var runErr error
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
for {
|
||||
result, runErr = multiagent.RunEinoSingleChatModelAgent(
|
||||
taskCtx,
|
||||
@@ -467,28 +418,9 @@ func (h *AgentHandler) EinoSingleAgentLoop(c *gin.Context) {
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(prep.ConversationID),
|
||||
)
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
continue
|
||||
}
|
||||
if runErr == nil {
|
||||
break
|
||||
}
|
||||
if handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
} else if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||
}
|
||||
|
||||
+37
-20
@@ -10,8 +10,10 @@ import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/audit"
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/monitor"
|
||||
"cyberstrike-ai/internal/security"
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
@@ -19,12 +21,18 @@ import (
|
||||
|
||||
// MonitorHandler 监控处理器
|
||||
type MonitorHandler struct {
|
||||
mcpServer *mcp.Server
|
||||
externalMCPMgr *mcp.ExternalMCPManager
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
mcpServer *mcp.Server
|
||||
externalMCPMgr *mcp.ExternalMCPManager
|
||||
executor *security.Executor
|
||||
db *database.DB
|
||||
logger *zap.Logger
|
||||
audit *audit.Service
|
||||
monitorRetention *monitor.Service
|
||||
}
|
||||
|
||||
// SetMonitorRetention wires MCP execution retention settings.
|
||||
func (h *MonitorHandler) SetMonitorRetention(s *monitor.Service) {
|
||||
h.monitorRetention = s
|
||||
}
|
||||
|
||||
// SetAudit wires platform audit logging.
|
||||
@@ -50,13 +58,14 @@ func (h *MonitorHandler) SetExternalMCPManager(mgr *mcp.ExternalMCPManager) {
|
||||
|
||||
// MonitorResponse 监控响应
|
||||
type MonitorResponse struct {
|
||||
Executions []*mcp.ToolExecution `json:"executions"`
|
||||
Stats map[string]*mcp.ToolStats `json:"stats"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Total int `json:"total,omitempty"`
|
||||
Page int `json:"page,omitempty"`
|
||||
PageSize int `json:"page_size,omitempty"`
|
||||
TotalPages int `json:"total_pages,omitempty"`
|
||||
Executions []*mcp.ToolExecution `json:"executions"`
|
||||
Stats map[string]*mcp.ToolStats `json:"stats"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Total int `json:"total,omitempty"`
|
||||
Page int `json:"page,omitempty"`
|
||||
PageSize int `json:"page_size,omitempty"`
|
||||
TotalPages int `json:"total_pages,omitempty"`
|
||||
RetentionDays int `json:"retention_days,omitempty"`
|
||||
}
|
||||
|
||||
// Monitor 获取监控信息
|
||||
@@ -89,16 +98,24 @@ func (h *MonitorHandler) Monitor(c *gin.Context) {
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, MonitorResponse{
|
||||
Executions: executions,
|
||||
Stats: stats,
|
||||
Timestamp: time.Now(),
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
Executions: executions,
|
||||
Stats: stats,
|
||||
Timestamp: time.Now(),
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
TotalPages: totalPages,
|
||||
RetentionDays: h.monitorRetentionDays(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) monitorRetentionDays() int {
|
||||
if h.monitorRetention != nil {
|
||||
return h.monitorRetention.RetentionDays()
|
||||
}
|
||||
return config.MonitorConfig{}.RetentionDaysEffective()
|
||||
}
|
||||
|
||||
func (h *MonitorHandler) loadExecutions() []*mcp.ToolExecution {
|
||||
executions, _ := h.loadExecutionsWithPagination(1, 1000, "", "")
|
||||
return executions
|
||||
|
||||
@@ -136,7 +136,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
|
||||
var cancelWithCause context.CancelCauseFunc
|
||||
curFinalMessage := prep.FinalMessage
|
||||
segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失
|
||||
curHistory := prep.History
|
||||
roleTools := prep.RoleTools
|
||||
orch := strings.TrimSpace(req.Orchestration)
|
||||
@@ -187,8 +186,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
|
||||
// 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表
|
||||
var cumulativeMCPExecutionIDs []string
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
// 同一请求内分段续跑时,主代理 iteration 事件按偏移累计,避免 UI 出现「第3轮 → 第1轮」回跳。
|
||||
var mainIterationOffset int
|
||||
|
||||
@@ -225,6 +222,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
}
|
||||
taskCtxLoop := mcp.WithMCPConversationID(taskCtx, conversationID)
|
||||
taskCtxLoop = mcp.WithToolRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = mcp.WithEinoExecuteRunRegistry(taskCtxLoop, h.tasks)
|
||||
taskCtxLoop = multiagent.WithHITLToolInterceptor(taskCtxLoop, func(ctx context.Context, toolName, arguments string) (string, error) {
|
||||
return h.interceptHITLForEinoTool(ctx, cancelWithCause, conversationID, assistantMessageID, sendEvent, toolName, arguments)
|
||||
})
|
||||
@@ -252,54 +250,11 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs)
|
||||
}
|
||||
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, conversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
|
||||
if runErr == nil {
|
||||
// 任一段成功完成后,重置临时错误重试窗口(次数/退避从头开始)。
|
||||
transientRunAttempts = 0
|
||||
emptyResponseAttempts = 0
|
||||
timeoutCancel()
|
||||
break
|
||||
}
|
||||
|
||||
handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, conversationID, result, runErr, &transientRunAttempts,
|
||||
&curHistory, &curFinalMessage, segmentUserMessage, progressCallback,
|
||||
func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) },
|
||||
)
|
||||
if handled {
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute)
|
||||
h.tasks.UpdateTaskStatus(conversationID, "running")
|
||||
continue
|
||||
}
|
||||
if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
|
||||
cause := context.Cause(baseCtx)
|
||||
if errors.Is(cause, multiagent.ErrInterruptContinue) {
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
@@ -324,8 +279,6 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
|
||||
"source": "interrupt_continue",
|
||||
})
|
||||
mainIterationOffset += segmentMainIterationMax
|
||||
// 非临时错误分段续跑(用户中断并继续)时,清空 transient 计数,避免跨分段累加。
|
||||
transientRunAttempts = 0
|
||||
timeoutCancel()
|
||||
baseCtx, cancelWithCause = context.WithCancelCause(context.Background())
|
||||
h.tasks.BindTaskCancel(conversationID, cancelWithCause)
|
||||
@@ -460,8 +413,6 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
curMsg := prep.FinalMessage
|
||||
var result *multiagent.RunResult
|
||||
var runErr error
|
||||
var transientRunAttempts int
|
||||
var emptyResponseAttempts int
|
||||
for {
|
||||
result, runErr = multiagent.RunDeepAgent(
|
||||
taskCtx,
|
||||
@@ -481,28 +432,9 @@ func (h *AgentHandler) MultiAgentLoop(c *gin.Context) {
|
||||
chatReasoningToClientIntent(req.Reasoning),
|
||||
h.projectBlackboardBlock(prep.ConversationID),
|
||||
)
|
||||
handledEmpty, exhaustedEmpty := h.handleEinoEmptyResponseContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &emptyResponseAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
)
|
||||
if exhaustedEmpty {
|
||||
runErr = nil
|
||||
break
|
||||
}
|
||||
if handledEmpty {
|
||||
continue
|
||||
}
|
||||
if runErr == nil {
|
||||
break
|
||||
}
|
||||
if handled, fatalErr := h.handleEinoTransientRetryContinue(
|
||||
baseCtx, prep.ConversationID, result, runErr, &transientRunAttempts,
|
||||
&curHist, &curMsg, prep.FinalMessage, progressCallback, nil,
|
||||
); handled {
|
||||
continue
|
||||
} else if fatalErr != nil {
|
||||
runErr = fatalErr
|
||||
}
|
||||
if shouldPersistEinoAgentTraceAfterRunError(baseCtx) {
|
||||
h.persistEinoAgentTraceForResume(prep.ConversationID, result)
|
||||
}
|
||||
|
||||
@@ -594,6 +594,9 @@ func (h *RobotHandler) cmdDelete(platform, userID, convID string) string {
|
||||
h.mu.Unlock()
|
||||
h.deleteSessionBinding(sk)
|
||||
}
|
||||
if h.agentHandler != nil {
|
||||
h.agentHandler.CancelRunningTaskForConversation(convID)
|
||||
}
|
||||
if err := h.db.DeleteConversation(convID); err != nil {
|
||||
return "删除失败: " + err.Error()
|
||||
}
|
||||
|
||||
@@ -37,6 +37,11 @@ type AgentTask struct {
|
||||
// InterruptContinueNote 无 MCP 时「中断并继续」由用户在弹窗中填写的补充说明(Cancel 前写入,续跑轮次读取后清空)
|
||||
InterruptContinueNote string `json:"-"`
|
||||
|
||||
// activeEinoExecuteCancel 当前进行中的 Eino filesystem execute 取消函数(与 MCP 工具并行,供中断并继续)
|
||||
activeEinoExecuteCancel context.CancelFunc
|
||||
// activeEinoExecuteAbortNote AbortActiveEinoExecute 写入的用户说明,由 execute 收尾时合并进工具结果
|
||||
activeEinoExecuteAbortNote string
|
||||
|
||||
cancel func(error)
|
||||
}
|
||||
|
||||
@@ -70,6 +75,69 @@ func (m *AgentTaskManager) UnregisterRunningTool(conversationID, executionID str
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterActiveEinoExecute 登记进行中的 Eino filesystem execute(每会话同时仅一条)。
|
||||
func (m *AgentTaskManager) RegisterActiveEinoExecute(conversationID string, cancel context.CancelFunc) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" || cancel == nil {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.activeEinoExecuteCancel = cancel
|
||||
t.activeEinoExecuteAbortNote = ""
|
||||
}
|
||||
}
|
||||
|
||||
// UnregisterActiveEinoExecute execute 正常结束或已取消后清除登记。
|
||||
func (m *AgentTaskManager) UnregisterActiveEinoExecute(conversationID string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
t.activeEinoExecuteCancel = nil
|
||||
t.activeEinoExecuteAbortNote = ""
|
||||
}
|
||||
}
|
||||
|
||||
// AbortActiveEinoExecute 终止当前 Eino execute 并暂存用户说明(与 MCP 工具终止一致)。
|
||||
func (m *AgentTaskManager) AbortActiveEinoExecute(conversationID, note string) bool {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return false
|
||||
}
|
||||
m.mu.Lock()
|
||||
t, ok := m.tasks[conversationID]
|
||||
if !ok || t == nil || t.activeEinoExecuteCancel == nil {
|
||||
m.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
t.activeEinoExecuteAbortNote = strings.TrimSpace(note)
|
||||
cancel := t.activeEinoExecuteCancel
|
||||
m.mu.Unlock()
|
||||
cancel()
|
||||
return true
|
||||
}
|
||||
|
||||
// TakeEinoExecuteAbortNote 读取并清空 execute 终止说明(execute 收尾时调用一次)。
|
||||
func (m *AgentTaskManager) TakeEinoExecuteAbortNote(conversationID string) string {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
if conversationID == "" {
|
||||
return ""
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if t, ok := m.tasks[conversationID]; ok && t != nil {
|
||||
n := t.activeEinoExecuteAbortNote
|
||||
t.activeEinoExecuteAbortNote = ""
|
||||
return n
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetInterruptContinueNote 在发起 ErrInterruptContinue 取消前写入用户补充说明(仅内存)。
|
||||
func (m *AgentTaskManager) SetInterruptContinueNote(conversationID, note string) {
|
||||
conversationID = strings.TrimSpace(conversationID)
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAbortActiveEinoExecute(t *testing.T) {
|
||||
m := NewAgentTaskManager()
|
||||
conv := "conv-eino-exec-abort"
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
_, err := m.StartTask(conv, "test", func(error) {})
|
||||
if err != nil {
|
||||
t.Fatalf("StartTask: %v", err)
|
||||
}
|
||||
m.RegisterActiveEinoExecute(conv, cancel)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
if !m.AbortActiveEinoExecute(conv, "跳过域名收集") {
|
||||
t.Fatal("expected abort to succeed")
|
||||
}
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("execute cancel did not propagate")
|
||||
}
|
||||
if got := m.TakeEinoExecuteAbortNote(conv); got != "跳过域名收集" {
|
||||
t.Fatalf("abort note = %q, want 跳过域名收集", got)
|
||||
}
|
||||
m.UnregisterActiveEinoExecute(conv)
|
||||
if m.AbortActiveEinoExecute(conv, "") {
|
||||
t.Fatal("second abort should fail when no active execute")
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,16 @@ type ToolRunRegistry interface {
|
||||
UnregisterRunningTool(conversationID, executionID string)
|
||||
}
|
||||
|
||||
// EinoExecuteRunRegistry tracks the in-progress Eino filesystem execute of a
// conversation so that "interrupt and continue" can terminate long-running
// commands (for example amass).
type EinoExecuteRunRegistry interface {
	// RegisterActiveEinoExecute records the cancel function of a running execute.
	RegisterActiveEinoExecute(conversationID string, cancel context.CancelFunc)
	// UnregisterActiveEinoExecute removes the registration for the conversation.
	UnregisterActiveEinoExecute(conversationID string)
	// AbortActiveEinoExecute terminates the running execute, attaching a user
	// note; the boolean result reports whether an abort took place.
	AbortActiveEinoExecute(conversationID, note string) bool
	// TakeEinoExecuteAbortNote returns the abort note for the conversation.
	TakeEinoExecuteAbortNote(conversationID string) string
}
|
||||
|
||||
// Unexported, distinct key types for values stored in a context.Context.
type (
	toolRunRegistryCtxKey        struct{}
	einoExecuteRunRegistryCtxKey struct{}
	mcpConversationIDCtxKey      struct{}
)
|
||||
|
||||
// WithToolRunRegistry 将登记器注入 ctx(Eino / 原生 Agent 任务 ctx)。
|
||||
@@ -31,6 +40,23 @@ func ToolRunRegistryFromContext(ctx context.Context) ToolRunRegistry {
|
||||
return v
|
||||
}
|
||||
|
||||
// WithEinoExecuteRunRegistry 将 Eino execute 取消登记器注入 ctx。
|
||||
func WithEinoExecuteRunRegistry(ctx context.Context, reg EinoExecuteRunRegistry) context.Context {
|
||||
if ctx == nil || reg == nil {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, einoExecuteRunRegistryCtxKey{}, reg)
|
||||
}
|
||||
|
||||
// EinoExecuteRunRegistryFromContext 取出 Eino execute 登记器(无则 nil)。
|
||||
func EinoExecuteRunRegistryFromContext(ctx context.Context) EinoExecuteRunRegistry {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
v, _ := ctx.Value(einoExecuteRunRegistryCtxKey{}).(EinoExecuteRunRegistry)
|
||||
return v
|
||||
}
|
||||
|
||||
// WithMCPConversationID 将对话 ID 注入 ctx,供 CallTool 内与 executionId 关联。
|
||||
func WithMCPConversationID(ctx context.Context, conversationID string) context.Context {
|
||||
if ctx == nil {
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const retentionPurgeInterval = time.Hour
|
||||
|
||||
// Service manages MCP tool execution monitor retention.
|
||||
type Service struct {
|
||||
db *database.DB
|
||||
cfg *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewService creates a monitor retention service.
|
||||
func NewService(db *database.DB, cfg *config.Config, logger *zap.Logger) *Service {
|
||||
return &Service{db: db, cfg: cfg, logger: logger}
|
||||
}
|
||||
|
||||
// RetentionDays returns configured retention; 0 means keep forever.
|
||||
func (s *Service) RetentionDays() int {
|
||||
if s == nil || s.cfg == nil {
|
||||
return config.MonitorConfig{}.RetentionDaysEffective()
|
||||
}
|
||||
return s.cfg.Monitor.RetentionDaysEffective()
|
||||
}
|
||||
|
||||
// PurgeExpired deletes tool execution rows older than retention_days when configured.
|
||||
func (s *Service) PurgeExpired() {
|
||||
if s == nil || s.db == nil || s.cfg == nil {
|
||||
return
|
||||
}
|
||||
days := s.cfg.Monitor.RetentionDaysEffective()
|
||||
if days <= 0 {
|
||||
return
|
||||
}
|
||||
cutoff := time.Now().AddDate(0, 0, -days)
|
||||
n, err := s.db.PurgeToolExecutionsBefore(cutoff)
|
||||
if err != nil {
|
||||
if s.logger != nil {
|
||||
s.logger.Warn("清理过期 MCP 执行记录失败", zap.Error(err))
|
||||
}
|
||||
return
|
||||
}
|
||||
if n > 0 && s.logger != nil {
|
||||
s.logger.Info("已清理过期 MCP 执行记录", zap.Int64("deleted", n), zap.Int("retention_days", days))
|
||||
}
|
||||
}
|
||||
|
||||
// StartRetentionLoop periodically purges expired tool execution rows.
|
||||
func StartRetentionLoop(s *Service, logger *zap.Logger) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(retentionPurgeInterval)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
s.PurgeExpired()
|
||||
if logger != nil {
|
||||
logger.Debug("monitor retention tick completed")
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package monitor
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/config"
|
||||
"cyberstrike-ai/internal/database"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestServicePurgeExpired_respectsZeroRetention(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
exec := &mcp.ToolExecution{
|
||||
ID: "ancient",
|
||||
ToolName: "curl::get",
|
||||
Arguments: map[string]interface{}{},
|
||||
Status: "completed",
|
||||
StartTime: mustParseTime(t, "2020-01-01T00:00:00Z"),
|
||||
}
|
||||
if err := db.SaveToolExecution(exec); err != nil {
|
||||
t.Fatalf("SaveToolExecution: %v", err)
|
||||
}
|
||||
|
||||
zero := 0
|
||||
svc := NewService(db, &config.Config{
|
||||
Monitor: config.MonitorConfig{RetentionDays: &zero},
|
||||
}, zap.NewNop())
|
||||
svc.PurgeExpired()
|
||||
|
||||
if _, err := db.GetToolExecution("ancient"); err != nil {
|
||||
t.Fatalf("record should remain when retention_days=0: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServicePurgeExpired_deletesOldRows(t *testing.T) {
|
||||
dbPath := filepath.Join(t.TempDir(), "monitor.db")
|
||||
db, err := database.NewDB(dbPath, zap.NewNop())
|
||||
if err != nil {
|
||||
t.Fatalf("NewDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
exec := &mcp.ToolExecution{
|
||||
ID: "ancient",
|
||||
ToolName: "curl::get",
|
||||
Arguments: map[string]interface{}{},
|
||||
Status: "completed",
|
||||
StartTime: mustParseTime(t, "2020-01-01T00:00:00Z"),
|
||||
}
|
||||
if err := db.SaveToolExecution(exec); err != nil {
|
||||
t.Fatalf("SaveToolExecution: %v", err)
|
||||
}
|
||||
|
||||
days := 90
|
||||
svc := NewService(db, &config.Config{
|
||||
Monitor: config.MonitorConfig{RetentionDays: &days},
|
||||
}, zap.NewNop())
|
||||
svc.PurgeExpired()
|
||||
|
||||
if _, err := db.GetToolExecution("ancient"); err == nil {
|
||||
t.Fatal("record should be purged when older than retention_days")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetentionDaysEffective_defaults(t *testing.T) {
|
||||
got := config.MonitorConfig{}.RetentionDaysEffective()
|
||||
if got != 90 {
|
||||
t.Fatalf("default = %d, want 90", got)
|
||||
}
|
||||
zero := 0
|
||||
cfg := config.MonitorConfig{RetentionDays: &zero}
|
||||
if cfg.RetentionDaysEffective() != 0 {
|
||||
t.Fatalf("zero = %d, want 0", cfg.RetentionDaysEffective())
|
||||
}
|
||||
}
|
||||
|
||||
// mustParseTime parses value as an RFC 3339 timestamp and fails the test
// immediately when it cannot be parsed.
func mustParseTime(t *testing.T, value string) time.Time {
	t.Helper()
	ts, parseErr := time.Parse(time.RFC3339, value)
	if parseErr == nil {
		return ts
	}
	t.Fatalf("parse time: %v", parseErr)
	return ts
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// continuationSessionMarker matches Cursor / IDE session-resume user injections.
|
||||
const continuationSessionMarker = "This session is being continued from a previous conversation"
|
||||
|
||||
// continuationUserDedupMiddleware keeps only the latest session-resume user message when
|
||||
// multiple continuation injections were stacked (e.g. after repeated out-of-context resumes).
|
||||
type continuationUserDedupMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
logger *zap.Logger
|
||||
phase string
|
||||
}
|
||||
|
||||
func newContinuationUserDedupMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware {
|
||||
return &continuationUserDedupMiddleware{logger: logger, phase: phase}
|
||||
}
|
||||
|
||||
func (m *continuationUserDedupMiddleware) BeforeModelRewriteState(
|
||||
ctx context.Context,
|
||||
state *adk.ChatModelAgentState,
|
||||
mc *adk.ModelContext,
|
||||
) (context.Context, *adk.ChatModelAgentState, error) {
|
||||
_ = mc
|
||||
if m == nil || state == nil || len(state.Messages) == 0 {
|
||||
return ctx, state, nil
|
||||
}
|
||||
deduped, dropped := dedupContinuationUserMessages(state.Messages)
|
||||
if dropped == 0 {
|
||||
return ctx, state, nil
|
||||
}
|
||||
if m.logger != nil {
|
||||
m.logger.Info("eino continuation user messages deduplicated",
|
||||
zap.String("phase", m.phase),
|
||||
zap.Int("dropped", dropped),
|
||||
zap.Int("messages_before", len(state.Messages)),
|
||||
zap.Int("messages_after", len(deduped)),
|
||||
)
|
||||
}
|
||||
out := *state
|
||||
out.Messages = deduped
|
||||
return ctx, &out, nil
|
||||
}
|
||||
|
||||
func adkUserMessageText(msg adk.Message) string {
|
||||
if msg == nil {
|
||||
return ""
|
||||
}
|
||||
var b strings.Builder
|
||||
if s := strings.TrimSpace(msg.Content); s != "" {
|
||||
b.WriteString(s)
|
||||
}
|
||||
for _, part := range msg.UserInputMultiContent {
|
||||
if part.Type == schema.ChatMessagePartTypeText {
|
||||
if s := strings.TrimSpace(part.Text); s != "" {
|
||||
if b.Len() > 0 {
|
||||
b.WriteByte('\n')
|
||||
}
|
||||
b.WriteString(s)
|
||||
}
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func isContinuationUserMessage(msg adk.Message) bool {
|
||||
if msg == nil || msg.Role != schema.User {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(adkUserMessageText(msg), continuationSessionMarker)
|
||||
}
|
||||
|
||||
func dedupContinuationUserMessages(msgs []adk.Message) ([]adk.Message, int) {
|
||||
lastIdx := -1
|
||||
contCount := 0
|
||||
for i, msg := range msgs {
|
||||
if !isContinuationUserMessage(msg) {
|
||||
continue
|
||||
}
|
||||
contCount++
|
||||
lastIdx = i
|
||||
}
|
||||
if contCount <= 1 {
|
||||
return msgs, 0
|
||||
}
|
||||
out := make([]adk.Message, 0, len(msgs)-(contCount-1))
|
||||
dropped := 0
|
||||
for i, msg := range msgs {
|
||||
if isContinuationUserMessage(msg) && i != lastIdx {
|
||||
dropped++
|
||||
continue
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out, dropped
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func continuationUser(text string) adk.Message {
|
||||
return &schema.Message{
|
||||
Role: schema.User,
|
||||
UserInputMultiContent: []schema.MessageInputPart{
|
||||
{Type: schema.ChatMessagePartTypeText, Text: continuationSessionMarker + "\n" + text},
|
||||
{Type: schema.ChatMessagePartTypeText, Text: "Please continue the conversation from where we left it off."},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestDedupContinuationUserMessages_KeepsLatest(t *testing.T) {
|
||||
msgs := []adk.Message{
|
||||
continuationUser("summary old"),
|
||||
schema.UserMessage("real task"),
|
||||
continuationUser("summary new"),
|
||||
}
|
||||
out, dropped := dedupContinuationUserMessages(msgs)
|
||||
if dropped != 1 {
|
||||
t.Fatalf("dropped=%d want 1", dropped)
|
||||
}
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("len=%d want 2", len(out))
|
||||
}
|
||||
if out[0].Role != schema.User || adkUserMessageText(out[0]) != "real task" {
|
||||
t.Fatalf("first should remain real task, got %q", adkUserMessageText(out[0]))
|
||||
}
|
||||
if !strings.Contains(adkUserMessageText(out[1]), "summary new") {
|
||||
t.Fatalf("latest continuation not kept: %q", adkUserMessageText(out[1]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDedupContinuationUserMessages_NoOpSingle(t *testing.T) {
|
||||
msgs := []adk.Message{continuationUser("only"), schema.UserMessage("task")}
|
||||
out, dropped := dedupContinuationUserMessages(msgs)
|
||||
if dropped != 0 || len(out) != 2 {
|
||||
t.Fatalf("unexpected change dropped=%d len=%d", dropped, len(out))
|
||||
}
|
||||
}
|
||||
|
||||
func TestContinuationUserDedupMiddleware(t *testing.T) {
|
||||
mw := newContinuationUserDedupMiddleware(nil, "test")
|
||||
state := &adk.ChatModelAgentState{Messages: []adk.Message{
|
||||
continuationUser("old"),
|
||||
continuationUser("new"),
|
||||
schema.UserMessage("task"),
|
||||
}}
|
||||
_, out, err := mw.(*continuationUserDedupMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(out.Messages) != 2 {
|
||||
t.Fatalf("want 2 messages after dedup, got %d", len(out.Messages))
|
||||
}
|
||||
}
|
||||
@@ -383,6 +383,12 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
}
|
||||
runner := adk.NewRunner(ctx, runnerCfg)
|
||||
startRunnerIter := func(runMsgs []adk.Message) *adk.AsyncIterator[*adk.AgentEvent] {
|
||||
if checkPointID != "" {
|
||||
return runner.Run(ctx, runMsgs, adk.WithCheckPointID(checkPointID))
|
||||
}
|
||||
return runner.Run(ctx, runMsgs)
|
||||
}
|
||||
var iter *adk.AsyncIterator[*adk.AgentEvent]
|
||||
if cpStore != nil && checkPointID != "" {
|
||||
if _, existed, getErr := cpStore.Get(ctx, checkPointID); getErr != nil {
|
||||
@@ -422,12 +428,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
}
|
||||
if iter == nil {
|
||||
if checkPointID != "" {
|
||||
iter = runner.Run(ctx, msgs, adk.WithCheckPointID(checkPointID))
|
||||
} else {
|
||||
iter = runner.Run(ctx, msgs)
|
||||
}
|
||||
iter = startRunnerIter(msgs)
|
||||
}
|
||||
transientRetrier := newEinoTransientRunRetrier(einoTransientRunRetryPolicyFromArgs(args))
|
||||
handleRunErr := func(runErr error) error {
|
||||
if runErr == nil {
|
||||
return nil
|
||||
@@ -480,26 +483,60 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
return runErr
|
||||
}
|
||||
|
||||
// maybeRetryTransientRun:不在此层 runner.Run/Resume;由 handler 落库 + loadHistoryFromAgentTrace 分段续跑(同中断并继续)。
|
||||
maybeRetryTransientRun := func(runErr error) (retry bool, fatal error) {
|
||||
if runErr == nil || !isEinoTransientRunError(runErr) {
|
||||
maybeRetryTransientRun := func(runErr error) (restarted bool, fatal error) {
|
||||
if runErr == nil {
|
||||
return false, nil
|
||||
}
|
||||
if !isEinoTransientRunError(runErr) {
|
||||
return false, handleRunErr(runErr)
|
||||
}
|
||||
restarted, restartMsgs, ctxSource, backoff, retErr := transientRetrier.tryRetry(
|
||||
ctx, runErr, args, baseMsgs, runAccumulatedMsgs, baseAccumulatedCount,
|
||||
)
|
||||
if retErr != nil {
|
||||
flushAllPendingAsFailed(runErr)
|
||||
if logger != nil {
|
||||
logger.Warn("eino transient retry exhausted",
|
||||
zap.Error(retErr),
|
||||
zap.String("orchestration", orchMode),
|
||||
zap.Int("maxAttempts", transientRetrier.maxAttempts()))
|
||||
}
|
||||
return false, retErr
|
||||
}
|
||||
if !restarted {
|
||||
return false, nil
|
||||
}
|
||||
attemptNo := transientRetrier.attempt()
|
||||
maxAttempts := transientRetrier.maxAttempts()
|
||||
if logger != nil {
|
||||
logger.Warn("eino transient error, ending run segment for handler resume",
|
||||
logger.Warn("eino transient error, retrying after backoff",
|
||||
zap.Error(runErr),
|
||||
zap.String("orchestration", orchMode))
|
||||
zap.String("orchestration", orchMode),
|
||||
zap.Int("attempt", attemptNo),
|
||||
zap.Int("maxAttempts", maxAttempts),
|
||||
zap.Duration("backoff", backoff))
|
||||
}
|
||||
if progress != nil {
|
||||
progress("eino_run_retry", "遇到临时错误(限流或网络波动),将保存上下文并重试…", map[string]interface{}{
|
||||
progress("eino_run_retry", fmt.Sprintf("遇到临时错误(限流或网络波动),%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"orchestration": orchMode,
|
||||
"error": runErr.Error(),
|
||||
"resumeKind": "trace_segment",
|
||||
"attempt": attemptNo,
|
||||
"maxAttempts": maxAttempts,
|
||||
"backoffSec": int(backoff.Seconds()),
|
||||
})
|
||||
progress("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"orchestration": orchMode,
|
||||
"attempt": attemptNo,
|
||||
"contextSource": string(ctxSource),
|
||||
})
|
||||
}
|
||||
return false, ErrTransientRetryContinue
|
||||
msgs = restartMsgs
|
||||
iter = startRunnerIter(msgs)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
takePartial := func(runErr error) (*RunResult, error) {
|
||||
@@ -583,9 +620,15 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
continue
|
||||
}
|
||||
if ev.Err != nil {
|
||||
if _, retErr := maybeRetryTransientRun(ev.Err); retErr != nil {
|
||||
restarted, retErr := maybeRetryTransientRun(ev.Err)
|
||||
if retErr != nil {
|
||||
return takePartial(retErr)
|
||||
}
|
||||
if restarted {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
transientRetrier.reset()
|
||||
}
|
||||
if ev.AgentName != "" && progress != nil {
|
||||
iterEinoAgent := orchestratorName
|
||||
@@ -951,9 +994,13 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
"einoRole": einoRoleTag(ev.AgentName),
|
||||
})
|
||||
}
|
||||
if _, retErr := maybeRetryTransientRun(streamRecvErr); retErr != nil {
|
||||
restarted, retErr := maybeRetryTransientRun(streamRecvErr)
|
||||
if retErr != nil {
|
||||
return takePartial(retErr)
|
||||
}
|
||||
if restarted {
|
||||
continue
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -1057,32 +1104,9 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
orchMode, runAccumulatedMsgs, persistTraceSource(args, runAccumulatedMsgs),
|
||||
lastAssistant, lastPlanExecuteExecutor, emptyHint, ids, false,
|
||||
)
|
||||
if shouldEinoEmptyResponseContinue(out, emptyHint, len(runAccumulatedMsgs), baseAccumulatedCount) {
|
||||
if logger != nil {
|
||||
logger.Info("eino empty response, ending run segment for handler resume",
|
||||
zap.String("conversationId", conversationID),
|
||||
zap.String("orchestration", orchMode),
|
||||
zap.Int("traceMessages", len(runAccumulatedMsgs)))
|
||||
}
|
||||
if progress != nil {
|
||||
progress("eino_empty_response_continue", "会话已结束但未产生助手正文,正在基于轨迹自动续跑…", map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
"resumeKind": "trace_segment",
|
||||
})
|
||||
}
|
||||
return out, ErrEmptyResponseContinue
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func shouldEinoEmptyResponseContinue(out *RunResult, emptyHint string, accumulatedLen, baseCount int) bool {
|
||||
if out == nil || accumulatedLen <= baseCount {
|
||||
return false
|
||||
}
|
||||
return strings.TrimSpace(out.Response) == strings.TrimSpace(emptyHint)
|
||||
}
|
||||
|
||||
func persistTraceSource(args *einoADKRunLoopArgs, fallback []adk.Message) []adk.Message {
|
||||
if args != nil && args.ModelFacingTrace != nil {
|
||||
if snap := args.ModelFacingTrace.Snapshot(); len(snap) > 0 {
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// einoChatModelTailConfig configures middleware appended after reduction/skill/plantask
|
||||
// and immediately before each ChatModel invocation pipeline completes.
|
||||
//
|
||||
// Order (best practice):
|
||||
// 1. system merge — accurate token count for summarization
|
||||
// 2. continuation user dedup — drop stale session-resume injections
|
||||
// 3. summarization
|
||||
// 4. orphan tool prune
|
||||
// 5. telemetry
|
||||
// 6. model-facing trace snapshot
|
||||
type einoChatModelTailConfig struct {
|
||||
logger *zap.Logger
|
||||
phase string
|
||||
summarization adk.ChatModelAgentMiddleware
|
||||
modelName string
|
||||
conversationID string
|
||||
trace *modelFacingTraceHolder
|
||||
skipOrphanPruner bool
|
||||
skipTelemetry bool
|
||||
skipTrace bool
|
||||
}
|
||||
|
||||
func appendEinoChatModelTailMiddlewares(handlers []adk.ChatModelAgentMiddleware, cfg einoChatModelTailConfig) []adk.ChatModelAgentMiddleware {
|
||||
handlers = append(handlers, newSystemMessageNormalizerMiddleware(cfg.logger, cfg.phase))
|
||||
handlers = append(handlers, newContinuationUserDedupMiddleware(cfg.logger, cfg.phase))
|
||||
if cfg.summarization != nil {
|
||||
handlers = append(handlers, cfg.summarization)
|
||||
}
|
||||
if !cfg.skipOrphanPruner {
|
||||
handlers = append(handlers, newOrphanToolPrunerMiddleware(cfg.logger, cfg.phase))
|
||||
}
|
||||
if !cfg.skipTelemetry {
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(cfg.logger, cfg.modelName, cfg.conversationID, cfg.phase); teleMw != nil {
|
||||
handlers = append(handlers, teleMw)
|
||||
}
|
||||
}
|
||||
if !cfg.skipTrace && cfg.trace != nil {
|
||||
if capMw := newModelFacingTraceMiddleware(cfg.trace); capMw != nil {
|
||||
handlers = append(handlers, capMw)
|
||||
}
|
||||
}
|
||||
return handlers
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package multiagent
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestShouldEinoEmptyResponseContinue(t *testing.T) {
|
||||
t.Parallel()
|
||||
hint := "(empty hint)"
|
||||
out := &RunResult{Response: hint}
|
||||
if !shouldEinoEmptyResponseContinue(out, hint, 3, 1) {
|
||||
t.Fatal("expected continue when response is empty hint and trace grew")
|
||||
}
|
||||
if shouldEinoEmptyResponseContinue(out, hint, 1, 1) {
|
||||
t.Fatal("expected no continue when trace did not grow")
|
||||
}
|
||||
if shouldEinoEmptyResponseContinue(&RunResult{Response: "hello"}, hint, 3, 1) {
|
||||
t.Fatal("expected no continue when response has content")
|
||||
}
|
||||
if shouldEinoEmptyResponseContinue(nil, hint, 3, 1) {
|
||||
t.Fatal("expected no continue for nil result")
|
||||
}
|
||||
}
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
"cyberstrike-ai/internal/security"
|
||||
|
||||
"github.com/cloudwego/eino/adk/filesystem"
|
||||
@@ -80,15 +82,23 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
req.Command = prependPythonUnbufferedEnv(req.Command)
|
||||
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
||||
agentTag := strings.TrimSpace(w.einoAgentName)
|
||||
convID := mcp.MCPConversationIDFromContext(ctx)
|
||||
execReg := mcp.EinoExecuteRunRegistryFromContext(ctx)
|
||||
|
||||
execCtx := ctx
|
||||
var execCancel context.CancelFunc
|
||||
execCtx, execCancel := context.WithCancel(ctx)
|
||||
var timeoutCancel context.CancelFunc
|
||||
if w.toolTimeoutMinutes > 0 {
|
||||
execCtx, execCancel = context.WithTimeout(ctx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
|
||||
execCtx, timeoutCancel = context.WithTimeout(execCtx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
|
||||
}
|
||||
if execReg != nil && convID != "" {
|
||||
execReg.RegisterActiveEinoExecute(convID, execCancel)
|
||||
}
|
||||
|
||||
sr, err := w.inner.ExecuteStreaming(execCtx, &req)
|
||||
if err != nil {
|
||||
if timeoutCancel != nil {
|
||||
timeoutCancel()
|
||||
}
|
||||
if execCancel != nil {
|
||||
execCancel()
|
||||
}
|
||||
@@ -111,6 +121,9 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
return nil, err
|
||||
}
|
||||
if sr == nil || w.invokeNotify == nil {
|
||||
if timeoutCancel != nil {
|
||||
timeoutCancel()
|
||||
}
|
||||
if execCancel != nil {
|
||||
execCancel()
|
||||
}
|
||||
@@ -119,11 +132,32 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
|
||||
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
||||
|
||||
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, tctx context.Context) {
|
||||
defer inner.Close()
|
||||
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry) {
|
||||
var innerCloseOnce sync.Once
|
||||
closeInner := func() {
|
||||
innerCloseOnce.Do(func() { inner.Close() })
|
||||
}
|
||||
defer closeInner()
|
||||
if timeoutCleanup != nil {
|
||||
defer timeoutCleanup()
|
||||
}
|
||||
if cancel != nil {
|
||||
defer cancel()
|
||||
}
|
||||
if reg != nil && conversationID != "" {
|
||||
defer reg.UnregisterActiveEinoExecute(conversationID)
|
||||
}
|
||||
|
||||
// ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。
|
||||
stopWatch := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-tctx.Done():
|
||||
closeInner()
|
||||
case <-stopWatch:
|
||||
}
|
||||
}()
|
||||
defer close(stopWatch)
|
||||
|
||||
var sb strings.Builder
|
||||
success := true
|
||||
@@ -144,6 +178,10 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
invokeErr = context.DeadlineExceeded
|
||||
break
|
||||
}
|
||||
if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) {
|
||||
invokeErr = context.Canceled
|
||||
break
|
||||
}
|
||||
_ = outW.Send(nil, rerr)
|
||||
break
|
||||
}
|
||||
@@ -178,6 +216,21 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
success = false
|
||||
invokeErr = context.DeadlineExceeded
|
||||
}
|
||||
// 用户「中断并继续」终止 execute:合并说明进工具结果(与 MCP CancelToolExecutionWithNote 一致)。
|
||||
partialStreamed := sb.String()
|
||||
var abortNote string
|
||||
if reg != nil && conversationID != "" && (invokeErr != nil || errors.Is(tctx.Err(), context.Canceled)) {
|
||||
if note := reg.TakeEinoExecuteAbortNote(conversationID); note != "" {
|
||||
abortNote = note
|
||||
merged := mcp.MergePartialToolOutputAndAbortNote(partialStreamed, note)
|
||||
sb.Reset()
|
||||
sb.WriteString(merged)
|
||||
if invokeErr == nil {
|
||||
success = false
|
||||
invokeErr = context.Canceled
|
||||
}
|
||||
}
|
||||
}
|
||||
// ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。
|
||||
if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) {
|
||||
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||||
@@ -187,12 +240,20 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi
|
||||
}
|
||||
sb.WriteString(hint)
|
||||
}
|
||||
// 中断时循环内已逐行写入 stdout;此处只追加 USER INTERRUPT NOTE,避免整段输出重复。
|
||||
if invokeErr != nil && errors.Is(invokeErr, context.Canceled) && abortNote != "" {
|
||||
if partialStreamed != "" {
|
||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: "\n\n" + mcp.AbortNoteBannerForModel + "\n" + abortNote}, nil)
|
||||
} else if text := strings.TrimSpace(sb.String()); text != "" {
|
||||
_ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil)
|
||||
}
|
||||
}
|
||||
if w.recordMonitor != nil {
|
||||
w.recordMonitor(tid, command, sb.String(), success, invokeErr)
|
||||
}
|
||||
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
||||
outW.Close()
|
||||
}(sr, userCmd, execCancel, execCtx)
|
||||
}(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg)
|
||||
|
||||
return outR, nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"cyberstrike-ai/internal/einomcp"
|
||||
"cyberstrike-ai/internal/mcp"
|
||||
|
||||
"github.com/cloudwego/eino/adk/filesystem"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
@@ -122,6 +123,94 @@ func TestEinoStreamingShellWrap_ToolTimeoutRecvErrIsSoft(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoStreamingShellWrap_CapturesOutputWithToolTimeout(t *testing.T) {
|
||||
inner := &mockStreamingShell{output: "100\n"}
|
||||
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||
var firedContent string
|
||||
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||
firedContent = content
|
||||
})
|
||||
wrap := &einoStreamingShellWrap{
|
||||
inner: inner,
|
||||
invokeNotify: notify,
|
||||
toolTimeoutMinutes: 60,
|
||||
}
|
||||
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "echo 100"})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
var got strings.Builder
|
||||
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("unexpected stream error: %v", rerr)
|
||||
}
|
||||
if resp != nil && resp.Output != "" {
|
||||
got.WriteString(resp.Output)
|
||||
}
|
||||
}
|
||||
if !strings.Contains(got.String(), "100") {
|
||||
t.Fatalf("stream output = %q, want contains 100", got.String())
|
||||
}
|
||||
if !strings.Contains(firedContent, "100") {
|
||||
t.Fatalf("notify content = %q, want contains 100", firedContent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoStreamingShellWrap_AbortNoteDoesNotDuplicateStreamedOutput(t *testing.T) {
|
||||
inner := &mockStreamingShell{output: "line1\nline2\n", recvErr: context.Canceled}
|
||||
notify := einomcp.NewToolInvokeNotifyHolder()
|
||||
wrap := &einoStreamingShellWrap{
|
||||
inner: inner,
|
||||
invokeNotify: notify,
|
||||
}
|
||||
reg := &abortNoteTestRegistry{note: "改成20次"}
|
||||
ctx := mcp.WithEinoExecuteRunRegistry(
|
||||
mcp.WithMCPConversationID(context.Background(), "conv-abort-dup"),
|
||||
reg,
|
||||
)
|
||||
sr, err := wrap.ExecuteStreaming(ctx, &filesystem.ExecuteRequest{Command: "ping -c 10 baidu.com"})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
var got strings.Builder
|
||||
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("unexpected stream error: %v", rerr)
|
||||
}
|
||||
if resp != nil && resp.Output != "" {
|
||||
got.WriteString(resp.Output)
|
||||
}
|
||||
}
|
||||
out := got.String()
|
||||
if strings.Count(out, "line1") != 1 || strings.Count(out, "line2") != 1 {
|
||||
t.Fatalf("stream duplicated stdout: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "改成20次") {
|
||||
t.Fatalf("stream missing abort note: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
// abortNoteTestRegistry is a test double for the execute-run registry: registration and
// abort calls do nothing, and TakeEinoExecuteAbortNote always returns the fixed note.
type abortNoteTestRegistry struct {
	note string // returned verbatim by TakeEinoExecuteAbortNote
}

// RegisterActiveEinoExecute is a no-op.
func (reg *abortNoteTestRegistry) RegisterActiveEinoExecute(_ string, _ context.CancelFunc) {}

// UnregisterActiveEinoExecute is a no-op.
func (reg *abortNoteTestRegistry) UnregisterActiveEinoExecute(_ string) {}

// AbortActiveEinoExecute never aborts anything and reports false.
func (reg *abortNoteTestRegistry) AbortActiveEinoExecute(_ string, _ string) bool {
	return false
}

// TakeEinoExecuteAbortNote returns the configured note regardless of the argument.
func (reg *abortNoteTestRegistry) TakeEinoExecuteAbortNote(_ string) string {
	return reg.note
}
|
||||
|
||||
func TestEinoStreamingShellWrap_NonTimeoutRecvErrStillHard(t *testing.T) {
|
||||
inner := &mockStreamingShell{recvErr: errors.New("broken pipe")}
|
||||
wrap := &einoStreamingShellWrap{inner: inner}
|
||||
|
||||
@@ -243,17 +243,14 @@ func prependEinoMiddlewares(
|
||||
return outTools, extraHandlers, toolSearchActive, nil
|
||||
}
|
||||
|
||||
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry *adk.ModelRetryConfig, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
||||
func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, taskDesc func(context.Context, []adk.Agent) (string, error)) {
|
||||
if ma == nil {
|
||||
return "", nil, nil
|
||||
return "", nil
|
||||
}
|
||||
mw := ma.EinoMiddleware
|
||||
if k := strings.TrimSpace(mw.DeepOutputKey); k != "" {
|
||||
outputKey = k
|
||||
}
|
||||
if mw.DeepModelRetryMaxRetries > 0 {
|
||||
retry = &adk.ModelRetryConfig{MaxRetries: mw.DeepModelRetryMaxRetries}
|
||||
}
|
||||
prefix := strings.TrimSpace(mw.TaskToolDescriptionPrefix)
|
||||
if prefix != "" {
|
||||
taskDesc = func(ctx context.Context, agents []adk.Agent) (string, error) {
|
||||
@@ -274,5 +271,5 @@ func deepExtrasFromConfig(ma *config.MultiAgentConfig) (outputKey string, retry
|
||||
return prefix + "\n可用子代理(按名称 transfer / task 调用):" + strings.Join(names, "、"), nil
|
||||
}
|
||||
}
|
||||
return outputKey, retry, taskDesc
|
||||
return outputKey, taskDesc
|
||||
}
|
||||
|
||||
@@ -94,24 +94,20 @@ func NewPlanExecuteRoot(ctx context.Context, a *PlanExecuteRootArgs) (adk.Resuma
|
||||
if a.SkillMiddleware != nil {
|
||||
execHandlers = append(execHandlers, a.SkillMiddleware)
|
||||
}
|
||||
// 4. summarization(最后,与 Deep/Supervisor 一致)
|
||||
// 4. pre-summarization normalize + continuation dedup, then summarization (与 Deep/Supervisor 一致)
|
||||
if a.AppCfg != nil {
|
||||
sumMw, sumErr := newEinoSummarizationMiddleware(ctx, a.ExecModel, a.AppCfg, a.MwCfg, a.ConversationID, a.DB, a.ProjectID, a.Logger)
|
||||
if sumErr != nil {
|
||||
return nil, fmt.Errorf("plan_execute executor summarization: %w", sumErr)
|
||||
}
|
||||
execHandlers = append(execHandlers, sumMw)
|
||||
}
|
||||
// 5. 孤儿 tool 消息兜底:必须挂在所有改写历史中间件(summarization/reduction/skill)之后、
|
||||
// telemetry 之前,保证送入 ChatModel 的消息序列 tool_call ↔ tool_result 配对完整。
|
||||
execHandlers = append(execHandlers, newOrphanToolPrunerMiddleware(a.Logger, "plan_execute_executor"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(a.Logger, a.ModelName, a.ConversationID, "plan_execute_executor"); teleMw != nil {
|
||||
execHandlers = append(execHandlers, teleMw)
|
||||
}
|
||||
if a.ModelFacingTrace != nil {
|
||||
if capMw := newModelFacingTraceMiddleware(a.ModelFacingTrace); capMw != nil {
|
||||
execHandlers = append(execHandlers, capMw)
|
||||
}
|
||||
execHandlers = appendEinoChatModelTailMiddlewares(execHandlers, einoChatModelTailConfig{
|
||||
logger: a.Logger,
|
||||
phase: "plan_execute_executor",
|
||||
summarization: sumMw,
|
||||
modelName: a.ModelName,
|
||||
conversationID: a.ConversationID,
|
||||
trace: a.ModelFacingTrace,
|
||||
})
|
||||
}
|
||||
executor, err := newPlanExecuteExecutor(ctx, &planexecute.ExecutorConfig{
|
||||
Model: a.ExecModel,
|
||||
|
||||
@@ -144,13 +144,14 @@ func RunEinoSingleChatModelAgent(
|
||||
}
|
||||
handlers = append(handlers, einoSkillMW)
|
||||
}
|
||||
handlers = append(handlers, mainSumMw)
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "eino_single"); teleMw != nil {
|
||||
handlers = append(handlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
handlers = append(handlers, capMw)
|
||||
}
|
||||
handlers = appendEinoChatModelTailMiddlewares(handlers, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "eino_single",
|
||||
summarization: mainSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
trace: modelFacingTrace,
|
||||
})
|
||||
|
||||
maxIter := agentMaxIterations(appCfg)
|
||||
|
||||
@@ -188,13 +189,10 @@ func RunEinoSingleChatModelAgent(
|
||||
MaxIterations: maxIter,
|
||||
Handlers: handlers,
|
||||
}
|
||||
outKey, modelRetry, _ := deepExtrasFromConfig(ma)
|
||||
outKey, _ := deepExtrasFromConfig(ma)
|
||||
if outKey != "" {
|
||||
chatCfg.OutputKey = outKey
|
||||
}
|
||||
if modelRetry != nil {
|
||||
chatCfg.ModelRetryConfig = modelRetry
|
||||
}
|
||||
|
||||
chatAgent, err := adk.NewChatModelAgent(ctx, chatCfg)
|
||||
if err != nil {
|
||||
|
||||
@@ -22,8 +22,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const defaultSummarizationRetryMax = 3
|
||||
|
||||
// einoSummarizeUserInstruction:压缩历史时保留渗透测试关键信息。
|
||||
const einoSummarizeUserInstruction = `在保持所有关键安全测试信息完整的前提下压缩对话历史。
|
||||
|
||||
@@ -97,10 +95,8 @@ func newEinoSummarizationMiddleware(
|
||||
}
|
||||
}
|
||||
|
||||
retryMax := defaultSummarizationRetryMax
|
||||
if mwCfg != nil && mwCfg.SummarizationRetryMaxAttempts > 0 {
|
||||
retryMax = mwCfg.SummarizationRetryMaxAttempts
|
||||
}
|
||||
retryPolicy := einoTransientRunRetryPolicyFromMW(mwCfg)
|
||||
retryMax := retryPolicy.maxAttempts
|
||||
|
||||
// ModelOptions apply only to summarization Generate (same ChatModel instance as the agent).
|
||||
// Strip thinking/reasoning on this call path; mark requests for empty-choices diagnostics.
|
||||
@@ -137,13 +133,14 @@ func newEinoSummarizationMiddleware(
|
||||
Retry: &summarization.RetryConfig{
|
||||
MaxRetries: &retryMax,
|
||||
ShouldRetry: func(_ context.Context, _ adk.Message, err error) bool {
|
||||
if err != nil && logger != nil {
|
||||
logger.Warn("eino summarization generate attempt failed, will retry if attempts remain",
|
||||
retry := isEinoTransientRunError(err)
|
||||
if retry && logger != nil {
|
||||
logger.Warn("eino summarization generate transient error, will retry if attempts remain",
|
||||
zap.Error(err),
|
||||
zap.Int("max_retries", retryMax),
|
||||
)
|
||||
}
|
||||
return err != nil
|
||||
return retry
|
||||
},
|
||||
},
|
||||
Finalize: func(ctx context.Context, originalMessages []adk.Message, summary adk.Message) ([]adk.Message, error) {
|
||||
@@ -260,17 +257,19 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
nonSystem = append(nonSystem, msg)
|
||||
}
|
||||
|
||||
mergedSystem := mergeCollectedSystemMessages(systemMsgs)
|
||||
|
||||
if recentTrailTokenBudget <= 0 || len(nonSystem) == 0 {
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1)
|
||||
out = append(out, systemMsgs...)
|
||||
out := make([]adk.Message, 0, len(mergedSystem)+1)
|
||||
out = append(out, mergedSystem...)
|
||||
out = append(out, summary)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
rounds := splitMessagesIntoRounds(nonSystem)
|
||||
if len(rounds) == 0 {
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1)
|
||||
out = append(out, systemMsgs...)
|
||||
out := make([]adk.Message, 0, len(mergedSystem)+1)
|
||||
out = append(out, mergedSystem...)
|
||||
out = append(out, summary)
|
||||
return out, nil
|
||||
}
|
||||
@@ -322,8 +321,8 @@ func summarizeFinalizeWithRecentAssistantToolTrail(
|
||||
selectedMsgs = append(selectedMsgs, selectedRoundsReverse[i].messages...)
|
||||
}
|
||||
|
||||
out := make([]adk.Message, 0, len(systemMsgs)+1+len(selectedMsgs))
|
||||
out = append(out, systemMsgs...)
|
||||
out := make([]adk.Message, 0, len(mergedSystem)+1+len(selectedMsgs))
|
||||
out = append(out, mergedSystem...)
|
||||
out = append(out, summary)
|
||||
out = append(out, selectedMsgs...)
|
||||
return out, nil
|
||||
|
||||
@@ -192,8 +192,8 @@ func TestSummarizeFinalize_KeepsToolRoundIntact(t *testing.T) {
|
||||
if len(out) < 2 {
|
||||
t.Fatalf("output too short: %d", len(out))
|
||||
}
|
||||
if out[0] != sys {
|
||||
t.Fatalf("first message must be system")
|
||||
if out[0].Role != schema.System || out[0].Content != "sys" {
|
||||
t.Fatalf("first message must be system sys, got %s: %q", out[0].Role, out[0].Content)
|
||||
}
|
||||
if out[1] != summary {
|
||||
t.Fatalf("second message must be summary")
|
||||
@@ -293,12 +293,12 @@ func TestSummarizeFinalize_BudgetZeroFallsBackToSummaryOnly(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(out) != 2 || out[0] != sys || out[1] != summary {
|
||||
if len(out) != 2 || out[0].Role != schema.System || out[0].Content != "sys" || out[1] != summary {
|
||||
t.Fatalf("budget=0 must yield [system, summary] only, got %+v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) {
|
||||
func TestSummarizeFinalize_MergesSystemMessages(t *testing.T) {
|
||||
sys1 := schema.SystemMessage("sys1")
|
||||
sys2 := schema.SystemMessage("sys2")
|
||||
summary := schema.AssistantMessage("s", nil)
|
||||
@@ -321,10 +321,13 @@ func TestSummarizeFinalize_PreservesAllSystemMessages(t *testing.T) {
|
||||
for _, m := range out {
|
||||
if m != nil && m.Role == schema.System {
|
||||
systemCount++
|
||||
if got := m.Content; got != "sys1\n\nsys2" {
|
||||
t.Fatalf("unexpected merged system content: %q", got)
|
||||
}
|
||||
}
|
||||
}
|
||||
if systemCount != 2 {
|
||||
t.Fatalf("want 2 system messages retained, got %d", systemCount)
|
||||
if systemCount != 1 {
|
||||
t.Fatalf("want 1 merged system message, got %d", systemCount)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -378,6 +381,12 @@ func TestWriteSummarizationTranscript(t *testing.T) {
|
||||
if !strings.Contains(text, "tool_calls:") || !strings.Contains(text, "nmap output") {
|
||||
t.Fatalf("missing tool round: %q", text)
|
||||
}
|
||||
if !strings.Contains(text, `"name":"stub_tool"`) || !strings.Contains(text, `"arguments":"{}"`) {
|
||||
t.Fatalf("missing tool name/arguments: %q", text)
|
||||
}
|
||||
if strings.Contains(text, "tool_call_id") || strings.Contains(text, `"id":"tc1"`) {
|
||||
t.Fatalf("transcript should omit tool_call_id: %q", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeSystemContentForTranscript_BestPractice(t *testing.T) {
|
||||
|
||||
@@ -23,6 +23,11 @@ const (
|
||||
transcriptSkillsSystemMarker = "# Skills System"
|
||||
)
|
||||
|
||||
// transcriptToolCall is the reduced form of a tool call that gets serialized
// into the summarization transcript. It carries only the two fields below, so
// nothing else from the source tool call reaches the transcript.
type transcriptToolCall struct {
	// Name is serialized under the JSON key "name".
	Name string `json:"name"`
	// Arguments is serialized under the JSON key "arguments".
	Arguments string `json:"arguments"`
}
|
||||
|
||||
// formatSummarizationTranscript renders pre-compaction messages for transcript.txt.
|
||||
// Best practice: keep full user/assistant/tool turns; slim system to dynamic blocks only.
|
||||
func formatSummarizationTranscript(msgs []adk.Message) string {
|
||||
@@ -138,15 +143,21 @@ func appendTranscriptMessage(sb *strings.Builder, msg adk.Message) {
|
||||
}
|
||||
}
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
if b, err := sonic.Marshal(msg.ToolCalls); err == nil {
|
||||
if b, err := sonic.Marshal(formatTranscriptToolCalls(msg.ToolCalls)); err == nil {
|
||||
sb.WriteString("tool_calls: ")
|
||||
sb.Write(b)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
if msg.ToolCallID != "" {
|
||||
sb.WriteString("tool_call_id: ")
|
||||
sb.WriteString(msg.ToolCallID)
|
||||
sb.WriteByte('\n')
|
||||
}
|
||||
}
|
||||
|
||||
func formatTranscriptToolCalls(calls []schema.ToolCall) []transcriptToolCall {
|
||||
out := make([]transcriptToolCall, 0, len(calls))
|
||||
for _, tc := range calls {
|
||||
out = append(out, transcriptToolCall{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: tc.Function.Arguments,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package multiagent
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -17,8 +18,9 @@ const (
|
||||
defaultEinoRunRetryMaxBackoff = 30 * time.Second
|
||||
)
|
||||
|
||||
// isEinoTransientRunError 判断 ADK 运行期错误是否适合指数退避续跑(429、5xx、网络抖动等)。
|
||||
// 用户取消、超时、迭代上限等由 run loop 单独处理,不在此列。
|
||||
// isEinoTransientRunError 是 Eino 运行期「可退避重试 vs 直接失败」的唯一判据。
|
||||
// 429/5xx/网络抖动等返回 true;用户取消、超时、迭代上限、鉴权失败等返回 false。
|
||||
// 其它模块(run loop、summarization 等)只调用本函数,不在别处维护平行规则。
|
||||
func isEinoTransientRunError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
@@ -60,6 +62,7 @@ func isEinoTransientRunError(err error) bool {
|
||||
"dial tcp",
|
||||
"tls handshake timeout",
|
||||
"stream error",
|
||||
"goaway", // http2: server sent GOAWAY and closed the connection
|
||||
"unexpected eof",
|
||||
`": eof`, // net/http: Post "url": EOF (often wraps io.EOF)
|
||||
"unexpected end of json",
|
||||
@@ -78,6 +81,71 @@ func isEinoTransientRunError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// einoTransientRunRetryPolicy groups the two limits that govern retrying a run
// after a transient error.
type einoTransientRunRetryPolicy struct {
	// maxAttempts is the retry count after which the retrier reports exhaustion.
	maxAttempts int
	// maxBackoff is the ceiling handed to the backoff computation.
	maxBackoff time.Duration
}
|
||||
|
||||
func einoTransientRunRetryPolicyFromArgs(args *einoADKRunLoopArgs) einoTransientRunRetryPolicy {
|
||||
return einoTransientRunRetryPolicy{
|
||||
maxAttempts: einoRunRetryMaxAttempts(args),
|
||||
maxBackoff: einoRunRetryMaxBackoff(args),
|
||||
}
|
||||
}
|
||||
|
||||
func einoTransientRunRetryPolicyFromMW(mw *config.MultiAgentEinoMiddlewareConfig) einoTransientRunRetryPolicy {
|
||||
maxBackoff := defaultEinoRunRetryMaxBackoff
|
||||
if mw != nil && mw.RunRetryMaxBackoffSec > 0 {
|
||||
maxBackoff = time.Duration(mw.RunRetryMaxBackoffSec) * time.Second
|
||||
}
|
||||
return einoTransientRunRetryPolicy{
|
||||
maxAttempts: RunRetryMaxAttemptsFromConfig(mw),
|
||||
maxBackoff: maxBackoff,
|
||||
}
|
||||
}
|
||||
|
||||
// einoTransientRunRetrier 在 run loop 内对临时错误做指数退避并重启 Runner(唯一重试执行层)。
|
||||
type einoTransientRunRetrier struct {
|
||||
policy einoTransientRunRetryPolicy
|
||||
attempts int
|
||||
}
|
||||
|
||||
func newEinoTransientRunRetrier(policy einoTransientRunRetryPolicy) *einoTransientRunRetrier {
|
||||
return &einoTransientRunRetrier{policy: policy}
|
||||
}
|
||||
|
||||
// tryRetry 对临时错误退避后返回重启消息;次数用尽返回 exhausted 错误。
|
||||
func (r *einoTransientRunRetrier) tryRetry(
|
||||
ctx context.Context,
|
||||
runErr error,
|
||||
args *einoADKRunLoopArgs,
|
||||
baseMsgs, accumulated []adk.Message,
|
||||
baseCount int,
|
||||
) (restarted bool, restartMsgs []adk.Message, ctxSource einoRunRestartContextSource, backoff time.Duration, fatal error) {
|
||||
if runErr == nil || !isEinoTransientRunError(runErr) {
|
||||
return false, nil, "", 0, runErr
|
||||
}
|
||||
r.attempts++
|
||||
if r.attempts > r.policy.maxAttempts {
|
||||
return false, nil, "", 0, fmt.Errorf("transient retry exhausted after %d attempts: %w", r.policy.maxAttempts, runErr)
|
||||
}
|
||||
backoff = einoTransientRetryBackoff(r.attempts-1, r.policy.maxBackoff)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false, nil, "", 0, ctx.Err()
|
||||
case <-time.After(backoff):
|
||||
}
|
||||
restartMsgs, ctxSource = einoMessagesForRunRestart(args, baseMsgs, accumulated, baseCount)
|
||||
return true, restartMsgs, ctxSource, backoff, nil
|
||||
}
|
||||
|
||||
// attempt returns the current value of the retry counter.
func (r *einoTransientRunRetrier) attempt() int {
	return r.attempts
}
|
||||
|
||||
// maxAttempts returns the attempt limit configured in the retrier's policy.
func (r *einoTransientRunRetrier) maxAttempts() int {
	return r.policy.maxAttempts
}
|
||||
|
||||
// reset 在一次成功推进后清零重试计数,使后续临时错误从第 1 次退避重新开始。
|
||||
func (r *einoTransientRunRetrier) reset() { r.attempts = 0 }
|
||||
|
||||
func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
||||
if args != nil && args.RunRetryMaxAttempts > 0 {
|
||||
return args.RunRetryMaxAttempts
|
||||
@@ -85,7 +153,7 @@ func einoRunRetryMaxAttempts(args *einoADKRunLoopArgs) int {
|
||||
return defaultEinoRunRetryMaxAttempts
|
||||
}
|
||||
|
||||
// RunRetryMaxAttemptsFromConfig 供 handler 分段续跑计数(与 eino_middleware.run_retry_max_attempts 一致)。
|
||||
// RunRetryMaxAttemptsFromConfig 与 eino_middleware.run_retry_max_attempts 一致。
|
||||
func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) int {
|
||||
if mw != nil && mw.RunRetryMaxAttempts > 0 {
|
||||
return mw.RunRetryMaxAttempts
|
||||
@@ -93,15 +161,6 @@ func RunRetryMaxAttemptsFromConfig(mw *config.MultiAgentEinoMiddlewareConfig) in
|
||||
return defaultEinoRunRetryMaxAttempts
|
||||
}
|
||||
|
||||
// TransientRetryBackoff 供 handler 在分段续跑前退避。
|
||||
func TransientRetryBackoff(attempt int, maxBackoffSec int) time.Duration {
|
||||
max := defaultEinoRunRetryMaxBackoff
|
||||
if maxBackoffSec > 0 {
|
||||
max = time.Duration(maxBackoffSec) * time.Second
|
||||
}
|
||||
return einoTransientRetryBackoff(attempt, max)
|
||||
}
|
||||
|
||||
func einoRunRetryMaxBackoff(args *einoADKRunLoopArgs) time.Duration {
|
||||
if args != nil && args.RunRetryMaxBackoffSec > 0 {
|
||||
return time.Duration(args.RunRetryMaxBackoffSec) * time.Second
|
||||
@@ -122,10 +181,11 @@ const (
|
||||
// 1) ModelFacingTrace(与模型实际入参一致) 2) 事件流累积的 runAccumulatedMsgs 3) 初始 msgs。
|
||||
func einoMessagesForRunRestart(args *einoADKRunLoopArgs, baseMsgs, accumulated []adk.Message, baseCount int) ([]adk.Message, einoRunRestartContextSource) {
|
||||
if trace := persistTraceSource(args, nil); len(trace) > 0 {
|
||||
return append([]adk.Message(nil), trace...), einoRestartContextModelTrace
|
||||
// modelFacingTrace includes prior Instruction system message(s); genModelInput will prepend again.
|
||||
return stripADKSystemMessages(trace), einoRestartContextModelTrace
|
||||
}
|
||||
if len(accumulated) > baseCount {
|
||||
return append([]adk.Message(nil), accumulated...), einoRestartContextAccumulated
|
||||
return stripADKSystemMessages(accumulated), einoRestartContextAccumulated
|
||||
}
|
||||
return append([]adk.Message(nil), baseMsgs...), einoRestartContextInitial
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ func TestIsEinoTransientRunError(t *testing.T) {
|
||||
{"429", errors.New("HTTP 429 Too Many Requests"), true},
|
||||
{"rate limit", errors.New(`{"error":"rate limit exceeded"}`), true},
|
||||
{"connection reset", errors.New("read tcp: connection reset by peer"), true},
|
||||
{"http2 goaway", errors.New("failed to receive stream chunk: error, http2: server sent GOAWAY and closed the connection; LastStreamID=791, ErrCode=NO_ERROR"), true},
|
||||
{"unexpected eof", errors.New("unexpected EOF"), true},
|
||||
{"503", errors.New("upstream returned 503"), true},
|
||||
{"iteration limit", errors.New("max iteration reached"), false},
|
||||
@@ -90,6 +91,20 @@ func TestEinoRunRetryMaxAttemptsFromArgs(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoTransientRunRetrierReset(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := newEinoTransientRunRetrier(einoTransientRunRetryPolicy{maxAttempts: 10, maxBackoff: 30 * time.Second})
|
||||
r.attempts = 3
|
||||
r.reset()
|
||||
if r.attempt() != 0 {
|
||||
t.Fatalf("after reset: attempt=%d, want 0", r.attempt())
|
||||
}
|
||||
// 重置后下一次退避应从 2s 起算(attempt index 0)。
|
||||
if got := einoTransientRetryBackoff(r.attempt(), r.policy.maxBackoff); got != 2*time.Second {
|
||||
t.Fatalf("backoff after reset: got %v, want 2s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendUserMessageIfNeeded(t *testing.T) {
|
||||
t.Parallel()
|
||||
msgs := []adk.Message{schema.UserMessage("old task")}
|
||||
@@ -102,10 +117,3 @@ func TestAppendUserMessageIfNeeded(t *testing.T) {
|
||||
t.Fatalf("should not duplicate user message: len=%d", len(dup))
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrTransientRetryContinue(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !errors.Is(ErrTransientRetryContinue, ErrTransientRetryContinue) {
|
||||
t.Fatal("sentinel should match")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,11 +5,3 @@ import "errors"
|
||||
// ErrInterruptContinue is used as a context.CancelCause: when the user picks "interrupt and continue" and no MCP tool is in flight,
|
||||
// the current inference/streaming output is cancelled and the next round is resumed automatically within the same conversation task, carrying the user's supplementary note (similar to a Hermes-style human-in-the-loop turn).
|
||||
var ErrInterruptContinue = errors.New("agent interrupt: continue with user-supplied context")
|
||||
|
||||
// ErrTransientRetryContinue 表示 Run 因 429/网络等临时错误结束,应由 handler 落库轨迹后
|
||||
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue 同级的「分段续跑」语义)。
|
||||
var ErrTransientRetryContinue = errors.New("agent transient: retry after persisting trace")
|
||||
|
||||
// ErrEmptyResponseContinue 表示 Eino ADK 会话正常结束但未捕获到助手正文,应由 handler 落库轨迹后
|
||||
// loadHistoryFromAgentTrace 再开下一轮 Run(与 ErrInterruptContinue / ErrTransientRetryContinue 同级)。
|
||||
var ErrEmptyResponseContinue = errors.New("agent empty response: continue after persisting trace")
|
||||
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
// 本中间件与之互补,专职兜底正向孤儿。
|
||||
// - 仅剔除消息,不向历史里注入虚构 assistant(tc):虚构 tool_calls 反而会误导模型后续推理。
|
||||
// 摘要已覆盖被裁剪段的语义,丢一条原始 tool 结果对对话连贯性影响最小。
|
||||
// - 位置建议:挂在所有可能改写历史的中间件(summarization / reduction / skill / plantask /
|
||||
// - 位置建议:挂在 summarization / reduction / skill / plantask / system 合并 / 续聊 dedup 之后,
|
||||
// tool_search)之后,靠近 ChatModel 调用的那一端。
|
||||
type orphanToolPrunerMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
|
||||
@@ -231,13 +231,13 @@ func RunDeepAgent(
|
||||
}
|
||||
subHandlers = append(subHandlers, einoSkillMW)
|
||||
}
|
||||
subHandlers = append(subHandlers, subSumMw)
|
||||
// 孤儿 tool 消息兜底:放在 summarization 之后,telemetry 之前,
|
||||
// 以便 telemetry 记录的 token 数与 LLM 实际入参一致。
|
||||
subHandlers = append(subHandlers, newOrphanToolPrunerMiddleware(logger, "sub_agent:"+id))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "sub_agent"); teleMw != nil {
|
||||
subHandlers = append(subHandlers, teleMw)
|
||||
}
|
||||
subHandlers = appendEinoChatModelTailMiddlewares(subHandlers, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "sub_agent:" + id,
|
||||
summarization: subSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
})
|
||||
|
||||
subInstrFinal := project.AppendVisionImageAnalysisIfReady(instr, appCfg.Vision.Ready())
|
||||
subInstrFinal = injectToolNamesOnlyInstruction(ctx, subInstrFinal, subTools, subToolSearchActive)
|
||||
@@ -379,14 +379,14 @@ func RunDeepAgent(
|
||||
if einoSkillMW != nil {
|
||||
deepHandlers = append(deepHandlers, einoSkillMW)
|
||||
}
|
||||
deepHandlers = append(deepHandlers, mainSumMw)
|
||||
deepHandlers = append(deepHandlers, newOrphanToolPrunerMiddleware(logger, "deep_orchestrator"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "deep_orchestrator"); teleMw != nil {
|
||||
deepHandlers = append(deepHandlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
deepHandlers = append(deepHandlers, capMw)
|
||||
}
|
||||
deepHandlers = appendEinoChatModelTailMiddlewares(deepHandlers, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "deep_orchestrator",
|
||||
summarization: mainSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
trace: modelFacingTrace,
|
||||
})
|
||||
|
||||
supHandlers := []adk.ChatModelAgentMiddleware{}
|
||||
if len(mainOrchestratorPre) > 0 {
|
||||
@@ -395,14 +395,14 @@ func RunDeepAgent(
|
||||
if einoSkillMW != nil {
|
||||
supHandlers = append(supHandlers, einoSkillMW)
|
||||
}
|
||||
supHandlers = append(supHandlers, mainSumMw)
|
||||
supHandlers = append(supHandlers, newOrphanToolPrunerMiddleware(logger, "supervisor_orchestrator"))
|
||||
if teleMw := newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "supervisor_orchestrator"); teleMw != nil {
|
||||
supHandlers = append(supHandlers, teleMw)
|
||||
}
|
||||
if capMw := newModelFacingTraceMiddleware(modelFacingTrace); capMw != nil {
|
||||
supHandlers = append(supHandlers, capMw)
|
||||
}
|
||||
supHandlers = appendEinoChatModelTailMiddlewares(supHandlers, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "supervisor_orchestrator",
|
||||
summarization: mainSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
trace: modelFacingTrace,
|
||||
})
|
||||
|
||||
mainToolsCfg := adk.ToolsConfig{
|
||||
ToolsNodeConfig: compose.ToolsNodeConfig{
|
||||
@@ -416,7 +416,7 @@ func RunDeepAgent(
|
||||
EmitInternalEvents: true,
|
||||
}
|
||||
|
||||
deepOutKey, modelRetry, taskGen := deepExtrasFromConfig(ma)
|
||||
deepOutKey, taskGen := deepExtrasFromConfig(ma)
|
||||
|
||||
var da adk.Agent
|
||||
switch orchMode {
|
||||
@@ -451,12 +451,14 @@ func RunDeepAgent(
|
||||
SkillMiddleware: einoSkillMW,
|
||||
FilesystemMiddleware: peFsMw,
|
||||
ModelFacingTrace: modelFacingTrace,
|
||||
PlannerReplannerRewriteHandlers: []adk.ChatModelAgentMiddleware{
|
||||
mainSumMw,
|
||||
// 孤儿 tool 消息兜底:必须挂在 summarization 之后、telemetry 之前。
|
||||
newOrphanToolPrunerMiddleware(logger, "plan_execute_planner_replanner"),
|
||||
newEinoModelInputTelemetryMiddleware(logger, appCfg.OpenAI.Model, conversationID, "plan_execute_planner_replanner_rewrite"),
|
||||
},
|
||||
PlannerReplannerRewriteHandlers: appendEinoChatModelTailMiddlewares(nil, einoChatModelTailConfig{
|
||||
logger: logger,
|
||||
phase: "plan_execute_planner_replanner",
|
||||
summarization: mainSumMw,
|
||||
modelName: appCfg.OpenAI.Model,
|
||||
conversationID: conversationID,
|
||||
skipTrace: true,
|
||||
}),
|
||||
})
|
||||
if perr != nil {
|
||||
return nil, perr
|
||||
@@ -473,9 +475,6 @@ func RunDeepAgent(
|
||||
Handlers: supHandlers,
|
||||
Exit: &adk.ExitTool{},
|
||||
}
|
||||
if modelRetry != nil {
|
||||
supCfg.ModelRetryConfig = modelRetry
|
||||
}
|
||||
if deepOutKey != "" {
|
||||
supCfg.OutputKey = deepOutKey
|
||||
}
|
||||
@@ -509,9 +508,6 @@ func RunDeepAgent(
|
||||
if deepOutKey != "" {
|
||||
dcfg.OutputKey = deepOutKey
|
||||
}
|
||||
if modelRetry != nil {
|
||||
dcfg.ModelRetryConfig = modelRetry
|
||||
}
|
||||
if taskGen != nil {
|
||||
dcfg.TaskToolDescriptionGenerator = taskGen
|
||||
}
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// systemMessageNormalizerMiddleware merges duplicate role=system messages into a single
|
||||
// leading system message before summarization and each ChatModel call.
|
||||
type systemMessageNormalizerMiddleware struct {
|
||||
adk.BaseChatModelAgentMiddleware
|
||||
logger *zap.Logger
|
||||
phase string
|
||||
}
|
||||
|
||||
func newSystemMessageNormalizerMiddleware(logger *zap.Logger, phase string) adk.ChatModelAgentMiddleware {
|
||||
return &systemMessageNormalizerMiddleware{logger: logger, phase: phase}
|
||||
}
|
||||
|
||||
func (m *systemMessageNormalizerMiddleware) BeforeModelRewriteState(
|
||||
ctx context.Context,
|
||||
state *adk.ChatModelAgentState,
|
||||
mc *adk.ModelContext,
|
||||
) (context.Context, *adk.ChatModelAgentState, error) {
|
||||
_ = mc
|
||||
if m == nil || state == nil || len(state.Messages) == 0 {
|
||||
return ctx, state, nil
|
||||
}
|
||||
before := countADKSystemMessages(state.Messages)
|
||||
if before <= 1 {
|
||||
return ctx, state, nil
|
||||
}
|
||||
normalized := normalizeSingleLeadingSystemMessage(state.Messages, "")
|
||||
if len(normalized) == len(state.Messages) && countADKSystemMessages(normalized) >= before {
|
||||
return ctx, state, nil
|
||||
}
|
||||
if m.logger != nil {
|
||||
m.logger.Info("eino system messages merged",
|
||||
zap.String("phase", m.phase),
|
||||
zap.Int("system_before", before),
|
||||
zap.Int("system_after", countADKSystemMessages(normalized)),
|
||||
zap.Int("messages_before", len(state.Messages)),
|
||||
zap.Int("messages_after", len(normalized)),
|
||||
)
|
||||
}
|
||||
out := *state
|
||||
out.Messages = normalized
|
||||
return ctx, &out, nil
|
||||
}
|
||||
|
||||
func countADKSystemMessages(msgs []adk.Message) int {
|
||||
n := 0
|
||||
for _, msg := range msgs {
|
||||
if msg != nil && msg.Role == schema.System {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// stripADKSystemMessages removes all system messages. Use before runner.Run restart when
|
||||
// genModelInput will prepend a fresh Instruction.
|
||||
func stripADKSystemMessages(msgs []adk.Message) []adk.Message {
|
||||
if len(msgs) == 0 {
|
||||
return msgs
|
||||
}
|
||||
out := make([]adk.Message, 0, len(msgs))
|
||||
for _, msg := range msgs {
|
||||
if msg == nil || msg.Role == schema.System {
|
||||
continue
|
||||
}
|
||||
out = append(out, msg)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// mergeCollectedSystemMessages collapses multiple system messages into one (or none).
|
||||
func mergeCollectedSystemMessages(systemMsgs []adk.Message) []adk.Message {
|
||||
if len(systemMsgs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return normalizeSingleLeadingSystemMessage(systemMsgs, "")
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/cloudwego/eino/adk"
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestStripADKSystemMessages(t *testing.T) {
|
||||
in := []adk.Message{
|
||||
schema.SystemMessage("a"),
|
||||
schema.UserMessage("u"),
|
||||
schema.SystemMessage("b"),
|
||||
schema.AssistantMessage("x", nil),
|
||||
}
|
||||
out := stripADKSystemMessages(in)
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("got %d messages, want 2", len(out))
|
||||
}
|
||||
if out[0].Role != schema.User || out[1].Role != schema.Assistant {
|
||||
t.Fatalf("unexpected roles: %s, %s", out[0].Role, out[1].Role)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEinoMessagesForRunRestart_StripsSystemFromTrace(t *testing.T) {
|
||||
holder := newModelFacingTraceHolder()
|
||||
holder.storeFromState(&adk.ChatModelAgentState{Messages: []adk.Message{
|
||||
schema.SystemMessage("sys-1"),
|
||||
schema.SystemMessage("sys-2"),
|
||||
schema.UserMessage("task"),
|
||||
}})
|
||||
msgs, src := einoMessagesForRunRestart(&einoADKRunLoopArgs{ModelFacingTrace: holder}, nil, nil, 0)
|
||||
if src != einoRestartContextModelTrace {
|
||||
t.Fatalf("source: got %q want model_trace", src)
|
||||
}
|
||||
if len(msgs) != 1 || msgs[0].Role != schema.User {
|
||||
t.Fatalf("expected user-only restart msgs, got %+v", msgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemMessageNormalizerMiddleware_MergesDuplicates(t *testing.T) {
|
||||
mw := newSystemMessageNormalizerMiddleware(nil, "test")
|
||||
state := &adk.ChatModelAgentState{Messages: []adk.Message{
|
||||
schema.SystemMessage("a"),
|
||||
schema.SystemMessage("b"),
|
||||
schema.UserMessage("u"),
|
||||
}}
|
||||
_, out, err := mw.(*systemMessageNormalizerMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if countADKSystemMessages(out.Messages) != 1 {
|
||||
t.Fatalf("want 1 system, got %d", countADKSystemMessages(out.Messages))
|
||||
}
|
||||
if out.Messages[0].Content != "a\n\nb" {
|
||||
t.Fatalf("merged content: %q", out.Messages[0].Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemMessageNormalizerMiddleware_NoOpSingleSystem(t *testing.T) {
|
||||
mw := newSystemMessageNormalizerMiddleware(nil, "test")
|
||||
state := &adk.ChatModelAgentState{Messages: []adk.Message{
|
||||
schema.SystemMessage("only"),
|
||||
schema.UserMessage("u"),
|
||||
}}
|
||||
_, out, err := mw.(*systemMessageNormalizerMiddleware).BeforeModelRewriteState(context.Background(), state, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if out != state {
|
||||
t.Fatalf("expected same state pointer for no-op")
|
||||
}
|
||||
}
|
||||
@@ -11797,34 +11797,44 @@ tr.mcp-stats-tool-row[data-tool-name]:focus-visible {
|
||||
background: transparent;
|
||||
color: var(--text-muted);
|
||||
cursor: pointer;
|
||||
border-radius: 4px;
|
||||
border-radius: 6px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.batch-delete-btn svg {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
transition: transform 0.2s ease;
|
||||
}
|
||||
|
||||
.batch-delete-btn:hover {
|
||||
background: rgba(220, 53, 69, 0.1);
|
||||
color: var(--error-color);
|
||||
}
|
||||
|
||||
.batch-delete-btn:hover svg {
|
||||
transform: scale(1.08);
|
||||
}
|
||||
|
||||
.batch-delete-btn:active {
|
||||
background: rgba(220, 53, 69, 0.2);
|
||||
transform: scale(0.95);
|
||||
}
|
||||
|
||||
.batch-manage-footer {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
justify-content: flex-end;
|
||||
padding: 16px 24px;
|
||||
border-top: 1px solid var(--border-color);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.select-all-checkbox {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
.batch-table-col-checkbox input[type="checkbox"] {
|
||||
cursor: pointer;
|
||||
font-size: 0.875rem;
|
||||
color: var(--text-primary);
|
||||
}
|
||||
|
||||
.batch-footer-actions {
|
||||
|
||||
@@ -1656,6 +1656,7 @@
|
||||
"rateWarning": "Some failures detected",
|
||||
"rateCritical": "High failure rate",
|
||||
"statsSubtitle": "Refreshed {{time}} · {{count}} tools",
|
||||
"retentionHint": "Execution records are kept for {{days}} days, then purged automatically.",
|
||||
"timelineTitle": "Call trend",
|
||||
"timelineHint": "All tools combined (not split by tool)",
|
||||
"timelineRange24h": "24h",
|
||||
|
||||
@@ -1644,6 +1644,7 @@
|
||||
"rateWarning": "存在失败调用",
|
||||
"rateCritical": "失败率偏高",
|
||||
"statsSubtitle": "最后刷新 {{time}} · 共 {{count}} 个工具",
|
||||
"retentionHint": "执行记录保留 {{days}} 天,超期自动清理",
|
||||
"timelineTitle": "调用趋势",
|
||||
"timelineHint": "全部工具合计,不按工具拆分",
|
||||
"timelineRange24h": "24 小时",
|
||||
|
||||
+45
-5
@@ -7450,14 +7450,14 @@ async function showBatchManageModal() {
|
||||
updateBatchManageTitle(allConversationsForBatch.length);
|
||||
|
||||
renderBatchConversations();
|
||||
openAppModal('batch-manage-modal');
|
||||
openAppModal('batch-manage-modal', { focus: false });
|
||||
} catch (error) {
|
||||
console.error('加载对话列表失败:', error);
|
||||
// 错误时使用空数组,不显示错误提示(更友好的用户体验)
|
||||
allConversationsForBatch = [];
|
||||
updateBatchManageTitle(0);
|
||||
renderBatchConversations();
|
||||
openAppModal('batch-manage-modal');
|
||||
openAppModal('batch-manage-modal', { focus: false });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7517,6 +7517,7 @@ function renderBatchConversations(filtered = null) {
|
||||
checkbox.type = 'checkbox';
|
||||
checkbox.className = 'batch-conversation-checkbox';
|
||||
checkbox.dataset.conversationId = conv.id;
|
||||
checkbox.addEventListener('change', syncSelectAllBatchCheckbox);
|
||||
|
||||
const name = document.createElement('div');
|
||||
name.className = 'batch-table-col-name';
|
||||
@@ -7542,9 +7543,21 @@ function renderBatchConversations(filtered = null) {
|
||||
const action = document.createElement('div');
|
||||
action.className = 'batch-table-col-action';
|
||||
const deleteBtn = document.createElement('button');
|
||||
deleteBtn.type = 'button';
|
||||
deleteBtn.className = 'batch-delete-btn';
|
||||
deleteBtn.innerHTML = '🗑️';
|
||||
deleteBtn.onclick = () => deleteConversation(conv.id);
|
||||
deleteBtn.innerHTML = `
|
||||
<svg width="16" height="16" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg" aria-hidden="true">
|
||||
<path d="M3 6h18M8 6V4a2 2 0 0 1 2-2h4a2 2 0 0 1 2 2v2m3 0v14a2 2 0 0 1-2 2H7a2 2 0 0 1-2-2V6h14zM10 11v6M14 11v6"
|
||||
stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
`;
|
||||
const deleteLabel = typeof window.t === 'function' ? window.t('contextMenu.deleteConversation') : '删除此对话';
|
||||
deleteBtn.title = deleteLabel;
|
||||
deleteBtn.setAttribute('aria-label', deleteLabel);
|
||||
deleteBtn.onclick = (e) => {
|
||||
e.stopPropagation();
|
||||
deleteConversation(conv.id);
|
||||
};
|
||||
action.appendChild(deleteBtn);
|
||||
|
||||
row.appendChild(checkbox);
|
||||
@@ -7554,6 +7567,8 @@ function renderBatchConversations(filtered = null) {
|
||||
|
||||
list.appendChild(row);
|
||||
});
|
||||
|
||||
syncSelectAllBatchCheckbox();
|
||||
}
|
||||
|
||||
// 筛选批量管理对话
|
||||
@@ -7575,12 +7590,35 @@ function filterBatchConversations(query) {
|
||||
// Propagate the state of the header "select all" checkbox to every row
// checkbox in the batch-manage list. A direct toggle always leaves the header
// checkbox in a definite (non-indeterminate) state.
// Does nothing when the header checkbox is not in the DOM; previously that
// case threw a TypeError as soon as at least one row checkbox existed.
function toggleSelectAllBatch() {
    const selectAll = document.getElementById('batch-select-all');
    if (!selectAll) {
        return;
    }
    selectAll.indeterminate = false;

    const checked = selectAll.checked;
    document.querySelectorAll('.batch-conversation-checkbox').forEach(cb => {
        cb.checked = checked;
    });
}
|
||||
|
||||
// Derive the header "select all" checkbox state from the row checkboxes:
//   - no rows, or none checked -> unchecked
//   - every row checked        -> checked
//   - some but not all checked -> indeterminate
// Does nothing when the header checkbox is not in the DOM.
function syncSelectAllBatchCheckbox() {
    const master = document.getElementById('batch-select-all');
    if (!master) {
        return;
    }

    const total = document.querySelectorAll('.batch-conversation-checkbox').length;
    const checked = document.querySelectorAll('.batch-conversation-checkbox:checked').length;

    const anyChecked = total !== 0 && checked !== 0;
    const allChecked = anyChecked && checked === total;

    master.checked = allChecked;
    master.indeterminate = anyChecked && !allChecked;
}
|
||||
|
||||
// 删除选中的对话
|
||||
async function deleteSelectedConversations() {
|
||||
const checkboxes = document.querySelectorAll('.batch-conversation-checkbox:checked');
|
||||
@@ -7604,6 +7642,7 @@ async function deleteSelectedConversations() {
|
||||
const selectAll = document.getElementById('batch-select-all');
|
||||
if (selectAll) {
|
||||
selectAll.checked = false;
|
||||
selectAll.indeterminate = false;
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('删除失败:', error);
|
||||
@@ -7619,6 +7658,7 @@ function closeBatchManageModal() {
|
||||
const selectAll = document.getElementById('batch-select-all');
|
||||
if (selectAll) {
|
||||
selectAll.checked = false;
|
||||
selectAll.indeterminate = false;
|
||||
}
|
||||
allConversationsForBatch = [];
|
||||
}
|
||||
|
||||
@@ -1944,6 +1944,7 @@ function handleStreamEvent(event, progressElement, progressId,
|
||||
message: event.message || '',
|
||||
data: d
|
||||
});
|
||||
finalizeOutstandingToolCallsForProgress(progressId, 'failed');
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -3531,6 +3532,7 @@ const monitorState = {
|
||||
timelineRange: null,
|
||||
timelineError: null,
|
||||
lastFetchedAt: null,
|
||||
retentionDays: 0,
|
||||
pagination: {
|
||||
page: 1,
|
||||
pageSize: (() => {
|
||||
@@ -3626,6 +3628,7 @@ async function refreshMonitorPanel(page = null) {
|
||||
monitorState.timeline = timeline;
|
||||
monitorState.timelineError = timelineError;
|
||||
monitorState.lastFetchedAt = new Date();
|
||||
monitorState.retentionDays = typeof result.retention_days === 'number' ? result.retention_days : 0;
|
||||
|
||||
// 更新分页信息
|
||||
if (result.total !== undefined) {
|
||||
@@ -3709,6 +3712,7 @@ async function refreshMonitorPanelWithFilter(statusFilter = 'all', toolFilter =
|
||||
monitorState.timeline = timeline;
|
||||
monitorState.timelineError = timelineError;
|
||||
monitorState.lastFetchedAt = new Date();
|
||||
monitorState.retentionDays = typeof result.retention_days === 'number' ? result.retention_days : 0;
|
||||
|
||||
// 更新分页信息
|
||||
if (result.total !== undefined) {
|
||||
@@ -4526,15 +4530,20 @@ function renderMcpStatsStackedBar(success, failed) {
|
||||
</div>`;
|
||||
}
|
||||
|
||||
function updateMonitorStatsSubtitle(lastFetchedAt, toolCount) {
|
||||
function updateMonitorStatsSubtitle(lastFetchedAt, toolCount, retentionDays) {
|
||||
const subtitle = document.getElementById('monitor-stats-subtitle');
|
||||
if (!subtitle) return;
|
||||
const locale = (typeof window.__locale === 'string' && window.__locale.startsWith('zh')) ? 'zh-CN' : 'en-US';
|
||||
const timeText = lastFetchedAt
|
||||
? (lastFetchedAt.toLocaleString ? lastFetchedAt.toLocaleString(locale) : String(lastFetchedAt))
|
||||
: '—';
|
||||
const text = mcpMonitorT('statsSubtitle', { time: timeText, count: toolCount })
|
||||
let text = mcpMonitorT('statsSubtitle', { time: timeText, count: toolCount })
|
||||
|| monitorFallback(`最后刷新 ${timeText} · 共 ${toolCount} 个工具`, `Refreshed ${timeText} · ${toolCount} tools`);
|
||||
if (typeof retentionDays === 'number' && retentionDays > 0) {
|
||||
const hint = mcpMonitorT('retentionHint', { days: retentionDays })
|
||||
|| monitorFallback(`执行记录保留 ${retentionDays} 天,超期自动清理`, `Execution records are kept for ${retentionDays} days, then purged automatically.`);
|
||||
text += ' · ' + hint;
|
||||
}
|
||||
subtitle.textContent = text;
|
||||
subtitle.hidden = false;
|
||||
}
|
||||
@@ -4959,7 +4968,7 @@ function renderMonitorStats(statsMap = {}, lastFetchedAt = null) {
|
||||
} else if (toolFilterEl) {
|
||||
toolFilterEl.classList.remove('is-filter-active');
|
||||
}
|
||||
updateMonitorStatsSubtitle(lastFetchedAt, entries.length);
|
||||
updateMonitorStatsSubtitle(lastFetchedAt, entries.length, monitorState.retentionDays);
|
||||
}
|
||||
|
||||
function renderMonitorExecutions(executions = [], statusFilter = 'all') {
|
||||
|
||||
@@ -3777,7 +3777,9 @@
|
||||
<div class="modal-body batch-manage-body">
|
||||
<div class="batch-conversations-table">
|
||||
<div class="batch-table-header">
|
||||
<div class="batch-table-col-checkbox"></div>
|
||||
<div class="batch-table-col-checkbox">
|
||||
<input type="checkbox" id="batch-select-all" onchange="toggleSelectAllBatch()" data-i18n="batchManageModal.selectAll" data-i18n-attr="title" title="全选" />
|
||||
</div>
|
||||
<div class="batch-table-col-name" data-i18n="batchManageModal.conversationName">对话名称</div>
|
||||
<div class="batch-table-col-time" data-i18n="batchManageModal.lastTime">最近一次对话时间</div>
|
||||
<div class="batch-table-col-action" data-i18n="batchManageModal.action">操作</div>
|
||||
@@ -3786,10 +3788,6 @@
|
||||
</div>
|
||||
</div>
|
||||
<div class="modal-footer batch-manage-footer">
|
||||
<label class="select-all-checkbox">
|
||||
<input type="checkbox" id="batch-select-all" onchange="toggleSelectAllBatch()" />
|
||||
<span data-i18n="batchManageModal.selectAll">全选</span>
|
||||
</label>
|
||||
<div class="batch-footer-actions">
|
||||
<button class="btn-secondary" onclick="closeBatchManageModal()" data-i18n="common.cancel">取消</button>
|
||||
<button class="btn-primary" onclick="deleteSelectedConversations()" data-i18n="batchManageModal.deleteSelected">删除所选</button>
|
||||
|
||||
Reference in New Issue
Block a user