mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-24 06:49:59 +02:00
260 lines
9.6 KiB
Go
260 lines
9.6 KiB
Go
package multiagent
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"cyberstrike-ai/internal/einomcp"
|
||
"cyberstrike-ai/internal/mcp"
|
||
"cyberstrike-ai/internal/security"
|
||
|
||
"github.com/cloudwego/eino/adk/filesystem"
|
||
"github.com/cloudwego/eino/compose"
|
||
"github.com/cloudwego/eino/schema"
|
||
)
|
||
|
||
// prependPythonUnbufferedEnv 为 /bin/sh -c 注入 PYTHONUNBUFFERED=1。
|
||
// eino-ext local 对流式 stdout 使用 bufio 按「行」推送;python3 写管道时默认块缓冲,print 长期留在用户态缓冲,
|
||
// 管道里收不到换行,表现为长时间无输出直至超时或退出。若命令里已出现 PYTHONUNBUFFERED 则不再覆盖。
|
||
func prependPythonUnbufferedEnv(shellCommand string) string {
|
||
if strings.TrimSpace(shellCommand) == "" {
|
||
return shellCommand
|
||
}
|
||
if strings.Contains(strings.ToUpper(shellCommand), "PYTHONUNBUFFERED") {
|
||
return shellCommand
|
||
}
|
||
return "export PYTHONUNBUFFERED=1\n" + shellCommand
|
||
}
|
||
|
||
// einoExecuteTimeoutUserHint 与写入 ADK 工具消息(模型可见)及 SSE tool_result 尾标一致。
|
||
func einoExecuteTimeoutUserHint() string {
|
||
return "已超时终止 · Timed out"
|
||
}
|
||
|
||
// einoExecuteRecvErrIsToolTimeout 判断 Recv 错误是否由 agent.tool_timeout_minutes 触发。
|
||
// WithTimeout 到期后 local 侧常报 canceled / exit -1,但 execCtx.Err() 仍为 DeadlineExceeded。
|
||
func einoExecuteRecvErrIsToolTimeout(rerr error, tctx context.Context) bool {
|
||
if tctx != nil && errors.Is(tctx.Err(), context.DeadlineExceeded) {
|
||
return true
|
||
}
|
||
return errors.Is(rerr, context.DeadlineExceeded)
|
||
}
|
||
|
||
// einoStreamingShellWrap 包装 Eino filesystem 使用的 StreamingShell(cloudwego eino-ext local.Local)。
|
||
// 官方 execute 工具默认走 ExecuteStreaming 且不设 RunInBackendGround;末尾带 & 时子进程仍与管道相连,
|
||
// streamStdout 按行读取会在无换行输出时长时间阻塞(与 MCP 工具 exec 的独立实现不同)。
|
||
// 对「完全后台」命令自动开启 RunInBackendGround,与 local.runCmdInBackground 行为对齐。
|
||
//
|
||
// 使用 Pipe 将内层流转发给调用方:在 inner EOF 后、关闭 Pipe 前同步调用 ToolInvokeNotify.Fire,
|
||
// 保证 run loop 在模型开始下一轮输出前已记录 execute 结果(用于 UI 与「重复助手复述」去重)。
|
||
//
|
||
// 若 inner 在校验阶段直接返回 error(未建立 reader),不会进入下方 goroutine,也必须 Fire;
|
||
// 否则 pending tool_call 要等整轮 run 结束才被 force-close,与已展示的助手/工具软错误文案不同步。
|
||
type einoStreamingShellWrap struct {
|
||
inner filesystem.StreamingShell
|
||
invokeNotify *einomcp.ToolInvokeNotifyHolder
|
||
einoAgentName string
|
||
// outputChunk 可选;非 nil 时在收到内层 ExecuteResponse 片段时推送,与 MCP 工具的 tool_result_delta 一致(需有效 toolCallId)。
|
||
outputChunk func(toolName, toolCallID, chunk string)
|
||
// toolTimeoutMinutes 与 agent.tool_timeout_minutes 对齐;>0 时对单次 execute 套用 context 超时(与 MCP 工具经 executeToolViaMCP 行为一致)。0 表示仅依赖上层 ctx(如整任务 10h 上限)。
|
||
toolTimeoutMinutes int
|
||
// recordMonitor 在 execute 流结束后写入 tool_executions 并 recorder(executionId),使「渗透测试详情」与常规 MCP 一致。
|
||
recordMonitor func(toolCallID, command, stdout string, success bool, invokeErr error)
|
||
}
|
||
|
||
func (w *einoStreamingShellWrap) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
||
if w.inner == nil {
|
||
return nil, fmt.Errorf("einoStreamingShellWrap: inner shell is nil")
|
||
}
|
||
if input == nil {
|
||
return w.inner.ExecuteStreaming(ctx, nil)
|
||
}
|
||
req := *input
|
||
userCmd := strings.TrimSpace(req.Command)
|
||
if security.IsBackgroundShellCommand(req.Command) && !req.RunInBackendGround {
|
||
req.RunInBackendGround = true
|
||
}
|
||
req.Command = prependPythonUnbufferedEnv(req.Command)
|
||
tid := strings.TrimSpace(compose.GetToolCallID(ctx))
|
||
agentTag := strings.TrimSpace(w.einoAgentName)
|
||
convID := mcp.MCPConversationIDFromContext(ctx)
|
||
execReg := mcp.EinoExecuteRunRegistryFromContext(ctx)
|
||
|
||
execCtx, execCancel := context.WithCancel(ctx)
|
||
var timeoutCancel context.CancelFunc
|
||
if w.toolTimeoutMinutes > 0 {
|
||
execCtx, timeoutCancel = context.WithTimeout(execCtx, time.Duration(w.toolTimeoutMinutes)*time.Minute)
|
||
}
|
||
if execReg != nil && convID != "" {
|
||
execReg.RegisterActiveEinoExecute(convID, execCancel)
|
||
}
|
||
|
||
sr, err := w.inner.ExecuteStreaming(execCtx, &req)
|
||
if err != nil {
|
||
if timeoutCancel != nil {
|
||
timeoutCancel()
|
||
}
|
||
if execCancel != nil {
|
||
execCancel()
|
||
}
|
||
if einoExecuteRecvErrIsToolTimeout(err, execCtx) {
|
||
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||
if w.recordMonitor != nil {
|
||
w.recordMonitor(tid, userCmd, hint, false, context.DeadlineExceeded)
|
||
}
|
||
if w.invokeNotify != nil && tid != "" {
|
||
w.invokeNotify.Fire(tid, "execute", agentTag, false, hint, context.DeadlineExceeded)
|
||
}
|
||
return schema.StreamReaderFromArray([]*filesystem.ExecuteResponse{{Output: hint}}), nil
|
||
}
|
||
if w.recordMonitor != nil {
|
||
w.recordMonitor(tid, userCmd, "", false, err)
|
||
}
|
||
if w.invokeNotify != nil && tid != "" {
|
||
w.invokeNotify.Fire(tid, "execute", agentTag, false, "", err)
|
||
}
|
||
return nil, err
|
||
}
|
||
if sr == nil || w.invokeNotify == nil {
|
||
if timeoutCancel != nil {
|
||
timeoutCancel()
|
||
}
|
||
if execCancel != nil {
|
||
execCancel()
|
||
}
|
||
return sr, nil
|
||
}
|
||
|
||
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](32)
|
||
|
||
go func(inner *schema.StreamReader[*filesystem.ExecuteResponse], command string, cancel context.CancelFunc, timeoutCleanup context.CancelFunc, tctx context.Context, conversationID string, reg mcp.EinoExecuteRunRegistry) {
|
||
var innerCloseOnce sync.Once
|
||
closeInner := func() {
|
||
innerCloseOnce.Do(func() { inner.Close() })
|
||
}
|
||
defer closeInner()
|
||
if timeoutCleanup != nil {
|
||
defer timeoutCleanup()
|
||
}
|
||
if cancel != nil {
|
||
defer cancel()
|
||
}
|
||
if reg != nil && conversationID != "" {
|
||
defer reg.UnregisterActiveEinoExecute(conversationID)
|
||
}
|
||
|
||
// ctx 取消时关闭内层流,避免 amass 等长时间无换行输出时 Recv 永久阻塞。
|
||
stopWatch := make(chan struct{})
|
||
go func() {
|
||
select {
|
||
case <-tctx.Done():
|
||
closeInner()
|
||
case <-stopWatch:
|
||
}
|
||
}()
|
||
defer close(stopWatch)
|
||
|
||
var sb strings.Builder
|
||
success := true
|
||
var invokeErr error
|
||
exitCode := 0
|
||
hasExitCode := false
|
||
|
||
for {
|
||
resp, rerr := inner.Recv()
|
||
if errors.Is(rerr, io.EOF) {
|
||
break
|
||
}
|
||
if rerr != nil {
|
||
success = false
|
||
invokeErr = rerr
|
||
// 单次 execute 超时须与 MCP 工具一致:写入工具结果尾标、继续迭代,不得向 ADK 流注入硬错误。
|
||
if einoExecuteRecvErrIsToolTimeout(rerr, tctx) {
|
||
invokeErr = context.DeadlineExceeded
|
||
break
|
||
}
|
||
if errors.Is(rerr, context.Canceled) || (tctx != nil && errors.Is(tctx.Err(), context.Canceled)) {
|
||
invokeErr = context.Canceled
|
||
break
|
||
}
|
||
_ = outW.Send(nil, rerr)
|
||
break
|
||
}
|
||
if resp != nil {
|
||
if resp.ExitCode != nil {
|
||
hasExitCode = true
|
||
exitCode = *resp.ExitCode
|
||
}
|
||
var appended string
|
||
if resp.Output != "" {
|
||
sb.WriteString(resp.Output)
|
||
appended = resp.Output
|
||
}
|
||
if w.outputChunk != nil && strings.TrimSpace(appended) != "" {
|
||
w.outputChunk("execute", tid, appended)
|
||
}
|
||
if outW.Send(resp, nil) {
|
||
success = false
|
||
invokeErr = fmt.Errorf("execute stream closed by consumer")
|
||
break
|
||
}
|
||
}
|
||
}
|
||
|
||
if success && hasExitCode && exitCode != 0 {
|
||
success = false
|
||
invokeErr = fmt.Errorf("execute exited with code %d", exitCode)
|
||
}
|
||
// WithTimeout 触发后,子进程常被信号结束,local 侧多报 exit -1 / canceled,错误链里不一定带 DeadlineExceeded。
|
||
// 用执行所用 ctx 归一化,便于 UI 展示「超时」而非含糊的 -1。
|
||
if tctx != nil && errors.Is(tctx.Err(), context.DeadlineExceeded) {
|
||
success = false
|
||
invokeErr = context.DeadlineExceeded
|
||
}
|
||
// 用户「中断并继续」终止 execute:合并说明进工具结果(与 MCP CancelToolExecutionWithNote 一致)。
|
||
partialStreamed := sb.String()
|
||
var abortNote string
|
||
if reg != nil && conversationID != "" && (invokeErr != nil || errors.Is(tctx.Err(), context.Canceled)) {
|
||
if note := reg.TakeEinoExecuteAbortNote(conversationID); note != "" {
|
||
abortNote = note
|
||
merged := mcp.MergePartialToolOutputAndAbortNote(partialStreamed, note)
|
||
sb.Reset()
|
||
sb.WriteString(merged)
|
||
if invokeErr == nil {
|
||
success = false
|
||
invokeErr = context.Canceled
|
||
}
|
||
}
|
||
}
|
||
// ADK 从本 Pipe 拼出 tool 消息正文;仅 Notify 尾标不会进入模型上下文。超时句写入流,与 UI 一致。
|
||
if invokeErr != nil && errors.Is(invokeErr, context.DeadlineExceeded) {
|
||
hint := "\n\n" + einoExecuteTimeoutUserHint() + "\n"
|
||
_ = outW.Send(&filesystem.ExecuteResponse{Output: hint}, nil)
|
||
if w.outputChunk != nil && tid != "" {
|
||
w.outputChunk("execute", tid, hint)
|
||
}
|
||
sb.WriteString(hint)
|
||
}
|
||
// 中断时循环内已逐行写入 stdout;此处只追加 USER INTERRUPT NOTE,避免整段输出重复。
|
||
if invokeErr != nil && errors.Is(invokeErr, context.Canceled) && abortNote != "" {
|
||
if partialStreamed != "" {
|
||
_ = outW.Send(&filesystem.ExecuteResponse{Output: "\n\n" + mcp.AbortNoteBannerForModel + "\n" + abortNote}, nil)
|
||
} else if text := strings.TrimSpace(sb.String()); text != "" {
|
||
_ = outW.Send(&filesystem.ExecuteResponse{Output: text + "\n"}, nil)
|
||
}
|
||
}
|
||
if w.recordMonitor != nil {
|
||
w.recordMonitor(tid, command, sb.String(), success, invokeErr)
|
||
}
|
||
w.invokeNotify.Fire(tid, "execute", agentTag, success, sb.String(), invokeErr)
|
||
outW.Close()
|
||
}(sr, userCmd, execCancel, timeoutCancel, execCtx, convID, execReg)
|
||
|
||
return outR, nil
|
||
}
|