Files
phishingclub/backend/repository/utils.go
2025-08-21 16:14:09 +02:00

408 lines
11 KiB
Go

package repository
import (
"context"
"fmt"
"slices"
"strings"
"github.com/go-errors/errors"
"github.com/google/uuid"
"github.com/phishingclub/phishingclub/errs"
"github.com/phishingclub/phishingclub/utils"
"github.com/phishingclub/phishingclub/vo"
"gorm.io/gorm"
)
// defaultAllowdSearchColumns lists the columns that useQuery and
// useHasNextPage search in when the caller supplies no allowed columns.
var defaultAllowdSearchColumns = []string{
"name",
}
// defaultAllowdColumns is the set of column names accepted for ordering by
// WithOrderBy and WithOrderByOnTable when the caller supplies no allow-list.
var defaultAllowdColumns = map[string]struct{}{
"name": {},
"created_at": {},
"updated_at": {},
}
// withOffsetLimit applies OFFSET and LIMIT to the query.
// When both values are zero the query is handed back untouched.
func withOffsetLimit(db *gorm.DB, offset, limit int) *gorm.DB {
	paginate := offset != 0 || limit != 0
	if !paginate {
		return db
	}
	withOffset := db.Offset(offset)
	return withOffset.Limit(limit)
}
// WithOrderBy adds a case-insensitive ORDER BY on orderBy to the query.
//
// An empty orderBy leaves the query unchanged. The column must appear in
// allowed; when allowed is empty it must be one of defaultAllowdColumns.
// A rejected column yields the unmodified db together with an error.
// desc selects descending order, otherwise ascending order is used.
func WithOrderBy(db *gorm.DB, orderBy string, desc bool, allowed ...string) (*gorm.DB, error) {
	if orderBy == "" {
		return db, nil
	}
	// pick the allow-list to validate against: the caller's list when one is
	// given, the package defaults otherwise
	var (
		permitted bool
		allowList interface{}
	)
	if len(allowed) > 0 {
		permitted = slices.Contains(allowed, orderBy)
		allowList = allowed
	} else {
		_, permitted = defaultAllowdColumns[orderBy]
		allowList = defaultAllowdColumns
	}
	if !permitted {
		return db, fmt.Errorf(
			"not known or allowed column: %s - allowd: %s",
			orderBy,
			allowList,
		)
	}
	// TODO case-insensitive ordering is done with COLLATE NOCASE as a quick fix.
	// Wrapping the column in LOWER() was the earlier approach but ruins index
	// performance; the real solution is a case insensitive collation, which
	// differs per database, or LOWER indexes on some columns (also not ideal).
	direction := "ASC"
	if desc {
		direction = "DESC"
	}
	return db.Order(orderBy + " COLLATE NOCASE " + direction), nil
}
// WithOrderByOnTable adds a case-insensitive ORDER BY on `table`.`orderBy`.
//
// An empty orderBy leaves the query unchanged. The column must appear in
// allowed; when allowed is empty it must be one of defaultAllowdColumns.
// A rejected column yields the unmodified db together with an error naming
// the rejected column. desc selects descending order, otherwise ascending.
func WithOrderByOnTable(db *gorm.DB, table string, orderBy string, desc bool, allowed ...string) (*gorm.DB, error) {
	if orderBy == "" {
		return db, nil
	}
	// only check default columns if no allowed columns are provided
	if len(allowed) == 0 {
		if _, ok := defaultAllowdColumns[orderBy]; !ok {
			return db, fmt.Errorf("invalid column: %s", orderBy)
		}
	} else if !slices.Contains(allowed, orderBy) {
		// report the column that was rejected, not an entry of the allow-list
		return db, fmt.Errorf("invalid column: %s", orderBy)
	}
	// Use the same collation clause as WithOrderBy. The previous clauses
	// ("COLLATE NOCAS" and "NO COLLATE") were not valid SQL, and wrapping the
	// column in LOWER() is redundant once a case-insensitive collation is used.
	direction := "ASC"
	if desc {
		direction = "DESC"
	}
	return db.Order(
		fmt.Sprintf("`%s`.`%s` COLLATE NOCASE %s", table, orderBy, direction),
	), nil
}
// assignTableToColumn qualifies column with table as "`table`.`column`".
// A column that already contains a dot is treated as qualified and is
// returned unchanged.
func assignTableToColumn(table, column string) string {
	// if the column already contains a dot, it is already formatted
	if strings.Contains(column, ".") {
		return column
	}
	return fmt.Sprintf("`%s`.`%s`", table, column)
}

