mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-24 14:59:59 +02:00
Add files via upload
This commit is contained in:
+17
-4
@@ -779,13 +779,26 @@ func (a *Agent) ExecuteMCPToolForConversation(ctx context.Context, conversationI
|
||||
return a.executeToolViaMCP(ctx, toolName, args)
|
||||
}
|
||||
|
||||
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
|
||||
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
|
||||
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
// BeginLocalToolExecution 在非 CallTool 路径工具开始时写入 running 状态,供 MCP 监控页展示「执行中」。
|
||||
func (a *Agent) BeginLocalToolExecution(toolName string, args map[string]interface{}) string {
|
||||
if a == nil || a.mcpServer == nil {
|
||||
return ""
|
||||
}
|
||||
return a.mcpServer.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
||||
return a.mcpServer.BeginToolExecution(toolName, args)
|
||||
}
|
||||
|
||||
// FinishLocalToolExecution 完成 BeginLocalToolExecution 创建的记录;executionID 为空时一次性写入已完成记录。
|
||||
func (a *Agent) FinishLocalToolExecution(executionID, toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
if a == nil || a.mcpServer == nil {
|
||||
return ""
|
||||
}
|
||||
return a.mcpServer.FinishToolExecution(executionID, toolName, args, resultText, invokeErr)
|
||||
}
|
||||
|
||||
// RecordLocalToolExecution 将非 CallTool 路径完成的工具调用写入 MCP 监控库(与 CallTool 落库一致),返回 executionId。
|
||||
// 用于 Eino filesystem execute 等场景,使助手气泡「渗透测试详情」与常规 MCP 一致可点进监控。
|
||||
func (a *Agent) RecordLocalToolExecution(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
return a.FinishLocalToolExecution("", toolName, args, resultText, invokeErr)
|
||||
}
|
||||
|
||||
// UpdateMCPExecutionDisplayResult 将监控库中的工具结果更新为送入模型的展示正文(reduction 后)。
|
||||
|
||||
@@ -605,6 +605,8 @@ type DatabaseConfig struct {
|
||||
type AgentConfig struct {
|
||||
MaxIterations int `yaml:"max_iterations" json:"max_iterations"`
|
||||
ToolTimeoutMinutes int `yaml:"tool_timeout_minutes" json:"tool_timeout_minutes"` // 单次工具执行最大时长(分钟),超时自动终止,防止长时间挂起;0 表示不限制(不推荐)
|
||||
// ShellNoOutputTimeoutSeconds execute/exec 无任何 stdout/stderr 时的空闲终止秒数(通用防挂死,不维护命令黑名单);0=默认 300(5 分钟);-1=关闭。
|
||||
ShellNoOutputTimeoutSeconds int `yaml:"shell_no_output_timeout_seconds" json:"shell_no_output_timeout_seconds"`
|
||||
// SystemPromptPath 单代理系统提示 Markdown/文本文件路径(相对 config.yaml 所在目录,或可写绝对路径)。非空且可读时替换内置单代理提示;留空用内置。
|
||||
SystemPromptPath string `yaml:"system_prompt_path,omitempty" json:"system_prompt_path,omitempty"`
|
||||
}
|
||||
@@ -1270,8 +1272,9 @@ func Default() *Config {
|
||||
MaxTotalTokens: 120000,
|
||||
},
|
||||
Agent: AgentConfig{
|
||||
MaxIterations: 30, // 默认最大迭代次数
|
||||
ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用
|
||||
MaxIterations: 30, // 默认最大迭代次数
|
||||
ToolTimeoutMinutes: 10, // 单次工具执行默认最多 10 分钟,避免异常长时间占用
|
||||
ShellNoOutputTimeoutSeconds: 300, // execute/exec 无新输出空闲终止(秒);-1 关闭
|
||||
},
|
||||
Security: SecurityConfig{
|
||||
Tools: []ToolConfig{}, // 工具配置应该从 config.yaml 或 tools/ 目录加载
|
||||
|
||||
+83
-16
@@ -921,9 +921,8 @@ func (s *Server) CallTool(ctx context.Context, toolName string, args map[string]
|
||||
return finalResult, executionID, nil
|
||||
}
|
||||
|
||||
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致),
|
||||
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。
|
||||
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
// BeginToolExecution 创建 running 状态的执行记录,供 Eino 等非 CallTool 路径在工具开始时落库。
|
||||
func (s *Server) BeginToolExecution(toolName string, args map[string]interface{}) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
@@ -931,21 +930,73 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
|
||||
args = map[string]interface{}{}
|
||||
}
|
||||
executionID := uuid.New().String()
|
||||
now := time.Now()
|
||||
failed := invokeErr != nil
|
||||
exec := &ToolExecution{
|
||||
execution := &ToolExecution{
|
||||
ID: executionID,
|
||||
ToolName: toolName,
|
||||
Arguments: args,
|
||||
StartTime: now,
|
||||
EndTime: &now,
|
||||
Duration: 0,
|
||||
Status: "running",
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.executions[executionID] = execution
|
||||
s.cleanupOldExecutions()
|
||||
s.mu.Unlock()
|
||||
|
||||
if s.storage != nil {
|
||||
if err := s.storage.SaveToolExecution(execution); err != nil {
|
||||
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
return executionID
|
||||
}
|
||||
|
||||
// FinishToolExecution 完成先前 BeginToolExecution 创建的记录;executionID 为空时等同 RecordCompletedToolInvocation。
|
||||
func (s *Server) FinishToolExecution(executionID, toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
if args == nil {
|
||||
args = map[string]interface{}{}
|
||||
}
|
||||
id := strings.TrimSpace(executionID)
|
||||
if id == "" {
|
||||
return s.RecordCompletedToolInvocation(toolName, args, resultText, invokeErr)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
failed := invokeErr != nil
|
||||
var finalResult *ToolResult
|
||||
|
||||
s.mu.Lock()
|
||||
exec, inMem := s.executions[id]
|
||||
if !inMem || exec == nil {
|
||||
exec = &ToolExecution{
|
||||
ID: id,
|
||||
ToolName: toolName,
|
||||
Arguments: args,
|
||||
StartTime: now,
|
||||
}
|
||||
s.executions[id] = exec
|
||||
} else if toolName != "" {
|
||||
exec.ToolName = toolName
|
||||
}
|
||||
if len(args) > 0 {
|
||||
exec.Arguments = args
|
||||
}
|
||||
exec.EndTime = &now
|
||||
if exec.StartTime.IsZero() {
|
||||
exec.StartTime = now
|
||||
}
|
||||
exec.Duration = now.Sub(exec.StartTime)
|
||||
|
||||
if failed {
|
||||
exec.Status = "failed"
|
||||
exec.Error = invokeErr.Error()
|
||||
st, msg := executionStatusAndMessage(invokeErr)
|
||||
exec.Status = st
|
||||
exec.Error = msg
|
||||
if strings.TrimSpace(resultText) != "" {
|
||||
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}}
|
||||
finalResult = &ToolResult{Content: []Content{{Type: "text", Text: resultText}}}
|
||||
exec.Result = finalResult
|
||||
}
|
||||
} else {
|
||||
exec.Status = "completed"
|
||||
@@ -953,15 +1004,31 @@ func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]
|
||||
if strings.TrimSpace(text) == "" {
|
||||
text = "(无输出)"
|
||||
}
|
||||
exec.Result = &ToolResult{Content: []Content{{Type: "text", Text: text}}}
|
||||
finalResult = &ToolResult{Content: []Content{{Type: "text", Text: text}}}
|
||||
exec.Result = finalResult
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if s.storage != nil {
|
||||
if err := s.storage.SaveToolExecution(exec); err != nil {
|
||||
s.logger.Warn("RecordCompletedToolInvocation 保存失败", zap.Error(err))
|
||||
s.logger.Warn("保存执行记录到数据库失败", zap.Error(err))
|
||||
}
|
||||
}
|
||||
s.updateStats(toolName, failed)
|
||||
return executionID
|
||||
|
||||
s.updateStats(exec.ToolName, failed)
|
||||
|
||||
if s.storage != nil {
|
||||
s.mu.Lock()
|
||||
delete(s.executions, id)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
// RecordCompletedToolInvocation 将已在其它路径完成的工具调用写入监控存储(格式与 CallTool 结束后一致),
|
||||
// 用于 Eino ADK filesystem execute 等未经过 CallTool 的场景;返回 executionId 供助手消息 mcpExecutionIds 关联。
|
||||
func (s *Server) RecordCompletedToolInvocation(toolName string, args map[string]interface{}, resultText string, invokeErr error) string {
|
||||
return s.FinishToolExecution("", toolName, args, resultText, invokeErr)
|
||||
}
|
||||
|
||||
// UpdateToolExecutionResult 将监控库中的工具结果更新为送入模型的展示正文(如 reduction 后的 persisted-output)。
|
||||
|
||||
@@ -199,6 +199,8 @@ type ToolExecution struct {
|
||||
StartTime time.Time `json:"startTime"`
|
||||
EndTime *time.Time `json:"endTime,omitempty"`
|
||||
Duration time.Duration `json:"duration,omitempty"`
|
||||
// ConversationID 仅 API 展示用(进行中的 Agent 任务),不写入 tool_executions 表。
|
||||
ConversationID string `json:"conversationId,omitempty"`
|
||||
}
|
||||
|
||||
// ToolStats 工具统计信息
|
||||
|
||||
@@ -32,10 +32,11 @@ var ToolOutputCallbackCtxKey = toolOutputCallbackCtxKey{}
|
||||
|
||||
// Executor 安全工具执行器
|
||||
type Executor struct {
|
||||
config *config.SecurityConfig
|
||||
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
|
||||
mcpServer *mcp.Server
|
||||
logger *zap.Logger
|
||||
config *config.SecurityConfig
|
||||
toolIndex map[string]*config.ToolConfig // 工具索引,用于 O(1) 查找
|
||||
mcpServer *mcp.Server
|
||||
logger *zap.Logger
|
||||
shellNoOutputTimeoutSec int // execute/exec 无新输出空闲秒数;0=默认 300;-1=关闭(见 SetShellNoOutputTimeoutSeconds)
|
||||
}
|
||||
|
||||
// NewExecutor 创建新的执行器
|
||||
@@ -51,6 +52,11 @@ func NewExecutor(cfg *config.SecurityConfig, mcpServer *mcp.Server, logger *zap.
|
||||
return executor
|
||||
}
|
||||
|
||||
// SetShellNoOutputTimeoutSeconds 配置 exec 工具无输出空闲终止(与 agent.shell_no_output_timeout_seconds 一致)。
|
||||
func (e *Executor) SetShellNoOutputTimeoutSeconds(sec int) {
|
||||
e.shellNoOutputTimeoutSec = sec
|
||||
}
|
||||
|
||||
// buildToolIndex 构建工具索引,将 O(n) 查找优化为 O(1)
|
||||
func (e *Executor) buildToolIndex() {
|
||||
e.toolIndex = make(map[string]*config.ToolConfig)
|
||||
@@ -133,6 +139,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
// 执行命令
|
||||
cmd := exec.CommandContext(ctx, toolConfig.Command, cmdArgs...)
|
||||
applyDefaultTerminalEnv(cmd)
|
||||
attachNonInteractiveStdin(cmd)
|
||||
_ = prepareShellCmdSession(cmd)
|
||||
|
||||
e.logger.Info("执行安全工具",
|
||||
@@ -144,7 +151,7 @@ func (e *Executor) ExecuteTool(ctx context.Context, toolName string, args map[st
|
||||
var err error
|
||||
// 如果上层提供了 stdout/stderr 增量回调,则边执行边读取并回调。
|
||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||
output, err = streamCommandOutput(ctx, cmd, cb)
|
||||
output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
|
||||
if err != nil && shouldRetryWithPTY(output) {
|
||||
e.logger.Info("检测到工具需要 TTY,使用 PTY 重试",
|
||||
zap.String("tool", toolName),
|
||||
@@ -797,6 +804,8 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
zap.String("command", command),
|
||||
)
|
||||
|
||||
command = PrepareNonInteractiveShellCommand(command)
|
||||
|
||||
// 获取shell类型(可选,默认为sh)
|
||||
shell := "sh"
|
||||
if s, ok := args["shell"].(string); ok && s != "" {
|
||||
@@ -821,6 +830,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
cmd = exec.CommandContext(ctx, shell, "-c", command)
|
||||
}
|
||||
applyDefaultTerminalEnv(cmd)
|
||||
attachNonInteractiveStdin(cmd)
|
||||
_ = prepareShellCmdSession(cmd)
|
||||
|
||||
// 执行命令
|
||||
@@ -963,7 +973,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
var err error
|
||||
// 若上层提供工具输出增量回调,则边执行边流式读取。
|
||||
if cb, ok := ctx.Value(ToolOutputCallbackCtxKey).(ToolOutputCallback); ok && cb != nil {
|
||||
output, err = streamCommandOutput(ctx, cmd, cb)
|
||||
output, err = streamCommandOutput(ctx, cmd, cb, ResolveShellNoOutputTimeoutSeconds(e.shellNoOutputTimeoutSec))
|
||||
if err != nil && shouldRetryWithPTY(output) {
|
||||
e.logger.Info("检测到系统命令需要 TTY,使用 PTY 重试")
|
||||
cmd2 := exec.CommandContext(ctx, shell, "-c", command)
|
||||
@@ -1024,7 +1034,7 @@ func (e *Executor) executeSystemCommand(ctx context.Context, args map[string]int
|
||||
|
||||
// streamCommandOutput 以“边读边回调”的方式读取命令 stdout/stderr。
|
||||
// 使用定长块读取,避免按行读取在无换行输出时永久阻塞;ctx 取消时终止进程树。
|
||||
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback) (string, error) {
|
||||
func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback, noOutputSec int) (string, error) {
|
||||
if err := prepareShellCmdSession(cmd); err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -1091,12 +1101,43 @@ func streamCommandOutput(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallba
|
||||
lastFlush = time.Now()
|
||||
}
|
||||
|
||||
for chunk := range chunks {
|
||||
outBuilder.WriteString(chunk)
|
||||
deltaBuilder.WriteString(chunk)
|
||||
// 简单节流:buffer 大于 2KB 或 200ms 就刷新一次
|
||||
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
|
||||
flush()
|
||||
idleWatch := NewShellInactivityWatch(noOutputSec)
|
||||
if idleWatch != nil {
|
||||
defer idleWatch.Stop()
|
||||
}
|
||||
|
||||
fireInactivity := func() {
|
||||
terminateCmdTree(cmd)
|
||||
msg := ShellNoOutputTimeoutMessage(idleWatch.Sec)
|
||||
outBuilder.WriteString(msg)
|
||||
if cb != nil {
|
||||
cb(msg)
|
||||
}
|
||||
_ = cmd.Wait()
|
||||
}
|
||||
|
||||
chunksLoop:
|
||||
for {
|
||||
var idleCh <-chan struct{}
|
||||
if idleWatch != nil {
|
||||
idleCh = idleWatch.Expired
|
||||
}
|
||||
select {
|
||||
case <-idleCh:
|
||||
fireInactivity()
|
||||
return outBuilder.String(), fmt.Errorf("shell inactivity timeout (%ds)", idleWatch.Sec)
|
||||
case chunk, ok := <-chunks:
|
||||
if !ok {
|
||||
break chunksLoop
|
||||
}
|
||||
if chunk != "" && idleWatch != nil {
|
||||
idleWatch.Bump()
|
||||
}
|
||||
outBuilder.WriteString(chunk)
|
||||
deltaBuilder.WriteString(chunk)
|
||||
if deltaBuilder.Len() >= 2048 || time.Since(lastFlush) >= 200*time.Millisecond {
|
||||
flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
flush()
|
||||
@@ -1116,6 +1157,7 @@ func applyDefaultTerminalEnv(cmd *exec.Cmd) {
|
||||
if cmd.Env == nil {
|
||||
cmd.Env = os.Environ()
|
||||
}
|
||||
cmd.Env = ApplyNonInteractivePagerEnv(cmd.Env)
|
||||
// 如果用户已设置 TERM/COLUMNS/LINES,则不覆盖
|
||||
has := func(k string) bool {
|
||||
prefix := k + "="
|
||||
@@ -1159,7 +1201,7 @@ func runCommandWithPTY(ctx context.Context, cmd *exec.Cmd, cb ToolOutputCallback
|
||||
if runtime.GOOS == "windows" {
|
||||
// PTY 方案为类 Unix;Windows 走原逻辑
|
||||
if cb != nil {
|
||||
return streamCommandOutput(ctx, cmd, cb)
|
||||
return streamCommandOutput(ctx, cmd, cb, 0)
|
||||
}
|
||||
_ = prepareShellCmdSession(cmd)
|
||||
out, err := cmd.CombinedOutput()
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ShellNoOutputTimeoutMessage 长时间无新 stdout/stderr 时的提示(软失败,模型可见)。
|
||||
func ShellNoOutputTimeoutMessage(idleSec int) string {
|
||||
return fmt.Sprintf(`命令已终止:超过 %d 秒没有新的输出,疑似在等待交互输入或已挂起。
|
||||
|
||||
长时静默任务请使用末尾 & 后台运行,或增大 agent.shell_no_output_timeout_seconds(-1=关闭此检测)。
|
||||
|
||||
Command terminated: no new output for %d seconds (possible interactive wait or hung process).`, idleSec, idleSec)
|
||||
}
|
||||
|
||||
// ShellInactivityWatch 在 noOutputSec 内无任何新输出时向 expired 发送信号;每次 Bump 重置计时。
|
||||
// 与「仅有首包输出就永久取消计时」不同,可兜住 sudo 打印 Password 提示后继续挂起等情况。
|
||||
type ShellInactivityWatch struct {
|
||||
Sec int
|
||||
mu sync.Mutex
|
||||
timer *time.Timer
|
||||
Expired chan struct{}
|
||||
}
|
||||
|
||||
func NewShellInactivityWatch(noOutputSec int) *ShellInactivityWatch {
|
||||
sec := ResolveShellNoOutputTimeoutSeconds(noOutputSec)
|
||||
if sec <= 0 {
|
||||
return nil
|
||||
}
|
||||
w := &ShellInactivityWatch{
|
||||
Sec: sec,
|
||||
Expired: make(chan struct{}, 1),
|
||||
}
|
||||
w.Bump()
|
||||
return w
|
||||
}
|
||||
|
||||
func (w *ShellInactivityWatch) Bump() {
|
||||
if w == nil || w.Sec <= 0 {
|
||||
return
|
||||
}
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.timer != nil {
|
||||
w.timer.Stop()
|
||||
}
|
||||
w.timer = time.AfterFunc(time.Duration(w.Sec)*time.Second, func() {
|
||||
select {
|
||||
case w.Expired <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (w *ShellInactivityWatch) Stop() {
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.timer != nil {
|
||||
w.timer.Stop()
|
||||
w.timer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveShellNoOutputTimeoutSeconds:0=默认 300(5 分钟);-1=关闭;>0=自定义。
|
||||
func ResolveShellNoOutputTimeoutSeconds(sec int) int {
|
||||
if sec < 0 {
|
||||
return 0
|
||||
}
|
||||
if sec == 0 {
|
||||
return 300
|
||||
}
|
||||
return sec
|
||||
}
|
||||
|
||||
// PrependNonInteractiveShellExports 为 sh -c 注入通用非交互环境(pager 等),不维护命令黑名单。
|
||||
func PrependNonInteractiveShellExports(shellCommand string) string {
|
||||
if strings.TrimSpace(shellCommand) == "" {
|
||||
return shellCommand
|
||||
}
|
||||
upper := strings.ToUpper(shellCommand)
|
||||
var pairs []string
|
||||
add := func(key, val string) {
|
||||
if strings.Contains(upper, strings.ToUpper(key)) {
|
||||
return
|
||||
}
|
||||
pairs = append(pairs, key+"="+val)
|
||||
}
|
||||
add("GIT_PAGER", "cat")
|
||||
add("PAGER", "cat")
|
||||
add("SYSTEMD_PAGER", "cat")
|
||||
add("DEBIAN_FRONTEND", "noninteractive")
|
||||
if len(pairs) == 0 {
|
||||
return shellCommand
|
||||
}
|
||||
return "export " + strings.Join(pairs, " ") + "\n" + shellCommand
|
||||
}
|
||||
|
||||
// PrependNonInteractiveStdinRedirect 为 sh -c 关闭 stdin(与 attachNonInteractiveStdin 等价),
|
||||
// 使 read/input()/sudo -S 等从 stdin 读取的程序快速失败而非挂起。已含 </dev/null 时不重复注入。
|
||||
func PrependNonInteractiveStdinRedirect(shellCommand string) string {
|
||||
if strings.TrimSpace(shellCommand) == "" {
|
||||
return shellCommand
|
||||
}
|
||||
lower := strings.ToLower(shellCommand)
|
||||
if strings.Contains(lower, "</dev/null") || strings.Contains(lower, "0</dev/null") {
|
||||
return shellCommand
|
||||
}
|
||||
return "exec </dev/null\n" + shellCommand
|
||||
}
|
||||
|
||||
// PrepareNonInteractiveShellCommand 组合非交互包装:stdin 关闭 + pager 等环境变量(零名单)。
|
||||
func PrepareNonInteractiveShellCommand(shellCommand string) string {
|
||||
return PrependNonInteractiveStdinRedirect(PrependNonInteractiveShellExports(shellCommand))
|
||||
}
|
||||
|
||||
// ApplyNonInteractivePagerEnv 为 exec.Cmd 补齐与 PrependNonInteractiveShellExports 一致的环境变量。
|
||||
func ApplyNonInteractivePagerEnv(cmdEnv []string) []string {
|
||||
if cmdEnv == nil {
|
||||
cmdEnv = []string{}
|
||||
}
|
||||
has := func(k string) bool {
|
||||
prefix := k + "="
|
||||
for _, e := range cmdEnv {
|
||||
if strings.HasPrefix(e, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
if !has("GIT_PAGER") {
|
||||
cmdEnv = append(cmdEnv, "GIT_PAGER=cat")
|
||||
}
|
||||
if !has("PAGER") {
|
||||
cmdEnv = append(cmdEnv, "PAGER=cat")
|
||||
}
|
||||
if !has("SYSTEMD_PAGER") {
|
||||
cmdEnv = append(cmdEnv, "SYSTEMD_PAGER=cat")
|
||||
}
|
||||
if !has("DEBIAN_FRONTEND") {
|
||||
cmdEnv = append(cmdEnv, "DEBIAN_FRONTEND=noninteractive")
|
||||
}
|
||||
return cmdEnv
|
||||
}
|
||||
|
||||
// attachNonInteractiveStdin 关闭交互式 stdin,使部分命令快速失败而非等待输入。
|
||||
func attachNonInteractiveStdin(cmd *exec.Cmd) {
|
||||
if cmd == nil || cmd.Stdin != nil {
|
||||
return
|
||||
}
|
||||
f, err := os.Open(os.DevNull)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cmd.Stdin = f
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPrependNonInteractiveShellExports(t *testing.T) {
|
||||
out := PrependNonInteractiveShellExports("echo hi")
|
||||
if !strings.Contains(out, "GIT_PAGER=cat") || !strings.Contains(out, "PAGER=cat") {
|
||||
t.Fatalf("missing pager exports: %q", out)
|
||||
}
|
||||
if !strings.HasSuffix(strings.TrimSpace(out), "echo hi") {
|
||||
t.Fatalf("command suffix lost: %q", out)
|
||||
}
|
||||
skip := PrependNonInteractiveShellExports("GIT_PAGER=less echo hi")
|
||||
if strings.Contains(skip, "export GIT_PAGER=cat") {
|
||||
t.Fatalf("should not override existing GIT_PAGER: %q", skip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrependNonInteractiveStdinRedirect(t *testing.T) {
|
||||
out := PrependNonInteractiveStdinRedirect("echo hi")
|
||||
if !strings.HasPrefix(out, "exec </dev/null") {
|
||||
t.Fatalf("missing stdin redirect: %q", out)
|
||||
}
|
||||
if !strings.HasSuffix(strings.TrimSpace(out), "echo hi") {
|
||||
t.Fatalf("command suffix lost: %q", out)
|
||||
}
|
||||
skip := PrependNonInteractiveStdinRedirect("cmd </dev/null")
|
||||
if strings.HasPrefix(skip, "exec </dev/null") {
|
||||
t.Fatalf("should not double redirect: %q", skip)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareNonInteractiveShellCommand(t *testing.T) {
|
||||
out := PrepareNonInteractiveShellCommand("echo hi")
|
||||
if !strings.Contains(out, "exec </dev/null") {
|
||||
t.Fatalf("missing stdin redirect: %q", out)
|
||||
}
|
||||
if !strings.Contains(out, "GIT_PAGER=cat") {
|
||||
t.Fatalf("missing pager export: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewShellInactivityWatch(t *testing.T) {
|
||||
w := NewShellInactivityWatch(1)
|
||||
if w == nil {
|
||||
t.Fatal("expected watch")
|
||||
}
|
||||
w.Bump()
|
||||
select {
|
||||
case <-w.Expired:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("expected inactivity fire within 3s")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveShellNoOutputTimeoutSeconds(t *testing.T) {
|
||||
if ResolveShellNoOutputTimeoutSeconds(0) != 300 {
|
||||
t.Fatal("zero should default to 300")
|
||||
}
|
||||
if ResolveShellNoOutputTimeoutSeconds(-1) != 0 {
|
||||
t.Fatal("-1 should disable")
|
||||
}
|
||||
if ResolveShellNoOutputTimeoutSeconds(30) != 30 {
|
||||
t.Fatal("explicit value")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNonInteractiveStdinReadExitsQuickly 验证 exec </dev/null + attachNonInteractiveStdin 时 read 立即 EOF,不挂起。
|
||||
func TestNonInteractiveStdinReadExitsQuickly(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping shell integration in -short")
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "sh", "-c", PrepareNonInteractiveShellCommand(`read x; echo "x=<$x>"`))
|
||||
attachNonInteractiveStdin(cmd)
|
||||
|
||||
start := time.Now()
|
||||
out, err := cmd.CombinedOutput()
|
||||
elapsed := time.Since(start)
|
||||
if elapsed > 2*time.Second {
|
||||
t.Fatalf("read with closed stdin took %v, want <2s", elapsed)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v output=%q", err, out)
|
||||
}
|
||||
if !strings.Contains(string(out), "x=<>") {
|
||||
t.Fatalf("unexpected output: %q", out)
|
||||
}
|
||||
}
|
||||
|
||||
// TestNonInteractiveStdinReadBlocksWithoutRedirect 对照:stdin 为永不写入的管道时 read 会挂起。
|
||||
func TestNonInteractiveStdinReadBlocksWithoutRedirect(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping shell integration in -short")
|
||||
}
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer r.Close()
|
||||
// 保持 w 打开且不写数据,模拟「等待用户输入」
|
||||
|
||||
cmd := exec.Command("sh", "-c", `read x; echo done`)
|
||||
cmd.Stdin = r
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() { done <- cmd.Run() }()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
t.Fatalf("expected hang, but command finished: %v", err)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
if cmd.Process != nil {
|
||||
_ = cmd.Process.Kill()
|
||||
}
|
||||
_ = w.Close()
|
||||
<-done // 等待 goroutine 退出
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user