diff --git a/internal/security/command_failure_format.go b/internal/security/command_failure_format.go new file mode 100644 index 00000000..dc5af2c5 --- /dev/null +++ b/internal/security/command_failure_format.go @@ -0,0 +1,56 @@ +package security + +import ( + "errors" + "fmt" + "os/exec" + "strings" +) + +// FormatCommandFailureResult 与 exec 工具 ToolResult 文案一致(不含 ToolErrorPrefix)。 +func FormatCommandFailureResult(exitCode int, output string) string { + output = strings.TrimSpace(output) + errMsg := fmt.Sprintf("exit status %d", exitCode) + if output == "" { + return fmt.Sprintf("命令执行失败: %s", errMsg) + } + if strings.HasPrefix(output, "命令执行失败:") { + return output + } + return fmt.Sprintf("命令执行失败: %s\n输出: %s", errMsg, output) +} + +// FormatCommandFailureFromErr 根据 exec/execute 返回的 error 生成统一失败文案(IsError 正文)。 +func FormatCommandFailureFromErr(err error, output string) string { + if err == nil { + return strings.TrimSpace(output) + } + var exitError *exec.ExitError + if errors.As(err, &exitError) { + return FormatCommandFailureResult(exitError.ExitCode(), output) + } + output = strings.TrimSpace(output) + if output == "" { + return fmt.Sprintf("命令执行失败: %v", err) + } + if strings.HasPrefix(output, "命令执行失败:") { + return output + } + return fmt.Sprintf("命令执行失败: %v\n输出: %s", err, output) +} + +// ExecuteFailureStatusLine 流式 execute 结束时追加的单行状态(输出正文已在流中推送过)。 +func ExecuteFailureStatusLine(exitCode int) string { + return fmt.Sprintf("\n命令执行失败: exit status %d", exitCode) +} + +// IsCommandFailureResult 判断工具结果正文是否表示命令非零退出(用于 execute / exec 对齐 isError)。 +func IsCommandFailureResult(content string) bool { + return strings.Contains(content, "命令执行失败:") +} + +// IsLegacyShellExitNoise 过滤旧版 shell 流中冗余的 exit code 行。 +func IsLegacyShellExitNoise(s string) bool { + trimmed := strings.TrimSpace(s) + return strings.HasPrefix(trimmed, "command exited with non-zero code ") +} diff --git a/internal/security/command_failure_format_test.go b/internal/security/command_failure_format_test.go new file mode 100644 index 00000000..d7ca53a2 --- /dev/null +++ b/internal/security/command_failure_format_test.go @@ -0,0 +1,54 @@ +package security + +import ( + "errors" + "os/exec" + "strings" + "testing" +) + +func TestFormatCommandFailureResult(t *testing.T) { + got := FormatCommandFailureResult(1, "sudo: password required") + want := "命令执行失败: exit status 1\n输出: sudo: password required" + if got != want { + t.Fatalf("got %q want %q", got, want) + } + if FormatCommandFailureResult(2, "") != "命令执行失败: exit status 2" { + t.Fatal("empty output format") + } + if FormatCommandFailureResult(1, "命令执行失败: exit status 1") != "命令执行失败: exit status 1" { + t.Fatal("should not double-wrap") + } +} + +func TestIsCommandFailureResult(t *testing.T) { + if !IsCommandFailureResult("sudo: err\n命令执行失败: exit status 1") { + t.Fatal("expected true") + } + if IsCommandFailureResult("sudo: err only") { + t.Fatal("expected false") + } +} + +func TestFormatCommandFailureFromErr(t *testing.T) { + cmd := exec.Command("sh", "-c", "exit 42") + err := cmd.Run() + got := FormatCommandFailureFromErr(err, "oops") + if got != "命令执行失败: exit status 42\n输出: oops" { + t.Fatalf("got %q", got) + } + timeoutErr := errors.New("shell inactivity timeout (300s)") + got2 := FormatCommandFailureFromErr(timeoutErr, "already timed out") + if !strings.Contains(got2, "shell inactivity timeout") || !strings.Contains(got2, "already timed out") { + t.Fatalf("got %q", got2) + } +} + +func TestIsLegacyShellExitNoise(t *testing.T) { + if !IsLegacyShellExitNoise("command exited with non-zero code 1\n") { + t.Fatal("expected legacy noise") + } + if IsLegacyShellExitNoise("sudo: failed") { + t.Fatal("unexpected noise") + } +} diff --git a/internal/security/executor.go b/internal/security/executor.go index 260f9427..2092bf8b 100644 --- a/internal/security/executor.go +++ b/internal/security/executor.go @@ -829,9 +829,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int } else { cmd = exec.CommandContext(ctx, shell, "-c", command) } - applyDefaultTerminalEnv(cmd) - attachNonInteractiveStdin(cmd) - _ = prepareShellCmdSession(cmd) + ConfigureShellCmdForAgentExecute(cmd) // 执行命令 e.logger.Info("执行系统命令", @@ -860,8 +858,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int } else { pidCmd = exec.CommandContext(ctx, shell, "-c", pidCommand) } - applyDefaultTerminalEnv(pidCmd) - _ = prepareShellCmdSession(pidCmd) + ConfigureShellCmdForAgentExecute(pidCmd) // 获取stdout管道 stdout, err := pidCmd.StdoutPipe() @@ -980,8 +977,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int if workDir != "" { cmd2.Dir = workDir } - applyDefaultTerminalEnv(cmd2) - _ = prepareShellCmdSession(cmd2) + ConfigureShellCmdForAgentExecute(cmd2) output, err = runCommandWithPTY(ctx, cmd2, cb) } } else { @@ -994,8 +990,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int if workDir != "" { cmd2.Dir = workDir } - applyDefaultTerminalEnv(cmd2) - _ = prepareShellCmdSession(cmd2) + ConfigureShellCmdForAgentExecute(cmd2) output, err = runCommandWithPTY(ctx, cmd2, nil) } } @@ -1009,7 +1004,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int Content: []mcp.Content{ { Type: "text", - Text: fmt.Sprintf("命令执行失败: %v\n输出: %s", err, string(output)), + Text: FormatCommandFailureFromErr(err, output), }, }, IsError: true, diff --git a/internal/security/executor_test.go b/internal/security/executor_test.go index 5bb08678..fa24d6d0 100644 --- a/internal/security/executor_test.go +++ b/internal/security/executor_test.go @@ -71,6 +71,27 @@ func TestExecuteSystemCommand_BackgroundDoesNotBlockOnChildStdout(t *testing.T) } } +func TestExecuteSystemCommand_FailureFormat(t *testing.T) { + executor, _ := setupTestExecutor(t) + res, err := executor.executeSystemCommand(context.Background(), map[string]interface{}{ + "command": "echo fail-msg >&2; exit 7", + "shell": "sh", + }) + if err != nil { + t.Fatalf("executeSystemCommand: %v", err) + } + if res == nil || !res.IsError { + t.Fatalf("expected IsError, got %+v", res) + } + text := res.Content[0].Text + if text != FormatCommandFailureResult(7, "fail-msg\n") && text != FormatCommandFailureResult(7, "fail-msg") { + t.Fatalf("unexpected failure text: %q", text) + } + if !strings.Contains(text, "exit status 7") || !strings.Contains(text, "fail-msg") { + t.Fatalf("unexpected failure text: %q", text) + } +} + func TestBuildCommandArgs_NmapSkipsEmptyOptionalFlags(t *testing.T) { pos1 := 1 executor, _ := setupTestExecutor(t) diff --git a/internal/security/shell_execute_stream.go b/internal/security/shell_execute_stream.go new file mode 100644 index 00000000..ea3c7d39 --- /dev/null +++ b/internal/security/shell_execute_stream.go @@ -0,0 +1,200 @@ +package security + +import ( + "context" + "errors" + "fmt" + "io" + "os/exec" + "sync" + + "github.com/cloudwego/eino/adk/filesystem" + "github.com/cloudwego/eino/schema" +) + +// ConfigureShellCmdForAgentExecute 与 exec 工具一致:非交互 stdin、pager/TERM 环境、独立进程组。 +func ConfigureShellCmdForAgentExecute(cmd *exec.Cmd) { + if cmd == nil { + return + } + applyDefaultTerminalEnv(cmd) + attachNonInteractiveStdin(cmd) + _ = prepareShellCmdSession(cmd) +} + +// TerminateShellCmdTree 尽力终止 shell 及其子进程组(与 exec/execute 超时取消一致)。 +func TerminateShellCmdTree(cmd *exec.Cmd) { + terminateCmdTree(cmd) +} + +// EinoStreamingShell 为 Eino ADK execute 工具提供流式 shell,行为与 exec 对齐: +// 并发读取 stdout/stderr(定长块,非按行),避免官方 local.ExecuteStreaming 先排空 stdout +// 导致 stderr 错误(如 sudo 密码提示)长时间不可见、UI 一直显示「执行中」。 +type EinoStreamingShell struct{} + +// NewEinoStreamingShell 创建 execute 流式 shell 实现。 +func NewEinoStreamingShell() *EinoStreamingShell { + return &EinoStreamingShell{} +} + +// ExecuteStreaming 实现 filesystem.StreamingShell。 +func (s *EinoStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) { + if input == nil || input.Command == "" { + return nil, fmt.Errorf("command is required") + } + + sr, w := schema.Pipe[*filesystem.ExecuteResponse](100) + if input.RunInBackendGround { + go runShellInBackground(ctx, input.Command, w) + return sr, nil + } + go streamShellForeground(ctx, input.Command, w) + return sr, nil +} + +func runShellInBackground(ctx context.Context, command string, w *schema.StreamWriter[*filesystem.ExecuteResponse]) { + defer w.Close() + + cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command) + ConfigureShellCmdForAgentExecute(cmd) + stdout, err := cmd.StdoutPipe() + if err != nil { + _ = w.Send(nil, fmt.Errorf("failed to create stdout pipe: %w", err)) + return + } + stderr, err := cmd.StderrPipe() + if err != nil { + _ = stdout.Close() + _ = w.Send(nil, fmt.Errorf("failed to create stderr pipe: %w", err)) + return + } + if err := cmd.Start(); err != nil { + _ = stdout.Close() + _ = stderr.Close() + _ = w.Send(nil, fmt.Errorf("failed to start command: %w", err)) + return + } + + done := make(chan struct{}) + go func() { + drainShellPipes(stdout, stderr) + _ = cmd.Wait() + close(done) + }() + + select { + case <-done: + case <-ctx.Done(): + TerminateShellCmdTree(cmd) + } + + exitCode := 0 + _ = w.Send(&filesystem.ExecuteResponse{ + Output: "command started in background\n", + ExitCode: &exitCode, + }, nil) +} + +func drainShellPipes(stdout, stderr io.Reader) { + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + _, _ = io.Copy(io.Discard, stdout) + }() + go func() { + defer wg.Done() + _, _ = io.Copy(io.Discard, stderr) + }() + wg.Wait() +} + +func streamShellForeground(ctx context.Context, command string, w *schema.StreamWriter[*filesystem.ExecuteResponse]) { + defer w.Close() + + cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command) + ConfigureShellCmdForAgentExecute(cmd) + + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + _ = w.Send(nil, fmt.Errorf("failed to create stdout pipe: %w", err)) + return + } + stderrPipe, err := cmd.StderrPipe() + if err != nil { + _ = stdoutPipe.Close() + _ = w.Send(nil, fmt.Errorf("failed to create stderr pipe: %w", err)) + return + } + if err := cmd.Start(); err != nil { + _ = stdoutPipe.Close() + _ = stderrPipe.Close() + _ = w.Send(nil, fmt.Errorf("failed to start command: %w", err)) + return + } + + stopWatch := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + TerminateShellCmdTree(cmd) + case <-stopWatch: + } + }() + defer close(stopWatch) + + chunks := make(chan string, 64) + var wg sync.WaitGroup + readFn := func(r io.Reader) { + defer wg.Done() + buf := make([]byte, 8192) + for { + n, readErr := r.Read(buf) + if n > 0 { + chunks <- string(buf[:n]) + } + if readErr != nil { + return + } + } + } + + wg.Add(2) + go readFn(stdoutPipe) + go readFn(stderrPipe) + go func() { + wg.Wait() + close(chunks) + }() + + hadOutput := false + for chunk := range chunks { + if chunk == "" { + continue + } + hadOutput = true + if w.Send(&filesystem.ExecuteResponse{Output: chunk}, nil) { + TerminateShellCmdTree(cmd) + return + } + } + + waitErr := cmd.Wait() + if waitErr == nil { + exitCode := 0 + _ = w.Send(&filesystem.ExecuteResponse{ExitCode: &exitCode}, nil) + return + } + + var exitError *exec.ExitError + if errors.As(waitErr, &exitError) { + exitCode := exitError.ExitCode() + resp := &filesystem.ExecuteResponse{ExitCode: &exitCode} + if !hadOutput { + resp.Output = FormatCommandFailureResult(exitCode, "") + } + _ = w.Send(resp, nil) + return + } + _ = w.Send(nil, fmt.Errorf("command failed: %w", waitErr)) +} diff --git a/internal/security/shell_execute_stream_test.go b/internal/security/shell_execute_stream_test.go new file mode 100644 index 00000000..feeaecb5 --- /dev/null +++ b/internal/security/shell_execute_stream_test.go @@ -0,0 +1,117 @@ +package security + +import ( + "context" + "errors" + "io" + "strings" + "testing" + "time" + + "github.com/cloudwego/eino/adk/filesystem" +) + +func TestEinoStreamingShell_StreamsStderrBeforeStdoutEOF(t *testing.T) { + shell := NewEinoStreamingShell() + cmd := PrepareNonInteractiveShellCommand("echo err-only >&2; exit 1") + sr, err := shell.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: cmd}) + if err != nil { + t.Fatalf("ExecuteStreaming: %v", err) + } + defer sr.Close() + + start := time.Now() + var got strings.Builder + for { + resp, rerr := sr.Recv() + if errors.Is(rerr, io.EOF) { + break + } + if rerr != nil { + t.Fatalf("recv: %v", rerr) + } + if resp != nil && resp.Output != "" { + got.WriteString(resp.Output) + } + } + if time.Since(start) > 3*time.Second { + t.Fatalf("expected fast completion, took %v", time.Since(start)) + } + if !strings.Contains(got.String(), "err-only") { + t.Fatalf("expected stderr in output, got: %q", got.String()) + } +} + +func TestEinoStreamingShell_SudoFailsFast(t *testing.T) { + shell := NewEinoStreamingShell() + cmd := PrepareNonInteractiveShellCommand("sudo whoami && sudo cat /etc/os-release") + sr, err := shell.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: cmd}) + if err != nil { + t.Fatalf("ExecuteStreaming: %v", err) + } + defer sr.Close() + + start := time.Now() + var got strings.Builder + for { + resp, rerr := sr.Recv() + if errors.Is(rerr, io.EOF) { + break + } + if rerr != nil { + t.Fatalf("recv: %v", rerr) + } + if resp == nil { + continue + } + got.WriteString(resp.Output) + } + if time.Since(start) > 5*time.Second { + t.Fatalf("sudo should fail quickly, took %v output=%q", time.Since(start), got.String()) + } + out := got.String() + if strings.Contains(out, "command exited with non-zero code") { + t.Fatalf("legacy exit line present: %q", out) + } + if !strings.Contains(out, "sudo") && !strings.Contains(out, "password") && !strings.Contains(out, "terminal") { + t.Fatalf("expected sudo error text, got: %q", out) + } +} + +func TestEinoStreamingShell_StderrWhileStdoutBlocks(t *testing.T) { + shell := NewEinoStreamingShell() + // 模拟 sudo:stderr 先有输出,stdout 侧进程仍挂起;旧 eino local 在首包 stderr 前不会向流写任何内容。 + cmd := PrepareNonInteractiveShellCommand(`echo "password prompt" >&2; sleep 30`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + sr, err := shell.ExecuteStreaming(ctx, &filesystem.ExecuteRequest{Command: cmd}) + if err != nil { + t.Fatalf("ExecuteStreaming: %v", err) + } + defer sr.Close() + + start := time.Now() + var got strings.Builder + for { + resp, rerr := sr.Recv() + if errors.Is(rerr, io.EOF) { + break + } + if rerr != nil { + break + } + if resp != nil && resp.Output != "" { + got.WriteString(resp.Output) + if strings.Contains(got.String(), "password prompt") { + break + } + } + } + if time.Since(start) > 1500*time.Millisecond { + t.Fatalf("expected stderr promptly, took %v output=%q", time.Since(start), got.String()) + } + if !strings.Contains(got.String(), "password prompt") { + t.Fatalf("expected early stderr, got: %q", got.String()) + } +}