mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-25 15:30:15 +02:00
Add files via upload
This commit is contained in:
@@ -32,10 +32,11 @@ var ToolOutputCallbackCtxKey = toolOutputCallbackCtxKey{}
|
||||
|
||||
// Executor 安全工具执行器
|
||||
type Executor struct {
|
||||
config *config.SecurityConfig
|
||||
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
|
||||
mcpServer *mcp.Server
|
||||
logger *zap.Logger
|
||||
config *config.SecurityConfig
|
||||
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
|
||||
mcpServer *mcp.Server
|
||||
logger *zap.Logger
|
||||
shellNoOutputTimeoutSec int // execute/exec 无新输出空闲秒数;0=默认 300;-1=关闭(见 SetShellNoOutputTimeoutSeconds)
|
||||
}
|
||||
|
||||
// NewExecutor 创建新的执行器
|
||||
@@ -51,6 +52,11 @@ func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.
|
||||
return executor
|
||||
}
|
||||
|
||||
// SetShellNoOutputTimeoutSeconds 配置 exec 工具无输出空闲终止(与 agent.shell_no_output_timeout_seconds 一致)。
|
||||
func (e *Executor) SetShellNoOutputTimeoutSeconds(sec int) {
|
||||
e.shellNoOutputTimeoutSec = sec
|
||||
}
|
||||
|
||||
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
|
||||
func (e *Executor) buildToolIndex() {
|
||||
e.toolIndex = make(map[string]*config.ToolConfig)
|
||||
@@ -133,6 +139,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
// 执行命令
|
||||
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
||||
applyDefaultTerminalEnv(cmd)
|
||||
attachNonInteractiveStdin(cmd)
|
||||
_ = prepareShellCmdSession(cmd)
|
||||
|
||||
e.logger.Info("执行安全工具",
|
||||
@@ -144,7 +151,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
var err error
|
||||
// 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。
|
||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||
output, err = streamCommandOutput(ctx, cmd, cb)
|
||||
output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
|
||||
if err != nil && shouldRetryWithPTY(output) {
|
||||
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
|
||||
zap.String("tool", toolName),
|
||||
@@ -797,6 +804,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
zap.String("command", command),
|
||||
)
|
||||
|
||||
command = PrepareNonInteractiveShellCommand(command)
|
||||
|
||||
// 获取shell类型(可选,默认为sh)
|
||||
shell := "sh"
|
||||
if s, ok := args["shell"].(string); ok && s != "" {
|
||||
@@ -821,6 +830,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
cmd = exec.CommandContext(ctx, shell, "-c", command)
|
||||
}
|
||||
applyDefaultTerminalEnv(cmd)
|
||||
attachNonInteractiveStdin(cmd)
|
||||
_ = prepareShellCmdSession(cmd)
|
||||
|
||||
// 执行命令
|
||||
@@ -963,7 +973,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
var err error
|
||||
// 若上层提供工具输出增量回调,则边执行边流式读取。
|
||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||
output, err = streamCommandOutput(ctx, cmd, cb)
|
||||
output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
|
||||
if err != nil && shouldRetryWithPTY(output) {
|
||||
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
|
||||
cmd2 := exec.CommandContext(ctx, shell, "-c", command)
|
||||
@@ -1024,7 +1034,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
|
||||
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
|
||||
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。
|
||||
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback) (string, error) {
|
||||
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback, noOutputSec int) (string, error) {
|
||||
if err := prepareShellCmdSession(cmd); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1091,12 +1101,43 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
|
||||
lastFlush = time.Now()
|
||||
}
|
||||
|
||||
for chunk := range chunks {
|
||||
outBuilder.WriteString(chunk)
|
||||
deltaBuilder.WriteString(chunk)
|
||||
// 简单节流:buffer 大于 2KB 或 200ms 就刷新一次
|
||||
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
|
||||
flush()
|
||||
idleWatch := NewShellInactivityWatch(noOutputSec)
|
||||
if idleWatch != nil {
|
||||
defer idleWatch.Stop()
|
||||
}
|
||||
|
||||
fireInactivity := func() {
|
||||
terminateCmdTree(cmd)
|
||||
msg := ShellNoOutputTimeoutMessage(idleWatch.Sec)
|
||||
outBuilder.WriteString(msg)
|
||||
if cb != nil {
|
||||
cb(msg)
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
}
|
||||
|
||||
chunksLoop:
|
||||
for {
|
||||
var idleCh <-chan struct{}
|
||||
if idleWatch != nil {
|
||||
idleCh = idleWatch.Expired
|
||||
}
|
||||
select {
|
||||
case <-idleCh:
|
||||
fireInactivity()
|
||||
return outBuilder.String(), fmt.Errorf("shell inactivity timeout (%ds)", idleWatch.Sec)
|
||||
case chunk, ok := <-chunks:
|
||||
if !ok {
|
||||
break chunksLoop
|
||||
}
|
||||
if chunk != "" && idleWatch != nil {
|
||||
idleWatch.Bump()
|
||||
}
|
||||
outBuilder.WriteString(chunk)
|
||||
deltaBuilder.WriteString(chunk)
|
||||
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
|
||||
flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
flush()
|
||||
@@ -1116,6 +1157,7 @@ func applyDefaultTerminalEnv(cmd *exec.Cmd) {
|
||||
if cmd.Env == nil {
|
||||
cmd.Env = os.Environ()
|
||||
}
|
||||
cmd.Env = ApplyNonInteractivePagerEnv(cmd.Env)
|
||||
// 如果用户已设置 TERM/COLUMNS/LINES,则不覆盖
|
||||
has := func(k string) bool {
|
||||
prefix := k + "="
|
||||
@@ -1159,7 +1201,7 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
|
||||
if runtime.GOOS == "windows" {
|
||||
// PTY 方案为类 Unix;Windows 走原逻辑
|
||||
if cb != nil {
|
||||
return streamCommandOutput(ctx, cmd, cb)
|
||||
return streamCommandOutput(ctx, cmd, cb, 0)
|
||||
}
|
||||
_ = prepareShellCmdSession(cmd)
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ShellNoOutputTimeoutMessage 长时间无新 stdout/stderr 时的提示(软失败,模型可见)。
|
||||
func ShellNoOutputTimeoutMessage(idleSec int) string {
|
||||
return fmt.Sprintf(`命令已终止:超过 %d 秒没有新的输出,疑似在等待交互输入或已挂起。
|
||||
|
||||
长时静默任务请使用末尾 & 后台运行,或增大 agent.shell_no_output_timeout_seconds(-1=关闭此检测)。
|
||||
|
||||
Command terminated: no new output for %d seconds (possible interactive wait or hung process).`, idleSec, idleSec)
|
||||
}
|
||||
|
||||
// ShellInactivityWatch 在 noOutputSec 内无任何新输出时向 expired 发送信号;每次 Bump 重置计时。
|
||||
// 与「仅有首包输出就永久取消计时」不同,可兜住 sudo 打印 Password 提示后继续挂起等情况。
|
||||
type ShellInactivityWatch struct {
|
||||
Sec int
|
||||
mu sync.Mutex
|
||||
timer *time.Timer
|
||||
Expired chan struct{}
|
||||
}
|
||||
|
||||
func NewShellInactivityWatch(noOutputSec int) *ShellInactivityWatch {
|
||||
sec := ResolveShellNoOutputTimeoutSeconds(noOutputSec)
|
||||
if sec <= 0 {
|
||||
return nil
|
||||
}
|
||||
w := &ShellInactivityWatch{
|
||||
Sec: sec,
|
||||
Expired: make(chan struct{}, 1),
|
||||
}
|
||||
w.Bump()
|
||||
return w
|
||||
}
|
||||
|
||||
func (w *ShellInactivityWatch) Bump() {
|
||||
if w == nil || w.Sec <= 0 {
|
||||
return
|
||||
}
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.timer != nil {
|
||||
w.timer.Stop()
|
||||
}
|
||||
w.timer = time.AfterFunc(time.Duration(w.Sec)*time.Second, func() {
|
||||
select {
|
||||
case w.Expired <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (w *ShellInactivityWatch) Stop() {
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.timer != nil {
|
||||
w.timer.Stop()
|
||||
w.timer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveShellNoOutputTimeoutSeconds:0=默认 300(5 分钟);-1=关闭;>0=自定义。
|
||||
func ResolveShellNoOutputTimeoutSeconds(sec int) int {
|
||||
if sec < 0 {
|
||||
return 0
|
||||
}
|
||||
if sec == 0 {
|
||||
return 300
|
||||
}
|
||||
return sec
|
||||
}
|
||||
|
||||
// PrependNonInteractiveShellExports 为 sh -c 注入通用非交互环境(pager 等),不维护命令黑名单。
|
||||
func PrependNonInteractiveShellExports(shellCommand string) string {
|
||||
if strings.TrimSpace(shellCommand) == "" {
|
||||
return shellCommand
|
||||
}
|
||||
upper := strings.ToUpper(shellCommand)
|
||||
var pairs []string
|
||||
add := func(key, val string) {
|
||||
if strings.Contains(upper, strings.ToUpper(key)) {
|
||||
return
|
||||
}
|
||||
pairs = append(pairs, key+"="+val)
|
||||
}
|
||||
add("GIT_PAGER", "cat")
|
||||
add("PAGER", "cat")
|
||||
add("SYSTEMD_PAGER", "cat")
|
||||
add("DEBIAN_FRONTEND", "noninteractive")
|
||||
if len(pairs) == 0 {
|
||||
return shellCommand
|
||||
}
|
||||
return "export " + strings.Join(pairs, " ") + "\n" + shellCommand
|
||||
}
|
||||
|
||||
// PrependNonInteractiveStdinRedirect 为 sh -c 关闭 stdin(与 attachNonInteractiveStdin 等价),
|
||||
// 使 read/input()/sudo -S 等从 stdin 读取的程序快速失败而非挂起。已含 </dev/null 时不重复注入。
|
||||
func PrependNonInteractiveStdinRedirect(shellCommand string) string {
|
||||
if strings.TrimSpace(shellCommand) == "" {
|
||||
return shellCommand
|
||||
}
|
||||
lower := strings.ToLower(shellCommand)
|
||||
if strings.Contains(lower, "</dev/null") || strings.Contains(lower, "0</dev/null") {
|
||||
return shellCommand
|
||||
}
|
||||
return "exec </dev/null\n" + shellCommand
|
||||
}
|
||||
|
||||
// PrepareNonInteractiveShellCommand 组合非交互包装:stdin 关闭 + pager 等环境变量(零名单)。
|
||||
func PrepareNonInteractiveShellCommand(shellCommand string) string {
|
||||
return PrependNonInteractiveStdinRedirect(PrependNonInteractiveShellExports(shellCommand))
|
||||
}
|
||||
|
||||
// ApplyNonInteractivePagerEnv 为 exec.Cmd 补齐与 PrependNonInteractiveShellExports 一致的环境变量。
|
||||
func ApplyNonInteractivePagerEnv(cmdEnv []string) []string {
|
||||
if cmdEnv == nil {
|
||||
cmdEnv = []string{}
|
||||
}
|
||||
has := func(k string) bool {
|
||||
prefix := k + "="
|
||||
for _, e := range cmdEnv {
|
||||
if strings.HasPrefix(e, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
if !has("GIT_PAGER") {
|
||||
cmdEnv = append(cmdEnv, "GIT_PAGER=cat")
|
||||
}
|
||||
if !has("PAGER") {
|
||||
cmdEnv = append(cmdEnv, "PAGER=cat")
|
||||
}
|
||||
if !has("SYSTEMD_PAGER") {
|
||||
cmdEnv = append(cmdEnv, "SYSTEMD_PAGER=cat")
|
||||
}
|
||||
if !has("DEBIAN_FRONTEND") {
|
||||
cmdEnv = append(cmdEnv, "DEBIAN_FRONTEND=noninteractive")
|
||||
}
|
||||
return cmdEnv
|
||||
}
|
||||
|
||||
// attachNonInteractiveStdin 关闭交互式 stdin,使部分命令快速失败而非等待输入。
|
||||
func attachNonInteractiveStdin(cmd *exec.Cmd) {
|
||||
if cmd == nil || cmd.Stdin != nil {
|
||||
return
|
||||
}
|
||||
f, err := os.Open(os.DevNull)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cmd.Stdin = f
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPrependNonInteractiveShellExports(t *testing.T) {
|
||||
out := PrependNonInteractiveShellExports("echo hi")
|
||||
if !strings.Contains(out, "GIT_PAGER=cat") || !strings.Contains(out, "PAGER=cat") {
|
||||
t.Fatalf("missing pager exports: %q", out)
|
||||
}
|
||||
if !strings.HasSuffix(strings.TrimSpace(out), "echo hi") {
|
||||
t.Fatalf("command suffix lost: %q", out)
|
||||
}
|
||||
skip := PrependNonInteractiveShellExports("GIT_PAGER=less echo hi")
|
||||
if strings.Contains(skip, "export GIT_PAGER=cat") {
|
||||
t.Fatalf("should not override existing GIT_PAGER: %q", skip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrependNonInteractiveStdinRedirect(t *testing.T) {
|
||||
out := PrependNonInteractiveStdinRedirect("echo hi")
|
||||
if !strings.HasPrefix(out, "exec </dev/null") {
|
||||
t.Fatalf("missing stdin redirect: %q", out)
|
||||
}
|
||||
if !strings.HasSuffix(strings.TrimSpace(out), "echo hi") {
|
||||
t.Fatalf("command suffix lost: %q", out)
|
||||
}
|
||||
skip := PrependNonInteractiveStdinRedirect("cmd </dev/null")
|
||||
if strings.HasPrefix(skip, "exec </dev/null") {
|
||||
t.Fatalf("should not double redirect: %q", skip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareNonInteractiveShellCommand(t *testing.T) {
|
||||
out := PrepareNonInteractiveShellCommand("echo hi")
|
||||
if !strings.Contains(out, "exec </dev/null") {
|
||||
t.Fatalf("missing stdin redirect: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "GIT_PAGER=cat") {
|
||||
t.Fatalf("missing pager export: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewShellInactivityWatch(t *testing.T) {
|
||||
w := NewShellInactivityWatch(1)
|
||||
if w == nil {
|
||||
t.Fatal("expected watch")
|
||||
}
|
||||
w.Bump()
|
||||
select {
|
||||
case <-w.Expired:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("expected inactivity fire within 3s")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveShellNoOutputTimeoutSeconds(t *testing.T) {
|
||||
if ResolveShellNoOutputTimeoutSeconds(0) != 300 {
|
||||
t.Fatal("zero should default to 300")
|
||||
}
|
||||
if ResolveShellNoOutputTimeoutSeconds(-1) != 0 {
|
||||
t.Fatal("-1 should disable")
|
||||
}
|
||||
if ResolveShellNoOutputTimeoutSeconds(30) != 30 {
|
||||
t.Fatal("explicit value")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNonInteractiveStdinReadExitsQuickly 验证 exec </dev/null + attachNonInteractiveStdin 时 read 立即 EOF,不挂起。
|
||||
func TestNonInteractiveStdinReadExitsQuickly(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping shell integration in -short")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", PrepareNonInteractiveShellCommand(`read x; echo "x=<$x>"`))
|
||||
attachNonInteractiveStdin(cmd)
|
||||
|
||||
start := time.Now()
|
||||
out, err := cmd.CombinedOutput()
|
||||
elapsed := time.Since(start)
|
||||
if elapsed > 2*time.Second {
|
||||
t.Fatalf("read with closed stdin took %v, want <2s", elapsed)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v output=%q", err, out)
|
||||
}
|
||||
if !strings.Contains(string(out), "x=<>") {
|
||||
t.Fatalf("unexpected output: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNonInteractiveStdinReadBlocksWithoutRedirect 对照:stdin 为永不写入的管道时 read 会挂起。
|
||||
func TestNonInteractiveStdinReadBlocksWithoutRedirect(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping shell integration in -short")
|
||||
}
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
// 保持 w 打开且不写数据,模拟「等待用户输入」
|
||||
|
||||
cmd := exec.Command("sh", "-c", `read x; echo done`)
|
||||
cmd.Stdin = r
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- cmd.Run() }()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
t.Fatalf("expected hang, but command finished: %v", err)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = w.Close()
|
||||
<-done // 等待 goroutine 退出
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user