diff --git a/internal/app/app.go b/internal/app/app.go index f044dd6f..f1671ffa 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -675,6 +675,7 @@ func setupRoutes( protected.DELETE("/groups/:id", groupHandler.DeleteGroup) protected.PUT("/groups/:id/pinned", groupHandler.UpdateGroupPinned) protected.GET("/groups/:id/conversations", groupHandler.GetGroupConversations) + protected.GET("/groups/mappings", groupHandler.GetAllMappings) protected.POST("/groups/conversations", groupHandler.AddConversationToGroup) protected.DELETE("/groups/:id/conversations/:conversationId", groupHandler.RemoveConversationFromGroup) protected.PUT("/groups/:id/conversations/:conversationId/pinned", groupHandler.UpdateConversationPinnedInGroup) @@ -682,6 +683,7 @@ func setupRoutes( // 监控 protected.GET("/monitor", monitorHandler.Monitor) protected.GET("/monitor/execution/:id", monitorHandler.GetExecution) + protected.POST("/monitor/executions/names", monitorHandler.BatchGetToolNames) protected.DELETE("/monitor/execution/:id", monitorHandler.DeleteExecution) protected.DELETE("/monitor/executions", monitorHandler.DeleteExecutions) protected.GET("/monitor/stats", monitorHandler.GetStats) diff --git a/internal/database/conversation.go b/internal/database/conversation.go index db180a94..ca2b1f5a 100644 --- a/internal/database/conversation.go +++ b/internal/database/conversation.go @@ -310,15 +310,14 @@ func (db *DB) ListConversations(limit, offset int, search string) ([]*Conversati var err error if search != "" { - // 使用LIKE进行模糊搜索,搜索标题和消息内容 + // 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积 searchPattern := "%" + search + "%" - // 使用DISTINCT避免重复,因为一个对话可能有多条消息匹配 rows, err = db.Query( - `SELECT DISTINCT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at + `SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at FROM conversations c - LEFT JOIN messages m ON c.id = m.conversation_id - WHERE c.title LIKE ? OR m.content LIKE ? - ORDER BY c.updated_at DESC + WHERE c.title LIKE ? + OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?) + ORDER BY c.updated_at DESC LIMIT ? OFFSET ?`, searchPattern, searchPattern, limit, offset, ) diff --git a/internal/database/group.go b/internal/database/group.go index 35e249f6..a3d32106 100644 --- a/internal/database/group.go +++ b/internal/database/group.go @@ -403,6 +403,35 @@ func (db *DB) UpdateGroupPinned(id string, pinned bool) error { return nil } +// GroupMapping 分组映射关系 +type GroupMapping struct { + ConversationID string `json:"conversationId"` + GroupID string `json:"groupId"` +} + +// GetAllGroupMappings 批量获取所有分组映射(消除 N+1 查询) +func (db *DB) GetAllGroupMappings() ([]GroupMapping, error) { + rows, err := db.Query("SELECT conversation_id, group_id FROM conversation_group_mappings") + if err != nil { + return nil, fmt.Errorf("查询分组映射失败: %w", err) + } + defer rows.Close() + + var mappings []GroupMapping + for rows.Next() { + var m GroupMapping + if err := rows.Scan(&m.ConversationID, &m.GroupID); err != nil { + return nil, fmt.Errorf("扫描分组映射失败: %w", err) + } + mappings = append(mappings, m) + } + + if mappings == nil { + mappings = []GroupMapping{} + } + return mappings, nil +} + // UpdateConversationPinnedInGroup 更新对话在分组中的置顶状态 func (db *DB) UpdateConversationPinnedInGroup(conversationID, groupID string, pinned bool) error { pinnedValue := 0 diff --git a/internal/handler/group.go b/internal/handler/group.go index d3bfc9a8..495e7695 100644 --- a/internal/handler/group.go +++ b/internal/handler/group.go @@ -234,6 +234,18 @@ func (h *GroupHandler) GetGroupConversations(c *gin.Context) { c.JSON(http.StatusOK, groupConvs) } +// GetAllMappings 批量获取所有分组映射(消除前端 N+1 请求) +func (h *GroupHandler) GetAllMappings(c *gin.Context) { + mappings, err := h.db.GetAllGroupMappings() + if err != nil { + h.logger.Error("获取分组映射失败", zap.Error(err)) + c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, mappings) +} + // UpdateConversationPinnedRequest 更新对话置顶状态请求 type UpdateConversationPinnedRequest struct { Pinned bool `json:"pinned"` diff --git a/internal/handler/monitor.go b/internal/handler/monitor.go index e2ebc456..c337c374 100644 --- a/internal/handler/monitor.go +++ b/internal/handler/monitor.go @@ -246,6 +246,41 @@ func (h *MonitorHandler) GetExecution(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"error": "执行记录未找到"}) } +// BatchGetToolNames 批量获取工具执行的工具名称(消除前端 N+1 请求) +func (h *MonitorHandler) BatchGetToolNames(c *gin.Context) { + var req struct { + IDs []string `json:"ids"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + result := make(map[string]string, len(req.IDs)) + for _, id := range req.IDs { + // 先从内部MCP服务器查找 + if exec, exists := h.mcpServer.GetExecution(id); exists { + result[id] = exec.ToolName + continue + } + // 再从外部MCP管理器查找 + if h.externalMCPMgr != nil { + if exec, exists := h.externalMCPMgr.GetExecution(id); exists { + result[id] = exec.ToolName + continue + } + } + // 最后从数据库查找 + if h.db != nil { + if exec, err := h.db.GetToolExecution(id); err == nil && exec != nil { + result[id] = exec.ToolName + } + } + } + + c.JSON(http.StatusOK, result) +} + // GetStats 获取统计信息 func (h *MonitorHandler) GetStats(c *gin.Context) { stats := h.loadStats() diff --git a/internal/multiagent/runner.go b/internal/multiagent/runner.go index a04590c5..cda89530 100644 --- a/internal/multiagent/runner.go +++ b/internal/multiagent/runner.go @@ -233,6 +233,9 @@ func RunDeepAgent( ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: subTools, UnknownToolsHandler: einomcp.UnknownToolReminderHandler(), + ToolCallMiddlewares: []compose.ToolMiddleware{ + {Invokable: softRecoveryToolCallMiddleware()}, + }, }, EmitInternalEvents: true, }, @@ -288,6 +291,9 @@ func RunDeepAgent( ToolsNodeConfig: compose.ToolsNodeConfig{ Tools: mainTools, UnknownToolsHandler: einomcp.UnknownToolReminderHandler(), + ToolCallMiddlewares: []compose.ToolMiddleware{ + {Invokable: softRecoveryToolCallMiddleware()}, + }, }, EmitInternalEvents: true, }, diff --git a/internal/multiagent/tool_args_json_retry.go b/internal/multiagent/tool_args_json_retry.go index 9f97a0f0..d6d79971 100644 --- a/internal/multiagent/tool_args_json_retry.go +++ b/internal/multiagent/tool_args_json_retry.go @@ -10,7 +10,7 @@ import ( // maxToolCallRecoveryAttempts 含首次运行:首次 + 自动重试次数。 // 例如为 3 表示最多共 3 次完整 DeepAgent 运行(2 次失败后各追加一条纠错提示)。 // 该常量同时用于 JSON 参数错误和工具执行错误(如子代理名称不存在)的恢复重试。 -const maxToolCallRecoveryAttempts = 3 +const maxToolCallRecoveryAttempts = 5 // toolCallArgumentsJSONRetryHint 追加在用户消息后,提示模型输出合法 JSON 工具参数(部分云厂商会在流式阶段校验 arguments)。 func toolCallArgumentsJSONRetryHint() *schema.Message { diff --git a/internal/multiagent/tool_error_middleware.go b/internal/multiagent/tool_error_middleware.go new file mode 100644 index 00000000..10158fc2 --- /dev/null +++ b/internal/multiagent/tool_error_middleware.go @@ -0,0 +1,131 @@ +package multiagent + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/cloudwego/eino/compose" +) + +// softRecoveryToolCallMiddleware returns an InvokableToolMiddleware that catches +// specific recoverable errors from tool execution (JSON parse errors, tool-not-found, +// etc.) and converts them into soft errors: nil error + descriptive error content +// returned to the LLM. This allows the model to self-correct within the same +// iteration rather than crashing the entire graph and requiring a full replay. +// +// Without this middleware, a JSON parse failure in any tool's InvokableRun propagates +// as a hard error through the Eino ToolsNode → [NodeRunError] → ev.Err, which +// either triggers the full-replay retry loop (expensive) or terminates the run +// entirely once retries are exhausted. With it, the LLM simply sees an error message +// in the tool result and can adjust its next tool call accordingly. +func softRecoveryToolCallMiddleware() compose.InvokableToolMiddleware { + return func(next compose.InvokableToolEndpoint) compose.InvokableToolEndpoint { + return func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { + output, err := next(ctx, input) + if err == nil { + return output, nil + } + if !isSoftRecoverableToolError(err) { + return output, err + } + // Convert the hard error into a soft error: the LLM will see this + // message as the tool's output and can self-correct. + msg := buildSoftRecoveryMessage(input.Name, input.Arguments, err) + return &compose.ToolOutput{Result: msg}, nil + } + } +} + +// isSoftRecoverableToolError determines whether a tool execution error should be +// silently converted to a tool-result message rather than crashing the graph. +func isSoftRecoverableToolError(err error) bool { + if err == nil { + return false + } + s := strings.ToLower(err.Error()) + + // JSON unmarshal/parse failures — the model generated truncated or malformed arguments. + if isJSONRelatedError(s) { + return true + } + + // Sub-agent type not found (from deep/task_tool.go) + if strings.Contains(s, "subagent type") && strings.Contains(s, "not found") { + return true + } + + // Tool not found in ToolsNode indexes + if strings.Contains(s, "tool") && strings.Contains(s, "not found") { + return true + } + + return false +} + +// isJSONRelatedError checks whether an error string indicates a JSON parsing problem. +func isJSONRelatedError(lower string) bool { + if !strings.Contains(lower, "json") { + return false + } + jsonIndicators := []string{ + "unexpected end of json", + "unmarshal", + "invalid character", + "cannot unmarshal", + "invalid tool arguments", + "failed to unmarshal", + "must be in json format", + "unexpected eof", + } + for _, ind := range jsonIndicators { + if strings.Contains(lower, ind) { + return true + } + } + return false +} + +// buildSoftRecoveryMessage creates a bilingual error message that the LLM can act on. +func buildSoftRecoveryMessage(toolName, arguments string, err error) string { + // Truncate arguments preview to avoid flooding the context. + argPreview := arguments + if len(argPreview) > 300 { + argPreview = argPreview[:300] + "... (truncated)" + } + + // Try to determine if it's specifically a JSON parse error for a friendlier message. + errStr := err.Error() + var jsonErr *json.SyntaxError + isJSONErr := strings.Contains(strings.ToLower(errStr), "json") || + strings.Contains(strings.ToLower(errStr), "unmarshal") + _ = jsonErr // suppress unused + + if isJSONErr { + return fmt.Sprintf( + "[Tool Error] The arguments for tool '%s' are not valid JSON and could not be parsed.\n"+ + "Error: %s\n"+ + "Arguments received: %s\n\n"+ + "Please fix the JSON (ensure double-quoted keys, matched braces/brackets, no trailing commas, "+ + "no truncation) and call the tool again.\n\n"+ + "[工具错误] 工具 '%s' 的参数不是合法 JSON,无法解析。\n"+ + "错误:%s\n"+ + "收到的参数:%s\n\n"+ + "请修正 JSON(确保双引号键名、括号配对、无尾部逗号、无截断),然后重新调用工具。", + toolName, errStr, argPreview, + toolName, errStr, argPreview, + ) + } + + return fmt.Sprintf( + "[Tool Error] Tool '%s' execution failed: %s\n"+ + "Arguments: %s\n\n"+ + "Please review the available tools and their expected arguments, then retry.\n\n"+ + "[工具错误] 工具 '%s' 执行失败:%s\n"+ + "参数:%s\n\n"+ + "请检查可用工具及其参数要求,然后重试。", + toolName, errStr, argPreview, + toolName, errStr, argPreview, + ) +} diff --git a/internal/multiagent/tool_error_middleware_test.go b/internal/multiagent/tool_error_middleware_test.go new file mode 100644 index 00000000..d87e417b --- /dev/null +++ b/internal/multiagent/tool_error_middleware_test.go @@ -0,0 +1,166 @@ +package multiagent + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/cloudwego/eino/compose" +) + +func TestIsSoftRecoverableToolError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "unexpected end of JSON input", + err: errors.New("unexpected end of JSON input"), + expected: true, + }, + { + name: "failed to unmarshal task tool input json", + err: errors.New("failed to unmarshal task tool input json: unexpected end of JSON input"), + expected: true, + }, + { + name: "invalid tool arguments JSON", + err: errors.New("invalid tool arguments JSON: unexpected end of JSON input"), + expected: true, + }, + { + name: "json invalid character", + err: errors.New(`invalid character '}' looking for beginning of value in JSON`), + expected: true, + }, + { + name: "subagent type not found", + err: errors.New("subagent type recon_agent not found"), + expected: true, + }, + { + name: "tool not found", + err: errors.New("tool nmap_scan not found in toolsNode indexes"), + expected: true, + }, + { + name: "unrelated network error", + err: errors.New("connection refused"), + expected: false, + }, + { + name: "context cancelled", + err: context.Canceled, + expected: false, + }, + { + name: "real json unmarshal error", + err: func() error { + var v map[string]interface{} + return json.Unmarshal([]byte(`{"key": `), &v) + }(), + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isSoftRecoverableToolError(tt.err) + if got != tt.expected { + t.Errorf("isSoftRecoverableToolError(%v) = %v, want %v", tt.err, got, tt.expected) + } + }) + } +} + +func TestSoftRecoveryToolCallMiddleware_PassesThrough(t *testing.T) { + mw := softRecoveryToolCallMiddleware() + called := false + next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { + called = true + return &compose.ToolOutput{Result: "success"}, nil + } + wrapped := mw(next) + out, err := wrapped(context.Background(), &compose.ToolInput{ + Name: "test_tool", + Arguments: `{"key": "value"}`, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !called { + t.Fatal("next endpoint was not called") + } + if out.Result != "success" { + t.Fatalf("expected 'success', got %q", out.Result) + } +} + +func TestSoftRecoveryToolCallMiddleware_ConvertsJSONError(t *testing.T) { + mw := softRecoveryToolCallMiddleware() + next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { + return nil, errors.New("failed to unmarshal task tool input json: unexpected end of JSON input") + } + wrapped := mw(next) + out, err := wrapped(context.Background(), &compose.ToolInput{ + Name: "task", + Arguments: `{"subagent_type": "recon`, + }) + if err != nil { + t.Fatalf("expected nil error (soft recovery), got: %v", err) + } + if out == nil || out.Result == "" { + t.Fatal("expected non-empty recovery message") + } + if !containsAll(out.Result, "[Tool Error]", "task", "JSON") { + t.Fatalf("recovery message missing expected content: %s", out.Result) + } +} + +func TestSoftRecoveryToolCallMiddleware_PropagatesNonRecoverable(t *testing.T) { + mw := softRecoveryToolCallMiddleware() + origErr := errors.New("connection timeout to remote server") + next := func(ctx context.Context, input *compose.ToolInput) (*compose.ToolOutput, error) { + return nil, origErr + } + wrapped := mw(next) + _, err := wrapped(context.Background(), &compose.ToolInput{ + Name: "test_tool", + Arguments: `{}`, + }) + if err == nil { + t.Fatal("expected error to propagate for non-recoverable errors") + } + if err != origErr { + t.Fatalf("expected original error, got: %v", err) + } +} + +func containsAll(s string, subs ...string) bool { + for _, sub := range subs { + if !contains(s, sub) { + return false + } + } + return true +} + +func contains(s, sub string) bool { + return len(s) >= len(sub) && searchString(s, sub) +} + +func searchString(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +}