Add files via upload

This commit is contained in:
公明
2026-04-03 22:55:30 +08:00
committed by GitHub
parent acef729800
commit 4fd083ff37
7 changed files with 204 additions and 4 deletions

View File

@@ -664,6 +664,7 @@ func setupRoutes(
protected.GET("/messages/:id/process-details", conversationHandler.GetMessageProcessDetails)
protected.PUT("/conversations/:id", conversationHandler.UpdateConversation)
protected.DELETE("/conversations/:id", conversationHandler.DeleteConversation)
protected.POST("/conversations/:id/delete-turn", conversationHandler.DeleteConversationTurn)
protected.PUT("/conversations/:id/pinned", groupHandler.UpdateConversationPinned)
// 对话分组

View File

@@ -4,6 +4,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
@@ -553,6 +554,102 @@ func (db *DB) GetMessages(conversationID string) ([]Message, error) {
return messages, nil
}
// turnSliceRange 根据任意一条消息 ID 定位「一轮对话」在 msgs 中的 [start, end) 下标区间msgs 须已按时间升序,与 GetMessages 一致)。
// 一轮 = 从某条 user 消息起,至下一条 user 之前(含中间所有 assistant
func turnSliceRange(msgs []Message, anchorID string) (start, end int, err error) {
idx := -1
for i := range msgs {
if msgs[i].ID == anchorID {
idx = i
break
}
}
if idx < 0 {
return 0, 0, fmt.Errorf("message not found")
}
start = idx
for start > 0 && msgs[start].Role != "user" {
start--
}
if start < len(msgs) && msgs[start].Role != "user" {
start = 0
}
end = len(msgs)
for i := start + 1; i < len(msgs); i++ {
if msgs[i].Role == "user" {
end = i
break
}
}
return start, end, nil
}
// DeleteConversationTurn 删除锚点所在轮次的全部消息(用户提问 + 该轮助手回复等),并清空 last_react_*,避免与消息表不一致。
func (db *DB) DeleteConversationTurn(conversationID, anchorMessageID string) (deletedIDs []string, err error) {
msgs, err := db.GetMessages(conversationID)
if err != nil {
return nil, err
}
start, end, err := turnSliceRange(msgs, anchorMessageID)
if err != nil {
return nil, err
}
if start >= end {
return nil, fmt.Errorf("empty turn range")
}
deletedIDs = make([]string, 0, end-start)
for i := start; i < end; i++ {
deletedIDs = append(deletedIDs, msgs[i].ID)
}
tx, err := db.Begin()
if err != nil {
return nil, fmt.Errorf("begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
ph := strings.Repeat("?,", len(deletedIDs))
ph = ph[:len(ph)-1]
args := make([]interface{}, 0, 1+len(deletedIDs))
args = append(args, conversationID)
for _, id := range deletedIDs {
args = append(args, id)
}
res, err := tx.Exec(
"DELETE FROM messages WHERE conversation_id = ? AND id IN ("+ph+")",
args...,
)
if err != nil {
return nil, fmt.Errorf("delete messages: %w", err)
}
n, err := res.RowsAffected()
if err != nil {
return nil, err
}
if int(n) != len(deletedIDs) {
return nil, fmt.Errorf("deleted count mismatch")
}
_, err = tx.Exec(
`UPDATE conversations SET last_react_input = NULL, last_react_output = NULL, updated_at = ? WHERE id = ?`,
time.Now(), conversationID,
)
if err != nil {
return nil, fmt.Errorf("clear react data: %w", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit: %w", err)
}
db.logger.Info("conversation turn deleted",
zap.String("conversationId", conversationID),
zap.Strings("deletedMessageIds", deletedIDs),
zap.Int("count", len(deletedIDs)),
)
return deletedIDs, nil
}
// ProcessDetail 过程详情事件
type ProcessDetail struct {
ID string `json:"id"`

View File

@@ -0,0 +1,39 @@
package database
import (
"testing"
)
func TestTurnSliceRange(t *testing.T) {
mk := func(id, role string) Message {
return Message{ID: id, Role: role}
}
msgs := []Message{
mk("u1", "user"),
mk("a1", "assistant"),
mk("u2", "user"),
mk("a2", "assistant"),
}
cases := []struct {
anchor string
start int
end int
}{
{"u1", 0, 2},
{"a1", 0, 2},
{"u2", 2, 4},
{"a2", 2, 4},
}
for _, tc := range cases {
s, e, err := turnSliceRange(msgs, tc.anchor)
if err != nil {
t.Fatalf("anchor %s: %v", tc.anchor, err)
}
if s != tc.start || e != tc.end {
t.Fatalf("anchor %s: got [%d,%d) want [%d,%d)", tc.anchor, s, e, tc.start, tc.end)
}
}
if _, _, err := turnSliceRange(msgs, "nope"); err == nil {
t.Fatal("expected error for missing id")
}
}

View File

@@ -1256,7 +1256,7 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
// 保存用户消息:有附件时一并保存附件名与路径,刷新后显示、继续对话时大模型也能从历史中拿到路径
userContent := userMessageContentForStorage(req.Message, req.Attachments, savedPaths)
_, err = h.db.AddMessage(conversationID, "user", userContent, nil)
userMsgRow, err := h.db.AddMessage(conversationID, "user", userContent, nil)
if err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
}
@@ -1275,6 +1275,14 @@ func (h *AgentHandler) AgentLoopStream(c *gin.Context) {
assistantMessageID = assistantMsg.ID
}
// 尽早下发消息 ID便于前端在流式结束前挂上「删除本轮」等无需等整段结束再刷新
if userMsgRow != nil {
sendEvent("message_saved", "", map[string]interface{}{
"conversationId": conversationID,
"userMessageId": userMsgRow.ID,
})
}
// 创建进度回调函数,复用统一逻辑
progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent)

View File

@@ -190,3 +190,44 @@ func (h *ConversationHandler) DeleteConversation(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "删除成功"})
}
// DeleteTurnRequest 删除一轮对话POST /api/conversations/:id/delete-turn
type DeleteTurnRequest struct {
MessageID string `json:"messageId"`
}
// DeleteConversationTurn 删除锚点消息所在轮次(从该轮 user 到下一轮 user 之前),并清空 last_react_*。
func (h *ConversationHandler) DeleteConversationTurn(c *gin.Context) {
conversationID := c.Param("id")
if conversationID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "conversation id required"})
return
}
var req DeleteTurnRequest
if err := c.ShouldBindJSON(&req); err != nil || req.MessageID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "messageId required"})
return
}
if _, err := h.db.GetConversation(conversationID); err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "对话不存在"})
return
}
deletedIDs, err := h.db.DeleteConversationTurn(conversationID, req.MessageID)
if err != nil {
h.logger.Warn("删除对话轮次失败",
zap.String("conversationId", conversationID),
zap.String("messageId", req.MessageID),
zap.Error(err),
)
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"deletedMessageIds": deletedIDs,
"message": "ok",
})
}

