From cb0ad7ac9a5d7dc73d2ff6c19858bc54a6bcaadd Mon Sep 17 00:00:00 2001 From: Abdullah Atta Date: Fri, 7 Jun 2024 10:49:57 +0500 Subject: [PATCH] api: improve pro authorization handling --- .../Authorization/ProUserRequirement.cs | 34 ++++++++++++++++--- .../Authorization/SyncRequirement.cs | 12 +++---- .../AuthorizationResultTransformer.cs | 2 +- 3 files changed, 35 insertions(+), 13 deletions(-) diff --git a/Notesnook.API/Authorization/ProUserRequirement.cs b/Notesnook.API/Authorization/ProUserRequirement.cs index d0b0233..b8a61c8 100644 --- a/Notesnook.API/Authorization/ProUserRequirement.cs +++ b/Notesnook.API/Authorization/ProUserRequirement.cs @@ -17,21 +17,47 @@ You should have received a copy of the Affero GNU General Public License along with this program. If not, see . */ +using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using System.Security.Claims; +using System.Text.Json; using System.Threading.Tasks; using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Http; namespace Notesnook.API.Authorization { public class ProUserRequirement : AuthorizationHandler, IAuthorizationRequirement { - private readonly string[] allowedClaims = { "trial", "premium", "premium_canceled" }; + private readonly Dictionary pathErrorPhraseMap = new() + { + ["/s3"] = "upload attachments", + ["/s3/multipart"] = "upload attachments", + }; + private readonly string[] allowedClaims = ["trial", "premium", "premium_canceled"]; protected override Task HandleRequirementAsync(AuthorizationHandlerContext context, ProUserRequirement requirement) { - var isProOrTrial = context.User.HasClaim((c) => c.Type == "notesnook:status" && allowedClaims.Contains(c.Value)); - if (isProOrTrial) - context.Succeed(requirement); + PathString path = context.Resource is DefaultHttpContext httpContext ? httpContext.Request.Path : null; + var isProOrTrial = context.User.Claims.Any((c) => c.Type == "notesnook:status" && allowedClaims.Contains(c.Value)); + if (isProOrTrial) context.Succeed(requirement); + else + { + var phrase = "continue"; + foreach (var item in pathErrorPhraseMap) + { + if (path != null && path.StartsWithSegments(item.Key)) + phrase = item.Value; + } + var error = $"Please upgrade to Pro to {phrase}."; + context.Fail(new AuthorizationFailureReason(this, error)); + } return Task.CompletedTask; } + + public override Task HandleAsync(AuthorizationHandlerContext context) + { + return this.HandleRequirementAsync(context, this); + } } } \ No newline at end of file diff --git a/Notesnook.API/Authorization/SyncRequirement.cs b/Notesnook.API/Authorization/SyncRequirement.cs index eb9ccd2..6fd17d4 100644 --- a/Notesnook.API/Authorization/SyncRequirement.cs +++ b/Notesnook.API/Authorization/SyncRequirement.cs @@ -29,7 +29,7 @@ namespace Notesnook.API.Authorization { public class SyncRequirement : AuthorizationHandler, IAuthorizationRequirement { - private readonly Dictionary pathErrorPhraseMap = new Dictionary + private readonly Dictionary pathErrorPhraseMap = new() { ["/sync/attachments"] = "use attachments", ["/sync"] = "sync your notes", @@ -43,13 +43,9 @@ namespace Notesnook.API.Authorization PathString path = context.Resource is DefaultHttpContext httpContext ? httpContext.Request.Path : null; var result = this.IsAuthorized(context.User, path); if (result.Succeeded) context.Succeed(requirement); - else - { - var hasReason = result.AuthorizationFailure.FailureReasons.Any(); - if (hasReason) - context.Fail(result.AuthorizationFailure.FailureReasons.First()); - else context.Fail(); - } + else if (result.AuthorizationFailure.FailureReasons.Any()) + context.Fail(result.AuthorizationFailure.FailureReasons.First()); + else context.Fail(); return Task.CompletedTask; } diff --git a/Notesnook.API/Extensions/AuthorizationResultTransformer.cs b/Notesnook.API/Extensions/AuthorizationResultTransformer.cs index 8e1b87d..90f65bd 100644 --- a/Notesnook.API/Extensions/AuthorizationResultTransformer.cs +++ b/Notesnook.API/Extensions/AuthorizationResultTransformer.cs @@ -48,7 +48,7 @@ namespace Notesnook.API.Extensions { var error = string.Join("\n", policyAuthorizationResult.AuthorizationFailure.FailureReasons.Select((r) => r.Message)); - if (!string.IsNullOrEmpty(error) && !isWebsocket) + if (!string.IsNullOrEmpty(error)) { httpContext.Response.StatusCode = (int)HttpStatusCode.Unauthorized; httpContext.Response.ContentType = "application/json";