diff --git a/internal/handler/agent.go b/internal/handler/agent.go index a1e6d500..ad26d919 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -1185,6 +1185,8 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun } } flushResponsePlan() + // 助手正文开始前,推理流通常已结束;落库以便刷新后「渗透测试详情」可回放 + flushThinkingStreams() respPlan.meta = nil if dataMap, ok := data.(map[string]interface{}); ok { respPlan.meta = make(map[string]interface{}, len(dataMap)) @@ -1220,6 +1222,19 @@ func (h *AgentHandler) createProgressCallback(runCtx context.Context, cancelRun } if eventType == "response" { flushResponsePlan() + flushThinkingStreams() + return + } + if eventType == "done" { + flushResponsePlan() + flushThinkingStreams() + return + } + + // 流式思考/推理结束:聚合落库(与 eino_agent_reply_stream_end 同理) + if eventType == "thinking_stream_end" || eventType == "reasoning_chain_stream_end" { + flushResponsePlan() + flushThinkingStreams() return } diff --git a/internal/handler/agent_progress_callback_test.go b/internal/handler/agent_progress_callback_test.go index 0b64f47e..6eb13e31 100644 --- a/internal/handler/agent_progress_callback_test.go +++ b/internal/handler/agent_progress_callback_test.go @@ -3,10 +3,14 @@ package handler import ( "context" "fmt" + "os" + "path/filepath" "sync" "testing" "cyberstrike-ai/internal/config" + "cyberstrike-ai/internal/database" + "cyberstrike-ai/internal/openai" "go.uber.org/zap" ) @@ -46,3 +50,50 @@ func TestCreateProgressCallback_ConcurrentToolEvents(t *testing.T) { } wg.Wait() } + +// TestCreateProgressCallback_FlushesReasoningOnDone 流式推理聚合须在 done/response 时落库,刷新后可回放。 +func TestCreateProgressCallback_FlushesReasoningOnDone(t *testing.T) { + tmp := t.TempDir() + db, err := database.NewDB(filepath.Join(tmp, "test.sqlite"), zap.NewNop()) + if err != nil { + t.Fatalf("NewDB: %v", err) + } + defer os.RemoveAll(tmp) + + conv, err := db.CreateConversation("test", database.ConversationCreateMeta{}) + if err != nil { + t.Fatalf("CreateConversation: %v", err) + } + asst, err := db.AddMessage(conv.ID, "assistant", "处理中...", nil) + if err != nil { + t.Fatalf("AddMessage: %v", err) + } + + h := &AgentHandler{logger: zap.NewNop(), db: db} + cb := h.createProgressCallback(context.Background(), nil, conv.ID, asst.ID, nil) + + streamID := "eino-reasoning-test-1" + cb("reasoning_chain_stream_start", " ", map[string]interface{}{ + "streamId": streamID, + "source": "eino", + }) + cb("reasoning_chain_stream_delta", "step one", openai.WithSSEAccumulated(map[string]interface{}{ + "streamId": streamID, + }, "step one")) + cb("done", "", map[string]interface{}{"conversationId": conv.ID}) + + details, err := db.GetProcessDetails(asst.ID) + if err != nil { + t.Fatalf("GetProcessDetails: %v", err) + } + found := false + for _, d := range details { + if d.EventType == "reasoning_chain" && d.Message == "step one" { + found = true + break + } + } + if !found { + t.Fatalf("expected reasoning_chain persisted on done, got %+v", details) + } +}