diff --git a/internal/database/batch_task.go b/internal/database/batch_task.go index 8e72ed73..b59a9acb 100644 --- a/internal/database/batch_task.go +++ b/internal/database/batch_task.go @@ -23,6 +23,7 @@ type BatchTaskQueueRow struct { LastScheduleError sql.NullString LastRunError sql.NullString ProjectID sql.NullString + Concurrency sql.NullInt64 Status string CreatedAt time.Time StartedAt sql.NullTime @@ -53,6 +54,7 @@ func (db *DB) CreateBatchQueue( cronExpr string, nextRunAt *time.Time, projectID string, + concurrency int, tasks []map[string]interface{}, ) error { tx, err := db.Begin() @@ -72,8 +74,8 @@ func (db *DB) CreateBatchQueue( projectIDVal = strings.TrimSpace(projectID) } _, err = tx.Exec( - "INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0, + "INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, concurrency, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, concurrency, "pending", now, 0, ) if err != nil { return fmt.Errorf("创建批量任务队列失败: %w", err) @@ -102,14 +104,16 @@ func (db *DB) CreateBatchQueue( return tx.Commit() } +const batchQueueSelectColumns = `id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, concurrency, status, created_at, started_at, completed_at, current_index` + // GetBatchQueue 获取批量任务队列 func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) { var row BatchTaskQueueRow var createdAt string err := db.QueryRow( - "SELECT id, title, role, agent_mode, 
schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?", + "SELECT "+batchQueueSelectColumns+" FROM batch_task_queues WHERE id = ?", queueID, - ).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex) + ).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex) if err == sql.ErrNoRows { return nil, nil } @@ -133,7 +137,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) { // GetAllBatchQueues 获取所有批量任务队列 func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { rows, err := db.Query( - "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC", + "SELECT "+batchQueueSelectColumns+" FROM batch_task_queues ORDER BY created_at DESC", ) if err != nil { return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err) @@ -144,7 +148,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { for rows.Next() { var row BatchTaskQueueRow var createdAt string - if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, 
&row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { + if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { return nil, fmt.Errorf("扫描批量任务队列失败: %w", err) } parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) @@ -164,7 +168,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) { // ListBatchQueues 列出批量任务队列(支持筛选和分页) func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) { - query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1" + query := "SELECT " + batchQueueSelectColumns + " FROM batch_task_queues WHERE 1=1" args := []interface{}{} // 状态筛选 @@ -192,7 +196,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat for rows.Next() { var row BatchTaskQueueRow var createdAt string - if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { + if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil { return 
nil, fmt.Errorf("扫描批量任务队列失败: %w", err) } parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt) @@ -358,11 +362,11 @@ func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) err return nil } -// UpdateBatchQueueMetadata 更新批量任务队列标题、角色和代理模式 -func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string) error { +// UpdateBatchQueueMetadata 更新批量任务队列标题、角色、代理模式和并发数 +func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string, concurrency int) error { _, err := db.Exec( - "UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ? WHERE id = ?", - title, role, agentMode, queueID, + "UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ?, concurrency = ? WHERE id = ?", + title, role, agentMode, concurrency, queueID, ) if err != nil { return fmt.Errorf("更新批量任务队列元数据失败: %w", err) diff --git a/internal/database/database.go b/internal/database/database.go index 0ffbd2b8..55661e56 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -408,6 +408,8 @@ func (db *DB) initTables() error { last_schedule_trigger_at DATETIME, last_schedule_error TEXT, last_run_error TEXT, + project_id TEXT, + concurrency INTEGER NOT NULL DEFAULT 1, status TEXT NOT NULL, created_at DATETIME NOT NULL, started_at DATETIME, @@ -1137,6 +1139,21 @@ func (db *DB) migrateBatchTaskQueuesTable() error { } } + var concurrencyCount int + err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='concurrency'").Scan(&concurrencyCount) + if err != nil { + if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); addErr != nil { + errMsg := strings.ToLower(addErr.Error()) + if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") { + db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(addErr)) + } + } + } else if concurrencyCount == 0 { + if _, err := db.Exec("ALTER TABLE 
batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); err != nil { + db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(err)) + } + } + return nil } diff --git a/internal/multiagent/eino_adk_run_loop.go b/internal/multiagent/eino_adk_run_loop.go index 358a933d..59176c8e 100644 --- a/internal/multiagent/eino_adk_run_loop.go +++ b/internal/multiagent/eino_adk_run_loop.go @@ -90,7 +90,7 @@ type einoADKRunLoopArgs struct { FilesystemMonitorRecord einomcp.ExecutionRecorder MCPExecutionBinder *MCPExecutionBinder - // ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,MCP 桥 Fire 以补全 tool_result。 + // ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,execute/MCP 桥 Fire 时立即推送 tool_result(ADK 晚到经 toolResultSent 去重)。 ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder DA adk.Agent @@ -341,8 +341,22 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs } if args.ToolInvokeNotify != nil { args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) { - removePendingByID(strings.TrimSpace(toolCallID)) - // tool_result 仅由下方 ADK schema.Tool 事件推送,正文与送入模型的上下文一致(含 reduction 截断)。 + // Eino execute / MCP 桥在工具返回时 Fire;若 ADK schema.Tool 事件迟迟不到,此处立即推送 + // tool_result 解除 UI「执行中」。tryEmitToolResultProgress 经 toolResultSent 去重,ADK 晚到不重复。 + isErr := !success || invokeErr != nil + body := content + if strings.HasPrefix(body, einomcp.ToolErrorPrefix) { + isErr = true + body = strings.TrimPrefix(body, einomcp.ToolErrorPrefix) + } + if tail := friendlyEinoExecuteInvokeTail(invokeErr); tail != "" { + if body == "" { + body = tail + } else if !strings.Contains(body, tail) { + body = strings.TrimSpace(body) + "\n\n" + tail + } + } + tryEmitToolResultProgress(toolName, body, toolCallID, isErr, einoAgent) }) } @@ -551,10 +565,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs } for { - // 检测 context 
取消(用户关闭浏览器、请求超时等),flush pending 工具状态避免 UI 卡在 "执行中"。
-		select {
-		case <-ctx.Done():
-			flushAllPendingAsFailed(ctx.Err())
+		// iter.Next may block for a long time (tool execution, model inference); tie it to ctx so cancellation/timeouts can flush pending tool calls promptly.
+		ev, ok, iterCtxErr := nextAgentEventWithContext(ctx, iter)
+		if iterCtxErr != nil {
+			flushAllPendingAsFailed(iterCtxErr)
 			if progress != nil {
 				if isInterruptContinue(ctx) {
 					progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
@@ -563,17 +577,14 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
 						"kind":           "interrupt_continue",
 					})
 				} else {
-					progress("error", "Request cancelled / 请求已取消", map[string]interface{}{
+					progress("error", iterCtxErr.Error(), map[string]interface{}{
 						"conversationId": conversationID,
 						"source":         "eino",
 					})
 				}
 			}
-			return takePartial(ctx.Err())
-		default:
+			return takePartial(iterCtxErr)
 		}
-
-		ev, ok := iter.Next()
 		if !ok {
 			// iter 结束并不总是“正常完成”:
 			// 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。
@@ -691,29 +702,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
 
 		if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
 			toolName := strings.TrimSpace(mv.ToolName)
-			var toolBuf strings.Builder
-			streamToolCallID := ""
-			var toolStreamRecvErr error
-			for {
-				chunk, rerr := mv.MessageStream.Recv()
-				if errors.Is(rerr, io.EOF) {
-					break
-				}
-				if rerr != nil {
-					toolStreamRecvErr = rerr
-					break
-				}
-				if chunk == nil {
-					continue
-				}
-				if chunk.Content != "" {
-					toolBuf.WriteString(chunk.Content)
-				}
-				if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
-					streamToolCallID = tid
-				}
-			}
-			content := toolBuf.String()
+			content, streamToolCallID, toolStreamRecvErr := recvSchemaMessageStream(ctx, mv.MessageStream)
 			isErr := false
 			if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
 				isErr = true
@@ -1132,6 +1121,105 @@ func friendlyEinoExecuteInvokeTail(invokeErr error) string {
 	return "[执行未正常结束] " + invokeErr.Error()
 }
 
+// nextAgentEventWithContext waits for the next event from iter, but returns
+// early with ctx.Err() when ctx is cancelled while iter.Next() is still blocking
+// (commonly during tool execution or model inference).
+//
+// A nil iter yields (nil, false, nil). After a cancellation the helper goroutine
+// stays parked in iter.Next() until that call returns; the result channel is
+// buffered so its send never blocks, and the event it receives is dropped.
+func nextAgentEventWithContext(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) (ev *adk.AgentEvent, ok bool, ctxErr error) {
+	if iter == nil {
+		return nil, false, nil
+	}
+	// Check first so an already-cancelled ctx never consumes another event.
+	if err := ctx.Err(); err != nil {
+		return nil, false, err
+	}
+	type nextRes struct {
+		ev *adk.AgentEvent
+		ok bool
+	}
+	// Capacity 1 lets the goroutine finish even when nobody receives anymore.
+	ch := make(chan nextRes, 1)
+	go func() {
+		e, o := iter.Next()
+		ch <- nextRes{e, o}
+	}()
+	select {
+	case <-ctx.Done():
+		return nil, false, ctx.Err()
+	case res := <-ch:
+		return res.ev, res.ok, nil
+	}
+}
+
+// recvSchemaMessageStream drains a streamed ADK tool result, concatenating the
+// chunk contents and keeping the last non-blank ToolCallID seen.
+//
+// It returns when the stream reports io.EOF (recvErr is nil), when Recv fails
+// (recvErr is that error), or as soon as ctx is cancelled (recvErr is
+// ctx.Err()), so a tool that never produces output cannot stall the run loop.
+// Content collected before the return is still handed back. A nil stream yields
+// empty results and a nil error.
+func recvSchemaMessageStream(ctx context.Context, stream *schema.StreamReader[*schema.Message]) (content, toolCallID string, recvErr error) {
+	if stream == nil {
+		return "", "", nil
+	}
+	if err := ctx.Err(); err != nil {
+		return "", "", err
+	}
+	type streamMsg struct {
+		chunk *schema.Message
+		err   error
+	}
+	recvCh := make(chan streamMsg, 8)
+	// done is closed on return so the pump goroutine cannot stay blocked on a
+	// full recvCh once nobody is receiving anymore (e.g. after cancellation).
+	done := make(chan struct{})
+	defer close(done)
+	go func() {
+		defer close(recvCh)
+		for {
+			msg, rerr := stream.Recv()
+			select {
+			case recvCh <- streamMsg{chunk: msg, err: rerr}:
+			case <-done:
+				return
+			}
+			if rerr != nil {
+				return
+			}
+		}
+	}()
+	var buf strings.Builder
+	for {
+		select {
+		case <-ctx.Done():
+			return buf.String(), toolCallID, ctx.Err()
+		case sm, open := <-recvCh:
+			if !open {
+				return buf.String(), toolCallID, nil
+			}
+			if errors.Is(sm.err, io.EOF) {
+				return buf.String(), toolCallID, nil
+			}
+			if sm.err != nil {
+				return buf.String(), toolCallID, sm.err
+			}
+			if sm.chunk == nil {
+				continue
+			}
+			if sm.chunk.Content != "" {
+				buf.WriteString(sm.chunk.Content)
+			}
+			if tid := strings.TrimSpace(sm.chunk.ToolCallID); tid != "" {
+				toolCallID = tid
+			}
+		}
+	}
+}
+
 func buildEinoRunResultFromAccumulated(
 	orchMode string,
 	runAccumulatedMsgs []adk.Message,
diff --git a/internal/multiagent/eino_adk_run_loop_stream_test.go b/internal/multiagent/eino_adk_run_loop_stream_test.go
new file mode 100644
index 00000000..4c216938
--- /dev/null
+++ b/internal/multiagent/eino_adk_run_loop_stream_test.go
@@ -0,0 +1,74 @@
+package multiagent
+
+import (
+	"context"
+	"errors"
+	"io"
+
"testing" + "time" + + "github.com/cloudwego/eino/schema" +) + +func TestRecvSchemaMessageStream_EOF(t *testing.T) { + sr, sw := schema.Pipe[*schema.Message](4) + _ = sw.Send(schema.ToolMessage("hello", "tc-1"), nil) + sw.Close() + + content, tid, err := recvSchemaMessageStream(context.Background(), sr) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if content != "hello" { + t.Fatalf("content=%q want hello", content) + } + if tid != "tc-1" { + t.Fatalf("toolCallID=%q want tc-1", tid) + } +} + +func TestRecvSchemaMessageStream_ContextCancel(t *testing.T) { + sr, sw := schema.Pipe[*schema.Message](4) + t.Cleanup(func() { sw.Close() }) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(30 * time.Millisecond) + cancel() + }() + + content, _, err := recvSchemaMessageStream(ctx, sr) + if !errors.Is(err, context.Canceled) { + t.Fatalf("want context.Canceled, got %v content=%q", err, content) + } +} + +func TestRecvSchemaMessageStream_RecvError(t *testing.T) { + sr, sw := schema.Pipe[*schema.Message](4) + want := errors.New("stream broken") + _ = sw.Send(nil, want) + sw.Close() + + _, _, err := recvSchemaMessageStream(context.Background(), sr) + if !errors.Is(err, want) { + t.Fatalf("want %v, got %v", want, err) + } +} + +func TestRecvSchemaMessageStream_NilStream(t *testing.T) { + content, tid, err := recvSchemaMessageStream(context.Background(), nil) + if err != nil || content != "" || tid != "" { + t.Fatalf("nil stream: content=%q tid=%q err=%v", content, tid, err) + } +} + +func TestRecvSchemaMessageStream_EOFViaEmptyRead(t *testing.T) { + sr, sw := schema.Pipe[*schema.Message](4) + _ = sw.Send(nil, io.EOF) + sw.Close() + + _, _, err := recvSchemaMessageStream(context.Background(), sr) + if err != nil { + t.Fatalf("EOF should not surface as error, got %v", err) + } +} diff --git a/internal/multiagent/eino_execute_streaming_wrap.go b/internal/multiagent/eino_execute_streaming_wrap.go index 016eecc6..1af004e3 100644 --- 
a/internal/multiagent/eino_execute_streaming_wrap.go +++ b/internal/multiagent/eino_execute_streaming_wrap.go @@ -51,7 +51,7 @@ func einoExecuteRecvErrIsToolTimeout(rerr error, tctx context.Context) bool { // 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。 // // 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire, -// 保证 run loop 在模型开始下一轮输出前已记录 execute 结果(用于 UI 与「重复助手复述」去重)。 +// run loop 收到 Fire 后立即推送 tool_result(toolResultSent 去重),避免 ADK Tool 事件迟到时 UI 卡在「执行中」。 // // 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire; // 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。