mirror of
https://github.com/Ed1s0nZ/CyberStrikeAI.git
synced 2026-06-24 06:49:59 +02:00
228 lines
6.6 KiB
Go
228 lines
6.6 KiB
Go
package multiagent
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"cyberstrike-ai/internal/einomcp"
|
|
"cyberstrike-ai/internal/mcp"
|
|
|
|
"github.com/cloudwego/eino/adk/filesystem"
|
|
"github.com/cloudwego/eino/schema"
|
|
)
|
|
|
|
type mockStreamingShell struct {
|
|
immediateErr error
|
|
recvErr error
|
|
output string
|
|
}
|
|
|
|
func (m *mockStreamingShell) ExecuteStreaming(ctx context.Context, input *filesystem.ExecuteRequest) (*schema.StreamReader[*filesystem.ExecuteResponse], error) {
|
|
if m.immediateErr != nil {
|
|
return nil, m.immediateErr
|
|
}
|
|
outR, outW := schema.Pipe[*filesystem.ExecuteResponse](4)
|
|
go func() {
|
|
defer outW.Close()
|
|
if strings.TrimSpace(m.output) != "" {
|
|
_ = outW.Send(&filesystem.ExecuteResponse{Output: m.output}, nil)
|
|
}
|
|
if m.recvErr != nil {
|
|
_ = outW.Send(nil, m.recvErr)
|
|
}
|
|
}()
|
|
return outR, nil
|
|
}
|
|
|
|
func TestEinoExecuteRecvErrIsToolTimeout(t *testing.T) {
|
|
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
|
defer cancel()
|
|
time.Sleep(2 * time.Millisecond)
|
|
<-tctx.Done()
|
|
|
|
if !einoExecuteRecvErrIsToolTimeout(context.Canceled, tctx) {
|
|
t.Fatal("expected canceled recv with deadline exec ctx to count as tool timeout")
|
|
}
|
|
if !einoExecuteRecvErrIsToolTimeout(context.DeadlineExceeded, nil) {
|
|
t.Fatal("expected DeadlineExceeded recv without tctx")
|
|
}
|
|
if einoExecuteRecvErrIsToolTimeout(errors.New("exit status 1"), context.Background()) {
|
|
t.Fatal("unexpected timeout for generic error")
|
|
}
|
|
}
|
|
|
|
func TestEinoStreamingShellWrap_ToolTimeoutImmediateErrIsSoft(t *testing.T) {
|
|
inner := &mockStreamingShell{immediateErr: context.DeadlineExceeded}
|
|
wrap := &einoStreamingShellWrap{
|
|
inner: inner,
|
|
toolTimeoutMinutes: 60,
|
|
}
|
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "true"})
|
|
if err != nil {
|
|
t.Fatalf("immediate tool timeout must return soft stream, got err: %v", err)
|
|
}
|
|
defer sr.Close()
|
|
|
|
var got strings.Builder
|
|
for {
|
|
resp, rerr := sr.Recv()
|
|
if errors.Is(rerr, io.EOF) {
|
|
break
|
|
}
|
|
if rerr != nil {
|
|
t.Fatalf("outer stream must not hard-fail, got: %v", rerr)
|
|
}
|
|
if resp != nil && resp.Output != "" {
|
|
got.WriteString(resp.Output)
|
|
}
|
|
}
|
|
if !strings.Contains(got.String(), einoExecuteTimeoutUserHint()) {
|
|
t.Fatalf("expected timeout hint, got: %q", got.String())
|
|
}
|
|
}
|
|
|
|
func TestEinoStreamingShellWrap_ToolTimeoutRecvErrIsSoft(t *testing.T) {
|
|
inner := &mockStreamingShell{recvErr: context.DeadlineExceeded}
|
|
notify := einomcp.NewToolInvokeNotifyHolder()
|
|
wrap := &einoStreamingShellWrap{
|
|
inner: inner,
|
|
invokeNotify: notify,
|
|
toolTimeoutMinutes: 60,
|
|
}
|
|
// 生产路径由 Eino compose 注入 toolCallID;单测通过已过期 execCtx 识别 tool_timeout 软错误。
|
|
tctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
|
defer cancel()
|
|
time.Sleep(2 * time.Millisecond)
|
|
<-tctx.Done()
|
|
|
|
sr, err := wrap.ExecuteStreaming(tctx, &filesystem.ExecuteRequest{Command: "sleep 999"})
|
|
if err != nil {
|
|
t.Fatalf("ExecuteStreaming: %v", err)
|
|
}
|
|
defer sr.Close()
|
|
|
|
var got strings.Builder
|
|
for {
|
|
resp, rerr := sr.Recv()
|
|
if errors.Is(rerr, io.EOF) {
|
|
break
|
|
}
|
|
if rerr != nil {
|
|
t.Fatalf("outer stream must not hard-fail on tool timeout, got: %v", rerr)
|
|
}
|
|
if resp != nil && resp.Output != "" {
|
|
got.WriteString(resp.Output)
|
|
}
|
|
}
|
|
if !strings.Contains(got.String(), einoExecuteTimeoutUserHint()) {
|
|
t.Fatalf("expected timeout hint in stream, got: %q", got.String())
|
|
}
|
|
}
|
|
|
|
func TestEinoStreamingShellWrap_CapturesOutputWithToolTimeout(t *testing.T) {
|
|
inner := &mockStreamingShell{output: "100\n"}
|
|
notify := einomcp.NewToolInvokeNotifyHolder()
|
|
var firedContent string
|
|
notify.Set(func(toolCallID, toolName, einoAgent string, success bool, content string, invokeErr error) {
|
|
firedContent = content
|
|
})
|
|
wrap := &einoStreamingShellWrap{
|
|
inner: inner,
|
|
invokeNotify: notify,
|
|
toolTimeoutMinutes: 60,
|
|
}
|
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "echo 100"})
|
|
if err != nil {
|
|
t.Fatalf("ExecuteStreaming: %v", err)
|
|
}
|
|
defer sr.Close()
|
|
|
|
var got strings.Builder
|
|
for {
|
|
resp, rerr := sr.Recv()
|
|
if errors.Is(rerr, io.EOF) {
|
|
break
|
|
}
|
|
if rerr != nil {
|
|
t.Fatalf("unexpected stream error: %v", rerr)
|
|
}
|
|
if resp != nil && resp.Output != "" {
|
|
got.WriteString(resp.Output)
|
|
}
|
|
}
|
|
if !strings.Contains(got.String(), "100") {
|
|
t.Fatalf("stream output = %q, want contains 100", got.String())
|
|
}
|
|
if !strings.Contains(firedContent, "100") {
|
|
t.Fatalf("notify content = %q, want contains 100", firedContent)
|
|
}
|
|
}
|
|
|
|
func TestEinoStreamingShellWrap_AbortNoteDoesNotDuplicateStreamedOutput(t *testing.T) {
|
|
inner := &mockStreamingShell{output: "line1\nline2\n", recvErr: context.Canceled}
|
|
notify := einomcp.NewToolInvokeNotifyHolder()
|
|
wrap := &einoStreamingShellWrap{
|
|
inner: inner,
|
|
invokeNotify: notify,
|
|
}
|
|
reg := &abortNoteTestRegistry{note: "改成20次"}
|
|
ctx := mcp.WithEinoExecuteRunRegistry(
|
|
mcp.WithMCPConversationID(context.Background(), "conv-abort-dup"),
|
|
reg,
|
|
)
|
|
sr, err := wrap.ExecuteStreaming(ctx, &filesystem.ExecuteRequest{Command: "ping -c 10 baidu.com"})
|
|
if err != nil {
|
|
t.Fatalf("ExecuteStreaming: %v", err)
|
|
}
|
|
defer sr.Close()
|
|
|
|
var got strings.Builder
|
|
for {
|
|
resp, rerr := sr.Recv()
|
|
if errors.Is(rerr, io.EOF) {
|
|
break
|
|
}
|
|
if rerr != nil {
|
|
t.Fatalf("unexpected stream error: %v", rerr)
|
|
}
|
|
if resp != nil && resp.Output != "" {
|
|
got.WriteString(resp.Output)
|
|
}
|
|
}
|
|
out := got.String()
|
|
if strings.Count(out, "line1") != 1 || strings.Count(out, "line2") != 1 {
|
|
t.Fatalf("stream duplicated stdout: %q", out)
|
|
}
|
|
if !strings.Contains(out, "改成20次") {
|
|
t.Fatalf("stream missing abort note: %q", out)
|
|
}
|
|
}
|
|
|
|
type abortNoteTestRegistry struct {
|
|
note string
|
|
}
|
|
|
|
func (r *abortNoteTestRegistry) RegisterActiveEinoExecute(string, context.CancelFunc) {}
|
|
func (r *abortNoteTestRegistry) UnregisterActiveEinoExecute(string) {}
|
|
func (r *abortNoteTestRegistry) AbortActiveEinoExecute(string, string) bool { return false }
|
|
func (r *abortNoteTestRegistry) TakeEinoExecuteAbortNote(string) string { return r.note }
|
|
|
|
func TestEinoStreamingShellWrap_NonTimeoutRecvErrStillHard(t *testing.T) {
|
|
inner := &mockStreamingShell{recvErr: errors.New("broken pipe")}
|
|
wrap := &einoStreamingShellWrap{inner: inner}
|
|
sr, err := wrap.ExecuteStreaming(context.Background(), &filesystem.ExecuteRequest{Command: "true"})
|
|
if err != nil {
|
|
t.Fatalf("ExecuteStreaming: %v", err)
|
|
}
|
|
defer sr.Close()
|
|
|
|
_, rerr := sr.Recv()
|
|
if rerr == nil || errors.Is(rerr, io.EOF) {
|
|
t.Fatal("expected hard stream error for non-timeout failure")
|
|
}
|
|
}
|