diff --git a/internal/handler/agent.go b/internal/handler/agent.go index 7a92b34b..bcdde35a 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -214,7 +214,7 @@ type ChatAttachment struct { type ChatReasoningRequest struct { // Mode: default(跟随系统)| off | on | auto Mode string `json:"mode,omitempty"` - // Effort: low | medium | high | max;空表示不指定(由系统默认与各 profile 决定)。 + // Effort: low | medium | high | max | xhigh(原样下发;不同网关最高档命名不同)。空表示不指定。 Effort string `json:"effort,omitempty"` } @@ -830,11 +830,26 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, con } switch robotMode { case "eino_single": - resultMA, errMA := multiagent.RunEinoSingleChatModelAgent( - taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, - conversationID, finalMessage, agentHistoryMessages, roleTools, progressCallback, nil, - ) - if errMA != nil { + curHist := agentHistoryMessages + curMsg := finalMessage + segmentUserMessage := finalMessage + var resultMA *multiagent.RunResult + var errMA error + var transientRunAttempts int + for { + resultMA, errMA = multiagent.RunEinoSingleChatModelAgent( + taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, + conversationID, curMsg, curHist, roleTools, progressCallback, nil, + ) + if errMA == nil { + break + } + if handled, _ := h.handleEinoTransientRetryContinue( + taskCtx, conversationID, resultMA, errMA, &transientRunAttempts, + &curHist, &curMsg, segmentUserMessage, progressCallback, nil, + ); handled { + continue + } taskStatus = "failed" return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA) } @@ -845,12 +860,27 @@ func (h *AgentHandler) ProcessMessageForRobot(ctx context.Context, platform, con zap.String("robot_mode", robotMode)) break } - resultMA, errMA := multiagent.RunDeepAgent( - taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, - conversationID, finalMessage, agentHistoryMessages, roleTools, progressCallback, - h.agentsMarkdownDir, robotMode, nil, - ) - if errMA != nil { + curHist := agentHistoryMessages + curMsg := finalMessage + segmentUserMessage := finalMessage + var resultMA *multiagent.RunResult + var errMA error + var transientRunAttempts int + for { + resultMA, errMA = multiagent.RunDeepAgent( + taskCtx, h.config, &h.config.MultiAgent, h.agent, h.logger, + conversationID, curMsg, curHist, roleTools, progressCallback, + h.agentsMarkdownDir, robotMode, nil, + ) + if errMA == nil { + break + } + if handled, _ := h.handleEinoTransientRetryContinue( + taskCtx, conversationID, resultMA, errMA, &transientRunAttempts, + &curHist, &curMsg, segmentUserMessage, progressCallback, nil, + ); handled { + continue + } taskStatus = "failed" return h.finalizeRobotAgentError(taskCtx, assistantMessageID, conversationID, resultMA, errMA) } diff --git a/internal/handler/eino_resume_segment.go b/internal/handler/eino_resume_segment.go new file mode 100644 index 00000000..a72a1d61 --- /dev/null +++ b/internal/handler/eino_resume_segment.go @@ -0,0 +1,122 @@ +package handler + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "cyberstrike-ai/internal/agent" + "cyberstrike-ai/internal/multiagent" +) + +func (h *AgentHandler) einoRunRetryMaxAttempts() int { + if h.config != nil { + return multiagent.RunRetryMaxAttemptsFromConfig(&h.config.MultiAgent.EinoMiddleware) + } + return multiagent.RunRetryMaxAttemptsFromConfig(nil) +} + +func (h *AgentHandler) einoRunRetryMaxBackoffSec() int { + if h.config != nil && h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec > 0 { + return h.config.MultiAgent.EinoMiddleware.RunRetryMaxBackoffSec + } + return 0 +} + +// applyEinoTraceResumeSegment 中断并继续:persist last_react_* → loadHistory,可选替换下一段 user 文案。 +func (h *AgentHandler) applyEinoTraceResumeSegment( + conversationID string, + result *multiagent.RunResult, + curHistory *[]agent.ChatMessage, + curFinalMessage *string, + segmentUserMessage string, +) { + if shouldPersistEinoAgentTraceAfterRunError(context.Background()) { + h.persistEinoAgentTraceForResume(conversationID, result) + } + if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 { + *curHistory = hist + } + if segmentUserMessage != "" { + *curFinalMessage = segmentUserMessage + } +} + +// applyEinoTransientRetrySegment 临时错误重试:恢复轨迹并保留本请求原始 user 文案(不注入续跑说明)。 +// segmentUserMessage 为本轮 HTTP 请求开始时用户发送的内容,避免因清空 finalMessage 而丢失「你好」等短句。 +func (h *AgentHandler) applyEinoTransientRetrySegment( + conversationID string, + result *multiagent.RunResult, + curHistory *[]agent.ChatMessage, + curFinalMessage *string, + segmentUserMessage string, +) { + if shouldPersistEinoAgentTraceAfterRunError(context.Background()) { + h.persistEinoAgentTraceForResume(conversationID, result) + } + if hist, err := h.loadHistoryFromAgentTrace(conversationID); err == nil && len(hist) > 0 { + *curHistory = hist + } + if s := strings.TrimSpace(segmentUserMessage); s != "" { + *curFinalMessage = segmentUserMessage + } +} + +// handleEinoTransientRetryContinue 在 SSE 任务循环内处理临时错误重试;返回 true 表示外层 for 应 continue。 +func (h *AgentHandler) handleEinoTransientRetryContinue( + baseCtx context.Context, + conversationID string, + result *multiagent.RunResult, + runErr error, + transientAttempts *int, + curHistory *[]agent.ChatMessage, + curFinalMessage *string, + segmentUserMessage string, + progressCallback func(eventType, message string, data interface{}), + sendProgress func(msg string, extra map[string]interface{}), +) (handled bool, fatal error) { + if !errors.Is(runErr, multiagent.ErrTransientRetryContinue) { + return false, nil + } + maxAttempts := h.einoRunRetryMaxAttempts() + *transientAttempts++ + if *transientAttempts > maxAttempts { + if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { + h.persistEinoAgentTraceForResume(conversationID, result) + } + return false, errors.New("transient retry exhausted: " + runErr.Error()) + } + attemptNo := *transientAttempts + backoff := multiagent.TransientRetryBackoff(attemptNo-1, h.einoRunRetryMaxBackoffSec()) + if progressCallback != nil { + progressCallback("eino_run_retry", fmt.Sprintf("遇到临时错误,%d 秒后第 %d/%d 次重试…", int(backoff.Seconds()), attemptNo, maxAttempts), map[string]interface{}{ + "conversationId": conversationID, + "source": "eino", + "attempt": attemptNo, + "maxAttempts": maxAttempts, + "backoffSec": int(backoff.Seconds()), + }) + } + select { + case <-baseCtx.Done(): + return false, context.Cause(baseCtx) + case <-time.After(backoff): + } + h.applyEinoTransientRetrySegment(conversationID, result, curHistory, curFinalMessage, segmentUserMessage) + if progressCallback != nil { + progressCallback("eino_run_retry", "已恢复上下文,正在重试…", map[string]interface{}{ + "conversationId": conversationID, + "source": "eino", + "attempt": attemptNo, + }) + } + if sendProgress != nil { + sendProgress("正在重试…", map[string]interface{}{ + "conversationId": conversationID, + "source": "transient_retry", + }) + } + return true, nil +} diff --git a/internal/handler/eino_single_agent.go b/internal/handler/eino_single_agent.go index d51a9cfe..6fcf2366 100644 --- a/internal/handler/eino_single_agent.go +++ b/internal/handler/eino_single_agent.go @@ -119,6 +119,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { var cancelWithCause context.CancelCauseFunc curFinalMessage := prep.FinalMessage + segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失 curHistory := prep.History roleTools := prep.RoleTools @@ -176,6 +177,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { taskOwned = true var cumulativeMCPExecutionIDs []string + var transientRunAttempts int for { progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) @@ -198,16 +200,33 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { progressCallback, chatReasoningToClientIntent(req.Reasoning), ) - timeoutCancel() if result != nil && len(result.MCPExecutionIDs) > 0 { cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs) } if runErr == nil { + timeoutCancel() break } + handled, fatalErr := h.handleEinoTransientRetryContinue( + baseCtx, conversationID, result, runErr, &transientRunAttempts, + &curHistory, &curFinalMessage, segmentUserMessage, progressCallback, + func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) }, + ) + if handled { + timeoutCancel() + baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) + h.tasks.BindTaskCancel(conversationID, cancelWithCause) + taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute) + h.tasks.UpdateTaskStatus(conversationID, "running") + continue + } + if fatalErr != nil { + runErr = fatalErr + } + cause := context.Cause(baseCtx) if errors.Is(cause, multiagent.ErrInterruptContinue) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { @@ -231,10 +250,11 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { "conversationId": conversationID, "source": "interrupt_continue", }) - h.tasks.UpdateTaskStatus(conversationID, "running") + timeoutCancel() baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) h.tasks.BindTaskCancel(conversationID, cancelWithCause) taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute) + h.tasks.UpdateTaskStatus(conversationID, "running") continue } @@ -261,6 +281,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { "messageId": assistantMessageID, }) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + timeoutCancel() return } @@ -278,6 +299,7 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { "errorType": "timeout", }) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + timeoutCancel() return } @@ -294,9 +316,12 @@ func (h *AgentHandler) EinoSingleAgentLoopStream(c *gin.Context) { "messageId": assistantMessageID, }) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + timeoutCancel() return } + timeoutCancel() + if assistantMessageID != "" { _ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)) } diff --git a/internal/handler/multi_agent.go b/internal/handler/multi_agent.go index 8a707186..e1b0ebd4 100644 --- a/internal/handler/multi_agent.go +++ b/internal/handler/multi_agent.go @@ -136,6 +136,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { var cancelWithCause context.CancelCauseFunc curFinalMessage := prep.FinalMessage + segmentUserMessage := prep.FinalMessage // 本请求原始用户句,临时重试时不得丢失 curHistory := prep.History roleTools := prep.RoleTools orch := strings.TrimSpace(req.Orchestration) @@ -186,6 +187,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { // 同一 HTTP 流内多段 Run(如中断并继续)合并 MCP execution id,供最终 response / 库表与工具芯片展示完整列表 var cumulativeMCPExecutionIDs []string + var transientRunAttempts int for { progressCallback := h.createProgressCallback(taskCtx, cancelWithCause, conversationID, assistantMessageID, sendEvent) @@ -210,16 +212,33 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { orch, chatReasoningToClientIntent(req.Reasoning), ) - timeoutCancel() if result != nil && len(result.MCPExecutionIDs) > 0 { cumulativeMCPExecutionIDs = mergeMCPExecutionIDLists(cumulativeMCPExecutionIDs, result.MCPExecutionIDs) } if runErr == nil { + timeoutCancel() break } + handled, fatalErr := h.handleEinoTransientRetryContinue( + baseCtx, conversationID, result, runErr, &transientRunAttempts, + &curHistory, &curFinalMessage, segmentUserMessage, progressCallback, + func(msg string, extra map[string]interface{}) { sendEvent("progress", msg, extra) }, + ) + if handled { + timeoutCancel() + baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) + h.tasks.BindTaskCancel(conversationID, cancelWithCause) + taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute) + h.tasks.UpdateTaskStatus(conversationID, "running") + continue + } + if fatalErr != nil { + runErr = fatalErr + } + cause := context.Cause(baseCtx) if errors.Is(cause, multiagent.ErrInterruptContinue) { if shouldPersistEinoAgentTraceAfterRunError(baseCtx) { @@ -243,10 +262,11 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { "conversationId": conversationID, "source": "interrupt_continue", }) - h.tasks.UpdateTaskStatus(conversationID, "running") + timeoutCancel() baseCtx, cancelWithCause = context.WithCancelCause(context.Background()) h.tasks.BindTaskCancel(conversationID, cancelWithCause) taskCtx, timeoutCancel = context.WithTimeout(baseCtx, 600*time.Minute) + h.tasks.UpdateTaskStatus(conversationID, "running") continue } @@ -273,6 +293,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { "messageId": assistantMessageID, }) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + timeoutCancel() return } @@ -290,6 +311,7 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { "errorType": "timeout", }) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + timeoutCancel() return } @@ -306,9 +328,12 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { "messageId": assistantMessageID, }) sendEvent("done", "", map[string]interface{}{"conversationId": conversationID}) + timeoutCancel() return } + timeoutCancel() + if assistantMessageID != "" { _ = h.db.UpdateAssistantMessageFinalize(assistantMessageID, result.Response, cumulativeMCPExecutionIDs, multiagent.AggregatedReasoningFromTraceJSON(result.LastAgentTraceInput)) } diff --git a/internal/reasoning/eino.go b/internal/reasoning/eino.go index 397ac526..0d7a0a30 100644 --- a/internal/reasoning/eino.go +++ b/internal/reasoning/eino.go @@ -149,13 +149,18 @@ func effectiveEffort(sr *config.OpenAIReasoningConfig, client *ClientIntent, all func normalizeEffort(s string) string { e := strings.ToLower(strings.TrimSpace(s)) switch e { - case "low", "medium", "high", "max": + case "low", "medium", "high", "max", "xhigh": return e default: return "" } } +// usesExtraFieldsReasoningEffort 为 Eino 无枚举的最高档 effort,经 ExtraFields 原样下发(max / xhigh 由网关自行识别,不做互转)。 +func usesExtraFieldsReasoningEffort(e string) bool { + return e == "max" || e == "xhigh" +} + func resolveWireProfile(oa *config.OpenAIConfig, sr *config.OpenAIReasoningConfig) wireProfile { if strings.EqualFold(strings.TrimSpace(oa.Provider), "claude") { return wireClaude @@ -210,11 +215,11 @@ func applyOpenAICompat(cfg *einoopenai.ChatModelConfig, mode, effort string) { if e == "" { return } - if e == "max" { + if usesExtraFieldsReasoningEffort(e) { if cfg.ExtraFields == nil { cfg.ExtraFields = make(map[string]any) } - cfg.ExtraFields["reasoning_effort"] = "max" + cfg.ExtraFields["reasoning_effort"] = effortStringForAPI(e) return } switch e { @@ -245,6 +250,6 @@ func applyOutputConfigEffort(cfg *einoopenai.ChatModelConfig, mode, effort strin } func effortStringForAPI(e string) string { - // Gateways expect lowercase strings; "max" kept as max. + // 原样透传:OpenAI 官方多为 xhigh,部分兼容网关为 max,由配置/对话 effort 选择。 return strings.ToLower(strings.TrimSpace(e)) } diff --git a/internal/reasoning/eino_test.go b/internal/reasoning/eino_test.go new file mode 100644 index 00000000..1aae209e --- /dev/null +++ b/internal/reasoning/eino_test.go @@ -0,0 +1,66 @@ +package reasoning + +import ( + "testing" + + "cyberstrike-ai/internal/config" + + einoopenai "github.com/cloudwego/eino-ext/components/model/openai" +) + +func TestEffortStringForAPI_passthrough(t *testing.T) { + cases := map[string]string{ + "max": "max", + "xhigh": "xhigh", + "HIGH": "high", + "Medium": "medium", + } + for in, want := range cases { + if got := effortStringForAPI(in); got != want { + t.Fatalf("%q -> %q, want %q", in, got, want) + } + } +} + +func TestNormalizeEffort_maxAndXhigh(t *testing.T) { + if normalizeEffort("xhigh") != "xhigh" { + t.Fatal("xhigh not accepted") + } + if normalizeEffort("max") != "max" { + t.Fatal("max not accepted") + } +} + +func TestApplyOpenAICompat_xhighExtraField(t *testing.T) { + cfg := &einoopenai.ChatModelConfig{} + oa := &config.OpenAIConfig{ + Reasoning: config.OpenAIReasoningConfig{ + Profile: "openai_compat", + Mode: "on", + Effort: "xhigh", + }, + } + ApplyToEinoChatModelConfig(cfg, oa, nil) + if cfg.ExtraFields == nil { + t.Fatal("expected ExtraFields") + } + if got, _ := cfg.ExtraFields["reasoning_effort"].(string); got != "xhigh" { + t.Fatalf("reasoning_effort=%q", got) + } +} + +func TestApplyOpenAICompat_maxPassthrough(t *testing.T) { + cfg := &einoopenai.ChatModelConfig{} + oa := &config.OpenAIConfig{ + Reasoning: config.OpenAIReasoningConfig{ + Profile: "openai_compat", + Mode: "on", + Effort: "max", + }, + } + ApplyToEinoChatModelConfig(cfg, oa, nil) + got, _ := cfg.ExtraFields["reasoning_effort"].(string) + if got != "max" { + t.Fatalf("max effort wire=%q, want max", got) + } +}