Add files via upload

This commit is contained in:
公明
2026-06-24 17:17:33 +08:00
committed by GitHub
parent b6a6009629
commit 85d58eeeb3
7 changed files with 454 additions and 36 deletions
+56 -14
View File
@@ -32,10 +32,11 @@ var ToolOutputCallbackCtxKey = toolOutputCallbackCtxKey{}
// Executor 安全工具执行器
type Executor struct {
config *config.SecurityConfig
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
mcpServer *mcp.Server
logger *zap.Logger
config *config.SecurityConfig
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
mcpServer *mcp.Server
logger *zap.Logger
shellNoOutputTimeoutSec int // execute/exec 无新输出空闲秒数;0=默认 300-1=关闭(见 SetShellNoOutputTimeoutSeconds
}
// NewExecutor 创建新的执行器
@@ -51,6 +52,11 @@ func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.
return executor
}
// SetShellNoOutputTimeoutSeconds 配置 exec 工具无输出空闲终止(与 agent.shell_no_output_timeout_seconds 一致)。
func (e *Executor) SetShellNoOutputTimeoutSeconds(sec int) {
e.shellNoOutputTimeoutSec = sec
}
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
func (e *Executor) buildToolIndex() {
e.toolIndex = make(map[string]*config.ToolConfig)
@@ -133,6 +139,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
// 执行命令
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
applyDefaultTerminalEnv(cmd)
attachNonInteractiveStdin(cmd)
_ = prepareShellCmdSession(cmd)
e.logger.Info("执行安全工具",
@@ -144,7 +151,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
var err error
// 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
output, err = streamCommandOutput(ctx, cmd, cb)
output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
if err != nil && shouldRetryWithPTY(output) {
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
zap.String("tool", toolName),
@@ -797,6 +804,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
zap.String("command", command),
)
command = PrepareNonInteractiveShellCommand(command)
// 获取shell类型(可选,默认为sh)
shell := "sh"
if s, ok := args["shell"].(string); ok && s != "" {
@@ -821,6 +830,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
cmd = exec.CommandContext(ctx, shell, "-c", command)
}
applyDefaultTerminalEnv(cmd)
attachNonInteractiveStdin(cmd)
_ = prepareShellCmdSession(cmd)
// 执行命令
@@ -963,7 +973,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
var err error
// 若上层提供工具输出增量回调,则边执行边流式读取。
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
output, err = streamCommandOutput(ctx, cmd, cb)
output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
if err != nil && shouldRetryWithPTY(output) {
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
cmd2 := exec.CommandContext(ctx, shell, "-c", command)
@@ -1024,7 +1034,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback) (string, error) {
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback, noOutputSec int) (string, error) {
if err := prepareShellCmdSession(cmd); err != nil {
return "", err
}
@@ -1091,12 +1101,43 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
lastFlush = time.Now()
}
for chunk := range chunks {
outBuilder.WriteString(chunk)
deltaBuilder.WriteString(chunk)
// 简单节流:buffer 大于 2KB 或 200ms 就刷新一次
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
flush()
idleWatch := NewShellInactivityWatch(noOutputSec)
if idleWatch != nil {
defer idleWatch.Stop()
}
fireInactivity := func() {
terminateCmdTree(cmd)
msg := ShellNoOutputTimeoutMessage(idleWatch.Sec)
outBuilder.WriteString(msg)
if cb != nil {
cb(msg)
}
_ = cmd.Wait()
}
chunksLoop:
for {
var idleCh <-chan struct{}
if idleWatch != nil {
idleCh = idleWatch.Expired
}
select {
case <-idleCh:
fireInactivity()
return outBuilder.String(), fmt.Errorf("shell inactivity timeout (%ds)", idleWatch.Sec)
case chunk, ok := <-chunks:
if !ok {
break chunksLoop
}
if chunk != "" && idleWatch != nil {
idleWatch.Bump()
}
outBuilder.WriteString(chunk)
deltaBuilder.WriteString(chunk)
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
flush()
}
}
}
flush()
@@ -1116,6 +1157,7 @@ func applyDefaultTerminalEnv(cmd *exec.Cmd) {
if cmd.Env == nil {
cmd.Env = os.Environ()
}
cmd.Env = ApplyNonInteractivePagerEnv(cmd.Env)
// 如果用户已设置 TERM/COLUMNS/LINES,则不覆盖
has := func(k string) bool {
prefix := k + "="
@@ -1159,7 +1201,7 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
if runtime.GOOS == "windows" {
// PTY 方案为类 UnixWindows 走原逻辑
if cb != nil {
return streamCommandOutput(ctx, cmd, cb)
return streamCommandOutput(ctx, cmd, cb, 0)
}
_ = prepareShellCmdSession(cmd)
out, err := cmd.CombinedOutput()
+163
View File
@@ -0,0 +1,163 @@
package security
import (
"fmt"
"os"
"os/exec"
"strings"
"sync"
"time"
)
// ShellNoOutputTimeoutMessage 长时间无新 stdout/stderr 时的提示(软失败,模型可见)。
func ShellNoOutputTimeoutMessage(idleSec int) string {
return fmt.Sprintf(`命令已终止:超过 %d 秒没有新的输出,疑似在等待交互输入或已挂起。
长时静默任务请使用末尾 & 后台运行,或增大 agent.shell_no_output_timeout_seconds-1=关闭此检测)。
Command terminated: no new output for %d seconds (possible interactive wait or hung process).`, idleSec, idleSec)
}
// ShellInactivityWatch 在 noOutputSec 内无任何新输出时向 expired 发送信号;每次 Bump 重置计时。
// 与「仅有首包输出就永久取消计时」不同,可兜住 sudo 打印 Password 提示后继续挂起等情况。
type ShellInactivityWatch struct {
Sec int
mu sync.Mutex
timer *time.Timer
Expired chan struct{}
}
func NewShellInactivityWatch(noOutputSec int) *ShellInactivityWatch {
sec := ResolveShellNoOutputTimeoutSeconds(noOutputSec)
if sec <= 0 {
return nil
}
w := &ShellInactivityWatch{
Sec: sec,
Expired: make(chan struct{}, 1),
}
w.Bump()
return w
}
func (w *ShellInactivityWatch) Bump() {
if w == nil || w.Sec <= 0 {
return
}
w.mu.Lock()
defer w.mu.Unlock()
if w.timer != nil {
w.timer.Stop()
}
w.timer = time.AfterFunc(time.Duration(w.Sec)*time.Second, func() {
select {
case w.Expired <- struct{}{}:
default:
}
})
}
func (w *ShellInactivityWatch) Stop() {
if w == nil {
return
}
w.mu.Lock()
defer w.mu.Unlock()
if w.timer != nil {
w.timer.Stop()
w.timer = nil
}
}
// ResolveShellNoOutputTimeoutSeconds0=默认 3005 分钟);-1=关闭;>0=自定义。
func ResolveShellNoOutputTimeoutSeconds(sec int) int {
if sec < 0 {
return 0
}
if sec == 0 {
return 300
}
return sec
}
// PrependNonInteractiveShellExports 为 sh -c 注入通用非交互环境(pager 等),不维护命令黑名单。
func PrependNonInteractiveShellExports(shellCommand string) string {
if strings.TrimSpace(shellCommand) == "" {
return shellCommand
}
upper := strings.ToUpper(shellCommand)
var pairs []string
add := func(key, val string) {
if strings.Contains(upper, strings.ToUpper(key)) {
return
}
pairs = append(pairs, key+"="+val)
}
add("GIT_PAGER", "cat")
add("PAGER", "cat")
add("SYSTEMD_PAGER", "cat")
add("DEBIAN_FRONTEND", "noninteractive")
if len(pairs) == 0 {
return shellCommand
}
return "export " + strings.Join(pairs, " ") + "\n" + shellCommand
}
// PrependNonInteractiveStdinRedirect 为 sh -c 关闭 stdin(与 attachNonInteractiveStdin 等价),
// 使 read/input()/sudo -S 等从 stdin 读取的程序快速失败而非挂起。已含 </dev/null 时不重复注入。
func PrependNonInteractiveStdinRedirect(shellCommand string) string {
if strings.TrimSpace(shellCommand) == "" {
return shellCommand
}
lower := strings.ToLower(shellCommand)
if strings.Contains(lower, "</dev/null") || strings.Contains(lower, "0</dev/null") {
return shellCommand
}
return "exec </dev/null\n" + shellCommand
}
// PrepareNonInteractiveShellCommand 组合非交互包装:stdin 关闭 + pager 等环境变量(零名单)。
func PrepareNonInteractiveShellCommand(shellCommand string) string {
return PrependNonInteractiveStdinRedirect(PrependNonInteractiveShellExports(shellCommand))
}
// ApplyNonInteractivePagerEnv 为 exec.Cmd 补齐与 PrependNonInteractiveShellExports 一致的环境变量。
func ApplyNonInteractivePagerEnv(cmdEnv []string) []string {
if cmdEnv == nil {
cmdEnv = []string{}
}
has := func(k string) bool {
prefix := k + "="
for _, e := range cmdEnv {
if strings.HasPrefix(e, prefix) {
return true
}
}
return false
}
if !has("GIT_PAGER") {
cmdEnv = append(cmdEnv, "GIT_PAGER=cat")
}
if !has("PAGER") {
cmdEnv = append(cmdEnv, "PAGER=cat")
}
if !has("SYSTEMD_PAGER") {
cmdEnv = append(cmdEnv, "SYSTEMD_PAGER=cat")
}
if !has("DEBIAN_FRONTEND") {
cmdEnv = append(cmdEnv, "DEBIAN_FRONTEND=noninteractive")
}
return cmdEnv
}
// attachNonInteractiveStdin 关闭交互式 stdin,使部分命令快速失败而非等待输入。
func attachNonInteractiveStdin(cmd *exec.Cmd) {
if cmd == nil || cmd.Stdin != nil {
return
}
f, err := os.Open(os.DevNull)
if err != nil {
return
}
cmd.Stdin = f
}
@@ -0,0 +1,128 @@
package security
import (
"context"
"os"
"os/exec"
"strings"
"testing"
"time"
)
func TestPrependNonInteractiveShellExports(t *testing.T) {
out := PrependNonInteractiveShellExports("echo hi")
if !strings.Contains(out, "GIT_PAGER=cat") || !strings.Contains(out, "PAGER=cat") {
t.Fatalf("missing pager exports: %q", out)
}
if !strings.HasSuffix(strings.TrimSpace(out), "echo hi") {
t.Fatalf("command suffix lost: %q", out)
}
skip := PrependNonInteractiveShellExports("GIT_PAGER=less echo hi")
if strings.Contains(skip, "export GIT_PAGER=cat") {
t.Fatalf("should not override existing GIT_PAGER: %q", skip)
}
}
func TestPrependNonInteractiveStdinRedirect(t *testing.T) {
out := PrependNonInteractiveStdinRedirect("echo hi")
if !strings.HasPrefix(out, "exec </dev/null") {
t.Fatalf("missing stdin redirect: %q", out)
}
if !strings.HasSuffix(strings.TrimSpace(out), "echo hi") {
t.Fatalf("command suffix lost: %q", out)
}
skip := PrependNonInteractiveStdinRedirect("cmd </dev/null")
if strings.HasPrefix(skip, "exec </dev/null") {
t.Fatalf("should not double redirect: %q", skip)
}
}
func TestPrepareNonInteractiveShellCommand(t *testing.T) {
out := PrepareNonInteractiveShellCommand("echo hi")
if !strings.Contains(out, "exec </dev/null") {
t.Fatalf("missing stdin redirect: %q", out)
}
if !strings.Contains(out, "GIT_PAGER=cat") {
t.Fatalf("missing pager export: %q", out)
}
}
func TestNewShellInactivityWatch(t *testing.T) {
w := NewShellInactivityWatch(1)
if w == nil {
t.Fatal("expected watch")
}
w.Bump()
select {
case <-w.Expired:
case <-time.After(3 * time.Second):
t.Fatal("expected inactivity fire within 3s")
}
}
func TestResolveShellNoOutputTimeoutSeconds(t *testing.T) {
if ResolveShellNoOutputTimeoutSeconds(0) != 300 {
t.Fatal("zero should default to 300")
}
if ResolveShellNoOutputTimeoutSeconds(-1) != 0 {
t.Fatal("-1 should disable")
}
if ResolveShellNoOutputTimeoutSeconds(30) != 30 {
t.Fatal("explicit value")
}
}
// TestNonInteractiveStdinReadExitsQuickly 验证 exec </dev/null + attachNonInteractiveStdin 时 read 立即 EOF,不挂起。
func TestNonInteractiveStdinReadExitsQuickly(t *testing.T) {
if testing.Short() {
t.Skip("skipping shell integration in -short")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "sh", "-c", PrepareNonInteractiveShellCommand(`read x; echo "x=<$x>"`))
attachNonInteractiveStdin(cmd)
start := time.Now()
out, err := cmd.CombinedOutput()
elapsed := time.Since(start)
if elapsed > 2*time.Second {
t.Fatalf("read with closed stdin took %v, want <2s", elapsed)
}
if err != nil {
t.Fatalf("unexpected error: %v output=%q", err, out)
}
if !strings.Contains(string(out), "x=<>") {
t.Fatalf("unexpected output: %q", out)
}
}
// TestNonInteractiveStdinReadBlocksWithoutRedirect 对照:stdin 为永不写入的管道时 read 会挂起。
func TestNonInteractiveStdinReadBlocksWithoutRedirect(t *testing.T) {
if testing.Short() {
t.Skip("skipping shell integration in -short")
}
r, w, err := os.Pipe()
if err != nil {
t.Fatal(err)
}
defer r.Close()
// 保持 w 打开且不写数据,模拟「等待用户输入」
cmd := exec.Command("sh", "-c", `read x; echo done`)
cmd.Stdin = r
done := make(chan error, 1)
go func() { done <- cmd.Run() }()
select {
case err := <-done:
t.Fatalf("expected hang, but command finished: %v", err)
case <-time.After(500 * time.Millisecond):
if cmd.Process != nil {
_ = cmd.Process.Kill()
}
_ = w.Close()
<-done // 等待 goroutine 退出
}
}