Files
phishingclub/backend/controller/utils.go
2025-08-21 16:14:09 +02:00

359 lines
8.7 KiB
Go

package controller
import (
"bytes"
"encoding/csv"
"fmt"
"io"
"mime"
"path/filepath"
"strings"
"time"
"github.com/go-errors/errors"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/phishingclub/phishingclub/api"
"github.com/phishingclub/phishingclub/errs"
"github.com/phishingclub/phishingclub/model"
"github.com/phishingclub/phishingclub/service"
"github.com/phishingclub/phishingclub/utils"
"github.com/phishingclub/phishingclub/vo"
"go.uber.org/zap"
"gorm.io/gorm"
)
// Common is a common controller base struct it holds common operations on the
// common dependencies
type Common struct {
Response api.JSONResponseHandler
Logger *zap.SugaredLogger
SessionService *service.Session
}
// handleSession handles the session and returns the session and user
// if the session is not valid, a 401 response is sent
func (c *Common) handleSession(
g *gin.Context,
) (*model.Session, *model.User, bool) {
s, ok := g.Get("session")
if !ok {
c.Logger.Debug("session not found in context")
c.Response.Unauthorized(g)
return nil, nil, false
}
session, ok := s.(*model.Session)
if !ok {
c.Logger.Error("session in context is not of type model.Session")
c.Response.Unauthorized(g)
return nil, nil, false
}
user := session.User
if user == nil {
c.Logger.Error("user not found in session")
c.Response.Unauthorized(g)
return nil, nil, false
}
return session, user, true
}
// HandleParseRequest parses the request and returns true if successful
// if the request is not parsable, a 400 response is sent
func (c *Common) handleParseRequest(
g *gin.Context,
req any,
) bool {
body, err := io.ReadAll(g.Request.Body)
if err != nil {
c.Logger.Debugw("failed to read request body",
"error", err,
)
c.Response.BadRequest(g)
return false
}
if err := utils.Unmarshal(body, &req); err != nil {
c.Logger.Debugw("failed to parse request",
"error", err,
)
c.Response.BadRequestMessage(g, err.Error())
return false
}
return true
}
// handleParseIDParam parses the id parameter from the request
// and returns it if successful
// if the id is not parsable, a 400 response is sent
func (c *Common) handleParseIDParam(
g *gin.Context,
) (*uuid.UUID, bool) {
id, err := uuid.Parse(g.Param("id"))
if err != nil {
c.Logger.Debugw("failed to parse id",
"error", err,
)
c.Response.BadRequestMessage(g, errs.MsgFailedToParseUUID)
return nil, false
}
return &id, true
}
// handlePagination parses the pagination from the request and returns it
// if the pagination is not valid, a 400 response is sent
func (c *Common) handlePagination(
g *gin.Context,
) (*vo.Pagination, bool) {
pagination, err := vo.NewPaginationFromRequest(g)
if err != nil {
c.Logger.Debugw("invalid offset or limit",
"error", err,
)
c.Response.ValidationFailed(g, "pagination", err)
return nil, false
}
return pagination, true
}
// handleQueryArgs parses the query from the request and returns it
func (c *Common) handleQueryArgs(g *gin.Context) (*vo.QueryArgs, bool) {
q, err := vo.QueryFromRequest(g)
if err != nil {
c.Logger.Debugw("failed to parse query",
"error", err,
)
c.Response.ValidationFailed(g, "query args", err)
return nil, false
}
return q, true
}
// handleErrors is a helper function to handle common handleErrors
// it most often checks for more than what is needed, but is
// useful to avoid missing any error handling and saving time
// it returns true if no errors are found
// it returns false if an error is found and a response is sent
func (c *Common) handleErrors(
g *gin.Context,
err error,
) bool {
if err != nil {
if ok := handleAuthorizationError(g, c.Response, err); !ok {
c.Logger.Debugw("authorization error",
"auth_error", err,
)
return false
}
if ok := handleValidationError(g, c.Response, err); !ok {
c.Logger.Debugw("validation error",
"validation_error", err,
)
return false
}
if ok := handleCustomError(g, c.Response, err); !ok {
c.Logger.Debugw("custom error",
"custom_error", err,
)
return false
}
if ok := handleDBRowNotFound(g, c.Response, err); !ok {
c.Logger.Debugw("DB row not found error",
"error", err,
)
return false
}
c.Logger.Errorw("API unknown error type", "error", err)
_ = handleServerError(g, c.Response, err)
return false
}
return true
}
// responseWithCSV
func (c *Common) responseWithCSV(
g *gin.Context,
buffer *bytes.Buffer,
writer *csv.Writer,
name string,
) {
writer.Flush()
if err := writer.Error(); err != nil {
c.handleErrors(g, err)
return
}
// Set CSV response headers
setSecureContentDisposition(g, name)
g.Header("Content-Type", "text/csv")
g.Header("Content-Length", fmt.Sprint(buffer.Len()))
// Write the CSV buffer to the response
_, err := g.Writer.Write(buffer.Bytes())
if err != nil {
c.handleErrors(g, err)
}
}
// responseWithZIP
func (c *Common) responseWithZIP(
g *gin.Context,
buffer *bytes.Buffer,
name string,
) {
g.Header("Content-Type", "application/zip")
setSecureContentDisposition(g, name)
g.Header("Content-Transfer-Encoding", "binary")
g.Header("Expires", "0")
g.Header("Cache-Control", "must-revalidate")
g.Header("Pragma", "public")
g.Header("Content-Length", fmt.Sprintf("%d", buffer.Len()))
_, err := g.Writer.Write(buffer.Bytes())
if err != nil {
c.handleErrors(g, err)
}
}
// companyIDFromRequestQuery returns the companyID as a UUID from the query
// or nil if not found
func companyIDFromRequestQuery(g *gin.Context) *uuid.UUID {
companyID := g.Query("companyID")
if companyID != "" {
cid, err := uuid.Parse(companyID)
if err != nil {
return nil
}
return &cid
}
return nil
}
// SetSessionInGinContext sets the session in the gin context
func SetSessionInGinContext(c *gin.Context, s *model.Session) {
c.Set("session", s)
}
// handleDBRowNotFound checks if the error is a not found error
// if it is, a 404 response is sent
// if it is not, true is returned
func handleDBRowNotFound(
g *gin.Context,
responseHandler api.JSONResponseHandler,
err error,
) bool {
if errors.Is(err, gorm.ErrRecordNotFound) {
// error is logged in service
_ = err
responseHandler.NotFound(g)
return false
}
return true
}
// handleAuthorizationError checks if the error is an authorization error
// if it is, a 403 response is sent
// if it is not, true is returned
func handleAuthorizationError(
g *gin.Context,
responseHandler api.JSONResponseHandler,
err error,
) bool {
if errors.Is(err, errs.ErrAuthorizationFailed) {
// error is logged in service
_ = err
responseHandler.Forbidden(g)
return false
}
return true
}
// handleValidationError checks if the error is a validation error
// if it is, a 400 response is sent
// if it is not, true is returned
func handleValidationError(
g *gin.Context,
responseHandler api.JSONResponseHandler,
err error,
) bool {
if errors.As(err, &errs.ValidationError{}) {
// error is logged in service
_ = err
responseHandler.BadRequestMessage(g, err.Error())
return false
}
return true
}
// handleCustomError checks if the error is a custom error
// if it is a 400 response is sent
// if it is not, true is returned
func handleCustomError(
g *gin.Context,
responseHandler api.JSONResponseHandler,
err error,
) bool {
if errors.As(err, &errs.CustomError{}) {
// error is logged in service
_ = err
responseHandler.BadRequestMessage(g, err.Error())
return false
}
return true
}
// handleServerError checks if the error is a server error
// if it is, a 500 response is sent
// if it is not, true is returned
func handleServerError(
g *gin.Context,
responseHandler api.JSONResponseHandler,
err error,
) bool {
if err != nil {
// error is logged in service
_ = err
responseHandler.ServerError(g)
return false
}
return true
}
func setSecureContentDisposition(c *gin.Context, filename string) {
// Strip any directory components
filename = filepath.Base(filename)
// Remove any potentially problematic characters
filename = strings.Map(func(r rune) rune {
// Keep only alphanumeric, space, dash, underscore and dot
if (r >= 'a' && r <= 'z') ||
(r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') ||
(r == ' ' || r == '-' || r == '_' || r == '.') {
return r
}
return -1
}, filename)
// Ensure we still have a valid filename
if filename == "" || filename == "." || filename == ".." {
filename = time.Now().UTC().Format("20060102150405")
}
// Properly encode the filename for Content-Disposition
encodedFilename := mime.QEncoding.Encode("utf-8", filename)
c.Header("Content-Disposition",
fmt.Sprintf(`attachment; filename="%s";`,
encodedFilename,
),
)
}
func (c *Common) requiresFlag(g *gin.Context, featureFlag string) {
// handle session
_, _, ok := c.handleSession(g)
if !ok {
return
}
c.Response.ServerErrorMessage(g, "requires "+featureFlag+" edition")
}