// assignTableToColumns returns a new slice holding every column qualified
// with table (see assignTableToColumn).
//
// The input slice is never modified. Callers pass package-level slices such
// as defaultAllowdSearchColumns; rewriting those in place would bake the
// first table name into the shared defaults and make every later call, for
// any other table, reuse it.
func assignTableToColumns(table string, columns []string) []string {
	if columns == nil {
		return nil
	}
	qualified := make([]string, len(columns))
	for i, column := range columns {
		qualified[i] = assignTableToColumn(table, column)
	}
	return qualified
}
// useQuery applies the pagination, ordering and search of q to db.
//
// A nil q leaves the query unchanged. Ordering is validated by WithOrderBy
// against allowdColumns (or the package defaults when none are given); a
// rejected column is returned as a wrapped error. When q.Search is set, a
// "col LIKE ? OR ..." filter is added over allowdColumns, falling back to
// defaultAllowdSearchColumns when no columns are supplied.
func useQuery(db *gorm.DB, tableName string, q *vo.QueryArgs, allowdColumns ...string) (*gorm.DB, error) {
	if q == nil {
		return db, nil
	}
	db = withOffsetLimit(db, q.Offset, q.Limit)
	// only qualify a column that was actually requested: qualifying an empty
	// name would produce "`table`.``" and bypass WithOrderBy's empty check
	orderBy := q.OrderBy
	if orderBy != "" {
		orderBy = assignTableToColumn(tableName, orderBy)
	}
	db, err := WithOrderBy(db, orderBy, q.Desc, allowdColumns...)
	if err != nil {
		return db, errs.Wrap(err)
	}
	if q.Search == "" {
		return db, nil
	}
	// add columns that are allowed to be searched in
	searchColumns := make([]string, 0, len(allowdColumns))
	for _, column := range allowdColumns {
		searchColumns = append(searchColumns, assignTableToColumn(tableName, column))
	}
	// if no columns has been added, use the default; a copy is passed so the
	// package-level defaults can never be rewritten with this table's name
	if len(searchColumns) == 0 {
		searchColumns = assignTableToColumns(tableName, slices.Clone(defaultAllowdSearchColumns))
	}
	// remove search symbols so the user input cannot act as a LIKE wildcard
	search := strings.ReplaceAll(q.Search, "%", " ")
	search = strings.ReplaceAll(search, "_", " ")
	pattern := "%" + search + "%"
	// build the LIKE query
	conditions := make([]string, 0, len(searchColumns))
	args := make([]interface{}, 0, len(searchColumns))
	for _, column := range searchColumns {
		conditions = append(conditions, column+" LIKE ?")
		args = append(args, pattern)
	}
	// keep the *gorm.DB returned by Where: the condition is only guaranteed to
	// be part of the returned value, not of the receiver
	db = db.Where(strings.Join(conditions, " OR "), args...)
	return db, nil
}
// useHasNextPage reports whether a row exists directly after the page
// described by q, i.e. at offset q.Offset+q.Limit.
//
// It returns false without querying when q is nil or when both q.Limit and
// q.Offset are zero. Ordering and search are applied the same way as in
// useQuery; a rejected order column or a query failure is returned as a
// wrapped error.
func useHasNextPage(
	db *gorm.DB,
	tableName string,
	q *vo.QueryArgs,
	allowdColumns ...string,
) (bool, error) {
	if q == nil {
		return false, nil
	}
	if q.Limit == 0 && q.Offset == 0 {
		return false, nil
	}
	// look at the single row that would start the next page
	db = withOffsetLimit(db, q.Offset+q.Limit, 1)
	// only qualify a column that was actually requested: qualifying an empty
	// name would produce "`table`.``" and bypass WithOrderBy's empty check
	orderBy := q.OrderBy
	if orderBy != "" {
		orderBy = assignTableToColumn(tableName, orderBy)
	}
	db, err := WithOrderBy(db, orderBy, q.Desc, allowdColumns...)
	if err != nil {
		return false, errs.Wrap(err)
	}
	// handle search
	if q.Search != "" {
		// add columns that are allowed to be searched in
		searchColumns := make([]string, 0, len(allowdColumns))
		for _, column := range allowdColumns {
			searchColumns = append(searchColumns, assignTableToColumn(tableName, column))
		}
		// if no columns has been added, use the default; a copy is passed so
		// the package-level defaults can never be rewritten with this table's name
		if len(searchColumns) == 0 {
			searchColumns = assignTableToColumns(tableName, slices.Clone(defaultAllowdSearchColumns))
		}
		// remove search symbols so the user input cannot act as a LIKE wildcard
		search := strings.ReplaceAll(q.Search, "%", " ")
		search = strings.ReplaceAll(search, "_", " ")
		pattern := "%" + search + "%"
		// build the LIKE query
		conditions := make([]string, 0, len(searchColumns))
		args := make([]interface{}, 0, len(searchColumns))
		for _, column := range searchColumns {
			conditions = append(conditions, column+" LIKE ?")
			args = append(args, pattern)
		}
		// keep the *gorm.DB returned by Where: the condition is only guaranteed
		// to be part of the returned value, not of the receiver
		db = db.Where(strings.Join(conditions, " OR "), args...)
	}
	// Check if there's at least one record
	var exists bool
	err = db.Select("1").Find(&exists).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return false, nil
	}
	if err != nil {
		return false, errs.Wrap(err)
	}
	return exists, nil
}
/*
func useQueryWithTable(db *gorm.DB, table string, q *vo.QueryArgs, allowdColumns ...string) (*gorm.DB, error) {
if q == nil {
return db, nil
}
db = withOffsetLimit(db, q.Offset, q.Limit)
db, err := WithOrderByOnTable(db, table, q.OrderBy, q.Desc, allowdColumns...)
// handle search
if q.Search != "" {
searchColumns := []string{}
// add columns that are allowed to be searched in
for _, column := range allowdColumns {
if column == "created_at" || column == "updated_at" {
continue
}
searchColumns = append(searchColumns, column)
}
// if no columns has been added, use the default
if len(searchColumns) == 0 {
searchColumns = defaultAllowdSearchColumns
}
// remove search symbols
search := strings.ReplaceAll(q.Search, "%", " ")
search = strings.ReplaceAll(search, "_", " ")
// build the LIKE query
// todo perhaps this needs table prefix also
var searches []interface{}
q := ""
for i, column := range searchColumns {
if i == 0 {
q += column + " LIKE ?"
} else {
q += " OR " + column + " LIKE ?"
}
searches = append(searches, "%"+search+"%")
}
db.Where(q, searches...)
}
return db,errs.Wrap(err)
}
*/
// SelectTable returns a one-element select list that selects every column
// of tableName ("`tableName`.*").
func SelectTable(tableName string) []string {
	allColumns := fmt.Sprintf("`%s`.*", tableName)
	selects := make([]string, 0, 1)
	selects = append(selects, allColumns)
	return selects
}
// SelectColumnAs creates a list of "`table`.`column` AS alias" select
// expressions. columns maps each alias (the key) to the column of tableName
// it selects (the value).
//
// The expressions are returned sorted by alias so the generated SQL is the
// same on every call; ranging over the map directly would yield a random
// order. An empty map yields a nil slice.
func SelectColumnAs(tableName string, columns map[string]string) []string {
	if len(columns) == 0 {
		return nil
	}
	aliases := make([]string, 0, len(columns))
	for alias := range columns {
		aliases = append(aliases, alias)
	}
	slices.Sort(aliases)
	cols := make([]string, 0, len(aliases))
	for _, alias := range aliases {
		cols = append(cols, fmt.Sprintf("`%s`.`%s` AS %s", tableName, columns[alias], alias))
	}
	return cols
}
// useSelect restricts the query to the given select expressions.
// With no fields the query is returned as-is.
func useSelect(db *gorm.DB, fields []string) *gorm.DB {
	if len(fields) > 0 {
		return db.Select(fields)
	}
	return db
}
// LeftJoinOn builds a LEFT JOIN clause that joins tableRight on
// tableLeft.columnLeft = tableRight.columnRight.
func LeftJoinOn(tableLeft, columnLeft, tableRight, columnRight string) string {
	leftSide := fmt.Sprintf("`%s`.`%s`", tableLeft, columnLeft)
	rightSide := fmt.Sprintf("`%s`.`%s`", tableRight, columnRight)
	return fmt.Sprintf("LEFT JOIN `%s` on %s = %s", tableRight, leftSide, rightSide)
}
// LeftJoinOnWithAlias builds a LEFT JOIN clause that joins tableRight under
// the given alias on tableLeft.columnLeft = alias.columnRight.
//
// The alias is quoted as an identifier with backticks, matching how it is
// referenced in the ON condition and how every other identifier in this file
// is quoted. It was previously declared as a single-quoted string literal.
func LeftJoinOnWithAlias(tableLeft, columnLeft, tableRight, columnRight, alias string) string {
	return fmt.Sprintf(
		"LEFT JOIN `%s` AS `%s` on `%s`.`%s` = `%s`.`%s`",
		tableRight, alias, tableLeft, columnLeft, alias, columnRight,
	)
}
// withCompanyIncludingNullContext scopes the query on `tableName`.company_id.
// With a companyID it matches rows of that company as well as rows whose
// company_id is NULL; without one it matches only rows whose company_id is NULL.
func withCompanyIncludingNullContext(db *gorm.DB, companyID *uuid.UUID, tableName string) *gorm.DB {
	companyColumn := fmt.Sprintf("`%s`.company_id", tableName)
	if companyID == nil {
		onlyShared := fmt.Sprintf("(%s IS NULL)", companyColumn)
		return db.Where(onlyShared)
	}
	ownOrShared := fmt.Sprintf("(%s = ? OR %s IS NULL)", companyColumn, companyColumn)
	return db.Where(ownOrShared, companyID)
}
// whereCompany adds a condition on `tableName`.company_id.
// A nil companyID matches rows whose company_id IS NULL, otherwise rows
// whose company_id equals companyID.
func whereCompany(db *gorm.DB, tableName string, companyID *uuid.UUID) *gorm.DB {
	companyColumn := fmt.Sprintf("`%s`.company_id", tableName)
	if companyID != nil {
		return db.Where(fmt.Sprintf("%s = ?", companyColumn), companyID)
	}
	return db.Where(fmt.Sprintf("%s IS NULL", companyColumn))
}
// whereCompanyIsNull adds a condition matching rows whose
// `tableName`.company_id IS NULL.
func whereCompanyIsNull(db *gorm.DB, tableName string) *gorm.DB {
	condition := fmt.Sprintf("`%s`.company_id IS NULL", tableName)
	return db.Where(condition)
}
// AddTimestamps sets both created_at and updated_at in row to the same
// current timestamp obtained from utils.NowRFC3339UTC.
func AddTimestamps(row map[string]interface{}) {
	stamp := utils.NowRFC3339UTC()
	for _, key := range [...]string{"created_at", "updated_at"} {
		row[key] = stamp
	}
}
// AddUpdatedAt sets updated_at in row to the current timestamp obtained
// from utils.NowRFC3339UTC.
func AddUpdatedAt(row map[string]interface{}) {
	stamp := utils.NowRFC3339UTC()
	row["updated_at"] = stamp
}
// CheckColumnIsUnique checks if a column value is unique within a company and globally.
// If companyID is nil, the check is global: no other row may use the value, period.
// If companyID is set, no row of the same company and no row with a NULL
// company ID may use it.
// If currentID is set, the row with that id is excluded from the check, so
// a row is allowed to keep its own value.
//
// column is concatenated into the SQL and is NOT sqli safe: never pass
// user-controlled input as column.
//
// It returns true when no conflicting row exists. The query runs with ctx so
// cancellation and deadlines are honoured; a nil ctx leaves the context of db
// unchanged.
func CheckColumnIsUnique(
	ctx context.Context,
	db *gorm.DB,
	table string,
	column string,
	value string,
	companyID *uuid.UUID,
	currentID *uuid.UUID, // if currentID is set, it is allowed to use the same value
) (bool, error) {
	tx := db
	if ctx != nil {
		// previously ctx was accepted but never attached to the query
		tx = tx.WithContext(ctx)
	}
	tx = tx.Table(table)
	if companyID != nil {
		tx = tx.Where(column+" = ? AND (company_id = ? OR company_id IS NULL)", value, companyID)
	} else {
		tx = tx.Where(column+" = ?", value)
	}
	if currentID != nil {
		tx = tx.Where("id != ?", currentID)
	}
	var count int64
	if err := tx.Count(&count).Error; err != nil {
		return false, err
	}
	return count == 0, nil
}
// UUIDsToStrings converts ids to their string form.
// Nil entries are skipped instead of being dereferenced, which would panic.
// The result is never nil, also for empty input.
func UUIDsToStrings(ids []*uuid.UUID) []string {
	args := make([]string, 0, len(ids))
	for _, id := range ids {
		if id == nil {
			continue
		}
		args = append(args, id.String())
	}
	return args
}
// CheckNameIsUnique reports whether name is unused in the "name" column of
// table by delegating to CheckColumnIsUnique.
// With a nil companyID the check is global: no other row may use the name.
// With a companyID set, neither that company nor rows with a NULL company ID
// may use it.
func CheckNameIsUnique(
	ctx context.Context,
	db *gorm.DB,
	table string,
	name string,
	companyID *uuid.UUID,
	currentID *uuid.UUID,
) (bool, error) {
	const nameColumn = "name"
	isUnique, err := CheckColumnIsUnique(ctx, db, table, nameColumn, name, companyID, currentID)
	return isUnique, err
}
// TableSelect joins the given select expressions into one comma-separated
// string.
func TableSelect(selects ...string) string {
	var joined strings.Builder
	for i, expr := range selects {
		if i > 0 {
			joined.WriteString(",")
		}
		joined.WriteString(expr)
	}
	return joined.String()
}
// TableColumn returns the backtick-quoted, table-qualified column
// "`tableName`.`columnName`".
func TableColumn(tableName, columnName string) string {
	quotedTable := fmt.Sprintf("`%s`", tableName)
	quotedColumn := fmt.Sprintf("`%s`", columnName)
	return quotedTable + "." + quotedColumn
}
// TableColumnAlias returns "`tableName`.`columnName` AS `alias`".
func TableColumnAlias(tableName, columnName, alias string) string {
	qualified := fmt.Sprintf("`%s`.`%s`", tableName, columnName)
	return fmt.Sprintf("%s AS `%s`", qualified, alias)
}
// TableColumnAll returns the expression selecting every column of
// tableName ("`tableName`.*").
func TableColumnAll(tableName string) string {
	quotedTable := fmt.Sprintf("`%s`", tableName)
	return quotedTable + ".*"
}
// TableColumnID returns the qualified "id" column of tableName, as built by
// TableColumn.
func TableColumnID(tableName string) string {
	const idColumn = "id"
	return TableColumn(tableName, idColumn)
}
// TableColumnName returns the qualified "name" column of tableName, as built
// by TableColumn.
func TableColumnName(tableName string) string {
	const nameColumn = "name"
	return TableColumn(tableName, nameColumn)
}