diff --git a/internal/app/app.go b/internal/app/app.go index 1e967b2e..0f6a2f10 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -664,6 +664,7 @@ func setupRoutes( protected.GET("/messages/:id/process-details", conversationHandler.GetMessageProcessDetails) protected.PUT("/conversations/:id", conversationHandler.UpdateConversation) protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation) + protected.POST("/conversations/:id/delete-turn", conversationHandler.DeleteConversationTurn) protected.PUT("/conversations/:id/pinned", groupHandler.UpdateConversationPinned) // 对话分组 diff --git a/internal/database/conversation.go b/internal/database/conversation.go index bd7bd97f..db180a94 100644 --- a/internal/database/conversation.go +++ b/internal/database/conversation.go @@ -4,6 +4,7 @@ import ( "database/sql" "encoding/json" "fmt" + "strings" "time" "github.com/google/uuid" @@ -553,6 +554,102 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) { return messages, nil } +// turnSliceRange 根据任意一条消息 ID 定位「一轮对话」在 msgs 中的 [start, end) 下标区间(msgs 须已按时间升序,与 GetMessages 一致)。 +// 一轮 = 从某条 user 消息起,至下一条 user 之前(含中间所有 assistant)。 +func turnSliceRange(msgs []Message, anchorID string) (start, end int, err error) { + idx := -1 + for i := range msgs { + if msgs[i].ID == anchorID { + idx = i + break + } + } + if idx < 0 { + return 0, 0, fmt.Errorf("message not found") + } + start = idx + for start > 0 && msgs[start].Role != "user" { + start-- + } + if start < len(msgs) && msgs[start].Role != "user" { + start = 0 + } + end = len(msgs) + for i := start + 1; i < len(msgs); i++ { + if msgs[i].Role == "user" { + end = i + break + } + } + return start, end, nil +} + +// DeleteConversationTurn 删除锚点所在轮次的全部消息(用户提问 + 该轮助手回复等),并清空 last_react_*,避免与消息表不一致。 +func (db *DB) DeleteConversationTurn(conversationID, anchorMessageID string) (deletedIDs []string, err error) { + msgs, err := db.GetMessages(conversationID) + if err != nil { + return nil, err + } + start, end, err := turnSliceRange(msgs, anchorMessageID) + if err != nil { + return nil, err + } + if start >= end { + return nil, fmt.Errorf("empty turn range") + } + deletedIDs = make([]string, 0, end-start) + for i := start; i < end; i++ { + deletedIDs = append(deletedIDs, msgs[i].ID) + } + + tx, err := db.Begin() + if err != nil { + return nil, fmt.Errorf("begin tx: %w", err) + } + defer func() { _ = tx.Rollback() }() + + ph := strings.Repeat("?,", len(deletedIDs)) + ph = ph[:len(ph)-1] + args := make([]interface{}, 0, 1+len(deletedIDs)) + args = append(args, conversationID) + for _, id := range deletedIDs { + args = append(args, id) + } + res, err := tx.Exec( + "DELETE FROM messages WHERE conversation_id = ? AND id IN ("+ph+")", + args..., + ) + if err != nil { + return nil, fmt.Errorf("delete messages: %w", err) + } + n, err := res.RowsAffected() + if err != nil { + return nil, err + } + if int(n) != len(deletedIDs) { + return nil, fmt.Errorf("deleted count mismatch") + } + + _, err = tx.Exec( + `UPDATE conversations SET last_react_input = NULL, last_react_output = NULL, updated_at = ? WHERE id = ?`, + time.Now(), conversationID, + ) + if err != nil { + return nil, fmt.Errorf("clear react data: %w", err) + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit: %w", err) + } + + db.logger.Info("conversation turn deleted", + zap.String("conversationId", conversationID), + zap.Strings("deletedMessageIds", deletedIDs), + zap.Int("count", len(deletedIDs)), + ) + return deletedIDs, nil +} + // ProcessDetail 过程详情事件 type ProcessDetail struct { ID string `json:"id"` diff --git a/internal/database/conversation_turn_test.go b/internal/database/conversation_turn_test.go new file mode 100644 index 00000000..68743468 --- /dev/null +++ b/internal/database/conversation_turn_test.go @@ -0,0 +1,39 @@ +package database + +import ( + "testing" +) + +func TestTurnSliceRange(t *testing.T) { + mk := func(id, role string) Message { + return Message{ID: id, Role: role} + } + msgs := []Message{ + mk("u1", "user"), + mk("a1", "assistant"), + mk("u2", "user"), + mk("a2", "assistant"), + } + cases := []struct { + anchor string + start int + end int + }{ + {"u1", 0, 2}, + {"a1", 0, 2}, + {"u2", 2, 4}, + {"a2", 2, 4}, + } + for _, tc := range cases { + s, e, err := turnSliceRange(msgs, tc.anchor) + if err != nil { + t.Fatalf("anchor %s: %v", tc.anchor, err) + } + if s != tc.start || e != tc.end { + t.Fatalf("anchor %s: got [%d,%d) want [%d,%d)", tc.anchor, s, e, tc.start, tc.end) + } + } + if _, _, err := turnSliceRange(msgs, "nope"); err == nil { + t.Fatal("expected error for missing id") + } +} diff --git a/internal/handler/agent.go b/internal/handler/agent.go index e0ecafb4..871df57c 100644 --- a/internal/handler/agent.go +++ b/internal/handler/agent.go @@ -1256,7 +1256,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { // 保存用户消息:有附件时一并保存附件名与路径,刷新后显示、继续对话时大模型也能从历史中拿到路径 userContent := userMessageContentForStorage(req.Message, req.Attachments, savedPaths) - _, err = h.db.AddMessage(conversationID, "user", userContent, nil) + userMsgRow, err := h.db.AddMessage(conversationID, "user", userContent, nil) if err != nil { h.logger.Error("保存用户消息失败", zap.Error(err)) } @@ -1275,6 +1275,14 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) { assistantMessageID = assistantMsg.ID } + // 尽早下发消息 ID,便于前端在流式结束前挂上「删除本轮」等(无需等整段结束再刷新) + if userMsgRow != nil { + sendEvent("message_saved", "", map[string]interface{}{ + "conversationId": conversationID, + "userMessageId": userMsgRow.ID, + }) + } + // 创建进度回调函数,复用统一逻辑 progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent) diff --git a/internal/handler/conversation.go b/internal/handler/conversation.go index f0c074f0..4bb72bbe 100644 --- a/internal/handler/conversation.go +++ b/internal/handler/conversation.go @@ -190,3 +190,44 @@ func (h *ConversationHandler) DeleteConversation(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "删除成功"}) } +// DeleteTurnRequest 删除一轮对话(POST /api/conversations/:id/delete-turn) +type DeleteTurnRequest struct { + MessageID string `json:"messageId"` +} + +// DeleteConversationTurn 删除锚点消息所在轮次(从该轮 user 到下一轮 user 之前),并清空 last_react_*。 +func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) { + conversationID := c.Param("id") + if conversationID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "conversation id required"}) + return + } + + var req DeleteTurnRequest + if err := c.ShouldBindJSON(&req); err != nil || req.MessageID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "messageId required"}) + return + } + + if _, err := h.db.GetConversation(conversationID); err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"}) + return + } + + deletedIDs, err := h.db.DeleteConversationTurn(conversationID, req.MessageID) + if err != nil { + h.logger.Warn("删除对话轮次失败", + zap.String("conversationId", conversationID), + zap.String("messageId", req.MessageID), + zap.Error(err), + ) + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "deletedMessageIds": deletedIDs, + "message": "ok", + }) +} + diff --git a/internal/handler/multi_agent.go b/internal/handler/multi_agent.go index 4ff6b7f6..d8a54625 100644 --- a/internal/handler/multi_agent.go +++ b/internal/handler/multi_agent.go @@ -103,6 +103,13 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) { conversationID := prep.ConversationID assistantMessageID := prep.AssistantMessageID + if prep.UserMessageID != "" { + sendEvent("message_saved", "", map[string]interface{}{ + "conversationId": conversationID, + "userMessageId": prep.UserMessageID, + }) + } + progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent) baseCtx, cancelWithCause := context.WithCancelCause(context.Background()) diff --git a/internal/handler/multi_agent_prepare.go b/internal/handler/multi_agent_prepare.go index 60244117..4e2ea4fe 100644 --- a/internal/handler/multi_agent_prepare.go +++ b/internal/handler/multi_agent_prepare.go @@ -19,6 +19,7 @@ type multiAgentPrepared struct { FinalMessage string RoleTools []string AssistantMessageID string + UserMessageID string } func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPrepared, error) { @@ -109,9 +110,14 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr finalMessage = appendAttachmentsToMessage(finalMessage, req.Attachments, savedPaths) userContent := userMessageContentForStorage(req.Message, req.Attachments, savedPaths) - if _, err = h.db.AddMessage(conversationID, "user", userContent, nil); err != nil { - h.logger.Error("保存用户消息失败", zap.Error(err)) - return nil, fmt.Errorf("保存用户消息失败: %w", err) + userMsgRow, uerr := h.db.AddMessage(conversationID, "user", userContent, nil) + if uerr != nil { + h.logger.Error("保存用户消息失败", zap.Error(uerr)) + return nil, fmt.Errorf("保存用户消息失败: %w", uerr) + } + userMessageID := "" + if userMsgRow != nil { + userMessageID = userMsgRow.ID } assistantMsg, aerr := h.db.AddMessage(conversationID, "assistant", "处理中...", nil) @@ -129,5 +135,6 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr FinalMessage: finalMessage, RoleTools: roleTools, AssistantMessageID: assistantMessageID, + UserMessageID: userMessageID, }, nil }