mirror of
https://github.com/phishingclub/phishingclub.git
synced 2026-02-12 16:12:44 +00:00
926 lines
20 KiB
Go
926 lines
20 KiB
Go
package controller
|
|
|
|
import (
|
|
"net/http"
|
|
|
|
"github.com/go-errors/errors"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
"github.com/phishingclub/phishingclub/data"
|
|
"github.com/phishingclub/phishingclub/database"
|
|
"github.com/phishingclub/phishingclub/errs"
|
|
"github.com/phishingclub/phishingclub/model"
|
|
"github.com/phishingclub/phishingclub/repository"
|
|
"github.com/phishingclub/phishingclub/service"
|
|
"github.com/phishingclub/phishingclub/vo"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
var SessionColumnsMap = map[string]string{
|
|
"created_at": repository.TableColumn(database.SESSION_TABLE, "created_at"),
|
|
"updated_at": repository.TableColumn(database.SESSION_TABLE, "updated_at"),
|
|
"ip_address": repository.TableColumn(database.SESSION_TABLE, "ip_address"),
|
|
}
|
|
|
|
var UserColumnsMap = map[string]string{
|
|
"created_at": repository.TableColumn(database.USER_TABLE, "created_at"),
|
|
"updated_at": repository.TableColumn(database.USER_TABLE, "updated_at"),
|
|
"name": repository.TableColumn(database.USER_TABLE, "name"),
|
|
"username": repository.TableColumn(database.USER_TABLE, "username"),
|
|
"email": repository.TableColumn(database.USER_TABLE, "email"),
|
|
}
|
|
|
|
// UserLoginRequest is a request for login with username and password
|
|
type UserLoginRequest struct {
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
TOTP string `json:"totp"`
|
|
MFARecoveryCode string `json:"recoveryCode"`
|
|
}
|
|
|
|
// UserSetupTOTPRequest is a request for setting up TOTP
|
|
type UserSetupTOTPRequest struct {
|
|
Password string `json:"password"`
|
|
}
|
|
|
|
// UserSetupDisableTOTPRequest is a request for disabling TOTP
|
|
type UserDisableTOTPRequest struct {
|
|
Token string `json:"token"`
|
|
}
|
|
|
|
// UserVerifyTOTPRequest is a request for verifying TOTP
|
|
type UserVerifyTOTPRequest struct {
|
|
TOTP string `json:"token"`
|
|
}
|
|
|
|
// UserLoginWithMFARecoveryCodeRequest is a request for login with MFA recovery code
|
|
type UserLoginWithMFARecoveryCodeRequest struct {
|
|
RecoveryCode string `json:"recoveryCode"`
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
}
|
|
|
|
// User is the change email controller
|
|
type User struct {
|
|
Common
|
|
UserService *service.User
|
|
}
|
|
|
|
// Create creates a new user
|
|
func (c *User) Create(g *gin.Context) {
|
|
session, _, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// parse req
|
|
var req model.UserUpsertRequest
|
|
if ok := c.handleParseRequest(g, &req); !ok {
|
|
return
|
|
}
|
|
// create user
|
|
newUserID, err := c.UserService.Create(
|
|
g,
|
|
session,
|
|
&req,
|
|
)
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(
|
|
g,
|
|
gin.H{
|
|
"id": newUserID.String(),
|
|
},
|
|
)
|
|
}
|
|
|
|
// GetMaskedAPIKey gets logged-in users masked API key
|
|
func (c *User) GetMaskedAPIKey(g *gin.Context) {
|
|
session, user, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
if user == nil {
|
|
c.handleErrors(g, errors.New("no user in session"))
|
|
}
|
|
// get
|
|
cid := user.ID.MustGet()
|
|
apiKey, err := c.UserService.GetMaskedAPIKey(
|
|
g,
|
|
session,
|
|
&cid,
|
|
)
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(
|
|
g,
|
|
gin.H{
|
|
"apiKey": apiKey,
|
|
},
|
|
)
|
|
}
|
|
|
|
// UpsertAPIKey create/update API key
|
|
func (c *User) UpsertAPIKey(g *gin.Context) {
|
|
session, user, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
if user == nil {
|
|
c.handleErrors(g, errors.New("no user in session"))
|
|
}
|
|
// create user
|
|
uid := user.ID.MustGet()
|
|
apiKey, err := c.UserService.UpsertAPIKey(
|
|
g,
|
|
session,
|
|
&uid,
|
|
)
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(
|
|
g,
|
|
gin.H{
|
|
"apiKey": apiKey,
|
|
},
|
|
)
|
|
}
|
|
|
|
// RemoveAPIKey removes a api key
|
|
func (c *User) RemoveAPIKey(g *gin.Context) {
|
|
session, user, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
if user == nil {
|
|
c.handleErrors(g, errors.New("no user in session"))
|
|
}
|
|
// create user
|
|
uid := user.ID.MustGet()
|
|
err := c.UserService.RemoveAPIKey(
|
|
g,
|
|
session,
|
|
&uid,
|
|
)
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(
|
|
g,
|
|
gin.H{},
|
|
)
|
|
}
|
|
|
|
// UpdateByID updates a user by ID
|
|
func (c *User) UpdateByID(g *gin.Context) {
|
|
session, _, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// parse request
|
|
id, ok := c.handleParseIDParam(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
var req model.User
|
|
if ok := c.handleParseRequest(g, &req); !ok {
|
|
return
|
|
}
|
|
// update user
|
|
err := c.UserService.Update(
|
|
g,
|
|
session,
|
|
id,
|
|
&req,
|
|
)
|
|
// handle response
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(g, gin.H{})
|
|
}
|
|
|
|
// Delete deletes a user
|
|
func (c *User) Delete(g *gin.Context) {
|
|
session, _, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// parse request
|
|
id, ok := c.handleParseIDParam(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// delete user
|
|
err := c.UserService.Delete(g, session, id)
|
|
// handle response
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(g, gin.H{})
|
|
}
|
|
|
|
// GetAll gets all users using pagination
|
|
func (c *User) GetAll(g *gin.Context) {
|
|
session, _, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// parse request
|
|
queryArgs, ok := c.handleQueryArgs(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
queryArgs.DefaultSortByUpdatedAt()
|
|
queryArgs.RemapOrderBy(UserColumnsMap)
|
|
// get user
|
|
users, err := c.UserService.GetAll(g, session, &repository.UserOption{
|
|
QueryArgs: queryArgs,
|
|
WithRole: true,
|
|
WithCompany: true,
|
|
})
|
|
// handle response
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(g, users)
|
|
}
|
|
|
|
// GetByID gets a user by ID
|
|
func (c *User) GetByID(g *gin.Context) {
|
|
session, _, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// parse request
|
|
id, ok := c.handleParseIDParam(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// get user
|
|
user, err := c.UserService.GetByID(g, session, id)
|
|
// handle response
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(g, user)
|
|
}
|
|
|
|
// ChangeEmailOnLoggedInUser changes email on logged in user
|
|
// this is an administrator action
|
|
func (c *User) ChangeEmailOnLoggedInUser(g *gin.Context) {
|
|
session, sessionUser, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// parse and validate request
|
|
var request model.UserChangeEmailRequest
|
|
if ok := c.handleParseRequest(g, &request); !ok {
|
|
return
|
|
}
|
|
// change email
|
|
userID := sessionUser.ID.MustGet()
|
|
changedEmail, err := c.UserService.ChangeEmailAsAdministrator(
|
|
g,
|
|
session,
|
|
&userID,
|
|
&request.Email,
|
|
)
|
|
// handle response
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(
|
|
g,
|
|
gin.H{"email": changedEmail.String()},
|
|
)
|
|
}
|
|
|
|
// ChangeFullnameOnLoggedInUser is the handler for change fullname
|
|
func (c *User) ChangeFullnameOnLoggedInUser(g *gin.Context) {
|
|
session, sessionUser, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// parse req
|
|
var req model.UserChangeFullnameRequest
|
|
if ok := c.handleParseRequest(g, &req); !ok {
|
|
return
|
|
}
|
|
// change fullname
|
|
userID := sessionUser.ID.MustGet()
|
|
_, err := c.UserService.ChangeFullname(
|
|
g,
|
|
session,
|
|
&userID,
|
|
&req.NewFullname,
|
|
)
|
|
// handle response
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(g, gin.H{})
|
|
}
|
|
|
|
// ChangePasswordOnLoggedInUser changes the password on the logged in user
|
|
func (c *User) ChangePasswordOnLoggedInUser(g *gin.Context) {
|
|
session, sessionUser, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// parse req
|
|
var req model.UserChangePasswordRequest
|
|
if ok := c.handleParseRequest(g, &req); !ok {
|
|
return
|
|
}
|
|
// change password
|
|
err := c.UserService.ChangePassword(
|
|
g,
|
|
session,
|
|
&req.CurrentPassword,
|
|
&req.NewPassword,
|
|
)
|
|
// handle response
|
|
if errors.Is(err, errs.ErrUserWrongPasword) {
|
|
c.Response.BadRequestMessage(g, "Invalid current password")
|
|
return
|
|
}
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
// invalidate all currently running sessions
|
|
userID := sessionUser.ID.MustGet()
|
|
err = c.SessionService.ExpireAllByUserID(g, session, &userID)
|
|
// partial error, the password is changed but the sessions are not invalidated
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(
|
|
g,
|
|
"password changed - all sessions have been invalidated",
|
|
)
|
|
}
|
|
|
|
// ChangeUsernameOnLoggedInUser changes the username
|
|
func (c *User) ChangeUsernameOnLoggedInUser(g *gin.Context) {
|
|
session, sessionUser, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// parse req
|
|
var req model.UserChangeUsernameOnLoggedInRequest
|
|
if ok := c.handleParseRequest(g, &req); !ok {
|
|
return
|
|
}
|
|
userID := sessionUser.ID.MustGet()
|
|
// change username
|
|
err := c.UserService.ChangeUsername(
|
|
g.Request.Context(),
|
|
session,
|
|
&userID,
|
|
&req.NewUsername,
|
|
)
|
|
// handle error
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(g, gin.H{})
|
|
}
|
|
|
|
// ExpireSessionByID expires a session by ID
|
|
// a administrator can expire any session
|
|
// a user can expire their own sessions
|
|
func (c *User) ExpireSessionByID(g *gin.Context) {
|
|
session, _, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
id, ok := c.handleParseIDParam(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
isAuthorized, err := service.IsAuthorized(session, data.PERMISSION_ALLOW_GLOBAL)
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
if !isAuthorized {
|
|
c.Response.Forbidden(g)
|
|
return
|
|
}
|
|
err = c.SessionService.Expire(g, id)
|
|
// handle response
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(
|
|
g,
|
|
"session expired",
|
|
)
|
|
}
|
|
|
|
// GetSessionsByUserID gets all sessions by user ID
|
|
func (c *User) GetSessionsOnLoggedInUser(g *gin.Context) {
|
|
session, sessionUser, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// parse request
|
|
queryArgs, ok := c.handleQueryArgs(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
queryArgs.DefaultSortByUpdatedAt()
|
|
queryArgs.RemapOrderBy(SessionColumnsMap)
|
|
userID := sessionUser.ID.MustGet()
|
|
sessions, err := c.SessionService.GetSessionsByUserID(
|
|
g,
|
|
session,
|
|
&userID,
|
|
&repository.SessionOption{
|
|
QueryArgs: queryArgs,
|
|
},
|
|
)
|
|
// handle response
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
data := []map[string]interface{}{}
|
|
for _, sess := range sessions.Rows {
|
|
idStr := sess.ID.String()
|
|
data = append(data, map[string]interface{}{
|
|
"id": idStr,
|
|
"current": idStr == session.ID.String(),
|
|
"ip": sess.IP,
|
|
"createdAt": sess.CreatedAt,
|
|
"updatedAt": sess.UpdatedAt,
|
|
})
|
|
}
|
|
c.Response.OK(
|
|
g,
|
|
gin.H{
|
|
"sessions": data,
|
|
"hasNextPage": sessions.HasNextPage,
|
|
},
|
|
)
|
|
}
|
|
|
|
// Login logs in a user
|
|
func (c *User) Login(g *gin.Context) {
|
|
// parse req
|
|
var req UserLoginRequest
|
|
if ok := c.handleParseRequest(g, &req); !ok {
|
|
return
|
|
}
|
|
user, err := c.UserService.AuthenticateUsernameWithPassword(
|
|
g,
|
|
req.Username,
|
|
req.Password,
|
|
g.ClientIP(),
|
|
)
|
|
if errors.Is(err, errs.ErrUserWrongPasword) {
|
|
c.Response.BadRequestMessage(g, "Invalid password")
|
|
return
|
|
}
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
c.Response.BadRequestMessage(g, "Invalid credentials")
|
|
return
|
|
}
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
// if the user has MFA enabled then we check the MFA flow
|
|
// if the user has MFA enabled, we must check if there is a
|
|
// valid MFA or a valid recovery code
|
|
userID := user.ID.MustGet()
|
|
MFATokenSupplied := len(req.TOTP) > 0
|
|
MFARecoveryCodeSupplied := len(req.MFARecoveryCode) > 0
|
|
mfaEnabled, err := c.UserService.IsTOTPEnabledByUserID(
|
|
g,
|
|
&userID,
|
|
)
|
|
if errors.Is(err, errs.ErrUserWrongTOTP) {
|
|
c.Response.BadRequestMessage(g, "Invalid TOTP")
|
|
return
|
|
}
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
if mfaEnabled {
|
|
// if tokens or recovery codes are supplied
|
|
// return mfa is required
|
|
if !MFATokenSupplied && !MFARecoveryCodeSupplied {
|
|
c.Response.OK(
|
|
g,
|
|
gin.H{
|
|
"mfa": true,
|
|
},
|
|
)
|
|
return
|
|
}
|
|
// if the client has given both a TOTP and a recovery code
|
|
// we return a bad request
|
|
if MFATokenSupplied && MFARecoveryCodeSupplied {
|
|
c.Response.BadRequestMessage(g, "Cannot supply both MFA token and MFA recovery code")
|
|
return
|
|
}
|
|
// verify the TOTP MFA token
|
|
userID := user.ID.MustGet()
|
|
if MFATokenSupplied && !MFARecoveryCodeSupplied {
|
|
// if MFA is enabled, verify the TOTP
|
|
totpToken, err := vo.NewString64(req.TOTP)
|
|
if err != nil {
|
|
c.Logger.Debugw("failed to create TOTP",
|
|
"error", err,
|
|
)
|
|
c.Response.ValidationFailed(g, "TOTP", err)
|
|
return
|
|
}
|
|
err = c.UserService.CheckTOTP(
|
|
g,
|
|
&userID,
|
|
totpToken,
|
|
)
|
|
if err != nil {
|
|
if errors.Is(err, errs.ErrUserWrongTOTP) {
|
|
c.Response.BadRequestMessage(g, "Invalid TOTP")
|
|
return
|
|
}
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
// if the user has MFA enabled and the client has supplied a recovery code
|
|
// we verify the recovery code
|
|
if !MFATokenSupplied && MFARecoveryCodeSupplied {
|
|
recoveryCode, err := vo.NewString64(req.MFARecoveryCode)
|
|
if err != nil {
|
|
c.Logger.Debugw("failed to create recovery code",
|
|
"error", err,
|
|
)
|
|
c.Response.ValidationFailed(g, "RecoveryCode", err)
|
|
return
|
|
}
|
|
verifiedMFA, err := c.UserService.CheckMFARecoveryCode(
|
|
g,
|
|
&userID,
|
|
recoveryCode,
|
|
)
|
|
if err != nil {
|
|
if errors.Is(err, errs.ErrUserWrongRecoveryCode) {
|
|
c.Response.BadRequestMessage(g, "Invalid recovery code")
|
|
return
|
|
}
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
}
|
|
if !verifiedMFA {
|
|
c.Response.BadRequestMessage(g, "Invalid recovery code")
|
|
return
|
|
}
|
|
// as the recovery code is valid, we can now disable MFA
|
|
err = c.UserService.DisableTOTP(g, &userID)
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
// create a new session
|
|
session, err := c.SessionService.Create(
|
|
g,
|
|
user,
|
|
g.ClientIP(),
|
|
)
|
|
// handle response
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
// Set the session in the cookie
|
|
cookie := &http.Cookie{
|
|
Name: data.SessionCookieKey,
|
|
Value: session.ID.String(),
|
|
Path: "/",
|
|
SameSite: http.SameSiteStrictMode,
|
|
HttpOnly: true,
|
|
Secure: true,
|
|
Expires: *session.MaxAgeAt,
|
|
}
|
|
http.SetCookie(g.Writer, cookie)
|
|
c.Response.OK(g, session)
|
|
}
|
|
|
|
// expireCookieAndStatusOK expires the cookie and returns a 200 OK
|
|
func (c *User) expireCookieAndStatusOK(g *gin.Context) {
|
|
g.SetCookie(
|
|
data.SessionCookieKey,
|
|
"",
|
|
-1,
|
|
"/",
|
|
"",
|
|
false,
|
|
true,
|
|
)
|
|
c.logoutOK(g)
|
|
}
|
|
|
|
// logoutOK returns a 200 OK
|
|
func (c *User) logoutOK(g *gin.Context) {
|
|
c.Response.OK(
|
|
g,
|
|
gin.H{"message": "logged out"},
|
|
)
|
|
}
|
|
|
|
// Logout logs out the user
|
|
// only invalidates the session if the session cookie is
|
|
// in the request, this should reduce the risk of CSRF logout
|
|
func (c *User) Logout(g *gin.Context) {
|
|
sessionCookie, err := g.Cookie(data.SessionCookieKey)
|
|
if err != nil {
|
|
c.logoutOK(g)
|
|
return
|
|
}
|
|
sessionID, err := uuid.Parse(sessionCookie)
|
|
if err != nil {
|
|
c.logoutOK(g)
|
|
return
|
|
}
|
|
ctx := g.Request.Context()
|
|
err = c.SessionService.Expire(ctx, &sessionID)
|
|
if err != nil {
|
|
c.expireCookieAndStatusOK(g)
|
|
return
|
|
}
|
|
c.expireCookieAndStatusOK(g)
|
|
}
|
|
|
|
// SessionPing pings the session
|
|
func (c *User) SessionPing(g *gin.Context) {
|
|
// handle session
|
|
session, sessionUser, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
c.Logger.Debugw("pinged session for user",
|
|
"userID", sessionUser.ID.MustGet().String(),
|
|
)
|
|
sessionRole := sessionUser.Role
|
|
if sessionRole == nil {
|
|
c.Logger.Error("failed to load role from session user")
|
|
c.Response.ServerError(g)
|
|
return
|
|
}
|
|
sessionCompany := sessionUser.Company
|
|
companyName := ""
|
|
if sessionCompany != nil {
|
|
companyName = sessionCompany.Name.MustGet().String()
|
|
}
|
|
c.Response.OK(
|
|
g,
|
|
gin.H{
|
|
"userID": sessionUser.ID,
|
|
"username": sessionUser.Username.MustGet().String(),
|
|
"name": sessionUser.Name.MustGet().String(),
|
|
"role": sessionRole.Name,
|
|
"company": companyName,
|
|
"ip": session.IP,
|
|
},
|
|
)
|
|
}
|
|
|
|
// InvalidateAllSessionByUserID is the nuclear session button for a user
|
|
func (c *User) InvalidateAllSessionByUserID(g *gin.Context) {
|
|
session, user, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
var userID *uuid.UUID
|
|
// parse req
|
|
var req model.InvalidateAllSessionRequest
|
|
err := g.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
if user == nil || !user.ID.IsSpecified() {
|
|
c.Response.BadRequest(g)
|
|
return
|
|
}
|
|
uid := user.ID.MustGet()
|
|
userID = &uid
|
|
} else {
|
|
if req.UserID == nil {
|
|
c.Response.BadRequest(g)
|
|
return
|
|
}
|
|
userID = req.UserID
|
|
}
|
|
// invalidate
|
|
err = c.SessionService.ExpireAllByUserID(g, session, userID)
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(g, gin.H{})
|
|
}
|
|
|
|
// SetupTOTP generates a new TOTP MFA secrets
|
|
func (c *User) SetupTOTP(g *gin.Context) {
|
|
session, _, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// parse request
|
|
var request UserSetupTOTPRequest
|
|
if ok := c.handleParseRequest(g, &request); !ok {
|
|
return
|
|
}
|
|
passwd, err := vo.NewReasonableLengthPassword(request.Password)
|
|
if err != nil {
|
|
c.Logger.Debugw("failed to create password",
|
|
"error", err,
|
|
)
|
|
c.Response.ValidationFailed(g, "Password", err)
|
|
return
|
|
}
|
|
// get and save TOTP for user
|
|
totpValues, err := c.UserService.SetupTOTP(
|
|
g.Request.Context(),
|
|
session,
|
|
passwd,
|
|
)
|
|
// handle response
|
|
if errors.Is(err, errs.ErrAuthenticationFailed) {
|
|
c.Response.BadRequestMessage(g, "Incorrect password")
|
|
return
|
|
}
|
|
if ok := handleServerError(g, c.Response, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(
|
|
g,
|
|
gin.H{
|
|
"base32": totpValues.Secret,
|
|
"url": totpValues.URL,
|
|
"recoveryCode": totpValues.RecoveryCode,
|
|
},
|
|
)
|
|
}
|
|
|
|
// SetupVerifyTOTP verifies a TOTP
|
|
func (c *User) SetupVerifyTOTP(g *gin.Context) {
|
|
session, _, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// parse req
|
|
var req UserVerifyTOTPRequest
|
|
if ok := c.handleParseRequest(g, &req); !ok {
|
|
return
|
|
}
|
|
totp, err := vo.NewString64(req.TOTP)
|
|
if err != nil {
|
|
c.Logger.Debugw("failed to create TOTP",
|
|
"error", err,
|
|
)
|
|
c.Response.ValidationFailed(g, "TOTP", err)
|
|
return
|
|
}
|
|
// verify TOTP
|
|
err = c.UserService.SetupCheckTOTP(
|
|
g.Request.Context(),
|
|
session,
|
|
totp,
|
|
)
|
|
if errors.Is(err, errs.ErrUserWrongTOTP) {
|
|
c.Response.BadRequestMessage(g, "Invalid token")
|
|
return
|
|
}
|
|
|
|
// handle response
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(
|
|
g,
|
|
"TOTP verified",
|
|
)
|
|
}
|
|
|
|
// IsTOTPEnabled checks if TOTP is enabled
|
|
func (c *User) IsTOTPEnabled(g *gin.Context) {
|
|
session, _, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// check if TOTP is enabled
|
|
isEnabled, err := c.UserService.IsTOTPEnabled(
|
|
g.Request.Context(),
|
|
session,
|
|
)
|
|
// handle response
|
|
if ok := handleServerError(g, c.Response, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(
|
|
g,
|
|
gin.H{"enabled": isEnabled},
|
|
)
|
|
}
|
|
|
|
// DisableTOTP disables TOTP
|
|
func (c *User) DisableTOTP(g *gin.Context) {
|
|
_, user, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// parse request
|
|
var request UserDisableTOTPRequest
|
|
if ok := c.handleParseRequest(g, &request); !ok {
|
|
return
|
|
}
|
|
token, err := vo.NewString64(request.Token)
|
|
if err != nil {
|
|
c.Logger.Debugw("failed to create token",
|
|
"error", err,
|
|
)
|
|
c.Response.ValidationFailed(g, "Token", err)
|
|
return
|
|
}
|
|
// check TOTP
|
|
userID := user.ID.MustGet()
|
|
err = c.UserService.CheckTOTP(
|
|
g.Request.Context(),
|
|
&userID,
|
|
token,
|
|
)
|
|
if err != nil {
|
|
if errors.Is(err, errs.ErrUserWrongTOTP) {
|
|
c.Response.BadRequestMessage(g, "Invalid token")
|
|
return
|
|
}
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
}
|
|
// disable TOTP
|
|
err = c.UserService.DisableTOTP(
|
|
g.Request.Context(),
|
|
&userID,
|
|
)
|
|
// handle response
|
|
if err != nil {
|
|
if errors.Is(err, errs.ErrUserWrongTOTP) {
|
|
c.Response.BadRequestMessage(g, "Invalid token")
|
|
return
|
|
}
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
}
|
|
c.Response.OK(
|
|
g,
|
|
"TOTP disabled",
|
|
)
|
|
}
|
|
|
|
// VerifyTOTP verifies a TOTP
|
|
func (c *User) VerifyTOTP(g *gin.Context) {
|
|
_, user, ok := c.handleSession(g)
|
|
if !ok {
|
|
return
|
|
}
|
|
// parse req
|
|
var req UserVerifyTOTPRequest
|
|
if ok := c.handleParseRequest(g, &req); !ok {
|
|
return
|
|
}
|
|
totp, err := vo.NewString64(req.TOTP)
|
|
if err != nil {
|
|
c.Logger.Debugw("failed to create TOTP",
|
|
"error", err,
|
|
)
|
|
c.Response.ValidationFailed(g, "TOTP", err)
|
|
return
|
|
}
|
|
// verify TOTP
|
|
userID := user.ID.MustGet()
|
|
err = c.UserService.CheckTOTP(
|
|
g.Request.Context(),
|
|
&userID,
|
|
totp,
|
|
)
|
|
if errors.Is(err, errs.ErrUserWrongTOTP) {
|
|
c.Response.BadRequestMessage(g, "Invalid token")
|
|
return
|
|
}
|
|
// handle response
|
|
if ok := c.handleErrors(g, err); !ok {
|
|
return
|
|
}
|
|
c.Response.OK(
|
|
g,
|
|
"TOTP verified",
|
|
)
|
|
}
|