mirror of
https://github.com/phishingclub/phishingclub.git
synced 2026-02-12 16:12:44 +00:00
408 lines
11 KiB
Go
408 lines
11 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"slices"
|
|
"strings"
|
|
|
|
"github.com/go-errors/errors"
|
|
"github.com/google/uuid"
|
|
"github.com/phishingclub/phishingclub/errs"
|
|
"github.com/phishingclub/phishingclub/utils"
|
|
"github.com/phishingclub/phishingclub/vo"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Fallback column lists used when a caller does not pass its own.
var (
	// defaultAllowdSearchColumns lists the columns matched with LIKE when a
	// search is requested without an explicit column list.
	defaultAllowdSearchColumns = []string{"name"}

	// defaultAllowdColumns is the set of column names accepted for ordering
	// when no explicit allow list is supplied.
	defaultAllowdColumns = map[string]struct{}{
		"name":       {},
		"created_at": {},
		"updated_at": {},
	}
)
|
|
|
|
func withOffsetLimit(db *gorm.DB, offset, limit int) *gorm.DB {
|
|
if offset == 0 && limit == 0 {
|
|
return db
|
|
}
|
|
return db.Offset(offset).Limit(limit)
|
|
}
|
|
|
|
func WithOrderBy(db *gorm.DB, orderBy string, desc bool, allowed ...string) (*gorm.DB, error) {
|
|
if orderBy == "" {
|
|
return db, nil
|
|
}
|
|
// if no allowed columns are provided, use the default
|
|
// else check if the column is allowed
|
|
if len(allowed) == 0 {
|
|
if _, ok := defaultAllowdColumns[orderBy]; !ok {
|
|
return db, fmt.Errorf(
|
|
"not known or allowed column: %s - allowd: %s",
|
|
orderBy,
|
|
defaultAllowdColumns,
|
|
)
|
|
}
|
|
} else {
|
|
if !slices.Contains(allowed, orderBy) {
|
|
return db, fmt.Errorf(
|
|
"not known or allowed column: %s - allowd: %s",
|
|
orderBy,
|
|
allowed,
|
|
)
|
|
}
|
|
}
|
|
// TODO this ruins all indexes performance but is a quick fix to work for all databases
|
|
// to ensure that the order by is case insensitive
|
|
// the real solution is to use a case insensitive collation
|
|
// but these differ per database, another option would be LOWER indexes for some columns however this is also not ideal
|
|
//orderBy = fmt.Sprintf("LOWER(%s)", orderBy)
|
|
|
|
if !desc {
|
|
return db.Order(orderBy + " COLLATE NOCASE ASC"), nil
|
|
}
|
|
return db.Order(orderBy + " COLLATE NOCASE DESC"), nil
|
|
}
|
|
|
|
func WithOrderByOnTable(db *gorm.DB, table string, orderBy string, desc bool, allowed ...string) (*gorm.DB, error) {
|
|
if orderBy == "" {
|
|
return db, nil
|
|
}
|
|
// only check default columns if no allowed columns are provided
|
|
if _, ok := defaultAllowdColumns[orderBy]; !ok && len(allowed) == 0 {
|
|
return db, fmt.Errorf("invalid column: %s", orderBy)
|
|
}
|
|
for _, allowedOrderBy := range allowed {
|
|
if !slices.Contains(allowed, orderBy) {
|
|
return db, fmt.Errorf("invalid column: %s", allowedOrderBy)
|
|
}
|
|
}
|
|
|
|
// TODO this ruins all indexes performance but is a quick fix to work for all databases
|
|
// to ensure that the order by is case insensitive
|
|
// the real solution is to use a case insensitive collation
|
|
// but these differ per database, another option would be LOWER indexes for some columns however this is also not ideal
|
|
if !desc {
|
|
return db.Order(
|
|
//fmt.Sprintf("LOWER(`%s`.`%s`) ASC", table, orderBy),
|
|
fmt.Sprintf("LOWER(`%s`.`%s`) COLLATE NOCAS ASC", table, orderBy),
|
|
), nil
|
|
}
|
|
return db.Order(
|
|
//fmt.Sprintf("LOWER(`%s`.`%s`) DESC", table, orderBy),
|
|
fmt.Sprintf("LOWER(`%s`.`%s`) NO COLLATE DESC", table, orderBy),
|
|
), nil
|
|
}
|
|
|
|
// assignTableToColumn qualifies column with table as `table`.`column`.
// A column that already contains a dot is treated as qualified and returned
// unchanged.
func assignTableToColumn(table, column string) string {
	if strings.Contains(column, ".") {
		return column
	}
	return fmt.Sprintf("`%s`.`%s`", table, column)
}

