diff --git a/backend/app/administration.go b/backend/app/administration.go index eccabe9..5f89e90 100644 --- a/backend/app/administration.go +++ b/backend/app/administration.go @@ -280,8 +280,11 @@ func setupRoutes( controllers *Controllers, middleware *Middlewares, ) *gin.Engine { - // scim v2 provisioning endpoints — no ip allowlist, bearer-token auth only - r. + // scim v2 provisioning endpoints — no ip allowlist, bearer-token auth only. + // rate limited per company so cloud IdPs sharing source IPs across tenants + // do not throttle each other. + scim := r.Group("/", middleware.ScimRateLimiter) + scim. GET(ROUTE_SCIM_V2_SERVICE_PROVIDER_CONFIG, controllers.Scim.ServiceProviderConfig). GET(ROUTE_SCIM_V2_RESOURCE_TYPES, controllers.Scim.ResourceTypes). GET(ROUTE_SCIM_V2_SCHEMAS, controllers.Scim.Schemas). diff --git a/backend/app/middleware.go b/backend/app/middleware.go index 6b83886..fa6732a 100644 --- a/backend/app/middleware.go +++ b/backend/app/middleware.go @@ -13,6 +13,7 @@ import ( type Middlewares struct { IPLimiter gin.HandlerFunc LoginRateLimiter gin.HandlerFunc + ScimRateLimiter gin.HandlerFunc SessionHandler gin.HandlerFunc SoftSessionHandler gin.HandlerFunc } @@ -31,6 +32,12 @@ func NewMiddlewares( requestPerSecond, // requests per second requestBurst, // burst ) + // per-company SCIM limiter: each company gets its own bucket so cloud IdPs + // sharing source IPs across tenants do not throttle each other + scimThrottle := middleware.NewScimRateLimiterMiddleware( + 20, // requests per second per company + 40, // burst + ) sessionHandler := middleware.NewSessionHandler( services.Session, services.User, @@ -46,6 +53,7 @@ func NewMiddlewares( return &Middlewares{ IPLimiter: ipLimiter, LoginRateLimiter: loginThrottle, + ScimRateLimiter: scimThrottle, SessionHandler: sessionHandler, SoftSessionHandler: softSessionHandler, } diff --git a/backend/controller/scim.go b/backend/controller/scim.go index b8f03ea..1971df3 100644 --- a/backend/controller/scim.go +++ b/backend/controller/scim.go @@ -240,7 +240,7 @@ func (c *Scim) ListGroups(g *gin.Context) { } base := scimBaseURL(g, companyID) startIndex := parseIntQuery(g, "startIndex", 1) - count := parseIntQuery(g, "count", 0) + count := parseIntQuery(g, "count", -1) filter := g.Query("filter") excludedAttributes := g.Query("excludedAttributes") @@ -447,7 +447,7 @@ func (c *Scim) ListUsers(g *gin.Context) { base := scimBaseURL(g, companyID) filter := g.Query("filter") startIndex := parseIntQuery(g, "startIndex", 1) - count := parseIntQuery(g, "count", 0) + count := parseIntQuery(g, "count", -1) sortBy := g.Query("sortBy") sortOrder := g.Query("sortOrder") diff --git a/backend/middleware/ratelimiter.go b/backend/middleware/ratelimiter.go index f786eea..69ee52e 100644 --- a/backend/middleware/ratelimiter.go +++ b/backend/middleware/ratelimiter.go @@ -33,6 +33,28 @@ func NewIPRateLimiterMiddleware(limit float64, burst int) gin.HandlerFunc { } } +// NewScimRateLimiterMiddleware limits SCIM requests per company rather than per +// IP. cloud identity providers (Microsoft Entra, Okta) send every tenant's SCIM +// traffic from a shared pool of source IPs, so an IP based limit would throttle +// all companies together. keying on the companyID path param gives each company +// its own bucket. requests without a companyID fall back to the client IP. +// limit is requests per second, burst is the maximum burst size. +func NewScimRateLimiterMiddleware(limit float64, burst int) gin.HandlerFunc { + companyLimiter := NewKeyRateLimiter(rate.Limit(limit), burst, 10*time.Minute) + return func(c *gin.Context) { + key := c.Param("companyID") + if key == "" { + key = c.ClientIP() + } + limiter := companyLimiter.GetLimiter(key) + if !limiter.Allow() { + c.AbortWithStatus(http.StatusTooManyRequests) + return + } + c.Next() + } +} + // KeyRateLimiter is a rate limiter for key such as username, email or IP type KeyRateLimiter struct { // key is a map of key to limiterEntry diff --git a/backend/repository/recipient.go b/backend/repository/recipient.go index 02e494f..c86d16f 100644 --- a/backend/repository/recipient.go +++ b/backend/repository/recipient.go @@ -746,6 +746,33 @@ func (r *Recipient) GetByEmail( return ToRecipient(&dbRecipient) } +// GetByEmailLowerAndCompanyID looks up a recipient by a case-insensitive email +// match within a company. the supplied email is compared with LOWER() on both +// sides so that differently cased addresses (e.g. John@X.com vs john@x.com) are +// treated as the same recipient. used by SCIM provisioning to dedupe. +func (r *Recipient) GetByEmailLowerAndCompanyID( + ctx context.Context, + email *vo.Email, + companyID *uuid.UUID, + fields ...string, +) (*model.Recipient, error) { + var dbRecipient database.Recipient + emailCol := TableColumn(database.RECIPIENT_TABLE, "email") + companyCol := TableColumn(database.RECIPIENT_TABLE, "company_id") + q := r.DB.Where( + fmt.Sprintf("LOWER(%s) = LOWER(?) AND %s = ?", emailCol, companyCol), + email.String(), + companyID, + ) + fields = assignTableToColumns(database.RECIPIENT_TABLE, fields) + q = useSelect(q, fields) + res := q.First(&dbRecipient) + if res.Error != nil { + return nil, res.Error + } + return ToRecipient(&dbRecipient) +} + func (r *Recipient) GetByEmailAndCompanyID( ctx context.Context, email *vo.Email, diff --git a/backend/service/scim.go b/backend/service/scim.go index 03407ec..a2fa176 100644 --- a/backend/service/scim.go +++ b/backend/service/scim.go @@ -598,8 +598,11 @@ func (s *Scim) ListGroupsRaw( } all = all[offset:] - // apply count - if count > 0 && count < len(all) { + // apply count — 0 returns zero resources (RFC 7644 §3.4.2.4); a negative or + // absent value means no limit + if count == 0 { + all = []ScimGroup{} + } else if count > 0 && count < len(all) { all = all[:count] } @@ -684,6 +687,7 @@ func (s *Scim) CreateGroup( s.Logger.Errorw("scim create group: failed to reload group", "error", err) return nil, errs.Wrap(err) } + s.auditScim("Scim.CreateGroup", config, map[string]any{"groupID": groupID.String()}) g := recipientGroupToScimGroup(created, baseURL) return &g, nil } @@ -736,6 +740,7 @@ func (s *Scim) ReplaceGroup( if err != nil { return nil, errs.Wrap(err) } + s.auditScim("Scim.ReplaceGroup", config, map[string]any{"groupID": groupID.String()}) g := recipientGroupToScimGroup(updated, baseURL) return &g, nil } @@ -793,6 +798,7 @@ func (s *Scim) PatchGroup( if err != nil { return nil, errs.Wrap(err) } + s.auditScim("Scim.PatchGroup", config, map[string]any{"groupID": groupID.String()}) g := recipientGroupToScimGroup(updated, baseURL) return &g, nil } @@ -827,6 +833,7 @@ func (s *Scim) DeleteGroup( s.Logger.Errorw("scim delete group: failed to delete group", "error", err) return errs.Wrap(err) } + s.auditScim("Scim.DeleteGroup", config, map[string]any{"groupID": groupID.String()}) return nil } @@ -908,10 +915,21 @@ func (s *Scim) ListUsers( return nil, errs.Wrap(err) } + // load group memberships once so each user can report its groups + groupsByRecipient := map[uuid.UUID][]ScimUserGroup{} + if groupList, gErr := s.RecipientGroupRepository.GetAllByCompanyID(ctx, companyID, &repository.RecipientGroupOption{WithRecipients: true}); gErr != nil { + s.Logger.Warnw("scim list users: failed to load group memberships", "error", gErr) + } else { + groupsByRecipient = buildGroupsByRecipient(groupList) + } + // build the full filtered list first so totalResults is accurate all := make([]ScimUser, 0, len(recipientResult.Rows)) for _, r := range recipientResult.Rows { u := recipientToScimUser(r, baseURL) + if rid, idErr := r.ID.Get(); idErr == nil { + u.Groups = groupsByRecipient[rid] + } if filter != "" && !scimFilterMatchesUser(filter, u) { continue } @@ -935,8 +953,11 @@ func (s *Scim) ListUsers( } all = all[offset:] - // apply count - if count > 0 && count < len(all) { + // apply count — 0 returns zero resources (RFC 7644 §3.4.2.4); a negative or + // absent value means no limit + if count == 0 { + all = []ScimUser{} + } else if count > 0 && count < len(all) { all = all[:count] } @@ -971,6 +992,7 @@ func (s *Scim) GetUser( return nil, errs.Wrap(gorm.ErrRecordNotFound) } u := recipientToScimUser(recipient, baseURL) + u.Groups = s.groupsForRecipient(ctx, companyID, recipientID, baseURL) return &u, nil } @@ -993,15 +1015,9 @@ func (s *Scim) CreateUser( return nil, errs.NewValidationError(fmt.Errorf("invalid email %q: %w", email, err)) } - // dedup lookup uses lowercase so matching is case-insensitive - emailLower, _ := canonicalEmailLower(scimUser) - emailLowerVO, lookupErr := vo.NewEmail(emailLower) - if lookupErr != nil { - emailLowerVO = emailVO - } - - // reject duplicate userName — rfc 7644 requires 409 for uniqueness conflicts - existingByEmail, err := s.RecipientRepository.GetByEmailAndCompanyID(ctx, emailLowerVO, companyID) + // reject duplicate userName — rfc 7644 requires 409 for uniqueness conflicts. + // the lookup is case-insensitive so John@X.com and john@x.com collide. + existingByEmail, err := s.RecipientRepository.GetByEmailLowerAndCompanyID(ctx, emailVO, companyID) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { s.Logger.Errorw("scim create user: lookup by email failed", "error", err, "email", email) return nil, errs.Wrap(err) @@ -1033,6 +1049,10 @@ func (s *Scim) CreateUser( s.Logger.Errorw("scim create user: failed to reload recipient", "error", err) return nil, errs.Wrap(err) } + // note: active=false on create is not separately representable — a recipient + // either exists (active) or is deprovisioned (deleted). the resource is still + // created so the IdP receives a retrievable 201 response. + s.auditScim("Scim.CreateUser", config, map[string]any{"recipientID": recipientID.String()}) u := recipientToScimUser(created, baseURL) return &u, nil } @@ -1059,6 +1079,14 @@ func (s *Scim) ReplaceUser( if compErr != nil || rCompanyID != *companyID { return nil, errs.Wrap(gorm.ErrRecordNotFound) } + // a PUT with active=false is a deprovision request — hard-delete the recipient + if !scimUser.Active { + if err := s.deprovisionRecipient(ctx, recipientID); err != nil { + return nil, errs.Wrap(err) + } + s.auditScim("Scim.DeprovisionUser", config, map[string]any{"recipientID": recipientID.String(), "via": "replace"}) + return nil, errs.Wrap(gorm.ErrRecordNotFound) + } if err := s.applyScimUserToRecipient(ctx, existing, scimUser); err != nil { return nil, err } @@ -1071,6 +1099,7 @@ func (s *Scim) ReplaceUser( if err != nil { return nil, errs.Wrap(err) } + s.auditScim("Scim.ReplaceUser", config, map[string]any{"recipientID": recipientID.String()}) u := recipientToScimUser(updated, baseURL) return &u, nil } @@ -1108,17 +1137,16 @@ func (s *Scim) PatchUser( } // active=false triggers a hard-delete; nothing more to do if deactivated { + s.auditScim("Scim.DeprovisionUser", config, map[string]any{"recipientID": recipientID.String(), "via": "patch"}) return nil, errs.Wrap(gorm.ErrRecordNotFound) } case "remove": // remove op on "active" means deactivate — hard-delete the recipient if strings.EqualFold(op.Path, "active") { - if err := s.RecipientGroupRepository.RemoveRecipientByIDFromAllGroups(ctx, recipientID); err != nil { - s.Logger.Warnw("scim patch remove active: failed to remove from groups", "error", err) - } - if err := s.RecipientRepository.DeleteByID(ctx, recipientID); err != nil { + if err := s.deprovisionRecipient(ctx, recipientID); err != nil { return nil, errs.Wrap(err) } + s.auditScim("Scim.DeprovisionUser", config, map[string]any{"recipientID": recipientID.String(), "via": "patch"}) return nil, errs.Wrap(gorm.ErrRecordNotFound) } } @@ -1128,6 +1156,7 @@ func (s *Scim) PatchUser( if err != nil { return nil, errs.Wrap(err) } + s.auditScim("Scim.PatchUser", config, map[string]any{"recipientID": recipientID.String()}) u := recipientToScimUser(updated, baseURL) return &u, nil } @@ -1153,11 +1182,11 @@ func (s *Scim) DeprovisionUser( if compErr != nil || rCompanyID != *companyID { return errs.Wrap(gorm.ErrRecordNotFound) } - // remove from all groups first to avoid orphan join rows - if err := s.RecipientGroupRepository.RemoveRecipientByIDFromAllGroups(ctx, recipientID); err != nil { - s.Logger.Warnw("scim deprovision user: failed to remove from groups", "error", err) + if err := s.deprovisionRecipient(ctx, recipientID); err != nil { + return errs.Wrap(err) } - return s.RecipientRepository.DeleteByID(ctx, recipientID) + s.auditScim("Scim.DeprovisionUser", config, map[string]any{"recipientID": recipientID.String(), "via": "delete"}) + return nil } // VerifyAndLoadConfig authenticates the bearer token against the stored hash @@ -1194,6 +1223,35 @@ func (s *Scim) UpdateLastSync(ctx context.Context, config *model.CompanyScimConf // ── helpers ─────────────────────────────────────────────────────────────────── +// deprovisionRecipient removes a recipient from all groups and hard-deletes it. +// shared by DELETE, PUT active=false and PATCH active=false. +func (s *Scim) deprovisionRecipient(ctx context.Context, recipientID *uuid.UUID) error { + if err := s.RecipientGroupRepository.RemoveRecipientByIDFromAllGroups(ctx, recipientID); err != nil { + s.Logger.Warnw("scim deprovision: failed to remove recipient from groups", "error", err) + } + return s.RecipientRepository.DeleteByID(ctx, recipientID) +} + +// auditScim emits an audit event for an externally driven SCIM mutation. +// SCIM has no admin session, so the actor is identified by the company and the +// token prefix instead of a user id. +func (s *Scim) auditScim(name string, config *model.CompanyScimConfig, details map[string]any) { + ae := NewAuditEvent(name, nil) + ae.Details["actor"] = "scim" + if config != nil { + if cid, err := config.CompanyID.Get(); err == nil { + ae.Details["companyID"] = cid.String() + } + if tp, err := config.TokenPrefix.Get(); err == nil { + ae.Details["scimTokenPrefix"] = tp + } + } + for k, v := range details { + ae.Details[k] = v + } + s.AuditLogAuthorized(ae) +} + // groupsForRecipient returns the ScimUserGroup list for a recipient by scanning // all company groups for membership. func (s *Scim) groupsForRecipient( @@ -1499,53 +1557,42 @@ func (s *Scim) applyScimUserToRecipient( existing *model.Recipient, scimUser *ScimUser, ) error { - // always update the stored scim userName so it round-trips + // PUT is a full replace (RFC 7644 §3.5.1): attributes absent from the + // request are cleared. email is the one exception — it is required, so an + // absent or invalid email leaves the existing address untouched. existing.ScimUserName.Set(*vo.NewOptionalString127Must(truncate(scimUserNameFrom(scimUser), 127))) - // email - if email, err := canonicalEmail(scimUser); err == nil && email != "" { + // email — stored lowercased for case-insensitive matching + if email, err := canonicalEmailLower(scimUser); err == nil && email != "" { if ev, err := vo.NewEmail(email); err == nil { existing.Email.Set(*ev) } } - // first name - if fn := firstNameFrom(scimUser); fn != "" { - existing.FirstName.Set(*vo.NewOptionalString127Must(truncate(fn, 127))) - } - // last name - if ln := lastNameFrom(scimUser); ln != "" { - existing.LastName.Set(*vo.NewOptionalString127Must(truncate(ln, 127))) - } + // first and last name + existing.FirstName.Set(*vo.NewOptionalString127Must(truncate(firstNameFrom(scimUser), 127))) + existing.LastName.Set(*vo.NewOptionalString127Must(truncate(lastNameFrom(scimUser), 127))) // phone - if phone := primaryPhoneFrom(scimUser); phone != "" { - existing.Phone.Set(*vo.NewOptionalString127Must(truncate(phone, 127))) - } + existing.Phone.Set(*vo.NewOptionalString127Must(truncate(primaryPhoneFrom(scimUser), 127))) // department and title from enterprise extension + department := "" + title := "" if scimUser.EnterpriseUser != nil { - if scimUser.EnterpriseUser.Department != "" { - existing.Department.Set(*vo.NewOptionalString127Must(truncate(scimUser.EnterpriseUser.Department, 127))) - } - if scimUser.EnterpriseUser.Title != "" { - existing.Position.Set(*vo.NewOptionalString127Must(truncate(scimUser.EnterpriseUser.Title, 127))) - } + department = scimUser.EnterpriseUser.Department + title = scimUser.EnterpriseUser.Title } + existing.Department.Set(*vo.NewOptionalString127Must(truncate(department, 127))) + existing.Position.Set(*vo.NewOptionalString127Must(truncate(title, 127))) // addresses — city and country from primary/work address - if len(scimUser.Addresses) > 0 { - city, country := primaryAddressFrom(scimUser) - if city != "" { - existing.City.Set(*vo.NewOptionalString127Must(truncate(city, 127))) - } - if country != "" { - existing.Country.Set(*vo.NewOptionalString127Must(truncate(country, 127))) - } - } + city, country := primaryAddressFrom(scimUser) + existing.City.Set(*vo.NewOptionalString127Must(truncate(city, 127))) + existing.Country.Set(*vo.NewOptionalString127Must(truncate(country, 127))) // externalId -> extra_identifier - if scimUser.ExternalID != "" { - existing.ExtraIdentifier.Set(*vo.NewOptionalString127Must(truncate(scimUser.ExternalID, 127))) - } + existing.ExtraIdentifier.Set(*vo.NewOptionalString127Must(truncate(scimUser.ExternalID, 127))) // misc from custom extension - if scimUser.CustomExtension != nil && scimUser.CustomExtension.Misc != "" { - existing.Misc.Set(*vo.NewOptionalString127Must(truncate(scimUser.CustomExtension.Misc, 127))) + misc := "" + if scimUser.CustomExtension != nil { + misc = scimUser.CustomExtension.Misc } + existing.Misc.Set(*vo.NewOptionalString127Must(truncate(misc, 127))) id := existing.ID.MustGet() if err := s.RecipientRepository.UpdateByID(ctx, &id, existing); err != nil { @@ -1636,7 +1683,7 @@ func (s *Scim) applyPatchOperation( case "username": existing.ScimUserName.Set(*vo.NewOptionalString127Must(truncate(strVal, 127))) case "emails[type eq \"work\"].value", "emails": - if ev, err := vo.NewEmail(strVal); err == nil { + if ev, err := vo.NewEmail(strings.ToLower(strings.TrimSpace(strVal))); err == nil { existing.Email.Set(*ev) } case "name.givenname": @@ -1880,7 +1927,9 @@ func scimUserToRecipient(scimUser *ScimUser, companyID *uuid.UUID) *model.Recipi // store the original userName so it round-trips exactly r.ScimUserName = nullable.NewNullableWithValue(*vo.NewOptionalString127Must(truncate(scimUserNameFrom(scimUser), 127))) - emailStr, _ := canonicalEmail(scimUser) + // email is stored lowercased so dedup and the unique index are case + // insensitive; the original userName case is preserved in scim_user_name + emailStr, _ := canonicalEmailLower(scimUser) if emailStr != "" { if ev, err := vo.NewEmail(emailStr); err == nil { r.Email = nullable.NewNullableWithValue(*ev)