fix bug and trottle

Signed-off-by: Ronni Skansing <rskansing@gmail.com>
This commit is contained in:
Ronni Skansing
2026-06-07 17:27:07 +02:00
parent 1f0f138652
commit 1586e3eaec
6 changed files with 170 additions and 61 deletions
+5 -2
View File
@@ -280,8 +280,11 @@ func setupRoutes(
controllers *Controllers,
middleware *Middlewares,
) *gin.Engine {
// scim v2 provisioning endpoints — no ip allowlist, bearer-token auth only
r.
// scim v2 provisioning endpoints — no ip allowlist, bearer-token auth only.
// rate limited per company so cloud IdPs sharing source IPs across tenants
// do not throttle each other.
scim := r.Group("/", middleware.ScimRateLimiter)
scim.
GET(ROUTE_SCIM_V2_SERVICE_PROVIDER_CONFIG, controllers.Scim.ServiceProviderConfig).
GET(ROUTE_SCIM_V2_RESOURCE_TYPES, controllers.Scim.ResourceTypes).
GET(ROUTE_SCIM_V2_SCHEMAS, controllers.Scim.Schemas).
+8
View File
@@ -13,6 +13,7 @@ import (
type Middlewares struct {
IPLimiter gin.HandlerFunc
LoginRateLimiter gin.HandlerFunc
ScimRateLimiter gin.HandlerFunc
SessionHandler gin.HandlerFunc
SoftSessionHandler gin.HandlerFunc
}
@@ -31,6 +32,12 @@ func NewMiddlewares(
requestPerSecond, // requests per second
requestBurst, // burst
)
// per-company SCIM limiter: each company gets its own bucket so cloud IdPs
// sharing source IPs across tenants do not throttle each other
scimThrottle := middleware.NewScimRateLimiterMiddleware(
20, // requests per second per company
40, // burst
)
sessionHandler := middleware.NewSessionHandler(
services.Session,
services.User,
@@ -46,6 +53,7 @@ func NewMiddlewares(
return &Middlewares{
IPLimiter: ipLimiter,
LoginRateLimiter: loginThrottle,
ScimRateLimiter: scimThrottle,
SessionHandler: sessionHandler,
SoftSessionHandler: softSessionHandler,
}
+2 -2
View File
@@ -240,7 +240,7 @@ func (c *Scim) ListGroups(g *gin.Context) {
}
base := scimBaseURL(g, companyID)
startIndex := parseIntQuery(g, "startIndex", 1)
count := parseIntQuery(g, "count", 0)
count := parseIntQuery(g, "count", -1)
filter := g.Query("filter")
excludedAttributes := g.Query("excludedAttributes")
@@ -447,7 +447,7 @@ func (c *Scim) ListUsers(g *gin.Context) {
base := scimBaseURL(g, companyID)
filter := g.Query("filter")
startIndex := parseIntQuery(g, "startIndex", 1)
count := parseIntQuery(g, "count", 0)
count := parseIntQuery(g, "count", -1)
sortBy := g.Query("sortBy")
sortOrder := g.Query("sortOrder")
+22
View File
@@ -33,6 +33,28 @@ func NewIPRateLimiterMiddleware(limit float64, burst int) gin.HandlerFunc {
}
}
// NewScimRateLimiterMiddleware limits SCIM requests per company rather than per
// IP. cloud identity providers (Microsoft Entra, Okta) send every tenant's SCIM
// traffic from a shared pool of source IPs, so an IP based limit would throttle
// all companies together. keying on the companyID path param gives each company
// its own bucket. requests without a companyID fall back to the client IP.
// limit is requests per second, burst is the maximum burst size.
func NewScimRateLimiterMiddleware(limit float64, burst int) gin.HandlerFunc {
companyLimiter := NewKeyRateLimiter(rate.Limit(limit), burst, 10*time.Minute)
return func(c *gin.Context) {
key := c.Param("companyID")
if key == "" {
key = c.ClientIP()
}
limiter := companyLimiter.GetLimiter(key)
if !limiter.Allow() {
c.AbortWithStatus(http.StatusTooManyRequests)
return
}
c.Next()
}
}
// KeyRateLimiter is a rate limiter for key such as username, email or IP
type KeyRateLimiter struct {
// key is a map of key to limiterEntry
+27
View File
@@ -746,6 +746,33 @@ func (r *Recipient) GetByEmail(
return ToRecipient(&dbRecipient)
}
// GetByEmailLowerAndCompanyID looks up a recipient by a case-insensitive email
// match within a company. the supplied email is compared with LOWER() on both
// sides so that differently cased addresses (e.g. John@X.com vs john@x.com) are
// treated as the same recipient. used by SCIM provisioning to dedupe.
func (r *Recipient) GetByEmailLowerAndCompanyID(
ctx context.Context,
email *vo.Email,
companyID *uuid.UUID,
fields ...string,
) (*model.Recipient, error) {
var dbRecipient database.Recipient
emailCol := TableColumn(database.RECIPIENT_TABLE, "email")
companyCol := TableColumn(database.RECIPIENT_TABLE, "company_id")
q := r.DB.Where(
fmt.Sprintf("LOWER(%s) = LOWER(?) AND %s = ?", emailCol, companyCol),
email.String(),
companyID,
)
fields = assignTableToColumns(database.RECIPIENT_TABLE, fields)
q = useSelect(q, fields)
res := q.First(&dbRecipient)
if res.Error != nil {
return nil, res.Error
}
return ToRecipient(&dbRecipient)
}
func (r *Recipient) GetByEmailAndCompanyID(
ctx context.Context,
email *vo.Email,
+106 -57
View File
@@ -598,8 +598,11 @@ func (s *Scim) ListGroupsRaw(
}
all = all[offset:]
// apply count
if count > 0 && count < len(all) {
// apply count — 0 returns zero resources (RFC 7644 §3.4.2.4); a negative or
// absent value means no limit
if count == 0 {
all = []ScimGroup{}
} else if count > 0 && count < len(all) {
all = all[:count]
}
@@ -684,6 +687,7 @@ func (s *Scim) CreateGroup(
s.Logger.Errorw("scim create group: failed to reload group", "error", err)
return nil, errs.Wrap(err)
}
s.auditScim("Scim.CreateGroup", config, map[string]any{"groupID": groupID.String()})
g := recipientGroupToScimGroup(created, baseURL)
return &g, nil
}
@@ -736,6 +740,7 @@ func (s *Scim) ReplaceGroup(
if err != nil {
return nil, errs.Wrap(err)
}
s.auditScim("Scim.ReplaceGroup", config, map[string]any{"groupID": groupID.String()})
g := recipientGroupToScimGroup(updated, baseURL)
return &g, nil
}
@@ -793,6 +798,7 @@ func (s *Scim) PatchGroup(
if err != nil {
return nil, errs.Wrap(err)
}
s.auditScim("Scim.PatchGroup", config, map[string]any{"groupID": groupID.String()})
g := recipientGroupToScimGroup(updated, baseURL)
return &g, nil
}
@@ -827,6 +833,7 @@ func (s *Scim) DeleteGroup(
s.Logger.Errorw("scim delete group: failed to delete group", "error", err)
return errs.Wrap(err)
}
s.auditScim("Scim.DeleteGroup", config, map[string]any{"groupID": groupID.String()})
return nil
}
@@ -908,10 +915,21 @@ func (s *Scim) ListUsers(
return nil, errs.Wrap(err)
}
// load group memberships once so each user can report its groups
groupsByRecipient := map[uuid.UUID][]ScimUserGroup{}
if groupList, gErr := s.RecipientGroupRepository.GetAllByCompanyID(ctx, companyID, &repository.RecipientGroupOption{WithRecipients: true}); gErr != nil {
s.Logger.Warnw("scim list users: failed to load group memberships", "error", gErr)
} else {
groupsByRecipient = buildGroupsByRecipient(groupList)
}
// build the full filtered list first so totalResults is accurate
all := make([]ScimUser, 0, len(recipientResult.Rows))
for _, r := range recipientResult.Rows {
u := recipientToScimUser(r, baseURL)
if rid, idErr := r.ID.Get(); idErr == nil {
u.Groups = groupsByRecipient[rid]
}
if filter != "" && !scimFilterMatchesUser(filter, u) {
continue
}
@@ -935,8 +953,11 @@ func (s *Scim) ListUsers(
}
all = all[offset:]
// apply count
if count > 0 && count < len(all) {
// apply count — 0 returns zero resources (RFC 7644 §3.4.2.4); a negative or
// absent value means no limit
if count == 0 {
all = []ScimUser{}
} else if count > 0 && count < len(all) {
all = all[:count]
}
@@ -971,6 +992,7 @@ func (s *Scim) GetUser(
return nil, errs.Wrap(gorm.ErrRecordNotFound)
}
u := recipientToScimUser(recipient, baseURL)
u.Groups = s.groupsForRecipient(ctx, companyID, recipientID, baseURL)
return &u, nil
}
@@ -993,15 +1015,9 @@ func (s *Scim) CreateUser(
return nil, errs.NewValidationError(fmt.Errorf("invalid email %q: %w", email, err))
}
// dedup lookup uses lowercase so matching is case-insensitive
emailLower, _ := canonicalEmailLower(scimUser)
emailLowerVO, lookupErr := vo.NewEmail(emailLower)
if lookupErr != nil {
emailLowerVO = emailVO
}
// reject duplicate userName — rfc 7644 requires 409 for uniqueness conflicts
existingByEmail, err := s.RecipientRepository.GetByEmailAndCompanyID(ctx, emailLowerVO, companyID)
// reject duplicate userName — rfc 7644 requires 409 for uniqueness conflicts.
// the lookup is case-insensitive so John@X.com and john@x.com collide.
existingByEmail, err := s.RecipientRepository.GetByEmailLowerAndCompanyID(ctx, emailVO, companyID)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
s.Logger.Errorw("scim create user: lookup by email failed", "error", err, "email", email)
return nil, errs.Wrap(err)
@@ -1033,6 +1049,10 @@ func (s *Scim) CreateUser(
s.Logger.Errorw("scim create user: failed to reload recipient", "error", err)
return nil, errs.Wrap(err)
}
// note: active=false on create is not separately representable — a recipient
// either exists (active) or is deprovisioned (deleted). the resource is still
// created so the IdP receives a retrievable 201 response.
s.auditScim("Scim.CreateUser", config, map[string]any{"recipientID": recipientID.String()})
u := recipientToScimUser(created, baseURL)
return &u, nil
}
@@ -1059,6 +1079,14 @@ func (s *Scim) ReplaceUser(
if compErr != nil || rCompanyID != *companyID {
return nil, errs.Wrap(gorm.ErrRecordNotFound)
}
// a PUT with active=false is a deprovision request — hard-delete the recipient
if !scimUser.Active {
if err := s.deprovisionRecipient(ctx, recipientID); err != nil {
return nil, errs.Wrap(err)
}
s.auditScim("Scim.DeprovisionUser", config, map[string]any{"recipientID": recipientID.String(), "via": "replace"})
return nil, errs.Wrap(gorm.ErrRecordNotFound)
}
if err := s.applyScimUserToRecipient(ctx, existing, scimUser); err != nil {
return nil, err
}
@@ -1071,6 +1099,7 @@ func (s *Scim) ReplaceUser(
if err != nil {
return nil, errs.Wrap(err)
}
s.auditScim("Scim.ReplaceUser", config, map[string]any{"recipientID": recipientID.String()})
u := recipientToScimUser(updated, baseURL)
return &u, nil
}
@@ -1108,17 +1137,16 @@ func (s *Scim) PatchUser(
}
// active=false triggers a hard-delete; nothing more to do
if deactivated {
s.auditScim("Scim.DeprovisionUser", config, map[string]any{"recipientID": recipientID.String(), "via": "patch"})
return nil, errs.Wrap(gorm.ErrRecordNotFound)
}
case "remove":
// remove op on "active" means deactivate — hard-delete the recipient
if strings.EqualFold(op.Path, "active") {
if err := s.RecipientGroupRepository.RemoveRecipientByIDFromAllGroups(ctx, recipientID); err != nil {
s.Logger.Warnw("scim patch remove active: failed to remove from groups", "error", err)
}
if err := s.RecipientRepository.DeleteByID(ctx, recipientID); err != nil {
if err := s.deprovisionRecipient(ctx, recipientID); err != nil {
return nil, errs.Wrap(err)
}
s.auditScim("Scim.DeprovisionUser", config, map[string]any{"recipientID": recipientID.String(), "via": "patch"})
return nil, errs.Wrap(gorm.ErrRecordNotFound)
}
}
@@ -1128,6 +1156,7 @@ func (s *Scim) PatchUser(
if err != nil {
return nil, errs.Wrap(err)
}
s.auditScim("Scim.PatchUser", config, map[string]any{"recipientID": recipientID.String()})
u := recipientToScimUser(updated, baseURL)
return &u, nil
}
@@ -1153,11 +1182,11 @@ func (s *Scim) DeprovisionUser(
if compErr != nil || rCompanyID != *companyID {
return errs.Wrap(gorm.ErrRecordNotFound)
}
// remove from all groups first to avoid orphan join rows
if err := s.RecipientGroupRepository.RemoveRecipientByIDFromAllGroups(ctx, recipientID); err != nil {
s.Logger.Warnw("scim deprovision user: failed to remove from groups", "error", err)
if err := s.deprovisionRecipient(ctx, recipientID); err != nil {
return errs.Wrap(err)
}
return s.RecipientRepository.DeleteByID(ctx, recipientID)
s.auditScim("Scim.DeprovisionUser", config, map[string]any{"recipientID": recipientID.String(), "via": "delete"})
return nil
}
// VerifyAndLoadConfig authenticates the bearer token against the stored hash
@@ -1194,6 +1223,35 @@ func (s *Scim) UpdateLastSync(ctx context.Context, config *model.CompanyScimConf
// ── helpers ───────────────────────────────────────────────────────────────────
// deprovisionRecipient removes a recipient from all groups and hard-deletes it.
// shared by DELETE, PUT active=false and PATCH active=false.
func (s *Scim) deprovisionRecipient(ctx context.Context, recipientID *uuid.UUID) error {
if err := s.RecipientGroupRepository.RemoveRecipientByIDFromAllGroups(ctx, recipientID); err != nil {
s.Logger.Warnw("scim deprovision: failed to remove recipient from groups", "error", err)
}
return s.RecipientRepository.DeleteByID(ctx, recipientID)
}
// auditScim emits an audit event for an externally driven SCIM mutation.
// SCIM has no admin session, so the actor is identified by the company and the
// token prefix instead of a user id.
func (s *Scim) auditScim(name string, config *model.CompanyScimConfig, details map[string]any) {
ae := NewAuditEvent(name, nil)
ae.Details["actor"] = "scim"
if config != nil {
if cid, err := config.CompanyID.Get(); err == nil {
ae.Details["companyID"] = cid.String()
}
if tp, err := config.TokenPrefix.Get(); err == nil {
ae.Details["scimTokenPrefix"] = tp
}
}
for k, v := range details {
ae.Details[k] = v
}
s.AuditLogAuthorized(ae)
}
// groupsForRecipient returns the ScimUserGroup list for a recipient by scanning
// all company groups for membership.
func (s *Scim) groupsForRecipient(
@@ -1499,53 +1557,42 @@ func (s *Scim) applyScimUserToRecipient(
existing *model.Recipient,
scimUser *ScimUser,
) error {
// always update the stored scim userName so it round-trips
// PUT is a full replace (RFC 7644 §3.5.1): attributes absent from the
// request are cleared. email is the one exception — it is required, so an
// absent or invalid email leaves the existing address untouched.
existing.ScimUserName.Set(*vo.NewOptionalString127Must(truncate(scimUserNameFrom(scimUser), 127)))
// email
if email, err := canonicalEmail(scimUser); err == nil && email != "" {
// email — stored lowercased for case-insensitive matching
if email, err := canonicalEmailLower(scimUser); err == nil && email != "" {
if ev, err := vo.NewEmail(email); err == nil {
existing.Email.Set(*ev)
}
}
// first name
if fn := firstNameFrom(scimUser); fn != "" {
existing.FirstName.Set(*vo.NewOptionalString127Must(truncate(fn, 127)))
}
// last name
if ln := lastNameFrom(scimUser); ln != "" {
existing.LastName.Set(*vo.NewOptionalString127Must(truncate(ln, 127)))
}
// first and last name
existing.FirstName.Set(*vo.NewOptionalString127Must(truncate(firstNameFrom(scimUser), 127)))
existing.LastName.Set(*vo.NewOptionalString127Must(truncate(lastNameFrom(scimUser), 127)))
// phone
if phone := primaryPhoneFrom(scimUser); phone != "" {
existing.Phone.Set(*vo.NewOptionalString127Must(truncate(phone, 127)))
}
existing.Phone.Set(*vo.NewOptionalString127Must(truncate(primaryPhoneFrom(scimUser), 127)))
// department and title from enterprise extension
department := ""
title := ""
if scimUser.EnterpriseUser != nil {
if scimUser.EnterpriseUser.Department != "" {
existing.Department.Set(*vo.NewOptionalString127Must(truncate(scimUser.EnterpriseUser.Department, 127)))
}
if scimUser.EnterpriseUser.Title != "" {
existing.Position.Set(*vo.NewOptionalString127Must(truncate(scimUser.EnterpriseUser.Title, 127)))
}
department = scimUser.EnterpriseUser.Department
title = scimUser.EnterpriseUser.Title
}
existing.Department.Set(*vo.NewOptionalString127Must(truncate(department, 127)))
existing.Position.Set(*vo.NewOptionalString127Must(truncate(title, 127)))
// addresses — city and country from primary/work address
if len(scimUser.Addresses) > 0 {
city, country := primaryAddressFrom(scimUser)
if city != "" {
existing.City.Set(*vo.NewOptionalString127Must(truncate(city, 127)))
}
if country != "" {
existing.Country.Set(*vo.NewOptionalString127Must(truncate(country, 127)))
}
}
city, country := primaryAddressFrom(scimUser)
existing.City.Set(*vo.NewOptionalString127Must(truncate(city, 127)))
existing.Country.Set(*vo.NewOptionalString127Must(truncate(country, 127)))
// externalId -> extra_identifier
if scimUser.ExternalID != "" {
existing.ExtraIdentifier.Set(*vo.NewOptionalString127Must(truncate(scimUser.ExternalID, 127)))
}
existing.ExtraIdentifier.Set(*vo.NewOptionalString127Must(truncate(scimUser.ExternalID, 127)))
// misc from custom extension
if scimUser.CustomExtension != nil && scimUser.CustomExtension.Misc != "" {
existing.Misc.Set(*vo.NewOptionalString127Must(truncate(scimUser.CustomExtension.Misc, 127)))
misc := ""
if scimUser.CustomExtension != nil {
misc = scimUser.CustomExtension.Misc
}
existing.Misc.Set(*vo.NewOptionalString127Must(truncate(misc, 127)))
id := existing.ID.MustGet()
if err := s.RecipientRepository.UpdateByID(ctx, &id, existing); err != nil {
@@ -1636,7 +1683,7 @@ func (s *Scim) applyPatchOperation(
case "username":
existing.ScimUserName.Set(*vo.NewOptionalString127Must(truncate(strVal, 127)))
case "emails[type eq \"work\"].value", "emails":
if ev, err := vo.NewEmail(strVal); err == nil {
if ev, err := vo.NewEmail(strings.ToLower(strings.TrimSpace(strVal))); err == nil {
existing.Email.Set(*ev)
}
case "name.givenname":
@@ -1880,7 +1927,9 @@ func scimUserToRecipient(scimUser *ScimUser, companyID *uuid.UUID) *model.Recipi
// store the original userName so it round-trips exactly
r.ScimUserName = nullable.NewNullableWithValue(*vo.NewOptionalString127Must(truncate(scimUserNameFrom(scimUser), 127)))
emailStr, _ := canonicalEmail(scimUser)
// email is stored lowercased so dedup and the unique index are case
// insensitive; the original userName case is preserved in scim_user_name
emailStr, _ := canonicalEmailLower(scimUser)
if emailStr != "" {
if ev, err := vo.NewEmail(emailStr); err == nil {
r.Email = nullable.NewNullableWithValue(*ev)