// assignTableToColumns qualifies every entry of columns with table and
// returns the result as a new slice.
//
// The input slice is left untouched. It is called with package-level column
// lists; rewriting those in place would bake the first table name into them
// and leak it into queries against every other table.
func assignTableToColumns(table string, columns []string) []string {
	qualified := make([]string, len(columns))
	for i, column := range columns {
		qualified[i] = assignTableToColumn(table, column)
	}
	return qualified
}
|
|
|
|
func useQuery(db *gorm.DB, tableName string, q *vo.QueryArgs, allowdColumns ...string) (*gorm.DB, error) {
|
|
if q == nil {
|
|
return db, nil
|
|
}
|
|
db = withOffsetLimit(db, q.Offset, q.Limit)
|
|
db, err := WithOrderBy(db, assignTableToColumn(tableName, q.OrderBy), q.Desc, allowdColumns...)
|
|
if err != nil {
|
|
return db, errs.Wrap(err)
|
|
}
|
|
// handle search
|
|
if q.Search != "" {
|
|
searchColumns := []string{}
|
|
// add columns that are allowed to be searched in
|
|
for _, column := range allowdColumns {
|
|
searchColumns = append(
|
|
searchColumns,
|
|
assignTableToColumn(tableName, column),
|
|
)
|
|
}
|
|
// if no columns has been added, use the default
|
|
if len(searchColumns) == 0 {
|
|
searchColumns = assignTableToColumns(tableName, defaultAllowdSearchColumns)
|
|
}
|
|
// remove search symbols
|
|
search := strings.ReplaceAll(q.Search, "%", " ")
|
|
search = strings.ReplaceAll(search, "_", " ")
|
|
// build the LIKE query
|
|
var searches []interface{}
|
|
q := ""
|
|
for i, column := range searchColumns {
|
|
if i == 0 {
|
|
q += fmt.Sprintf("%s LIKE ?", column)
|
|
} else {
|
|
q += fmt.Sprintf(" OR %s LIKE ?", column)
|
|
}
|
|
searches = append(searches, "%"+search+"%")
|
|
}
|
|
db.Where(q, searches...)
|
|
}
|
|
return db, errs.Wrap(err)
|
|
}
|
|
|
|
func useHasNextPage(
|
|
db *gorm.DB,
|
|
tableName string,
|
|
q *vo.QueryArgs,
|
|
allowdColumns ...string,
|
|
) (bool, error) {
|
|
if q == nil {
|
|
return false, nil
|
|
}
|
|
if q.Limit == 0 && q.Offset == 0 {
|
|
return false, nil
|
|
}
|
|
db = withOffsetLimit(db, q.Offset+q.Limit, 1)
|
|
db, err := WithOrderBy(db, assignTableToColumn(tableName, q.OrderBy), q.Desc, allowdColumns...)
|
|
if err != nil {
|
|
return false, errs.Wrap(err)
|
|
}
|
|
// handle search
|
|
if q.Search != "" {
|
|
searchColumns := []string{}
|
|
// add columns that are allowed to be searched in
|
|
for _, column := range allowdColumns {
|
|
searchColumns = append(
|
|
searchColumns,
|
|
assignTableToColumn(tableName, column),
|
|
)
|
|
}
|
|
// if no columns has been added, use the default
|
|
if len(searchColumns) == 0 {
|
|
searchColumns = assignTableToColumns(tableName, defaultAllowdSearchColumns)
|
|
}
|
|
// remove search symbols
|
|
search := strings.ReplaceAll(q.Search, "%", " ")
|
|
search = strings.ReplaceAll(search, "_", " ")
|
|
// build the LIKE query
|
|
var searches []interface{}
|
|
q := ""
|
|
for i, column := range searchColumns {
|
|
if i == 0 {
|
|
q += fmt.Sprintf("%s LIKE ?", column)
|
|
} else {
|
|
q += fmt.Sprintf(" OR %s LIKE ?", column)
|
|
}
|
|
searches = append(searches, "%"+search+"%")
|
|
}
|
|
db.Where(q, searches...)
|
|
}
|
|
// Check if there's at least one record
|
|
var exists bool
|
|
err = db.Select("1").Find(&exists).Error
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return false, nil
|
|
}
|
|
if err != nil {
|
|
return false, errs.Wrap(err)
|
|
}
|
|
return exists, nil
|
|
}
|
|
|
|
/*
|
|
func useQueryWithTable(db *gorm.DB, table string, q *vo.QueryArgs, allowdColumns ...string) (*gorm.DB, error) {
|
|
if q == nil {
|
|
return db, nil
|
|
}
|
|
db = withOffsetLimit(db, q.Offset, q.Limit)
|
|
db, err := WithOrderByOnTable(db, table, q.OrderBy, q.Desc, allowdColumns...)
|
|
// handle search
|
|
if q.Search != "" {
|
|
searchColumns := []string{}
|
|
// add columns that are allowed to be searched in
|
|
for _, column := range allowdColumns {
|
|
if column == "created_at" || column == "updated_at" {
|
|
continue
|
|
}
|
|
searchColumns = append(searchColumns, column)
|
|
}
|
|
// if no columns has been added, use the default
|
|
if len(searchColumns) == 0 {
|
|
searchColumns = defaultAllowdSearchColumns
|
|
}
|
|
// remove search symbols
|
|
search := strings.ReplaceAll(q.Search, "%", " ")
|
|
search = strings.ReplaceAll(search, "_", " ")
|
|
// build the LIKE query
|
|
// todo perhaps this needs table prefix also
|
|
var searches []interface{}
|
|
q := ""
|
|
for i, column := range searchColumns {
|
|
if i == 0 {
|
|
q += column + " LIKE ?"
|
|
} else {
|
|
q += " OR " + column + " LIKE ?"
|
|
}
|
|
searches = append(searches, "%"+search+"%")
|
|
}
|
|
db.Where(q, searches...)
|
|
}
|
|
return db,errs.Wrap(err)
|
|
}
|
|
*/
|
|
|
|
// SelectTable returns a one-element select list that picks every column of
// tableName, written as `tableName`.*.
func SelectTable(tableName string) []string {
	allColumns := fmt.Sprintf("`%s`.*", tableName)
	return []string{allColumns}
}
|
|
|
|
// SelectColumnAs creates a list of aliased column selections for tableName.
//
// columns maps alias -> column name: each entry becomes
// "`tableName`.`column` AS alias". The alias is emitted unquoted.
// Entries are returned sorted by alias so the generated SQL is stable between
// calls instead of following random map iteration order. An empty or nil map
// yields a nil slice.
func SelectColumnAs(tableName string, columns map[string]string) []string {
	if len(columns) == 0 {
		return nil
	}
	aliases := make([]string, 0, len(columns))
	for alias := range columns {
		aliases = append(aliases, alias)
	}
	slices.Sort(aliases)

	cols := make([]string, 0, len(aliases))
	for _, alias := range aliases {
		cols = append(cols, fmt.Sprintf("`%s`.`%s` AS %s", tableName, columns[alias], alias))
	}
	return cols
}
|
|
|
|
func useSelect(db *gorm.DB, fields []string) *gorm.DB {
|
|
if len(fields) == 0 {
|
|
return db
|
|
}
|
|
return db.Select(fields)
|
|
}
|
|
|
|
// LeftJoinOn builds a LEFT JOIN clause that joins tableRight where
// tableLeft.columnLeft equals tableRight.columnRight.
// All identifiers are wrapped in backticks and are not escaped.
func LeftJoinOn(tableLeft, columnLeft, tableRight, columnRight string) string {
	const joinFormat = "LEFT JOIN `%s` on `%s`.`%s` = `%s`.`%s`"
	return fmt.Sprintf(joinFormat, tableRight, tableLeft, columnLeft, tableRight, columnRight)
}
|
|
|
|
// LeftJoinOnWithAlias builds a LEFT JOIN clause that joins tableRight under
// alias where tableLeft.columnLeft equals alias.columnRight.
// The alias is introduced in single quotes and referenced in backticks;
// identifiers are not escaped.
func LeftJoinOnWithAlias(tableLeft, columnLeft, tableRight, columnRight, alias string) string {
	const joinFormat = "LEFT JOIN `%s` '%s' on `%s`.`%s` = `%s`.`%s`"
	return fmt.Sprintf(joinFormat, tableRight, alias, tableLeft, columnLeft, alias, columnRight)
}
|
|
|
|
// withCompanyTableContext adds a company context to the query
|
|
func withCompanyIncludingNullContext(db *gorm.DB, companyID *uuid.UUID, tableName string) *gorm.DB {
|
|
column := fmt.Sprintf("`%s`.company_id", tableName)
|
|
if companyID != nil {
|
|
return db.Where(
|
|
fmt.Sprintf("(%s = ? OR %s IS NULL)", column, column), companyID)
|
|
}
|
|
return db.Where(
|
|
fmt.Sprintf("(%s IS NULL)", column),
|
|
)
|
|
}
|
|
|
|
// withCompany adds a where company id
|
|
// if companyID is NIL it will add a companyID IS NULL
|
|
func whereCompany(db *gorm.DB, tableName string, companyID *uuid.UUID) *gorm.DB {
|
|
column := fmt.Sprintf("`%s`.company_id", tableName)
|
|
if companyID == nil {
|
|
return db.Where(fmt.Sprintf("%s IS NULL", column))
|
|
} else {
|
|
return db.Where(
|
|
fmt.Sprintf("%s = ?", column), companyID)
|
|
}
|
|
}
|
|
|
|
// withCompany adds a where company id is null
|
|
func whereCompanyIsNull(db *gorm.DB, tableName string) *gorm.DB {
|
|
column := fmt.Sprintf("`%s`.company_id", tableName)
|
|
return db.Where(
|
|
fmt.Sprintf("%s IS NULL", column))
|
|
}
|
|
|
|
// AddTimestamps adds created_at and updated_at to a map
|
|
func AddTimestamps(row map[string]interface{}) {
|
|
now := utils.NowRFC3339UTC()
|
|
row["created_at"] = now
|
|
row["updated_at"] = now
|
|
}
|
|
|
|
// AddCreatedAt adds updated_at to a map
|
|
func AddUpdatedAt(row map[string]interface{}) {
|
|
row["updated_at"] = utils.NowRFC3339UTC()
|
|
}
|
|
|
|
// CheckColumnIsUnique checks if a column is unique within a company and globally
|
|
// if companyID is nil, it is global no other row should use the name, period.
|
|
// if companyID is set, then the same column id should not use it, and no null company ID
|
|
// columns is not sqli safe
|
|
func CheckColumnIsUnique(
|
|
ctx context.Context,
|
|
db *gorm.DB,
|
|
table string,
|
|
column string,
|
|
value string,
|
|
companyID *uuid.UUID,
|
|
currentID *uuid.UUID, // if currentID is set, it is allowed to use the same value
|
|
) (bool, error) {
|
|
var count int64
|
|
tx := db.Table(table)
|
|
|
|
if companyID != nil {
|
|
tx = tx.Where(column+" = ? AND (company_id = ? OR company_id IS NULL)", value, companyID)
|
|
} else {
|
|
tx = tx.Where(column+" = ?", value)
|
|
}
|
|
if currentID != nil {
|
|
tx = tx.Where("id != ?", currentID)
|
|
}
|
|
|
|
result := tx.Count(&count)
|
|
if result.Error != nil {
|
|
return false, result.Error
|
|
}
|
|
return count == 0, nil
|
|
}
|
|
|
|
func UUIDsToStrings(ids []*uuid.UUID) []string {
|
|
args := []string{}
|
|
for _, s := range ids {
|
|
args = append(args, s.String())
|
|
}
|
|
return args
|
|
}
|
|
|
|
// CheckNameIsUnique checks if a name is unique within a company and globally
|
|
// if companyID is nil, it is global no other row should use the name, period.
|
|
// if companyID is set, then the same company id should not use it, and no null company ID
|
|
func CheckNameIsUnique(
|
|
ctx context.Context,
|
|
db *gorm.DB,
|
|
table string,
|
|
name string,
|
|
companyID *uuid.UUID,
|
|
currentID *uuid.UUID,
|
|
) (bool, error) {
|
|
return CheckColumnIsUnique(ctx, db, table, "name", name, companyID, currentID)
|
|
}
|
|
|
|
// TableSelect joins the given select expressions with commas into a single
// select list. No arguments yield an empty string.
func TableSelect(selects ...string) string {
	const separator = ","
	return strings.Join(selects, separator)
}
|
|
|
|
// TableColumn returns the backtick-quoted, table-qualified column reference
// `tableName`.`columnName`. The names are not escaped.
func TableColumn(tableName, columnName string) string {
	const qualified = "`%s`.`%s`"
	return fmt.Sprintf(qualified, tableName, columnName)
}
|
|
|
|
// TableColumnAlias returns `tableName`.`columnName` AS `alias`, with all
// three identifiers wrapped in backticks and not escaped.
func TableColumnAlias(tableName, columnName, alias string) string {
	const aliased = "`%s`.`%s` AS `%s`"
	return fmt.Sprintf(aliased, tableName, columnName, alias)
}
|
|
|
|
// TableColumnAll returns the select expression `tableName`.* covering every
// column of the table. The name is not escaped.
func TableColumnAll(tableName string) string {
	const everyColumn = "`%s`.*"
	return fmt.Sprintf(everyColumn, tableName)
}
|
|
|
|
func TableColumnID(tableName string) string {
|
|
return TableColumn(tableName, "id")
|
|
}
|
|
|
|
func TableColumnName(tableName string) string {
|
|
return TableColumn(tableName, "name")
|
|
}
|