diff --git a/internal/database/conversation.go b/internal/database/conversation.go index cf2d59d8..7a769656 100644 --- a/internal/database/conversation.go +++ b/internal/database/conversation.go @@ -13,6 +13,9 @@ import ( "go.uber.org/zap" ) +// ProjectFilterUnbound 列表 API 中 project_id=__none__ 表示仅未绑定项目的对话。 +const ProjectFilterUnbound = "__none__" + // Conversation 对话 type Conversation struct { ID string `json:"id"` @@ -361,20 +364,44 @@ func (db *DB) GetConversationLite(id string) (*Conversation, error) { return &conv, nil } +func conversationProjectIDColumn(alias string) string { + if alias != "" { + return alias + ".project_id" + } + return "project_id" +} + +func appendConversationProjectFilter(where string, args []interface{}, projectID, alias string) (string, []interface{}) { + pid := strings.TrimSpace(projectID) + if pid == "" { + return where, args + } + col := conversationProjectIDColumn(alias) + if pid == ProjectFilterUnbound { + return where + fmt.Sprintf(" AND (%s IS NULL OR TRIM(COALESCE(%s, '')) = '')", col, col), args + } + return where + fmt.Sprintf(" AND %s = ?", col), append(args, pid) +} + // CountConversations 统计对话数量。 -func (db *DB) CountConversations(search string) (int, error) { +func (db *DB) CountConversations(search, projectID string) (int, error) { var count int var err error if search != "" { searchPattern := "%" + search + "%" - err = db.QueryRow( - `SELECT COUNT(*) FROM conversations c - WHERE c.title LIKE ? - OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?)`, - searchPattern, searchPattern, - ).Scan(&count) + where := ` WHERE (c.title LIKE ? + OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?))` + args := []interface{}{searchPattern, searchPattern} + where, args = appendConversationProjectFilter(where, args, projectID, "c") + err = db.QueryRow(`SELECT COUNT(*) FROM conversations c`+where, args...).Scan(&count) } else { - err = db.QueryRow(`SELECT COUNT(*) FROM conversations`).Scan(&count) + where := "" + args := []interface{}{} + where, args = appendConversationProjectFilter(where, args, projectID, "") + if where != "" { + where = " WHERE" + strings.TrimPrefix(where, " AND") + } + err = db.QueryRow(`SELECT COUNT(*) FROM conversations`+where, args...).Scan(&count) } if err != nil { return 0, fmt.Errorf("统计对话失败: %w", err) @@ -395,7 +422,7 @@ func conversationOrderClause(sortBy, tableAlias string) string { } // ListConversations 列出所有对话 -func (db *DB) ListConversations(limit, offset int, search, sortBy string) ([]*Conversation, error) { +func (db *DB) ListConversations(limit, offset int, search, sortBy, projectID string) ([]*Conversation, error) { var rows *sql.Rows var err error @@ -403,20 +430,30 @@ func (db *DB) ListConversations(limit, offset int, search, sortBy string) ([]*Co // 使用 EXISTS 子查询代替 LEFT JOIN + DISTINCT,避免大表笛卡尔积 searchPattern := "%" + search + "%" orderClause := conversationOrderClause(sortBy, "c") + where := ` WHERE (c.title LIKE ? + OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?))` + args := []interface{}{searchPattern, searchPattern} + where, args = appendConversationProjectFilter(where, args, projectID, "c") + args = append(args, limit, offset) rows, err = db.Query( `SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id - FROM conversations c - WHERE c.title LIKE ? - OR EXISTS (SELECT 1 FROM messages m WHERE m.conversation_id = c.id AND m.content LIKE ?) + FROM conversations c`+where+` `+orderClause+` LIMIT ? OFFSET ?`, - searchPattern, searchPattern, limit, offset, + args..., ) } else { orderClause := conversationOrderClause(sortBy, "") + where := "" + args := []interface{}{} + where, args = appendConversationProjectFilter(where, args, projectID, "") + if where != "" { + where = " WHERE" + strings.TrimPrefix(where, " AND") + } + args = append(args, limit, offset) rows, err = db.Query( - "SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations "+orderClause+" LIMIT ? OFFSET ?", - limit, offset, + "SELECT id, title, COALESCE(pinned, 0), created_at, updated_at, project_id FROM conversations"+where+" "+orderClause+" LIMIT ? OFFSET ?", + args..., ) } @@ -472,23 +509,30 @@ const ungroupedConversationsSQL = ` )` // CountUngroupedConversations 统计不在任何分组中的对话数量。 -func (db *DB) CountUngroupedConversations() (int, error) { +func (db *DB) CountUngroupedConversations(projectID string) (int, error) { + where := ungroupedConversationsSQL + args := []interface{}{} + where, args = appendConversationProjectFilter(where, args, projectID, "c") var count int - if err := db.QueryRow(`SELECT COUNT(*) ` + ungroupedConversationsSQL).Scan(&count); err != nil { + if err := db.QueryRow(`SELECT COUNT(*) `+where, args...).Scan(&count); err != nil { return 0, fmt.Errorf("统计未分组对话失败: %w", err) } return count, nil } // ListUngroupedConversations 列出不在任何分组中的对话(最近对话侧栏)。 -func (db *DB) ListUngroupedConversations(limit, offset int, sortBy string) ([]*Conversation, error) { +func (db *DB) ListUngroupedConversations(limit, offset int, sortBy, projectID string) ([]*Conversation, error) { orderClause := conversationOrderClause(sortBy, "c") + where := ungroupedConversationsSQL + args := []interface{}{} + where, args = appendConversationProjectFilter(where, args, projectID, "c") + args = append(args, limit, offset) rows, err := db.Query( `SELECT c.id, c.title, COALESCE(c.pinned, 0), c.created_at, c.updated_at, c.project_id `+ - ungroupedConversationsSQL+` + where+` `+orderClause+` LIMIT ? OFFSET ?`, - limit, offset, + args..., ) if err != nil { return nil, fmt.Errorf("查询未分组对话失败: %w", err) diff --git a/internal/database/conversation_project_filter_test.go b/internal/database/conversation_project_filter_test.go new file mode 100644 index 00000000..457542b7 --- /dev/null +++ b/internal/database/conversation_project_filter_test.go @@ -0,0 +1,60 @@ +package database + +import ( + "path/filepath" + "testing" + + "go.uber.org/zap" +) + +func TestConversationProjectFilter(t *testing.T) { + tmp := t.TempDir() + dbPath := filepath.Join(tmp, "conversations.db") + db, err := NewDB(dbPath, zap.NewNop()) + if err != nil { + t.Fatalf("NewDB: %v", err) + } + defer db.Close() + + p, err := db.CreateProject(&Project{Name: "target-a", Status: "active"}) + if err != nil { + t.Fatalf("CreateProject: %v", err) + } + + convNone, err := db.CreateConversation("unbound", ConversationCreateMeta{}) + if err != nil { + t.Fatalf("CreateConversation unbound: %v", err) + } + convBound, err := db.CreateConversation("bound", ConversationCreateMeta{ProjectID: p.ID}) + if err != nil { + t.Fatalf("CreateConversation bound: %v", err) + } + + totalAll, err := db.CountConversations("", "") + if err != nil || totalAll < 2 { + t.Fatalf("CountConversations all: total=%d err=%v", totalAll, err) + } + + totalBound, err := db.CountConversations("", p.ID) + if err != nil || totalBound != 1 { + t.Fatalf("CountConversations project: total=%d err=%v", totalBound, err) + } + + totalUnbound, err := db.CountConversations("", ProjectFilterUnbound) + if err != nil || totalUnbound != 1 { + t.Fatalf("CountConversations unbound: total=%d err=%v", totalUnbound, err) + } + + listBound, err := db.ListConversations(10, 0, "", "", p.ID) + if err != nil || len(listBound) != 1 || listBound[0].ID != convBound.ID { + t.Fatalf("ListConversations project: %+v err=%v", listBound, err) + } + + listUnbound, err := db.ListConversations(10, 0, "", "", ProjectFilterUnbound) + if err != nil || len(listUnbound) != 1 || listUnbound[0].ID != convNone.ID { + t.Fatalf("ListConversations unbound: %+v err=%v", listUnbound, err) + } + + _ = convNone + _ = convBound +}