Files
phishingclub/backend/repository/user.go
2025-08-21 16:14:09 +02:00

760 lines
16 KiB
Go

package repository
import (
"context"
"errors"
"fmt"
"github.com/google/uuid"
"github.com/oapi-codegen/nullable"
"github.com/phishingclub/phishingclub/database"
"github.com/phishingclub/phishingclub/errs"
"github.com/phishingclub/phishingclub/model"
"github.com/phishingclub/phishingclub/vo"
"gorm.io/gorm"
)
// Column allow-lists, built by assignTableToColumns from a table name and a
// list of column names.
var (
	// sessionAllowedColumns lists the session table columns.
	sessionAllowedColumns = assignTableToColumns(
		database.SESSION_TABLE,
		[]string{"created_at", "updated_at", "ip_address"},
	)

	// userAllowedColumns lists the user table columns; GetAll passes them to
	// useQuery and useHasNextPage.
	userAllowedColumns = assignTableToColumns(
		database.USER_TABLE,
		[]string{"created_at", "updated_at", "name", "username", "email"},
	)
)
// UserOption carries the query arguments and preload switches used by the
// user repository's read methods.
type UserOption struct {
// QueryArgs is forwarded to useQuery / useHasNextPage by the list queries.
*vo.QueryArgs
// WithRole makes the query preload the "Role" association.
WithRole bool
// WithCompany makes the query preload the "Company" association.
WithCompany bool
}
// User is the repository for user rows.
type User struct {
// DB is the gorm handle every non-transactional method runs its queries on.
DB *gorm.DB
}
// with adds the preloads requested in options to db and returns the
// resulting query handle. "Role" is preloaded when options.WithRole is set
// and "Company" when options.WithCompany is set.
func (r *User) with(options *UserOption, db *gorm.DB) *gorm.DB {
	preloads := []struct {
		enabled     bool
		association string
	}{
		{options.WithRole, "Role"},
		{options.WithCompany, "Company"},
	}
	query := db
	for _, p := range preloads {
		if !p.enabled {
			continue
		}
		query = query.Preload(p.association)
	}
	return query
}
// SetupTOTP stores the TOTP secret, recovery codes and auth URL for the user
// with the given id. totp_enabled is written as false, so TOTP is prepared
// but not yet active.
func (r *User) SetupTOTP(
	ctx context.Context,
	userID *uuid.UUID,
	secret string,
	recoveryCodes string,
	url string,
) error {
	values := map[string]any{}
	values["totp_enabled"] = false
	values["totp_secret"] = secret
	values["totp_recovery_code"] = recoveryCodes
	values["totp_auth_url"] = url
	AddUpdatedAt(values)

	query := r.DB.
		Model(&database.User{}).
		Where("id = ?", userID.String())
	if err := query.Updates(values).Error; err != nil {
		return errs.Wrap(err)
	}
	return nil
}
// EnableTOTP sets totp_enabled to true for the user with the given id.
func (r *User) EnableTOTP(
	ctx context.Context,
	userID *uuid.UUID,
) error {
	values := map[string]any{}
	values["totp_enabled"] = true
	AddUpdatedAt(values)

	query := r.DB.
		Model(&database.User{}).
		Where("id = ?", userID.String())
	if err := query.Updates(values).Error; err != nil {
		return errs.Wrap(err)
	}
	return nil
}
// RemoveTOTP disables TOTP for the user with the given id and blanks the
// stored secret, auth URL and recovery code.
func (r *User) RemoveTOTP(
	ctx context.Context,
	userID *uuid.UUID,
) error {
	values := map[string]any{"totp_enabled": false}
	// every stored TOTP artefact is overwritten with the empty string
	for _, column := range []string{
		"totp_secret",
		"totp_auth_url",
		"totp_recovery_code",
	} {
		values[column] = ""
	}
	AddUpdatedAt(values)

	query := r.DB.
		Model(&database.User{}).
		Where("id = ?", userID.String())
	if err := query.Updates(values).Error; err != nil {
		return errs.Wrap(err)
	}
	return nil
}
// GetTOTP returns the TOTP secret and the TOTP auth URL (in that order) of
// the user with the given id. The query error is returned unwrapped.
func (r *User) GetTOTP(
	ctx context.Context,
	userID *uuid.UUID,
) (string, string, error) {
	found := new(database.User)
	err := r.DB.
		Select("totp_secret", "totp_auth_url").
		Where("id = ?", userID.String()).
		First(&found).
		Error
	if err != nil {
		return "", "", err
	}
	secret, authURL := found.TOTPSecret, found.TOTPAuthURL
	return secret, authURL, nil
}
// GetMFARecoveryCode returns the stored TOTP recovery code of the user with
// the given id. The query error is returned unwrapped.
func (r *User) GetMFARecoveryCode(
	ctx context.Context,
	userID *uuid.UUID,
) (string, error) {
	found := new(database.User)
	err := r.DB.
		Select("totp_recovery_code").
		Where("id = ?", userID.String()).
		First(&found).
		Error
	if err != nil {
		return "", err
	}
	return found.TOTPRecoveryCode, nil
}
// IsTOTPEnabled reports the totp_enabled flag of the user with the given id.
// On a query error it returns false together with the unwrapped error.
func (r *User) IsTOTPEnabled(
	ctx context.Context,
	userID *uuid.UUID,
) (bool, error) {
	found := new(database.User)
	err := r.DB.
		Select("totp_enabled").
		Where("id = ?", userID.String()).
		First(&found).
		Error
	if err != nil {
		return false, err
	}
	return found.TOTPEnabled, nil
}
// Insert creates a new user row from user, adding a freshly generated id,
// the timestamps, the password hash and the SSO id.
// It returns the generated id, or nil and the unwrapped error on failure.
func (r *User) Insert(
	ctx context.Context,
	user *model.User,
	passwordHash string,
	ssoID string,
) (*uuid.UUID, error) {
	newID := uuid.New()

	values := user.ToDBMap()
	values["id"] = newID
	AddTimestamps(values)
	values["password_hash"], values["sso_id"] = passwordHash, ssoID

	if err := r.DB.Model(&database.User{}).Create(values).Error; err != nil {
		return nil, err
	}
	return &newID, nil
}
// UpdateByID writes the columns produced by user.ToDBMap (plus updated_at)
// to the user row with the given id. The query error is returned unwrapped.
func (r *User) UpdateByID(
	ctx context.Context,
	id *uuid.UUID,
	user *model.User,
) error {
	values := user.ToDBMap()
	AddUpdatedAt(values)

	query := r.DB.
		Model(&database.User{}).
		Where("id = ?", id.String())
	return query.Updates(values).Error
}
// UpsertAPIKey stores key as the API key of the user with the given id,
// replacing whatever value was there before.
// The query error is returned unwrapped.
func (r *User) UpsertAPIKey(
	ctx context.Context,
	id *uuid.UUID,
	key string,
) error {
	values := map[string]any{"api_key": key}
	AddUpdatedAt(values)

	query := r.DB.
		Model(&database.User{}).
		Where("id = ?", id.String())
	return query.Updates(values).Error
}
// RemoveAPIKey clears the API key of the user with the given id by
// overwriting it with the empty string.
// The query error is returned unwrapped.
func (r *User) RemoveAPIKey(
	ctx context.Context,
	id *uuid.UUID,
) error {
	values := map[string]any{"api_key": ""}
	AddUpdatedAt(values)

	query := r.DB.
		Model(&database.User{}).
		Where("id = ?", id.String())
	return query.Updates(values).Error
}
// GetAPIKey returns the API key of the user with the given id.
// The query error is returned unwrapped.
func (r *User) GetAPIKey(
	ctx context.Context,
	id *uuid.UUID,
) (string, error) {
	found := new(database.User)
	err := r.DB.
		Select("api_key").
		Where("id = ?", id.String()).
		First(&found).
		Error
	if err != nil {
		return "", err
	}
	return found.APIKey, nil
}
// GetAllAPIKeys gets the API keys of all users.
// It returns a map from API key to the id of the user owning it. Users whose
// API key is empty are left out, so the empty string is never a valid lookup
// key. On a query error the (empty) map and the unwrapped error are returned.
func (r *User) GetAllAPIKeys(
	ctx context.Context,
) (map[string]*uuid.UUID, error) {
	apiKeys := map[string]*uuid.UUID{}
	var dbUsers []database.User
	// Find, not First: First adds LIMIT 1 and therefore returned at most a
	// single user (and a record-not-found error on an empty table).
	result := r.DB.
		Select("id, api_key").
		Find(&dbUsers)
	if result.Error != nil {
		return apiKeys, result.Error
	}
	for i := range dbUsers {
		key := dbUsers[i].APIKey
		if key == "" {
			// user has no API key; mapping "" to a user would be meaningless
			continue
		}
		apiKeys[key] = dbUsers[i].ID
	}
	return apiKeys, nil
}
// DeleteByID anonymizes and then deletes the user with the given id.
//
// The row is first overwritten with a random "deleted-<uuid>" name, username
// and email, and all credentials (password hash, TOTP data, SSO id, API key)
// are cleared; afterwards the row is deleted. Both statements run in a single
// transaction so that a failure cannot leave an anonymized but undeleted
// user behind. Errors are returned unwrapped.
func (r *User) DeleteByID(
	ctx context.Context,
	id *uuid.UUID,
) error {
	newName := fmt.Sprintf(
		"deleted-%s",
		uuid.New().String(),
	)
	userID := id.String()

	tx := r.DB.Begin()
	if tx.Error != nil {
		return tx.Error
	}
	committed := false
	defer func() {
		// roll back on every early return (and on panic)
		if !committed {
			tx.Rollback()
		}
	}()

	// anonymize user
	res := tx.
		Table(database.USER_TABLE).
		Where("id = ?", userID).
		Updates(map[string]any{
			"name":               newName,
			"username":           newName,
			"email":              fmt.Sprintf("%s@deleted.deleteduser", newName),
			"password_hash":      "",
			"totp_enabled":       false,
			"totp_secret":        nil,
			"totp_auth_url":      nil,
			"totp_recovery_code": nil,
			"sso_id":             "",
			// the API key is a credential too; same value RemoveAPIKey writes
			"api_key": "",
		})
	if res.Error != nil {
		return res.Error
	}
	res = tx.
		Where("id = ?", userID).
		Delete(&database.User{})
	if res.Error != nil {
		return res.Error
	}
	if err := tx.Commit().Error; err != nil {
		return err
	}
	committed = true
	return nil
}
// GetAll returns the users selected by options.QueryArgs, with the preloads
// requested in options, and sets HasNextPage on the result.
// On any error the result built so far is returned alongside the error.
func (r *User) GetAll(
	ctx context.Context,
	options *UserOption,
) (*model.Result[model.User], error) {
	out := model.NewEmptyResult[model.User]()
	rows := make([]database.User, 0)

	query, err := useQuery(r.DB, database.USER_TABLE, options.QueryArgs, userAllowedColumns...)
	if err != nil {
		return out, errs.Wrap(err)
	}
	if findErr := r.with(options, query).Find(&rows).Error; findErr != nil {
		return out, findErr
	}

	more, err := useHasNextPage(
		query, database.USER_TABLE, options.QueryArgs, userAllowedColumns...,
	)
	if err != nil {
		return out, errs.Wrap(err)
	}
	out.HasNextPage = more

	for i := range rows {
		converted, convErr := ToUser(&rows[i])
		if convErr != nil {
			return out, errs.Wrap(convErr)
		}
		out.Rows = append(out.Rows, converted)
	}
	return out, nil
}
// GetByID returns the user with the given id, applying the preloads
// requested in options. A nil id yields a wrapped "ID is nil" error; query
// errors are returned unwrapped.
func (r *User) GetByID(
	ctx context.Context,
	id *uuid.UUID,
	options *UserOption,
) (*model.User, error) {
	if id == nil {
		return nil, errs.Wrap(errors.New("ID is nil"))
	}
	found := new(database.User)
	query := r.with(options, r.DB)
	condition := fmt.Sprintf(
		"%s = ?",
		TableColumnID(database.USER_TABLE),
	)
	if err := query.Where(condition, id.String()).First(&found).Error; err != nil {
		return nil, err
	}
	return ToUser(found)
}
// GetByUsername returns the user with the given username, applying the
// preloads requested in options. Query errors are returned unwrapped.
func (r *User) GetByUsername(
	ctx context.Context,
	username *vo.Username,
	options *UserOption,
) (*model.User, error) {
	found := new(database.User)
	query := r.with(options, r.DB)
	condition := fmt.Sprintf(
		"%s = ?",
		TableColumn(database.USER_TABLE, "username"),
	)
	if err := query.Where(condition, username.String()).First(&found).Error; err != nil {
		return nil, err
	}
	return ToUser(found)
}
// GetByEmail returns the user with the given email address, applying the
// preloads requested in options. Query errors are returned unwrapped.
func (r *User) GetByEmail(
	ctx context.Context,
	email *vo.Email,
	options *UserOption,
) (*model.User, error) {
	found := new(database.User)
	query := r.with(options, r.DB)
	condition := fmt.Sprintf(
		"%s = ?",
		TableColumn(database.USER_TABLE, "email"),
	)
	if err := query.Where(condition, email.String()).First(&found).Error; err != nil {
		return nil, err
	}
	return ToUser(found)
}
// GetPasswordHashByUsername returns the stored password hash of the user
// with the given username. The query error is returned unwrapped.
func (r *User) GetPasswordHashByUsername(
	ctx context.Context,
	username *vo.Username,
) (string, error) {
	found := new(database.User)
	err := r.DB.
		Select("password_hash").
		Where("username = ?", username.String()).
		First(&found).
		Error
	if err != nil {
		return "", err
	}
	return found.PasswordHash, nil
}
// GetBySessionID returns the user that owns the session with the given id by
// joining the sessions table on sessions.user_id = users.id.
// It only looks the session up; it does not validate it.
func (r *User) GetBySessionID(
	ctx context.Context,
	sessionID *uuid.UUID,
	options *UserOption,
) (*model.User, error) {
	found := new(database.User)
	query, err := useQuery(r.DB, database.SESSION_TABLE, options.QueryArgs, allowedSessionColumns...)
	if err != nil {
		return nil, errs.Wrap(err)
	}
	lookupErr := r.with(options, query).
		Joins("JOIN sessions ON sessions.user_id = users.id").
		Where("sessions.id = ?", sessionID.String()).
		First(&found).
		Error
	if lookupErr != nil {
		return nil, lookupErr
	}
	return ToUser(found)
}
// updateUsernameByID sets the username (and updated_at) of the user with the
// given id, running the statement on tx.
func (r *User) updateUsernameByID(
	tx *gorm.DB,
	id *uuid.UUID,
	username *vo.Username,
) error {
	values := map[string]any{}
	values["username"] = username.String()
	AddUpdatedAt(values)

	query := tx.
		Model(&database.User{}).
		Where("id = ?", id.String())
	if err := query.Updates(values).Error; err != nil {
		return errs.Wrap(err)
	}
	return nil
}
// UpdateUserToSSO converts the user with the given id to an SSO user: the
// password hash is blanked and sso_id is set to ssoID.
func (r *User) UpdateUserToSSO(
	ctx context.Context,
	id *uuid.UUID,
	ssoID string,
) error {
	query := r.DB.
		Table(database.USER_TABLE).
		Where("id = ?", id.String())
	values := map[string]any{
		"password_hash": "",
		"sso_id":        ssoID,
	}
	if err := query.Updates(values).Error; err != nil {
		return errs.Wrap(err)
	}
	return nil
}
// UpdateUserToNoSSO blanks the sso_id of the user with the given id.
func (r *User) UpdateUserToNoSSO(
	ctx context.Context,
	id *uuid.UUID,
) error {
	query := r.DB.
		Table(database.USER_TABLE).
		Where("id = ?", id.String())
	values := map[string]any{"sso_id": ""}
	if err := query.Updates(values).Error; err != nil {
		return errs.Wrap(err)
	}
	return nil
}
// UpdateUsernameByID sets the username of the user with the given id using
// the repository's own DB handle.
func (r *User) UpdateUsernameByID(
	ctx context.Context,
	id *uuid.UUID,
	username *vo.Username,
) error {
	err := r.updateUsernameByID(r.DB, id, username)
	return err
}
// UpdateUsernameByIDWithTransaction sets the username of the user with the
// given id, running the statement on the supplied tx instead of r.DB.
func (r *User) UpdateUsernameByIDWithTransaction(
	ctx context.Context,
	tx *gorm.DB,
	id *uuid.UUID,
	username *vo.Username,
) error {
	err := r.updateUsernameByID(tx, id, username)
	return err
}
// updateFullNameByID sets the name column (and updated_at) of the user with
// the given id, running the statement on tx.
func (r *User) updateFullNameByID(
	tx *gorm.DB,
	id *uuid.UUID,
	name *vo.UserFullname,
) error {
	values := map[string]any{}
	values["name"] = name.String()
	AddUpdatedAt(values)

	query := tx.
		Model(&database.User{}).
		Where("id = ?", id.String())
	if err := query.Updates(values).Error; err != nil {
		return errs.Wrap(err)
	}
	return nil
}
// UpdateFullNameByID sets the full name of the user with the given id using
// the repository's own DB handle.
func (r *User) UpdateFullNameByID(
	ctx context.Context,
	id *uuid.UUID,
	name *vo.UserFullname,
) error {
	err := r.updateFullNameByID(r.DB, id, name)
	return err
}
// UpdateFullNameByIDWithTransaction sets the full name of the user with the
// given id, running the statement on the supplied tx instead of r.DB.
func (r *User) UpdateFullNameByIDWithTransaction(
	ctx context.Context,
	tx *gorm.DB,
	id *uuid.UUID,
	name *vo.UserFullname,
) error {
	err := r.updateFullNameByID(tx, id, name)
	return err
}
// updateEmailByID sets the email (and updated_at) of the user with the given
// id, running the statement on tx.
func (r *User) updateEmailByID(
	tx *gorm.DB,
	id *uuid.UUID,
	email *vo.Email,
) error {
	values := map[string]any{}
	values["email"] = email.String()
	AddUpdatedAt(values)

	query := tx.
		Model(&database.User{}).
		Where("id = ?", id.String())
	if err := query.Updates(values).Error; err != nil {
		return errs.Wrap(err)
	}
	return nil
}
// UpdateEmailByID sets the email of the user with the given id using the
// repository's own DB handle.
func (r *User) UpdateEmailByID(
	ctx context.Context,
	id *uuid.UUID,
	email *vo.Email,
) error {
	err := r.updateEmailByID(r.DB, id, email)
	return err
}
// UpdateEmailByIDWithTransaction sets the email of the user with the given
// id, running the statement on the supplied tx instead of r.DB.
func (r *User) UpdateEmailByIDWithTransaction(
	ctx context.Context,
	tx *gorm.DB,
	id *uuid.UUID,
	email *vo.Email,
) error {
	err := r.updateEmailByID(tx, id, email)
	return err
}
// updatePasswordHashByID sets the password hash (and updated_at) of the user
// with the given id, running the statement on tx.
func (r *User) updatePasswordHashByID(
	tx *gorm.DB,
	id *uuid.UUID,
	passwordHash string,
) error {
	values := map[string]any{}
	values["password_hash"] = passwordHash
	AddUpdatedAt(values)

	query := tx.
		Model(&database.User{}).
		Where("id = ?", id.String())
	if err := query.Updates(values).Error; err != nil {
		return errs.Wrap(err)
	}
	return nil
}
// UpdatePasswordHashByID sets the password hash of the user with the given
// id using the repository's own DB handle.
func (r *User) UpdatePasswordHashByID(
	ctx context.Context,
	id *uuid.UUID,
	passwordHash string,
) error {
	err := r.updatePasswordHashByID(r.DB, id, passwordHash)
	return err
}
// UpdatePasswordHashByIDWithTransaction sets the password hash of the user
// with the given id, running the statement on the supplied tx instead of
// r.DB.
func (r *User) UpdatePasswordHashByIDWithTransaction(
	ctx context.Context,
	tx *gorm.DB,
	id *uuid.UUID,
	passwordHash string,
) error {
	err := r.updatePasswordHashByID(tx, id, passwordHash)
	return err
}
// updatePasswordHashByUsername sets the password hash of the user with the
// given username and resets require_password_renew to false, running the
// statement on tx.
func (r *User) updatePasswordHashByUsername(
	tx *gorm.DB,
	username *vo.Username,
	passwordHash string,
) error {
	values := map[string]any{}
	values["password_hash"] = passwordHash
	values["require_password_renew"] = false
	AddUpdatedAt(values)

	query := tx.
		Model(&database.User{}).
		Where("username = ?", username.String())
	if err := query.Updates(values).Error; err != nil {
		return errs.Wrap(err)
	}
	return nil
}
// UpdatePasswordHashByUsername sets the password hash of the user with the
// given username using the repository's own DB handle.
func (r *User) UpdatePasswordHashByUsername(
	ctx context.Context,
	username *vo.Username,
	passwordHash string,
) error {
	err := r.updatePasswordHashByUsername(r.DB, username, passwordHash)
	return err
}
// UpdatePasswordHashByUsernameWithTransaction sets the password hash of the
// user with the given username, running the statement on the supplied tx
// instead of r.DB.
func (r *User) UpdatePasswordHashByUsernameWithTransaction(
	ctx context.Context,
	tx *gorm.DB,
	username *vo.Username,
	passwordHash string,
) error {
	err := r.updatePasswordHashByUsername(tx, username, passwordHash)
	return err
}
// ToUser converts a database user row into a model.User.
//
// It returns an error when row or row.ID is nil instead of panicking on the
// dereference. CompanyID and RoleID become null nullables when the row holds
// no value for them. Role and Company are converted only when they are set
// on the row (i.e. were preloaded).
//
// Name, username and email are built with the vo.New…Must constructors, so
// the row is expected to hold values those constructors accept.
func ToUser(row *database.User) (*model.User, error) {
	if row == nil {
		return nil, errors.New("user row is nil")
	}
	if row.ID == nil {
		return nil, errors.New("user row has no id")
	}
	id := nullable.NewNullableWithValue(*row.ID)
	companyID := nullable.NewNullNullable[uuid.UUID]()
	if row.CompanyID != nil {
		companyID.Set(*row.CompanyID)
	}
	// same treatment as CompanyID: a missing role id is null, not a panic
	roleID := nullable.NewNullNullable[uuid.UUID]()
	if row.RoleID != nil {
		roleID.Set(*row.RoleID)
	}
	userFullname := nullable.NewNullableWithValue(*vo.NewUserFullnameMust(row.Name))
	username := nullable.NewNullableWithValue(*vo.NewUsernameMust(row.Username))
	email := nullable.NewNullableWithValue(*vo.NewEmailMust(row.Email))
	ssoID := nullable.NewNullableWithValue(row.SSOID)
	var role *model.Role
	if row.Role != nil {
		role = ToRole(row.Role)
	}
	var company *model.Company
	if row.Company != nil {
		company = ToCompany(row.Company)
	}
	return &model.User{
		ID:                   id,
		Name:                 userFullname,
		Username:             username,
		Email:                email,
		RoleID:               roleID,
		Role:                 role,
		RequirePasswordRenew: nullable.NewNullableWithValue(row.RequirePasswordRenew),
		CompanyID:            companyID,
		Company:              company,
		SSOID:                ssoID,
	}, nil
}