mirror of
https://github.com/phishingclub/phishingclub.git
synced 2026-02-12 16:12:44 +00:00
760 lines
16 KiB
Go
760 lines
16 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/oapi-codegen/nullable"
|
|
"github.com/phishingclub/phishingclub/database"
|
|
"github.com/phishingclub/phishingclub/errs"
|
|
"github.com/phishingclub/phishingclub/model"
|
|
"github.com/phishingclub/phishingclub/vo"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
var sessionAllowedColumns = assignTableToColumns(database.SESSION_TABLE, []string{
|
|
"created_at",
|
|
"updated_at",
|
|
"ip_address",
|
|
})
|
|
|
|
var userAllowedColumns = assignTableToColumns(database.USER_TABLE, []string{
|
|
"created_at",
|
|
"updated_at",
|
|
"name",
|
|
"username",
|
|
"email",
|
|
})
|
|
|
|
// UserOption is a user option
|
|
type UserOption struct {
|
|
*vo.QueryArgs
|
|
|
|
WithRole bool
|
|
WithCompany bool
|
|
}
|
|
|
|
// User is a repository for User
|
|
type User struct {
|
|
DB *gorm.DB
|
|
}
|
|
|
|
// / with preloads the role and company to a user
|
|
func (r *User) with(options *UserOption, db *gorm.DB) *gorm.DB {
|
|
if options.WithRole {
|
|
db = db.Preload("Role")
|
|
}
|
|
if options.WithCompany {
|
|
db = db.Preload("Company")
|
|
}
|
|
return db
|
|
}
|
|
|
|
// SetupTOTP adds TOTP to a user
|
|
// adds the url and secret, but does not enable TOTP
|
|
func (r *User) SetupTOTP(
|
|
ctx context.Context,
|
|
userID *uuid.UUID,
|
|
secret string,
|
|
recoveryCodes string,
|
|
url string,
|
|
) error {
|
|
row :=
|
|
map[string]any{
|
|
"totp_enabled": false,
|
|
"totp_secret": secret,
|
|
"totp_recovery_code": recoveryCodes,
|
|
"totp_auth_url": url,
|
|
}
|
|
AddUpdatedAt(row)
|
|
result := r.DB.
|
|
Model(&database.User{}).
|
|
Where("id = ?", userID.String()).
|
|
Updates(row)
|
|
|
|
if result.Error != nil {
|
|
return errs.Wrap(result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// EnableTOTP enables TOTP for a user
|
|
func (r *User) EnableTOTP(
|
|
ctx context.Context,
|
|
userID *uuid.UUID,
|
|
) error {
|
|
row := map[string]any{
|
|
"totp_enabled": true,
|
|
}
|
|
AddUpdatedAt(row)
|
|
result := r.DB.
|
|
Model(&database.User{}).
|
|
Where("id = ?", userID.String()).
|
|
Updates(row)
|
|
|
|
if result.Error != nil {
|
|
return errs.Wrap(result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RemoveTOTP removes TOTP from a user
|
|
func (r *User) RemoveTOTP(
|
|
ctx context.Context,
|
|
userID *uuid.UUID,
|
|
) error {
|
|
row := map[string]any{
|
|
"totp_enabled": false,
|
|
"totp_secret": "",
|
|
"totp_auth_url": "",
|
|
"totp_recovery_code": "",
|
|
}
|
|
AddUpdatedAt(row)
|
|
result := r.DB.
|
|
Model(&database.User{}).
|
|
Where("id = ?", userID.String()).
|
|
Updates(row)
|
|
|
|
if result.Error != nil {
|
|
return errs.Wrap(result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetTOTP gets TOTP from a user
|
|
func (r *User) GetTOTP(
|
|
ctx context.Context,
|
|
userID *uuid.UUID,
|
|
) (string, string, error) {
|
|
dbUser := &database.User{}
|
|
result := r.DB.
|
|
Select("totp_secret", "totp_auth_url").
|
|
Where("id = ?", userID.String()).
|
|
First(&dbUser)
|
|
|
|
if result.Error != nil {
|
|
return "", "", result.Error
|
|
}
|
|
return dbUser.TOTPSecret, dbUser.TOTPAuthURL, nil
|
|
}
|
|
|
|
// GetMFARecoveryCode gets the TOTP secret for a user
|
|
func (r *User) GetMFARecoveryCode(
|
|
ctx context.Context,
|
|
userID *uuid.UUID,
|
|
) (string, error) {
|
|
dbUser := &database.User{}
|
|
result := r.DB.
|
|
Select("totp_recovery_code").
|
|
Where("id = ?", userID.String()).
|
|
First(&dbUser)
|
|
|
|
if result.Error != nil {
|
|
return "", result.Error
|
|
}
|
|
return dbUser.TOTPRecoveryCode, nil
|
|
}
|
|
|
|
// IsTOTPEnabled checks if TOTP is enabled for a user
|
|
func (r *User) IsTOTPEnabled(
|
|
ctx context.Context,
|
|
userID *uuid.UUID,
|
|
) (bool, error) {
|
|
dbUser := &database.User{}
|
|
result := r.DB.
|
|
Select("totp_enabled").
|
|
Where("id = ?", userID.String()).
|
|
First(&dbUser)
|
|
|
|
if result.Error != nil {
|
|
return false, result.Error
|
|
}
|
|
return dbUser.TOTPEnabled, nil
|
|
}
|
|
|
|
// Insert creates a new user
|
|
func (r *User) Insert(
|
|
ctx context.Context,
|
|
user *model.User,
|
|
passwordHash string,
|
|
ssoID string,
|
|
) (*uuid.UUID, error) {
|
|
id := uuid.New()
|
|
row := user.ToDBMap()
|
|
row["id"] = id
|
|
AddTimestamps(row)
|
|
row["password_hash"] = passwordHash
|
|
row["sso_id"] = ssoID
|
|
|
|
res := r.DB.
|
|
Model(&database.User{}).
|
|
Create(row)
|
|
|
|
if res.Error != nil {
|
|
return nil, res.Error
|
|
}
|
|
return &id, nil
|
|
}
|
|
|
|
// UpdateByID updates a user by id
|
|
func (r *User) UpdateByID(
|
|
ctx context.Context,
|
|
id *uuid.UUID,
|
|
user *model.User,
|
|
) error {
|
|
row := user.ToDBMap()
|
|
AddUpdatedAt(row)
|
|
res := r.DB.
|
|
Model(&database.User{}).
|
|
Where("id = ?", id.String()).
|
|
Updates(row)
|
|
|
|
if res.Error != nil {
|
|
return res.Error
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpsertAPIKey upserts api key
|
|
func (r *User) UpsertAPIKey(
|
|
ctx context.Context,
|
|
id *uuid.UUID,
|
|
key string,
|
|
) error {
|
|
row := map[string]any{}
|
|
AddUpdatedAt(row)
|
|
row["api_key"] = key
|
|
res := r.DB.
|
|
Model(&database.User{}).
|
|
Where("id = ?", id.String()).
|
|
Updates(row)
|
|
|
|
if res.Error != nil {
|
|
return res.Error
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RemoveAPIKey deletes a api key
|
|
func (r *User) RemoveAPIKey(
|
|
ctx context.Context,
|
|
id *uuid.UUID,
|
|
) error {
|
|
row := map[string]any{}
|
|
AddUpdatedAt(row)
|
|
row["api_key"] = ""
|
|
res := r.DB.
|
|
Model(&database.User{}).
|
|
Where("id = ?", id.String()).
|
|
Updates(row)
|
|
|
|
if res.Error != nil {
|
|
return res.Error
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetAPIKey gets the users api key
|
|
func (r *User) GetAPIKey(
|
|
ctx context.Context,
|
|
id *uuid.UUID,
|
|
) (string, error) {
|
|
dbUser := &database.User{}
|
|
result := r.DB.
|
|
Select("api_key").
|
|
Where("id = ?", id.String()).
|
|
First(&dbUser)
|
|
|
|
if result.Error != nil {
|
|
return "", result.Error
|
|
}
|
|
return dbUser.APIKey, nil
|
|
}
|
|
|
|
// GetAllAPIKeys gets alll the users api keys
|
|
// return map[apiKey]userID
|
|
func (r *User) GetAllAPIKeys(
|
|
ctx context.Context,
|
|
) (map[string]*uuid.UUID, error) {
|
|
apiKeys := map[string]*uuid.UUID{}
|
|
dbUsers := []database.User{}
|
|
result := r.DB.
|
|
Select("id, api_key").
|
|
First(&dbUsers)
|
|
|
|
if result.Error != nil {
|
|
return apiKeys, result.Error
|
|
}
|
|
for _, dbUser := range dbUsers {
|
|
apiKeys[dbUser.APIKey] = dbUser.ID
|
|
}
|
|
return apiKeys, nil
|
|
}
|
|
|
|
// DeleteByID deletes a user by id
|
|
func (r *User) DeleteByID(
|
|
ctx context.Context,
|
|
id *uuid.UUID,
|
|
) error {
|
|
// anonymize user
|
|
// anon := uuid.New()
|
|
newName := fmt.Sprintf(
|
|
"deleted-%s",
|
|
uuid.New().String(),
|
|
)
|
|
res := r.DB.
|
|
Table(database.USER_TABLE).
|
|
Where("id = ?", id.String()).
|
|
Updates(map[string]any{
|
|
"name": newName,
|
|
"username": newName,
|
|
"email": fmt.Sprintf("%s@deleted.deleteduser", newName),
|
|
"password_hash": "",
|
|
"totp_enabled": false,
|
|
"totp_secret": nil,
|
|
"totp_auth_url": nil,
|
|
"totp_recovery_code": nil,
|
|
"sso_id": "",
|
|
})
|
|
|
|
if res.Error != nil {
|
|
return res.Error
|
|
}
|
|
|
|
res = r.DB.
|
|
Where("id = ?", id.String()).
|
|
Delete(&database.User{})
|
|
|
|
if res.Error != nil {
|
|
return res.Error
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetAll gets all users
|
|
func (r *User) GetAll(
|
|
ctx context.Context,
|
|
options *UserOption,
|
|
) (*model.Result[model.User], error) {
|
|
result := model.NewEmptyResult[model.User]()
|
|
dbUsers := []database.User{}
|
|
|
|
db, err := useQuery(r.DB, database.USER_TABLE, options.QueryArgs, userAllowedColumns...)
|
|
if err != nil {
|
|
return result, errs.Wrap(err)
|
|
}
|
|
dbRes := r.with(options, db).
|
|
Find(&dbUsers)
|
|
|
|
if dbRes.Error != nil {
|
|
return result, dbRes.Error
|
|
}
|
|
|
|
hasNextPage, err := useHasNextPage(
|
|
db, database.USER_TABLE, options.QueryArgs, userAllowedColumns...,
|
|
)
|
|
if err != nil {
|
|
return result, errs.Wrap(err)
|
|
}
|
|
result.HasNextPage = hasNextPage
|
|
|
|
for _, dbUsers := range dbUsers {
|
|
usr, err := ToUser(&dbUsers)
|
|
if err != nil {
|
|
return result, errs.Wrap(err)
|
|
}
|
|
result.Rows = append(result.Rows, usr)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// GetByID gets a user by id, includding the role and company
|
|
func (r *User) GetByID(
|
|
ctx context.Context,
|
|
id *uuid.UUID,
|
|
options *UserOption,
|
|
) (*model.User, error) {
|
|
if id == nil {
|
|
return nil, errs.Wrap(errors.New("ID is nil"))
|
|
}
|
|
dbUser := &database.User{}
|
|
result := r.with(options, r.DB).
|
|
Where(
|
|
fmt.Sprintf(
|
|
"%s = ?",
|
|
TableColumnID(database.USER_TABLE),
|
|
),
|
|
id.String(),
|
|
).
|
|
First(&dbUser)
|
|
|
|
if result.Error != nil {
|
|
return nil, result.Error
|
|
}
|
|
return ToUser(dbUser)
|
|
}
|
|
|
|
// GetByUsername gets a user by username
|
|
func (r *User) GetByUsername(
|
|
ctx context.Context,
|
|
username *vo.Username,
|
|
options *UserOption,
|
|
) (*model.User, error) {
|
|
dbUser := &database.User{}
|
|
result := r.with(options, r.DB).
|
|
Where(
|
|
fmt.Sprintf(
|
|
"%s = ?",
|
|
TableColumn(database.USER_TABLE, "username"),
|
|
),
|
|
username.String(),
|
|
).
|
|
First(&dbUser)
|
|
|
|
if result.Error != nil {
|
|
return nil, result.Error
|
|
}
|
|
return ToUser(dbUser)
|
|
}
|
|
|
|
// GetByEmail gets a user by email
|
|
func (r *User) GetByEmail(
|
|
ctx context.Context,
|
|
email *vo.Email,
|
|
options *UserOption,
|
|
) (*model.User, error) {
|
|
dbUser := &database.User{}
|
|
result := r.with(options, r.DB).
|
|
Where(
|
|
fmt.Sprintf(
|
|
"%s = ?",
|
|
TableColumn(database.USER_TABLE, "email"),
|
|
),
|
|
email.String(),
|
|
).
|
|
First(&dbUser)
|
|
|
|
if result.Error != nil {
|
|
return nil, result.Error
|
|
}
|
|
return ToUser(dbUser)
|
|
}
|
|
|
|
func (r *User) GetPasswordHashByUsername(
|
|
ctx context.Context,
|
|
username *vo.Username,
|
|
) (string, error) {
|
|
dbUser := &database.User{}
|
|
result := r.DB.
|
|
Select("password_hash").
|
|
Where("username = ?", username.String()).
|
|
First(&dbUser)
|
|
|
|
if result.Error != nil {
|
|
return "", result.Error
|
|
}
|
|
return dbUser.PasswordHash, nil
|
|
}
|
|
|
|
// GetBySessionID gets a user by session id
|
|
// this does not validate the session
|
|
func (r *User) GetBySessionID(
|
|
ctx context.Context,
|
|
sessionID *uuid.UUID,
|
|
options *UserOption,
|
|
) (*model.User, error) {
|
|
dbUser := &database.User{}
|
|
db, err := useQuery(r.DB, database.SESSION_TABLE, options.QueryArgs, allowedSessionColumns...)
|
|
|
|
if err != nil {
|
|
return nil, errs.Wrap(err)
|
|
}
|
|
result := r.with(options, db).
|
|
Joins("JOIN sessions ON sessions.user_id = users.id").
|
|
Where("sessions.id = ?", sessionID.String()).
|
|
First(&dbUser)
|
|
|
|
if result.Error != nil {
|
|
return nil, result.Error
|
|
}
|
|
return ToUser(dbUser)
|
|
|
|
}
|
|
|
|
// updateUsernameByID updates the username by id
|
|
func (r *User) updateUsernameByID(
|
|
tx *gorm.DB,
|
|
id *uuid.UUID,
|
|
username *vo.Username,
|
|
) error {
|
|
row := map[string]any{
|
|
"username": username.String(),
|
|
}
|
|
AddUpdatedAt(row)
|
|
result := tx.
|
|
Model(&database.User{}).
|
|
Where("id = ?", id.String()).
|
|
Updates(row)
|
|
|
|
if result.Error != nil {
|
|
return errs.Wrap(result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateUserToSSO removes the password hash and sets a sso id
|
|
func (r *User) UpdateUserToSSO(
|
|
ctx context.Context,
|
|
id *uuid.UUID,
|
|
ssoID string,
|
|
) error {
|
|
result := r.DB.
|
|
Table(database.USER_TABLE).
|
|
Where("id = ?", id.String()).
|
|
Updates(map[string]interface{}{
|
|
"password_hash": "",
|
|
"sso_id": ssoID,
|
|
})
|
|
|
|
if result.Error != nil {
|
|
return errs.Wrap(result.Error)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdateUserToNoSSO removes the SSO id
|
|
// f
|
|
func (r *User) UpdateUserToNoSSO(
|
|
ctx context.Context,
|
|
id *uuid.UUID,
|
|
) error {
|
|
result := r.DB.
|
|
Table(database.USER_TABLE).
|
|
Where("id = ?", id.String()).
|
|
Updates(map[string]interface{}{
|
|
"sso_id": "",
|
|
})
|
|
|
|
if result.Error != nil {
|
|
return errs.Wrap(result.Error)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// UpdateUsernameByID updates the username by id
|
|
func (r *User) UpdateUsernameByID(
|
|
ctx context.Context,
|
|
id *uuid.UUID,
|
|
username *vo.Username,
|
|
) error {
|
|
return r.updateUsernameByID(r.DB, id, username)
|
|
}
|
|
|
|
// UpdateUsernameByIDWithTransaction updates the username by id
|
|
func (r *User) UpdateUsernameByIDWithTransaction(
|
|
ctx context.Context,
|
|
tx *gorm.DB,
|
|
id *uuid.UUID,
|
|
username *vo.Username,
|
|
) error {
|
|
return r.updateUsernameByID(tx, id, username)
|
|
}
|
|
|
|
// updateFullNameByID updates the full name by id
|
|
func (r *User) updateFullNameByID(
|
|
tx *gorm.DB,
|
|
id *uuid.UUID,
|
|
name *vo.UserFullname,
|
|
) error {
|
|
row := map[string]any{
|
|
"name": name.String(),
|
|
}
|
|
AddUpdatedAt(row)
|
|
result := tx.
|
|
Model(&database.User{}).
|
|
Where("id = ?", id.String()).
|
|
Updates(row)
|
|
|
|
if result.Error != nil {
|
|
return errs.Wrap(result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateFullNameByID updates the full name by id
|
|
func (r *User) UpdateFullNameByID(
|
|
ctx context.Context,
|
|
id *uuid.UUID,
|
|
name *vo.UserFullname,
|
|
) error {
|
|
return r.updateFullNameByID(r.DB, id, name)
|
|
}
|
|
|
|
// UpdateFullNameByIDWithTransaction updates the full name by id
|
|
func (r *User) UpdateFullNameByIDWithTransaction(
|
|
ctx context.Context,
|
|
tx *gorm.DB,
|
|
id *uuid.UUID,
|
|
name *vo.UserFullname,
|
|
) error {
|
|
return r.updateFullNameByID(tx, id, name)
|
|
}
|
|
|
|
// updateEmailByID updates the email by id
|
|
func (r *User) updateEmailByID(
|
|
tx *gorm.DB,
|
|
id *uuid.UUID,
|
|
email *vo.Email,
|
|
) error {
|
|
row := map[string]any{
|
|
"email": email.String(),
|
|
}
|
|
AddUpdatedAt(row)
|
|
result := tx.
|
|
Model(&database.User{}).
|
|
Where("id = ?", id.String()).
|
|
Updates(row)
|
|
|
|
if result.Error != nil {
|
|
return errs.Wrap(result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateEmailByID updates the email by id
|
|
func (r *User) UpdateEmailByID(
|
|
ctx context.Context,
|
|
id *uuid.UUID,
|
|
email *vo.Email,
|
|
) error {
|
|
return r.updateEmailByID(r.DB, id, email)
|
|
}
|
|
|
|
// UpdateEmailByIDWithTransaction updates the email by id
|
|
func (r *User) UpdateEmailByIDWithTransaction(
|
|
ctx context.Context,
|
|
tx *gorm.DB,
|
|
id *uuid.UUID,
|
|
email *vo.Email,
|
|
) error {
|
|
return r.updateEmailByID(tx, id, email)
|
|
}
|
|
|
|
// updatePasswordHashByID updates the password hash by id
|
|
func (r *User) updatePasswordHashByID(
|
|
tx *gorm.DB,
|
|
id *uuid.UUID,
|
|
passwordHash string,
|
|
) error {
|
|
row := map[string]interface{}{
|
|
"password_hash": passwordHash,
|
|
}
|
|
AddUpdatedAt(row)
|
|
result := tx.
|
|
Model(&database.User{}).
|
|
Where("id = ?", id.String()).
|
|
Updates(row)
|
|
|
|
if result.Error != nil {
|
|
return errs.Wrap(result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdatePasswordHashByID updates the password hash by id
|
|
func (r *User) UpdatePasswordHashByID(
|
|
ctx context.Context,
|
|
id *uuid.UUID,
|
|
passwordHash string,
|
|
) error {
|
|
return r.updatePasswordHashByID(r.DB, id, passwordHash)
|
|
}
|
|
|
|
// UpdatePasswordHashByIDWithTransaction updates the password hash by id
|
|
func (r *User) UpdatePasswordHashByIDWithTransaction(
|
|
ctx context.Context,
|
|
tx *gorm.DB,
|
|
id *uuid.UUID,
|
|
passwordHash string,
|
|
) error {
|
|
return r.updatePasswordHashByID(tx, id, passwordHash)
|
|
}
|
|
|
|
// updatePasswordHashByID updates the password hash by id
|
|
func (r *User) updatePasswordHashByUsername(
|
|
tx *gorm.DB,
|
|
username *vo.Username,
|
|
passwordHash string,
|
|
) error {
|
|
row := map[string]interface{}{
|
|
"password_hash": passwordHash,
|
|
"require_password_renew": false,
|
|
}
|
|
AddUpdatedAt(row)
|
|
result := tx.
|
|
Model(&database.User{}).
|
|
Where("username = ?", username.String()).
|
|
Updates(row)
|
|
if result.Error != nil {
|
|
return errs.Wrap(result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdatePasswordHashByUsername updates the password hash by id
|
|
func (r *User) UpdatePasswordHashByUsername(
|
|
ctx context.Context,
|
|
username *vo.Username,
|
|
passwordHash string,
|
|
) error {
|
|
return r.updatePasswordHashByUsername(r.DB, username, passwordHash)
|
|
}
|
|
|
|
// UpdatePasswordHashByUsernameWithTransaction updates the password hash by id
|
|
func (r *User) UpdatePasswordHashByUsernameWithTransaction(
|
|
ctx context.Context,
|
|
tx *gorm.DB,
|
|
username *vo.Username,
|
|
passwordHash string,
|
|
) error {
|
|
return r.updatePasswordHashByUsername(tx, username, passwordHash)
|
|
}
|
|
|
|
func ToUser(row *database.User) (*model.User, error) {
|
|
id := nullable.NewNullableWithValue(*row.ID)
|
|
companyID := nullable.NewNullNullable[uuid.UUID]()
|
|
if row.CompanyID != nil {
|
|
companyID.Set(*row.CompanyID)
|
|
}
|
|
roleID := nullable.NewNullableWithValue(*row.RoleID)
|
|
userFullname := nullable.NewNullableWithValue(*vo.NewUserFullnameMust(row.Name))
|
|
username := nullable.NewNullableWithValue(*vo.NewUsernameMust(row.Username))
|
|
email := nullable.NewNullableWithValue(*vo.NewEmailMust(row.Email))
|
|
ssoID := nullable.NewNullableWithValue(row.SSOID)
|
|
var role *model.Role
|
|
if row.Role != nil {
|
|
role = ToRole(row.Role)
|
|
}
|
|
var company *model.Company
|
|
if row.Company != nil {
|
|
company = ToCompany(row.Company)
|
|
}
|
|
|
|
return &model.User{
|
|
ID: id,
|
|
Name: userFullname,
|
|
Username: username,
|
|
Email: email,
|
|
RoleID: roleID,
|
|
Role: role,
|
|
RequirePasswordRenew: nullable.NewNullableWithValue(row.RequirePasswordRenew),
|
|
CompanyID: companyID,
|
|
Company: company,
|
|
SSOID: ssoID,
|
|
}, nil
|
|
}
|