diff --git a/internal/multiagent/eino_execute_streaming_wrap.go b/internal/multiagent/eino_execute_streaming_wrap.go index a69586f7..016eecc6 100644 --- a/internal/multiagent/eino_execute_streaming_wrap.go +++ b/internal/multiagent/eino_execute_streaming_wrap.go @@ -6,9 +6,11 @@ import ( "fmt" "io" "strings" + "sync" "time" "cyberstrike-ai/internal/einomcp" + "cyberstrike-ai/internal/mcp" "cyberstrike-ai/internal/security" "github.com/cloudwego/eino/adk/filesystem" @@ -80,15 +82,23 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi req.Command = prependPythonUnbufferedEnv(req.Command) tid := strings.TrimSpace(compose.GetToolCallID(ctx)) agentTag := strings.TrimSpace(w.einoAgentName) + convID := mcp.MCPConversationIDFromContext(ctx) + execReg := mcp.EinoExecuteRunRegistryFromContext(ctx) - execCtx := ctx - var execCancel context.CancelFunc + execCtx, execCancel := context.WithCancel(ctx) + var timeoutCancel context.CancelFunc if w.toolTimeoutMinutes > 0 { - execCtx, execCancel = context.WithTimeout(ctx, time.Duration(w.toolTimeoutMinutes)*time.Minute) + execCtx, timeoutCancel = context.WithTimeout(execCtx, time.Duration(w.toolTimeoutMinutes)*time.Minute) + } + if execReg != nil && convID != "" { + execReg.RegisterActiveEinoExecute(convID, execCancel) } sr, err := w.inner.ExecuteStreaming(execCtx, &req) if err != nil { + if timeoutCancel != nil { + timeoutCancel() + } if execCancel != nil { execCancel() } @@ -111,6 +121,9 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi return nil, err } if sr == nil || w.invokeNotify == nil { + if timeoutCancel != nil { + timeoutCancel() + } if execCancel != nil { execCancel() } @@ -119,11 +132,32 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32) - go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, tctx context.Context) { - defer inner.Close() + go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry) { + var innerCloseOnce sync.Once + closeInner := func() { + innerCloseOnce.Do(func() { inner.Close() }) + } + defer closeInner() + if timeoutCleanup != nil { + defer timeoutCleanup() + } if cancel != nil { defer cancel() } + if reg != nil && conversationID != "" { + defer reg.UnregisterActiveEinoExecute(conversationID) + } + + // ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。 + stopWatch := make(chan struct{}) + go func() { + select { + case <-tctx.Done(): + closeInner() + case <-stopWatch: + } + }() + defer close(stopWatch) var sb strings.Builder success := true @@ -144,6 +178,10 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi invokeErr = context.DeadlineExceeded break } + if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) { + invokeErr = context.Canceled + break + } _ = outW.Send(nil, rerr) break } @@ -178,6 +216,21 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi success = false invokeErr = context.DeadlineExceeded } + // 用户「中断并继续」终止 execute:合并说明进工具结果(与 MCP CancelToolExecutionWithNote 一致)。 + partialStreamed := sb.String() + var abortNote string + if reg != nil && conversationID != "" && (invokeErr != nil || errors.Is(tctx.Err(), context.Canceled)) { + if note := reg.TakeEinoExecuteAbortNote(conversationID); note != "" { + abortNote = note + merged := mcp.MergePartialToolOutputAndAbortNote(partialStreamed, note) + sb.Reset() + sb.WriteString(merged) + if invokeErr == nil { + success = false + invokeErr = context.Canceled + } + } + } // ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。 if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) { hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n" @@ -187,12 +240,20 @@ func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *fi } sb.WriteString(hint) } + // 中断时循环内已逐行写入 stdout;此处只追加 USER INTERRUPT NOTE,避免整段输出重复。 + if invokeErr != nil && errors.Is(invokeErr, context.Canceled) && abortNote != "" { + if partialStreamed != "" { + _ = outW.Send(&filesystem.ExecuteResponse{Output: "\n\n" + mcp.AbortNoteBannerForModel + "\n" + abortNote}, nil) + } else if text := strings.TrimSpace(sb.String()); text != "" { + _ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil) + } + } if w.recordMonitor != nil { w.recordMonitor(tid, command, sb.String(), success, invokeErr) } w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr) outW.Close() - }(sr, userCmd, execCancel, execCtx) + }(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg) return outR, nil } diff --git a/internal/multiagent/eino_execute_streaming_wrap_test.go b/internal/multiagent/eino_execute_streaming_wrap_test.go index 3cadcfa5..5e8d0751 100644 --- a/internal/multiagent/eino_execute_streaming_wrap_test.go +++ b/internal/multiagent/eino_execute_streaming_wrap_test.go @@ -9,6 +9,7 @@ import ( "time" "cyberstrike-ai/internal/einomcp" + "cyberstrike-ai/internal/mcp" "github.com/cloudwego/eino/adk/filesystem" "github.com/cloudwego/eino/schema" @@ -122,6 +123,94 @@ func TestEinoStreamingShellWrap_ToolTimeoutRecvErrIsSoft(t *testing.T) { } } +func TestEinoStreamingShellWrap_CapturesOutputWithToolTimeout(t *testing.T) { + inner := &mockStreamingShell{output: "100\n"} + notify := einomcp.NewToolInvokeNotifyHolder() + var firedContent string + notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) { + firedContent = content + }) + wrap := &einoStreamingShellWrap{ + inner: inner, + invokeNotify: notify, + toolTimeoutMinutes: 60, + } + sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "echo 100"}) + if err != nil { + t.Fatalf("ExecuteStreaming: %v", err) + } + defer sr.Close() + + var got strings.Builder + for { + resp, rerr := sr.Recv() + if errors.Is(rerr, io.EOF) { + break + } + if rerr != nil { + t.Fatalf("unexpected stream error: %v", rerr) + } + if resp != nil && resp.Output != "" { + got.WriteString(resp.Output) + } + } + if !strings.Contains(got.String(), "100") { + t.Fatalf("stream output = %q, want contains 100", got.String()) + } + if !strings.Contains(firedContent, "100") { + t.Fatalf("notify content = %q, want contains 100", firedContent) + } +} + +func TestEinoStreamingShellWrap_AbortNoteDoesNotDuplicateStreamedOutput(t *testing.T) { + inner := &mockStreamingShell{output: "line1\nline2\n", recvErr: context.Canceled} + notify := einomcp.NewToolInvokeNotifyHolder() + wrap := &einoStreamingShellWrap{ + inner: inner, + invokeNotify: notify, + } + reg := &abortNoteTestRegistry{note: "改成20次"} + ctx := mcp.WithEinoExecuteRunRegistry( + mcp.WithMCPConversationID(context.Background(), "conv-abort-dup"), + reg, + ) + sr, err := wrap.ExecuteStreaming(ctx, &filesystem.ExecuteRequest{Command: "ping -c 10 baidu.com"}) + if err != nil { + t.Fatalf("ExecuteStreaming: %v", err) + } + defer sr.Close() + + var got strings.Builder + for { + resp, rerr := sr.Recv() + if errors.Is(rerr, io.EOF) { + break + } + if rerr != nil { + t.Fatalf("unexpected stream error: %v", rerr) + } + if resp != nil && resp.Output != "" { + got.WriteString(resp.Output) + } + } + out := got.String() + if strings.Count(out, "line1") != 1 || strings.Count(out, "line2") != 1 { + t.Fatalf("stream duplicated stdout: %q", out) + } + if !strings.Contains(out, "改成20次") { + t.Fatalf("stream missing abort note: %q", out) + } +} + +type abortNoteTestRegistry struct { + note string +} + +func (r *abortNoteTestRegistry) RegisterActiveEinoExecute(string, context.CancelFunc) {} +func (r *abortNoteTestRegistry) UnregisterActiveEinoExecute(string) {} +func (r *abortNoteTestRegistry) AbortActiveEinoExecute(string, string) bool { return false } +func (r *abortNoteTestRegistry) TakeEinoExecuteAbortNote(string) string { return r.note } + func TestEinoStreamingShellWrap_NonTimeoutRecvErrStillHard(t *testing.T) { inner := &mockStreamingShell{recvErr: errors.New("broken pipe")} wrap := &einoStreamingShellWrap{inner: inner}