mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-27 00:10:00 +02:00
Add files via upload
This commit is contained in:
@@ -691,83 +691,21 @@ func (e *Executor) formatParamValue(param config.ParameterConfig, value interfac
|
||||
// IsBackgroundShellCommand 检测命令是否为完全后台命令(末尾有独立 &,且不在引号内)。
|
||||
// command1 & command2 不算完全后台(command2 仍在前台执行)。
|
||||
func IsBackgroundShellCommand(command string) bool {
|
||||
// 移除首尾空格
|
||||
command = strings.TrimSpace(command)
|
||||
if command == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查命令中所有不在引号内的 & 符号
|
||||
// 找到最后一个 & 符号,检查它是否在命令末尾
|
||||
inSingleQuote := false
|
||||
inDoubleQuote := false
|
||||
escaped := false
|
||||
lastAmpersandPos := -1
|
||||
|
||||
for i, r := range command {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if r == '\'' && !inDoubleQuote {
|
||||
inSingleQuote = !inSingleQuote
|
||||
continue
|
||||
}
|
||||
if r == '"' && !inSingleQuote {
|
||||
inDoubleQuote = !inDoubleQuote
|
||||
continue
|
||||
}
|
||||
if r == '&' && !inSingleQuote && !inDoubleQuote {
|
||||
// 检查 & 前后是否有空格或换行(确保是独立的 &,而不是变量名的一部分)
|
||||
isStandalone := false
|
||||
|
||||
// 检查前面:空格、制表符、换行符,或者是命令开头
|
||||
if i == 0 {
|
||||
isStandalone = true
|
||||
} else {
|
||||
prev := command[i-1]
|
||||
if prev == ' ' || prev == '\t' || prev == '\n' || prev == '\r' {
|
||||
isStandalone = true
|
||||
}
|
||||
}
|
||||
|
||||
// 检查后面:空格、制表符、换行符,或者是命令末尾
|
||||
if isStandalone {
|
||||
if i == len(command)-1 {
|
||||
// 在末尾,肯定是独立的 &
|
||||
lastAmpersandPos = i
|
||||
} else {
|
||||
next := command[i+1]
|
||||
if next == ' ' || next == '\t' || next == '\n' || next == '\r' {
|
||||
// 后面有空格,是独立的 &
|
||||
lastAmpersandPos = i
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有找到 & 符号,不是后台命令
|
||||
if lastAmpersandPos == -1 {
|
||||
positions := findStandaloneAmpersandPositions(command)
|
||||
if len(positions) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查最后一个 & 后面是否还有非空内容
|
||||
afterAmpersand := strings.TrimSpace(command[lastAmpersandPos+1:])
|
||||
if afterAmpersand == "" {
|
||||
// & 在末尾或后面只有空白字符,这是完全后台命令
|
||||
// 检查 & 前面是否有内容
|
||||
beforeAmpersand := strings.TrimSpace(command[:lastAmpersandPos])
|
||||
return beforeAmpersand != ""
|
||||
last := positions[len(positions)-1]
|
||||
afterAmpersand := strings.TrimSpace(command[last+1:])
|
||||
if afterAmpersand != "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果 & 后面还有非空内容,说明是 command1 & command2 的情况
|
||||
// 这种情况下,command2会在前台执行,所以不算完全后台命令
|
||||
return false
|
||||
beforeAmpersand := strings.TrimSpace(command[:last])
|
||||
return beforeAmpersand != ""
|
||||
}
|
||||
|
||||
// executeSystemCommand 执行系统命令
|
||||
@@ -803,7 +741,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
zap.String("command", command),
|
||||
)
|
||||
|
||||
command = PrepareNonInteractiveShellCommand(command)
|
||||
command = PrepareShellCommandForExecute(command)
|
||||
|
||||
// 获取shell类型(可选,默认为sh)
|
||||
shell := "sh"
|
||||
@@ -844,10 +782,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
commandWithoutAmpersand := strings.TrimSuffix(strings.TrimSpace(command), "&")
|
||||
commandWithoutAmpersand = strings.TrimSpace(commandWithoutAmpersand)
|
||||
|
||||
// 构建新命令:将用户命令置于独立重定向的后台作业,再 echo $pid。
|
||||
// 若子进程与 echo 共享同一 stdout 管道,且长时间不向 stdout 写入换行,
|
||||
// bufio.ReadString('\n') 会永久阻塞(例如 beacon 持续写二进制/单行日志)。
|
||||
pidCommand := fmt.Sprintf("%s </dev/null >/dev/null 2>&1 & pid=$!; echo $pid", commandWithoutAmpersand)
|
||||
// 构建新命令:后台作业重定向标准流后 echo $pid(与 RedirectBackgroundJobStdio 一致)。
|
||||
pidCommand := RedirectBackgroundJobStdio(commandWithoutAmpersand+" &") + " pid=$!; echo $pid"
|
||||
|
||||
// 创建新命令来获取PID
|
||||
var pidCmd *exec.Cmd
|
||||
@@ -1029,27 +965,25 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
// 非流式路径不使用双流管道 fan-in,避免 stderr 撑满管道缓冲区时与 stdout 互相阻塞导致死锁。
|
||||
// 无输出空闲检测由上层 agent.tool_timeout_minutes 兜底,不改变原 CombinedOutput 语义。
|
||||
func combinedOutputCancellable(ctx context.Context, cmd *exec.Cmd) (string, error) {
|
||||
if err := prepareShellCmdSession(cmd); err != nil {
|
||||
return "", err
|
||||
}
|
||||
var stdoutBuf, stderrBuf strings.Builder
|
||||
cmd.Stdout = &stdoutBuf
|
||||
cmd.Stderr = &stderrBuf
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
session, err := StartShellSession(cmd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cmd.Wait()
|
||||
done <- session.Wait()
|
||||
}()
|
||||
|
||||
stopWatch := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
terminateCmdTree(cmd)
|
||||
TerminateShellCmdSession(session)
|
||||
case <-stopWatch:
|
||||
}
|
||||
}()
|
||||
@@ -1078,9 +1012,6 @@ func joinCommandOutput(stdout, stderr string) string {
|
||||
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
|
||||
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。
|
||||
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback, noOutputSec int) (string, error) {
|
||||
if err := prepareShellCmdSession(cmd); err != nil {
|
||||
return "", err
|
||||
}
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -1090,7 +1021,8 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
|
||||
_ = stdoutPipe.Close()
|
||||
return "", err
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
session, err := StartShellSession(cmd)
|
||||
if err != nil {
|
||||
_ = stdoutPipe.Close()
|
||||
_ = stderrPipe.Close()
|
||||
return "", err
|
||||
@@ -1100,7 +1032,7 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
terminateCmdTree(cmd)
|
||||
TerminateShellCmdSession(session)
|
||||
case <-stopWatch:
|
||||
}
|
||||
}()
|
||||
@@ -1152,13 +1084,13 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
|
||||
}
|
||||
|
||||
fireInactivity := func() {
|
||||
terminateCmdTree(cmd)
|
||||
TerminateShellCmdSession(session)
|
||||
msg := ShellNoOutputTimeoutMessage(idleWatch.Sec)
|
||||
outBuilder.WriteString(msg)
|
||||
if cb != nil {
|
||||
cb(msg)
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
_ = session.Wait()
|
||||
}
|
||||
|
||||
chunksLoop:
|
||||
@@ -1169,9 +1101,9 @@ chunksLoop:
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
terminateCmdTree(cmd)
|
||||
TerminateShellCmdSession(session)
|
||||
flush()
|
||||
_ = cmd.Wait()
|
||||
_ = session.Wait()
|
||||
return outBuilder.String(), ctx.Err()
|
||||
case <-idleCh:
|
||||
fireInactivity()
|
||||
@@ -1193,7 +1125,7 @@ chunksLoop:
|
||||
flush()
|
||||
|
||||
// 等待命令结束,返回最终退出状态
|
||||
waitErr := cmd.Wait()
|
||||
waitErr := session.Wait()
|
||||
return outBuilder.String(), waitErr
|
||||
}
|
||||
|
||||
@@ -1265,13 +1197,18 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
|
||||
}
|
||||
defer func() { _ = ptmx.Close() }()
|
||||
|
||||
rootPID := 0
|
||||
if cmd.Process != nil {
|
||||
rootPID = cmd.Process.Pid
|
||||
}
|
||||
|
||||
// ctx 取消时尽快终止子进程
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = ptmx.Close() // 触发读退出
|
||||
terminateCmdTree(cmd)
|
||||
terminateProcessGroup(rootPID, cmd)
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -19,13 +19,23 @@ func prepareShellCmdSession(cmd *exec.Cmd) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// terminateCmdTree 尽力终止 cmd 及其进程组(Unix 下 Setsid 后 PGID == 首进程 PID)。
|
||||
func terminateCmdTree(cmd *exec.Cmd) {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
// terminateProcessGroup 对 rootPID 对应进程组发 SIGKILL;rootPID 为 0 时回退到 cmd.Process.Pid。
|
||||
func terminateProcessGroup(rootPID int, cmd *exec.Cmd) {
|
||||
pid := rootPID
|
||||
if pid <= 0 && cmd != nil && cmd.Process != nil {
|
||||
pid = cmd.Process.Pid
|
||||
}
|
||||
if pid <= 0 {
|
||||
return
|
||||
}
|
||||
pid := cmd.Process.Pid
|
||||
if err := syscall.Kill(-pid, syscall.SIGKILL); err != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
if cmd != nil && cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// terminateCmdTree 尽力终止 cmd 及其进程组(Unix 下 Setsid 后 PGID == 首进程 PID)。
|
||||
func terminateCmdTree(cmd *exec.Cmd) {
|
||||
terminateProcessGroup(0, cmd)
|
||||
}
|
||||
|
||||
@@ -20,14 +20,24 @@ func prepareShellCmdSession(cmd *exec.Cmd) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// terminateCmdTree 使用 taskkill /F /T 终止进程及其子进程(Windows 上 Process.Kill 无法保证杀掉 python 等孙进程)。
|
||||
func terminateCmdTree(cmd *exec.Cmd) {
|
||||
if cmd == nil || cmd.Process == nil {
|
||||
// terminateProcessGroup 使用 taskkill /F /T 终止进程及其子进程;rootPID 为 0 时回退到 cmd.Process.Pid。
|
||||
func terminateProcessGroup(rootPID int, cmd *exec.Cmd) {
|
||||
pid := rootPID
|
||||
if pid <= 0 && cmd != nil && cmd.Process != nil {
|
||||
pid = cmd.Process.Pid
|
||||
}
|
||||
if pid <= 0 {
|
||||
return
|
||||
}
|
||||
pid := cmd.Process.Pid
|
||||
tk := exec.Command("taskkill", "/F", "/T", "/PID", strconv.Itoa(pid))
|
||||
if err := tk.Run(); err != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
if cmd != nil && cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// terminateCmdTree 使用 taskkill /F /T 终止进程及其子进程(Windows 上 Process.Kill 无法保证杀掉 python 等孙进程)。
|
||||
func terminateCmdTree(cmd *exec.Cmd) {
|
||||
terminateProcessGroup(0, cmd)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
package security
|
||||
|
||||
import "strings"
|
||||
|
||||
const backgroundJobStdioRedirect = " </dev/null >/dev/null 2>&1"
|
||||
|
||||
// findStandaloneAmpersandPositions 返回不在引号内的独立 & 下标(排除 &&)。
|
||||
func findStandaloneAmpersandPositions(command string) []int {
|
||||
command = strings.TrimSpace(command)
|
||||
if command == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var positions []int
|
||||
inSingleQuote := false
|
||||
inDoubleQuote := false
|
||||
escaped := false
|
||||
|
||||
for i := 0; i < len(command); i++ {
|
||||
r := command[i]
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if r == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if r == '\'' && !inDoubleQuote {
|
||||
inSingleQuote = !inSingleQuote
|
||||
continue
|
||||
}
|
||||
if r == '"' && !inSingleQuote {
|
||||
inDoubleQuote = !inDoubleQuote
|
||||
continue
|
||||
}
|
||||
if r != '&' || inSingleQuote || inDoubleQuote {
|
||||
continue
|
||||
}
|
||||
if i+1 < len(command) && command[i+1] == '&' {
|
||||
continue
|
||||
}
|
||||
if i > 0 && command[i-1] == '&' {
|
||||
continue
|
||||
}
|
||||
|
||||
isStandalone := i == 0
|
||||
if !isStandalone {
|
||||
prev := command[i-1]
|
||||
isStandalone = prev == ' ' || prev == '\t' || prev == '\n' || prev == '\r'
|
||||
}
|
||||
if !isStandalone {
|
||||
continue
|
||||
}
|
||||
if i == len(command)-1 {
|
||||
positions = append(positions, i)
|
||||
continue
|
||||
}
|
||||
next := command[i+1]
|
||||
if next == ' ' || next == '\t' || next == '\n' || next == '\r' {
|
||||
positions = append(positions, i)
|
||||
}
|
||||
}
|
||||
return positions
|
||||
}
|
||||
|
||||
func segmentHasStdioRedirect(segment string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(segment))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(lower, ">/dev/null") || strings.Contains(lower, "2>/dev/null") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(lower, "&>") || strings.Contains(lower, "&>>") {
|
||||
return true
|
||||
}
|
||||
if strings.Contains(lower, "2>&1") && strings.Contains(lower, "/dev/null") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RedirectBackgroundJobStdio 为每个独立 & 前的后台段注入 </dev/null >/dev/null 2>&1,
|
||||
// 避免后台子进程占用 execute/exec 管道导致挂死。
|
||||
func RedirectBackgroundJobStdio(command string) string {
|
||||
positions := findStandaloneAmpersandPositions(command)
|
||||
if len(positions) == 0 {
|
||||
return command
|
||||
}
|
||||
|
||||
out := command
|
||||
for j := len(positions) - 1; j >= 0; j-- {
|
||||
i := positions[j]
|
||||
before := out[:i]
|
||||
after := out[i:]
|
||||
trimmed := strings.TrimRight(before, " \t\r\n")
|
||||
if segmentHasStdioRedirect(trimmed) {
|
||||
continue
|
||||
}
|
||||
trailing := before[len(trimmed):]
|
||||
out = trimmed + backgroundJobStdioRedirect + trailing + after
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// PrepareShellCommandForExecute 组合 execute/exec 用的非交互包装与后台 IO 重定向。
|
||||
// 须先注入 exec </dev/null,再改写 & 后台段,否则段内 </dev/null 会使 stdin 重定向被误判为已存在。
|
||||
func PrepareShellCommandForExecute(shellCommand string) string {
|
||||
return RedirectBackgroundJobStdio(PrepareNonInteractiveShellCommand(shellCommand))
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRedirectBackgroundJobStdio_mixedCommand(t *testing.T) {
|
||||
in := "java -jar app.jar & JRMP_PID=$!; echo started"
|
||||
out := RedirectBackgroundJobStdio(in)
|
||||
if !strings.Contains(out, "java -jar app.jar </dev/null >/dev/null 2>&1 &") {
|
||||
t.Fatalf("expected redirect before &: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "echo started") {
|
||||
t.Fatalf("foreground tail preserved: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectBackgroundJobStdio_trailingOnly(t *testing.T) {
|
||||
in := "sleep 120 &"
|
||||
out := RedirectBackgroundJobStdio(in)
|
||||
want := "sleep 120 </dev/null >/dev/null 2>&1 &"
|
||||
if strings.TrimSpace(out) != want {
|
||||
t.Fatalf("got %q want %q", out, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectBackgroundJobStdio_skipsAlreadyRedirected(t *testing.T) {
|
||||
in := "sleep 1 >/dev/null 2>&1 & echo ok"
|
||||
out := RedirectBackgroundJobStdio(in)
|
||||
if out != in {
|
||||
t.Fatalf("should not double-redirect: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedirectBackgroundJobStdio_skipsAndAnd(t *testing.T) {
|
||||
in := "test -f /etc/passwd && echo ok"
|
||||
out := RedirectBackgroundJobStdio(in)
|
||||
if out != in {
|
||||
t.Fatalf("&& must not be treated as background &: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareShellCommandForExecute(t *testing.T) {
|
||||
out := PrepareShellCommandForExecute("java -jar x & echo hi")
|
||||
if !strings.Contains(out, "exec </dev/null") {
|
||||
t.Fatalf("missing stdin redirect: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "GIT_PAGER=cat") {
|
||||
t.Fatalf("missing pager export: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "java -jar x </dev/null >/dev/null 2>&1 &") {
|
||||
t.Fatalf("missing background redirect: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBackgroundShellCommand_usesSharedParser(t *testing.T) {
|
||||
if !IsBackgroundShellCommand("sleep 1 &") {
|
||||
t.Fatal("trailing & should be background")
|
||||
}
|
||||
if IsBackgroundShellCommand("sleep 1 & echo hi") {
|
||||
t.Fatal("mixed should not be fully background")
|
||||
}
|
||||
}
|
||||
@@ -27,6 +27,11 @@ func TerminateShellCmdTree(cmd *exec.Cmd) {
|
||||
terminateCmdTree(cmd)
|
||||
}
|
||||
|
||||
// TerminateShellCmdSession 使用 Start 时缓存的进程组 ID 终止(shell 已退出时仍有效)。
|
||||
func TerminateShellCmdSession(session *ShellSession) {
|
||||
TerminateShellSession(session)
|
||||
}
|
||||
|
||||
// EinoStreamingShell 为 Eino ADK execute 工具提供流式 shell,行为与 exec 对齐:
|
||||
// 并发读取 stdout/stderr(定长块,非按行),避免官方 local.ExecuteStreaming 先排空 stdout
|
||||
// 导致 stderr 错误(如 sudo 密码提示)长时间不可见、UI 一直显示「执行中」。
|
||||
@@ -55,8 +60,10 @@ func (s *EinoStreamingShell) ExecuteStreaming(ctx context.Context, input *filesy
|
||||
func runShellInBackground(ctx context.Context, command string, w *schema.StreamWriter[*filesystem.ExecuteResponse]) {
|
||||
defer w.Close()
|
||||
|
||||
command = PrepareShellCommandForExecute(command)
|
||||
cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command)
|
||||
ConfigureShellCmdForAgentExecute(cmd)
|
||||
applyDefaultTerminalEnv(cmd)
|
||||
attachNonInteractiveStdin(cmd)
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
_ = w.Send(nil, fmt.Errorf("failed to create stdout pipe: %w", err))
|
||||
@@ -68,7 +75,8 @@ func runShellInBackground(ctx context.Context, command string, w *schema.StreamW
|
||||
_ = w.Send(nil, fmt.Errorf("failed to create stderr pipe: %w", err))
|
||||
return
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
session, err := StartShellSession(cmd)
|
||||
if err != nil {
|
||||
_ = stdout.Close()
|
||||
_ = stderr.Close()
|
||||
_ = w.Send(nil, fmt.Errorf("failed to start command: %w", err))
|
||||
@@ -78,14 +86,14 @@ func runShellInBackground(ctx context.Context, command string, w *schema.StreamW
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
drainShellPipes(stdout, stderr)
|
||||
_ = cmd.Wait()
|
||||
_ = session.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
TerminateShellCmdTree(cmd)
|
||||
TerminateShellCmdSession(session)
|
||||
}
|
||||
|
||||
exitCode := 0
|
||||
@@ -112,8 +120,10 @@ func drainShellPipes(stdout, stderr io.Reader) {
|
||||
func streamShellForeground(ctx context.Context, command string, w *schema.StreamWriter[*filesystem.ExecuteResponse]) {
|
||||
defer w.Close()
|
||||
|
||||
command = PrepareShellCommandForExecute(command)
|
||||
cmd := exec.CommandContext(ctx, "/bin/sh", "-c", command)
|
||||
ConfigureShellCmdForAgentExecute(cmd)
|
||||
applyDefaultTerminalEnv(cmd)
|
||||
attachNonInteractiveStdin(cmd)
|
||||
|
||||
stdoutPipe, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
@@ -126,7 +136,8 @@ func streamShellForeground(ctx context.Context, command string, w *schema.Stream
|
||||
_ = w.Send(nil, fmt.Errorf("failed to create stderr pipe: %w", err))
|
||||
return
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
session, err := StartShellSession(cmd)
|
||||
if err != nil {
|
||||
_ = stdoutPipe.Close()
|
||||
_ = stderrPipe.Close()
|
||||
_ = w.Send(nil, fmt.Errorf("failed to start command: %w", err))
|
||||
@@ -137,7 +148,7 @@ func streamShellForeground(ctx context.Context, command string, w *schema.Stream
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
TerminateShellCmdTree(cmd)
|
||||
TerminateShellCmdSession(session)
|
||||
case <-stopWatch:
|
||||
}
|
||||
}()
|
||||
@@ -174,12 +185,12 @@ func streamShellForeground(ctx context.Context, command string, w *schema.Stream
|
||||
}
|
||||
hadOutput = true
|
||||
if w.Send(&filesystem.ExecuteResponse{Output: chunk}, nil) {
|
||||
TerminateShellCmdTree(cmd)
|
||||
TerminateShellCmdSession(session)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
waitErr := cmd.Wait()
|
||||
waitErr := session.Wait()
|
||||
if waitErr == nil {
|
||||
exitCode := 0
|
||||
_ = w.Send(&filesystem.ExecuteResponse{ExitCode: &exitCode}, nil)
|
||||
|
||||
@@ -115,3 +115,38 @@ func TestEinoStreamingShell_StderrWhileStdoutBlocks(t *testing.T) {
|
||||
t.Fatalf("expected early stderr, got: %q", got.String())
|
||||
}
|
||||
}
|
||||
|
||||
// TestEinoStreamingShell_BackgroundJobDoesNotHoldPipe 模拟 cmd & 后继续前台逻辑:重定向后应快速结束。
|
||||
func TestEinoStreamingShell_BackgroundJobDoesNotHoldPipe(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping shell integration in -short")
|
||||
}
|
||||
shell := NewEinoStreamingShell()
|
||||
cmd := `(sh -c 'printf x; sleep 120') & echo started; sleep 0`
|
||||
sr, err := shell.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: cmd})
|
||||
if err != nil {
|
||||
t.Fatalf("ExecuteStreaming: %v", err)
|
||||
}
|
||||
defer sr.Close()
|
||||
|
||||
start := time.Now()
|
||||
var got strings.Builder
|
||||
for {
|
||||
resp, rerr := sr.Recv()
|
||||
if errors.Is(rerr, io.EOF) {
|
||||
break
|
||||
}
|
||||
if rerr != nil {
|
||||
t.Fatalf("recv: %v", rerr)
|
||||
}
|
||||
if resp != nil && resp.Output != "" {
|
||||
got.WriteString(resp.Output)
|
||||
}
|
||||
}
|
||||
if time.Since(start) > 3*time.Second {
|
||||
t.Fatalf("expected fast completion, took %v output=%q", time.Since(start), got.String())
|
||||
}
|
||||
if !strings.Contains(got.String(), "started") {
|
||||
t.Fatalf("expected foreground echo, got: %q", got.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
package security
|
||||
|
||||
import "os/exec"
|
||||
|
||||
// ShellSession 在 Start 时记录根 shell 的进程组 ID,取消/超时时可杀整组(即使 cmd.Process 已失效)。
|
||||
type ShellSession struct {
|
||||
Cmd *exec.Cmd
|
||||
rootPID int
|
||||
}
|
||||
|
||||
// StartShellSession 配置独立进程组并启动 shell,缓存 rootPID(Unix 下即 PGID)。
|
||||
func StartShellSession(cmd *exec.Cmd) (*ShellSession, error) {
|
||||
if err := prepareShellCmdSession(cmd); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pid := 0
|
||||
if cmd.Process != nil {
|
||||
pid = cmd.Process.Pid
|
||||
}
|
||||
return &ShellSession{Cmd: cmd, rootPID: pid}, nil
|
||||
}
|
||||
|
||||
// Wait 等待 shell 退出。
|
||||
func (s *ShellSession) Wait() error {
|
||||
if s == nil || s.Cmd == nil {
|
||||
return nil
|
||||
}
|
||||
return s.Cmd.Wait()
|
||||
}
|
||||
|
||||
// Terminate 终止 shell 及其进程组。
|
||||
func (s *ShellSession) Terminate() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
terminateProcessGroup(s.rootPID, s.Cmd)
|
||||
}
|
||||
|
||||
// TerminateShellSession 终止由 StartShellSession 启动的会话。
|
||||
func TerminateShellSession(session *ShellSession) {
|
||||
if session != nil {
|
||||
session.Terminate()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestShellSession_TerminateUsesCachedRootPID(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix process group kill")
|
||||
}
|
||||
|
||||
cmd := exec.Command("sh", "-c", "sleep 300")
|
||||
ConfigureShellCmdForAgentExecute(cmd)
|
||||
|
||||
session, err := StartShellSession(cmd)
|
||||
if err != nil {
|
||||
t.Fatalf("StartShellSession: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
session.Terminate()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- session.Wait() }()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("session did not finish within 5s after Terminate")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShellSession_TerminateAfterContextCancel(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("unix process group kill")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", "sleep 300")
|
||||
ConfigureShellCmdForAgentExecute(cmd)
|
||||
|
||||
session, err := StartShellSession(cmd)
|
||||
if err != nil {
|
||||
t.Fatalf("StartShellSession: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
TerminateShellCmdSession(session)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- session.Wait() }()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("session did not finish within 5s after cancel+terminate")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user