View File

@@ -103,6 +103,13 @@ func (h *AgentHandler) MultiAgentLoopStream(c *gin.Context) {
conversationID := prep.ConversationID
assistantMessageID := prep.AssistantMessageID
if prep.UserMessageID != "" {
sendEvent("message_saved", "", map[string]interface{}{
"conversationId": conversationID,
"userMessageId": prep.UserMessageID,
})
}
progressCallback := h.createProgressCallback(conversationID, assistantMessageID, sendEvent)
baseCtx, cancelWithCause := context.WithCancelCause(context.Background())

View File

@@ -19,6 +19,7 @@ type multiAgentPrepared struct {
FinalMessage string
RoleTools []string
AssistantMessageID string
UserMessageID string
}
func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPrepared, error) {
@@ -109,9 +110,14 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
finalMessage = appendAttachmentsToMessage(finalMessage, req.Attachments, savedPaths)
userContent := userMessageContentForStorage(req.Message, req.Attachments, savedPaths)
if _, err = h.db.AddMessage(conversationID, "user", userContent, nil); err != nil {
h.logger.Error("保存用户消息失败", zap.Error(err))
return nil, fmt.Errorf("保存用户消息失败: %w", err)
userMsgRow, uerr := h.db.AddMessage(conversationID, "user", userContent, nil)
if uerr != nil {
h.logger.Error("保存用户消息失败", zap.Error(uerr))
return nil, fmt.Errorf("保存用户消息失败: %w", uerr)
}
userMessageID := ""
if userMsgRow != nil {
userMessageID = userMsgRow.ID
}
assistantMsg, aerr := h.db.AddMessage(conversationID, "assistant", "处理中...", nil)
@@ -129,5 +135,6 @@ func (h *AgentHandler) prepareMultiAgentSession(req *ChatRequest) (*multiAgentPr
FinalMessage: finalMessage,
RoleTools: roleTools,
AssistantMessageID: assistantMessageID,
UserMessageID: userMessageID,
}, nil
}