mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-24 06:49:59 +02:00
Add files via upload
This commit is contained in:
@@ -23,6 +23,7 @@ type BatchTaskQueueRow struct {
|
||||
LastScheduleError sql.NullString
|
||||
LastRunError sql.NullString
|
||||
ProjectID sql.NullString
|
||||
Concurrency sql.NullInt64
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
StartedAt sql.NullTime
|
||||
@@ -53,6 +54,7 @@ func (db *DB) CreateBatchQueue(
|
||||
cronExpr string,
|
||||
nextRunAt *time.Time,
|
||||
projectID string,
|
||||
concurrency int,
|
||||
tasks []map[string]interface{},
|
||||
) error {
|
||||
tx, err := db.Begin()
|
||||
@@ -72,8 +74,8 @@ func (db *DB) CreateBatchQueue(
|
||||
projectIDVal = strings.TrimSpace(projectID)
|
||||
}
|
||||
_, err = tx.Exec(
|
||||
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, "pending", now, 0,
|
||||
"INSERT INTO batch_task_queues (id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, project_id, concurrency, status, created_at, current_index) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
queueID, title, role, agentMode, scheduleMode, cronExpr, nextRunAtValue, 1, projectIDVal, concurrency, "pending", now, 0,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建批量任务队列失败: %w", err)
|
||||
@@ -102,14 +104,16 @@ func (db *DB) CreateBatchQueue(
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
const batchQueueSelectColumns = `id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, concurrency, status, created_at, started_at, completed_at, current_index`
|
||||
|
||||
// GetBatchQueue 获取批量任务队列
|
||||
func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
err := db.QueryRow(
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE id = ?",
|
||||
"SELECT "+batchQueueSelectColumns+" FROM batch_task_queues WHERE id = ?",
|
||||
queueID,
|
||||
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||
).Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -133,7 +137,7 @@ func (db *DB) GetBatchQueue(queueID string) (*BatchTaskQueueRow, error) {
|
||||
// GetAllBatchQueues 获取所有批量任务队列
|
||||
func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
rows, err := db.Query(
|
||||
"SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues ORDER BY created_at DESC",
|
||||
"SELECT "+batchQueueSelectColumns+" FROM batch_task_queues ORDER BY created_at DESC",
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询批量任务队列列表失败: %w", err)
|
||||
@@ -144,7 +148,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
for rows.Next() {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||
}
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
@@ -164,7 +168,7 @@ func (db *DB) GetAllBatchQueues() ([]*BatchTaskQueueRow, error) {
|
||||
|
||||
// ListBatchQueues 列出批量任务队列(支持筛选和分页)
|
||||
func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*BatchTaskQueueRow, error) {
|
||||
query := "SELECT id, title, role, agent_mode, schedule_mode, cron_expr, next_run_at, schedule_enabled, last_schedule_trigger_at, last_schedule_error, last_run_error, project_id, status, created_at, started_at, completed_at, current_index FROM batch_task_queues WHERE 1=1"
|
||||
query := "SELECT " + batchQueueSelectColumns + " FROM batch_task_queues WHERE 1=1"
|
||||
args := []interface{}{}
|
||||
|
||||
// 状态筛选
|
||||
@@ -192,7 +196,7 @@ func (db *DB) ListBatchQueues(limit, offset int, status, keyword string) ([]*Bat
|
||||
for rows.Next() {
|
||||
var row BatchTaskQueueRow
|
||||
var createdAt string
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
if err := rows.Scan(&row.ID, &row.Title, &row.Role, &row.AgentMode, &row.ScheduleMode, &row.CronExpr, &row.NextRunAt, &row.ScheduleEnabled, &row.LastScheduleTriggerAt, &row.LastScheduleError, &row.LastRunError, &row.ProjectID, &row.Concurrency, &row.Status, &createdAt, &row.StartedAt, &row.CompletedAt, &row.CurrentIndex); err != nil {
|
||||
return nil, fmt.Errorf("扫描批量任务队列失败: %w", err)
|
||||
}
|
||||
parsedTime, parseErr := time.Parse("2006-01-02 15:04:05", createdAt)
|
||||
@@ -358,11 +362,11 @@ func (db *DB) UpdateBatchQueueCurrentIndex(queueID string, currentIndex int) err
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色和代理模式
|
||||
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string) error {
|
||||
// UpdateBatchQueueMetadata 更新批量任务队列标题、角色、代理模式和并发数
|
||||
func (db *DB) UpdateBatchQueueMetadata(queueID, title, role, agentMode string, concurrency int) error {
|
||||
_, err := db.Exec(
|
||||
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ? WHERE id = ?",
|
||||
title, role, agentMode, queueID,
|
||||
"UPDATE batch_task_queues SET title = ?, role = ?, agent_mode = ?, concurrency = ? WHERE id = ?",
|
||||
title, role, agentMode, concurrency, queueID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("更新批量任务队列元数据失败: %w", err)
|
||||
|
||||
@@ -408,6 +408,8 @@ func (db *DB) initTables() error {
|
||||
last_schedule_trigger_at DATETIME,
|
||||
last_schedule_error TEXT,
|
||||
last_run_error TEXT,
|
||||
project_id TEXT,
|
||||
concurrency INTEGER NOT NULL DEFAULT 1,
|
||||
status TEXT NOT NULL,
|
||||
created_at DATETIME NOT NULL,
|
||||
started_at DATETIME,
|
||||
@@ -1137,6 +1139,21 @@ func (db *DB) migrateBatchTaskQueuesTable() error {
|
||||
}
|
||||
}
|
||||
|
||||
var concurrencyCount int
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('batch_task_queues') WHERE name='concurrency'").Scan(&concurrencyCount)
|
||||
if err != nil {
|
||||
if _, addErr := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); addErr != nil {
|
||||
errMsg := strings.ToLower(addErr.Error())
|
||||
if !strings.Contains(errMsg, "duplicate column") && !strings.Contains(errMsg, "already exists") {
|
||||
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(addErr))
|
||||
}
|
||||
}
|
||||
} else if concurrencyCount == 0 {
|
||||
if _, err := db.Exec("ALTER TABLE batch_task_queues ADD COLUMN concurrency INTEGER NOT NULL DEFAULT 1"); err != nil {
|
||||
db.logger.Warn("添加batch_task_queues.concurrency字段失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ type einoADKRunLoopArgs struct {
|
||||
FilesystemMonitorRecord einomcp.ExecutionRecorder
|
||||
MCPExecutionBinder *MCPExecutionBinder
|
||||
|
||||
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,MCP 桥 Fire 以补全 tool_result。
|
||||
// ToolInvokeNotify 与 einomcp.ToolsFromDefinitions 共享:run loop 在迭代前 Set,execute/MCP 桥 Fire 时立即推送 tool_result(ADK 晚到经 toolResultSent 去重)。
|
||||
ToolInvokeNotify *einomcp.ToolInvokeNotifyHolder
|
||||
|
||||
DA adk.Agent
|
||||
@@ -341,8 +341,22 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
if args.ToolInvokeNotify != nil {
|
||||
args.ToolInvokeNotify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
||||
removePendingByID(strings.TrimSpace(toolCallID))
|
||||
// tool_result 仅由下方 ADK schema.Tool 事件推送,正文与送入模型的上下文一致(含 reduction 截断)。
|
||||
// Eino execute / MCP 桥在工具返回时 Fire;若 ADK schema.Tool 事件迟迟不到,此处立即推送
|
||||
// tool_result 解除 UI「执行中」。tryEmitToolResultProgress 经 toolResultSent 去重,ADK 晚到不重复。
|
||||
isErr := !success || invokeErr != nil
|
||||
body := content
|
||||
if strings.HasPrefix(body, einomcp.ToolErrorPrefix) {
|
||||
isErr = true
|
||||
body = strings.TrimPrefix(body, einomcp.ToolErrorPrefix)
|
||||
}
|
||||
if tail := friendlyEinoExecuteInvokeTail(invokeErr); tail != "" {
|
||||
if body == "" {
|
||||
body = tail
|
||||
} else if !strings.Contains(body, tail) {
|
||||
body = strings.TrimSpace(body) + "\n\n" + tail
|
||||
}
|
||||
}
|
||||
tryEmitToolResultProgress(toolName, body, toolCallID, isErr, einoAgent)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -551,10 +565,10 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
}
|
||||
|
||||
for {
|
||||
// 检测 context 取消(用户关闭浏览器、请求超时等),flush pending 工具状态避免 UI 卡在 "执行中"。
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
flushAllPendingAsFailed(ctx.Err())
|
||||
// iter.Next 可能长时间阻塞(工具执行、模型推理);须与 ctx 联动,否则取消/超时无法及时 flush pending。
|
||||
ev, ok, iterCtxErr := nextAgentEventWithContext(ctx, iter)
|
||||
if iterCtxErr != nil {
|
||||
flushAllPendingAsFailed(iterCtxErr)
|
||||
if progress != nil {
|
||||
if isInterruptContinue(ctx) {
|
||||
progress("progress", "已暂停当前输出,正在合并用户补充并继续…", map[string]interface{}{
|
||||
@@ -563,17 +577,14 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
"kind": "interrupt_continue",
|
||||
})
|
||||
} else {
|
||||
progress("error", "Request cancelled / 请求已取消", map[string]interface{}{
|
||||
progress("error", iterCtxErr.Error(), map[string]interface{}{
|
||||
"conversationId": conversationID,
|
||||
"source": "eino",
|
||||
})
|
||||
}
|
||||
}
|
||||
return takePartial(ctx.Err())
|
||||
default:
|
||||
return takePartial(iterCtxErr)
|
||||
}
|
||||
|
||||
ev, ok := iter.Next()
|
||||
if !ok {
|
||||
// iter 结束并不总是“正常完成”:
|
||||
// 当取消/超时发生在 iter.Next() 阻塞期间时,可能直接返回 !ok。
|
||||
@@ -691,29 +702,7 @@ func runEinoADKAgentLoop(ctx context.Context, args *einoADKRunLoopArgs, baseMsgs
|
||||
|
||||
if mv.IsStreaming && mv.MessageStream != nil && mv.Role == schema.Tool {
|
||||
toolName := strings.TrimSpace(mv.ToolName)
|
||||
var toolBuf strings.Builder
|
||||
streamToolCallID := ""
|
||||
var toolStreamRecvErr error
|
||||
for {
|
||||
chunk, rerr := mv.MessageStream.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
toolStreamRecvErr = rerr
|
||||
break
|
||||
}
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
if chunk.Content != "" {
|
||||
toolBuf.WriteString(chunk.Content)
|
||||
}
|
||||
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
|
||||
streamToolCallID = tid
|
||||
}
|
||||
}
|
||||
content := toolBuf.String()
|
||||
content, streamToolCallID, toolStreamRecvErr := recvSchemaMessageStream(ctx, mv.MessageStream)
|
||||
isErr := false
|
||||
if strings.HasPrefix(content, einomcp.ToolErrorPrefix) {
|
||||
isErr = true
|
||||
@@ -1132,6 +1121,78 @@ func friendlyEinoExecuteInvokeTail(invokeErr error) string {
|
||||
return "[执行未正常结束] " + invokeErr.Error()
|
||||
}
|
||||
|
||||
// nextAgentEventWithContext 在 ctx 取消时不再无限阻塞于 iter.Next()(工具执行/模型推理期间常见)。
|
||||
func nextAgentEventWithContext(ctx context.Context, iter *adk.AsyncIterator[*adk.AgentEvent]) (ev *adk.AgentEvent, ok bool, ctxErr error) {
|
||||
if iter == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
type nextRes struct {
|
||||
ev *adk.AgentEvent
|
||||
ok bool
|
||||
}
|
||||
ch := make(chan nextRes, 1)
|
||||
go func() {
|
||||
e, o := iter.Next()
|
||||
ch <- nextRes{e, o}
|
||||
}()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, false, ctx.Err()
|
||||
case res := <-ch:
|
||||
return res.ev, res.ok, nil
|
||||
}
|
||||
}
|
||||
|
||||
// recvSchemaMessageStream 消费 ADK Tool 流式结果;ctx 取消时立即返回,避免 amass 等无输出时永久阻塞。
|
||||
func recvSchemaMessageStream(ctx context.Context, stream *schema.StreamReader[*schema.Message]) (content, toolCallID string, recvErr error) {
|
||||
if stream == nil {
|
||||
return "", "", nil
|
||||
}
|
||||
type streamMsg struct {
|
||||
chunk *schema.Message
|
||||
err error
|
||||
}
|
||||
recvCh := make(chan streamMsg, 8)
|
||||
go func() {
|
||||
defer close(recvCh)
|
||||
for {
|
||||
ch, rerr := stream.Recv()
|
||||
recvCh <- streamMsg{chunk: ch, err: rerr}
|
||||
if rerr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
var buf strings.Builder
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return buf.String(), toolCallID, ctx.Err()
|
||||
case sm, open := <-recvCh:
|
||||
if !open {
|
||||
return buf.String(), toolCallID, nil
|
||||
}
|
||||
rerr := sm.err
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
return buf.String(), toolCallID, nil
|
||||
}
|
||||
if rerr != nil {
|
||||
return buf.String(), toolCallID, rerr
|
||||
}
|
||||
chunk := sm.chunk
|
||||
if chunk == nil {
|
||||
continue
|
||||
}
|
||||
if chunk.Content != "" {
|
||||
buf.WriteString(chunk.Content)
|
||||
}
|
||||
if tid := strings.TrimSpace(chunk.ToolCallID); tid != "" {
|
||||
toolCallID = tid
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildEinoRunResultFromAccumulated(
|
||||
orchMode string,
|
||||
runAccumulatedMsgs []adk.Message,
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
package multiagent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudwego/eino/schema"
|
||||
)
|
||||
|
||||
func TestRecvSchemaMessageStream_EOF(t *testing.T) {
|
||||
sr, sw := schema.Pipe[*schema.Message](4)
|
||||
_ = sw.Send(schema.ToolMessage("hello", "tc-1"), nil)
|
||||
sw.Close()
|
||||
|
||||
content, tid, err := recvSchemaMessageStream(context.Background(), sr)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err: %v", err)
|
||||
}
|
||||
if content != "hello" {
|
||||
t.Fatalf("content=%q want hello", content)
|
||||
}
|
||||
if tid != "tc-1" {
|
||||
t.Fatalf("toolCallID=%q want tc-1", tid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvSchemaMessageStream_ContextCancel(t *testing.T) {
|
||||
sr, sw := schema.Pipe[*schema.Message](4)
|
||||
t.Cleanup(func() { sw.Close() })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
content, _, err := recvSchemaMessageStream(ctx, sr)
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("want context.Canceled, got %v content=%q", err, content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvSchemaMessageStream_RecvError(t *testing.T) {
|
||||
sr, sw := schema.Pipe[*schema.Message](4)
|
||||
want := errors.New("stream broken")
|
||||
_ = sw.Send(nil, want)
|
||||
sw.Close()
|
||||
|
||||
_, _, err := recvSchemaMessageStream(context.Background(), sr)
|
||||
if !errors.Is(err, want) {
|
||||
t.Fatalf("want %v, got %v", want, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvSchemaMessageStream_NilStream(t *testing.T) {
|
||||
content, tid, err := recvSchemaMessageStream(context.Background(), nil)
|
||||
if err != nil || content != "" || tid != "" {
|
||||
t.Fatalf("nil stream: content=%q tid=%q err=%v", content, tid, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvSchemaMessageStream_EOFViaEmptyRead(t *testing.T) {
|
||||
sr, sw := schema.Pipe[*schema.Message](4)
|
||||
_ = sw.Send(nil, io.EOF)
|
||||
sw.Close()
|
||||
|
||||
_, _, err := recvSchemaMessageStream(context.Background(), sr)
|
||||
if err != nil {
|
||||
t.Fatalf("EOF should not surface as error, got %v", err)
|
||||
}
|
||||
}
|
||||
@@ -51,7 +51,7 @@ func einoExecuteRecvErrIsToolTimeout(rerr error, tctx context.Context) bool {
|
||||
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
|
||||
//
|
||||
// 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire,
|
||||
// 保证 run loop 在模型开始下一轮输出前已记录 execute 结果(用于 UI 与「重复助手复述」去重)。
|
||||
// run loop 收到 Fire 后立即推送 tool_result(toolResultSent 去重),避免 ADK Tool 事件迟到时 UI 卡在「执行中」。
|
||||
//
|
||||
// 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire;
|
||||
// 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。
|
||||
|
||||
Reference in New Issue
Block a user