global: add null safety checks

This commit is contained in:
Abdullah Atta
2025-10-14 21:15:51 +05:00
parent be432dfd24
commit 6e35edb715
109 changed files with 452 additions and 590 deletions

View File

@@ -17,6 +17,7 @@ You should have received a copy of the Affero GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
@@ -24,9 +25,6 @@ using System.Security.Claims;
using System.Text.Json;
using System.Threading.Tasks;
using AspNetCore.Identity.Mongo.Model;
using IdentityServer4;
using IdentityServer4.Configuration;
using IdentityServer4.Models;
using IdentityServer4.Stores;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Identity;
@@ -74,7 +72,7 @@ namespace Streetwriters.Identity.Controllers
var client = Clients.FindClientById(clientId);
if (client == null) return BadRequest("Invalid client_id.");
var user = await UserManager.FindByIdAsync(userId);
var user = await UserManager.FindByIdAsync(userId) ?? throw new Exception("User not found.");
if (!await UserService.IsUserValidAsync(UserManager, user, clientId)) return BadRequest($"Unable to find user with ID '{userId}'.");
switch (type)
@@ -86,16 +84,13 @@ namespace Streetwriters.Identity.Controllers
var result = await UserManager.ConfirmEmailAsync(user, code);
if (!result.Succeeded) return BadRequest(result.Errors.ToErrors());
if (await UserManager.IsInRoleAsync(user, client.Id))
if (await UserManager.IsInRoleAsync(user, client.Id) && client.OnEmailConfirmed != null)
{
await client.OnEmailConfirmed(userId);
}
if (!await UserManager.GetTwoFactorEnabledAsync(user))
{
await MFAService.EnableMFAAsync(user, MFAMethods.Email);
user = await UserManager.GetUserAsync(User);
}
var redirectUrl = $"{client.EmailConfirmedRedirectURL}?userId={userId}";
return RedirectPermanent(redirectUrl);
@@ -122,11 +117,12 @@ namespace Streetwriters.Identity.Controllers
var client = Clients.FindClientById(User.FindFirstValue("client_id"));
if (client == null) return BadRequest("Invalid client_id.");
var user = await UserManager.GetUserAsync(User);
var user = await UserManager.GetUserAsync(User) ?? throw new Exception("User not found.");
if (!await UserService.IsUserValidAsync(UserManager, user, client.Id)) return BadRequest($"Unable to find user with ID '{UserManager.GetUserId(User)}'.");
if (string.IsNullOrEmpty(newEmail))
{
ArgumentNullException.ThrowIfNull(user.Email);
var code = await UserManager.GenerateEmailConfirmationTokenAsync(user);
var callbackUrl = Url.TokenLink(user.Id.ToString(), code, client.Id, TokenType.CONFRIM_EMAIL);
await EmailSender.SendConfirmationEmailAsync(user.Email, callbackUrl, client);
@@ -144,7 +140,7 @@ namespace Streetwriters.Identity.Controllers
{
var client = Clients.FindClientById(User.FindFirstValue("client_id"));
if (client == null) return BadRequest("Invalid client_id.");
var user = await UserManager.GetUserAsync(User);
var user = await UserManager.GetUserAsync(User) ?? throw new Exception("User not found.");
return Ok(UserAccountService.GetUserAsync(client.Id, user.Id.ToString()));
}
@@ -156,7 +152,7 @@ namespace Streetwriters.Identity.Controllers
var client = Clients.FindClientById(form.ClientId);
if (client == null) return BadRequest("Invalid client_id.");
var user = await UserManager.FindByEmailAsync(form.Email);
var user = await UserManager.FindByEmailAsync(form.Email) ?? throw new Exception("User not found.");
if (!await UserService.IsUserValidAsync(UserManager, user, form.ClientId)) return Ok();
var code = await UserManager.GenerateUserTokenAsync(user, TokenOptions.DefaultProvider, "ResetPassword");
@@ -176,7 +172,7 @@ namespace Streetwriters.Identity.Controllers
var client = Clients.FindClientById(User.FindFirstValue("client_id"));
if (client == null) return BadRequest("Invalid client_id.");
var user = await UserManager.GetUserAsync(User);
var user = await UserManager.GetUserAsync(User) ?? throw new Exception("User not found.");
if (!await UserService.IsUserValidAsync(UserManager, user, client.Id)) return BadRequest($"Unable to find user with ID '{UserManager.GetUserId(User)}'.");
var subjectId = User.FindFirstValue("sub");
@@ -187,7 +183,7 @@ namespace Streetwriters.Identity.Controllers
ClientId = client.Id,
SubjectId = subjectId
});
grants = grants.Where((grant) => grant.Data.Contains(jti));
grants = jti == null ? [] : grants.Where((grant) => grant.Data.Contains(jti));
if (grants.Any())
{
foreach (var grant in grants)
@@ -203,7 +199,7 @@ namespace Streetwriters.Identity.Controllers
public async Task<IActionResult> GetAccessTokenFromCode([FromForm] GetAccessTokenForm form)
{
if (!Clients.IsValidClient(form.ClientId)) return BadRequest("Invalid clientId.");
var user = await UserManager.FindByIdAsync(form.UserId);
var user = await UserManager.FindByIdAsync(form.UserId) ?? throw new Exception("User not found.");
if (!await UserService.IsUserValidAsync(UserManager, user, form.ClientId))
return BadRequest($"Unable to find user with ID '{form.UserId}'.");
@@ -224,7 +220,7 @@ namespace Streetwriters.Identity.Controllers
var client = Clients.FindClientById(User.FindFirstValue("client_id"));
if (client == null) return BadRequest("Invalid client_id.");
var user = await UserManager.GetUserAsync(User);
var user = await UserManager.GetUserAsync(User) ?? throw new Exception("User not found.");
if (!await UserService.IsUserValidAsync(UserManager, user, client.Id))
return BadRequest($"Unable to find user with ID '{UserManager.GetUserId(User)}'.");
@@ -232,6 +228,9 @@ namespace Streetwriters.Identity.Controllers
{
case "change_email":
{
ArgumentNullException.ThrowIfNull(form.NewEmail);
ArgumentNullException.ThrowIfNull(form.Password);
ArgumentNullException.ThrowIfNull(form.VerificationCode);
var result = await UserManager.ChangeEmailAsync(user, form.NewEmail, form.VerificationCode);
if (result.Succeeded)
{
@@ -251,6 +250,8 @@ namespace Streetwriters.Identity.Controllers
}
case "change_password":
{
ArgumentNullException.ThrowIfNull(form.OldPassword);
ArgumentNullException.ThrowIfNull(form.NewPassword);
var result = await UserManager.ChangePasswordAsync(user, form.OldPassword, form.NewPassword);
if (result.Succeeded)
{
@@ -261,6 +262,7 @@ namespace Streetwriters.Identity.Controllers
}
case "reset_password":
{
ArgumentNullException.ThrowIfNull(form.NewPassword);
var result = await UserManager.RemovePasswordAsync(user);
if (result.Succeeded)
{
@@ -295,7 +297,7 @@ namespace Streetwriters.Identity.Controllers
var client = Clients.FindClientById(User.FindFirstValue("client_id"));
if (client == null) return BadRequest("Invalid client_id.");
var user = await UserManager.GetUserAsync(User);
var user = await UserManager.GetUserAsync(User) ?? throw new Exception("User not found.");
if (!await UserService.IsUserValidAsync(UserManager, user, client.Id)) return BadRequest($"Unable to find user with ID '{user.Id}'.");
var jti = User.FindFirstValue("jti");

View File

@@ -51,7 +51,7 @@ namespace Streetwriters.Identity.Controllers
var client = Clients.FindClientById(User.FindFirstValue("client_id"));
if (client == null) return BadRequest("Invalid client_id.");
var user = await UserManager.GetUserAsync(User);
var user = await UserManager.GetUserAsync(User) ?? throw new Exception("User not found.");
try
{
@@ -83,7 +83,7 @@ namespace Streetwriters.Identity.Controllers
[HttpGet("codes")]
public async Task<IActionResult> GetRecoveryCodes()
{
var user = await UserManager.GetUserAsync(User);
var user = await UserManager.GetUserAsync(User) ?? throw new Exception("User not found.");
if (!await UserManager.GetTwoFactorEnabledAsync(user)) return BadRequest("Please enable 2FA.");
return Ok(await UserManager.GenerateNewTwoFactorRecoveryCodesAsync(user, 16));
}
@@ -97,7 +97,7 @@ namespace Streetwriters.Identity.Controllers
var client = Clients.FindClientById(User.FindFirstValue("client_id"));
if (client == null) return BadRequest("Invalid client_id.");
var user = await UserManager.FindByIdAsync(User.FindFirstValue("sub"));
var user = await UserManager.GetUserAsync(User);
if (user == null) return Ok(); // We cannot expose that the user doesn't exist.
await MFAService.SendOTPAsync(user, client, new MultiFactorSetupForm
@@ -111,7 +111,7 @@ namespace Streetwriters.Identity.Controllers
[HttpPatch]
public async Task<IActionResult> EnableAuthenticator([FromForm] MultiFactorEnableForm form)
{
var user = await UserManager.GetUserAsync(User);
var user = await UserManager.GetUserAsync(User) ?? throw new Exception("User not found.");
if (!await MFAService.VerifyOTPAsync(user, form.VerificationCode, form.Type))
return BadRequest("Invalid verification code.");

View File

@@ -88,6 +88,7 @@ namespace Streetwriters.Identity.Controllers
if (result.Errors.Any((e) => e.Code == "DuplicateEmail"))
{
var user = await UserManager.FindByEmailAsync(form.Email);
if (user == null) return BadRequest(new string[] { "User not found." });
if (!await UserManager.IsInRoleAsync(user, client.Id))
{
@@ -114,6 +115,8 @@ namespace Streetwriters.Identity.Controllers
if (result.Succeeded)
{
var user = await UserManager.FindByEmailAsync(form.Email);
if (user == null) return BadRequest(new string[] { "User not found after creation." });
await UserManager.AddToRoleAsync(user, client.Id);
if (Constants.IS_SELF_HOSTED)
{
@@ -124,7 +127,10 @@ namespace Streetwriters.Identity.Controllers
await UserManager.AddClaimAsync(user, new Claim("platform", PlatformFromUserAgent(base.HttpContext.Request.Headers.UserAgent)));
var code = await UserManager.GenerateEmailConfirmationTokenAsync(user);
var callbackUrl = Url.TokenLink(user.Id.ToString(), code, client.Id, TokenType.CONFRIM_EMAIL);
await EmailSender.SendConfirmationEmailAsync(user.Email, callbackUrl, client);
if (!string.IsNullOrEmpty(user.Email) && callbackUrl != null)
{
await EmailSender.SendConfirmationEmailAsync(user.Email, callbackUrl, client);
}
}
return Ok(new
{
@@ -141,8 +147,9 @@ namespace Streetwriters.Identity.Controllers
}
}
string PlatformFromUserAgent(string userAgent)
static string PlatformFromUserAgent(string? userAgent)
{
if (string.IsNullOrEmpty(userAgent)) return "unknown";
return userAgent.Contains("okhttp/") ? "android" : userAgent.Contains("Darwin/") || userAgent.Contains("CFNetwork/") ? "ios" : "web";
}
}

View File

@@ -33,14 +33,14 @@ namespace Microsoft.AspNetCore.Http
/// <param name="context">Http context</param>
/// <param name="allowForwarded">Whether to allow x-forwarded-for header check</param>
/// <returns>IPAddress</returns>
public static IPAddress GetRemoteIPAddress(this HttpContext context, bool allowForwarded = true)
public static IPAddress? GetRemoteIPAddress(this HttpContext context, bool allowForwarded = true)
{
if (allowForwarded)
{
// if you are allowing these forward headers, please ensure you are restricting context.Connection.RemoteIpAddress
// to cloud flare ips: https://www.cloudflare.com/ips/
string header = (context.Request.Headers["CF-Connecting-IP"].FirstOrDefault() ?? context.Request.Headers["X-Forwarded-For"].FirstOrDefault());
if (IPAddress.TryParse(header, out IPAddress ip))
string? header = context.Request.Headers["CF-Connecting-IP"].FirstOrDefault() ?? context.Request.Headers["X-Forwarded-For"].FirstOrDefault();
if (IPAddress.TryParse(header, out IPAddress? ip))
{
return ip;
}
@@ -48,12 +48,12 @@ namespace Microsoft.AspNetCore.Http
return context.Connection.RemoteIpAddress;
}
static UserAgentService userAgentService = new UserAgentService();
static readonly UserAgentService userAgentService = new();
public static string GetClientInfo(this HttpContext httpContext)
{
var clientIp = httpContext.GetRemoteIPAddress().ToString();
var clientIp = httpContext.GetRemoteIPAddress()?.ToString();
var country = httpContext.Request.Headers["CF-IPCountry"];
var userAgent = httpContext.Request.Headers["User-Agent"];
var userAgent = httpContext.Request.Headers.UserAgent;
var builder = new StringBuilder();
builder.AppendLine($"Date: {DateTime.UtcNow.ToString("yyyy-MM-dd HH:mm:ss")}");

View File

@@ -55,10 +55,9 @@ namespace Microsoft.AspNetCore.Authentication
return Task.FromResult(true);
}
Task<AuthenticationTicket> ITicketStore.RetrieveAsync(string key)
Task<AuthenticationTicket?> ITicketStore.RetrieveAsync(string key)
{
AuthenticationTicket ticket;
_cache.TryGetValue(key, out ticket);
_cache.TryGetValue(key, out AuthenticationTicket? ticket);
return Task.FromResult(ticket);
}

View File

@@ -29,7 +29,7 @@ namespace Microsoft.AspNetCore.Mvc
{
public static class UrlHelperExtensions
{
public static string TokenLink(this IUrlHelper urlHelper, string userId, string code, string clientId, TokenType type)
public static string? TokenLink(this IUrlHelper urlHelper, string userId, string code, string clientId, TokenType type)
{
return urlHelper.ActionLink(

View File

@@ -24,8 +24,10 @@ namespace Microsoft.AspNetCore.Identity
{
public static class UserManagerExtensions
{
public static async Task<User> FindRegisteredUserAsync(this UserManager<User> userManager, string email, string clientId)
public static async Task<User?> FindRegisteredUserAsync(this UserManager<User> userManager, string? email, string clientId)
{
if (email == null) return null;
var user = await userManager.FindByEmailAsync(email);
return user != null && await userManager.IsInRoleAsync(user, clientId) ? user : null;
}

View File

@@ -31,7 +31,7 @@ namespace Streetwriters.Identity.Interfaces
Task<bool> ResetMFAAsync(User user);
Task SetSecondaryMethodAsync(User user, string secondaryMethod);
string GetPrimaryMethod(User user);
string GetSecondaryMethod(User user);
string? GetSecondaryMethod(User user);
Task<int> GetRemainingValidCodesAsync(User user);
bool IsValidMFAMethod(string method);
bool IsValidMFAMethod(string method, User user);

View File

@@ -36,7 +36,7 @@ namespace Streetwriters.Identity.MessageHandlers
var client = Clients.FindClientByAppId(subscription.AppId);
if (client == null || user == null) return;
IdentityUserClaim<string> statusClaim = user.Claims.FirstOrDefault((c) => c.ClaimType == UserService.GetClaimKey(client.Id));
IdentityUserClaim<string>? statusClaim = user.Claims.FirstOrDefault((c) => c.ClaimType == UserService.GetClaimKey(client.Id));
Claim subscriptionClaim = UserService.SubscriptionPlanToClaim(client.Id, subscription);
if (statusClaim?.ClaimValue == subscriptionClaim.Value) return;
if (statusClaim != null)

View File

@@ -34,7 +34,7 @@ namespace Streetwriters.Identity.MessageHandlers
var client = Clients.FindClientByAppId(message.AppId);
if (client == null || user == null) return;
IdentityUserClaim<string> statusClaim = user.Claims.FirstOrDefault((c) => c.ClaimType == $"{client.Id}:status");
IdentityUserClaim<string>? statusClaim = user.Claims.FirstOrDefault((c) => c.ClaimType == $"{client.Id}:status");
if (statusClaim != null)
{
await userManager.RemoveClaimAsync(user, statusClaim.ToClaim());

View File

@@ -21,12 +21,12 @@ namespace Streetwriters.Identity.Models
{
public class AuthenticatorDetails
{
public string SharedKey
public required string SharedKey
{
get; set;
}
public string AuthenticatorUri
public required string AuthenticatorUri
{
get; set;
}

View File

@@ -31,7 +31,7 @@ namespace Streetwriters.Identity.Models
[Required]
[BindProperty(Name = "email")]
[EmailAddress]
public string NewEmail
public required string NewEmail
{
get; set;
}

View File

@@ -27,21 +27,21 @@ namespace Streetwriters.Identity.Models
{
[Required]
[BindProperty(Name = "authorization_code")]
public string Code
public required string Code
{
get; set;
}
[Required]
[BindProperty(Name = "user_id")]
public string UserId
public required string UserId
{
get; set;
}
[Required]
[BindProperty(Name = "client_id")]
public string ClientId
public required string ClientId
{
get; set;
}

View File

@@ -24,6 +24,6 @@ namespace Streetwriters.Identity.Models
public class MFAPasswordRequiredResponse
{
[JsonPropertyName("token")]
public string Token { get; set; }
public required string Token { get; set; }
}
}

View File

@@ -24,12 +24,12 @@ namespace Streetwriters.Identity.Models
public class MFARequiredResponse
{
[JsonPropertyName("primaryMethod")]
public string PrimaryMethod { get; set; }
public required string PrimaryMethod { get; set; }
[JsonPropertyName("secondaryMethod")]
public string SecondaryMethod { get; set; }
public string? SecondaryMethod { get; set; }
[JsonPropertyName("token")]
public string Token { get; set; }
public string? Token { get; set; }
[JsonPropertyName("phoneNumber")]
public string PhoneNumber { get; set; }
public string? PhoneNumber { get; set; }
}
}

View File

@@ -28,14 +28,14 @@ namespace Streetwriters.Identity.Models
[DataType(DataType.Text)]
[Display(Name = "Authenticator type")]
[BindProperty(Name = "type")]
public string Type { get; set; }
public required string Type { get; set; }
[Required]
[StringLength(6, ErrorMessage = "The {0} must be at least {2} and at max {1} characters long.", MinimumLength = 6)]
[DataType(DataType.Text)]
[Display(Name = "Verification Code")]
[BindProperty(Name = "code")]
public string VerificationCode { get; set; }
public required string VerificationCode { get; set; }
[BindProperty(Name = "isFallback")]
public bool IsFallback { get; set; }

View File

@@ -28,10 +28,10 @@ namespace Streetwriters.Identity.Models
[Required]
[Display(Name = "Authenticator type")]
[BindProperty(Name = "type")]
public string Type { get; set; }
public required string Type { get; set; }
[Display(Name = "Phone number")]
[BindProperty(Name = "phoneNumber")]
public string PhoneNumber { get; set; }
public string? PhoneNumber { get; set; }
}
}

View File

@@ -27,14 +27,14 @@ namespace Streetwriters.Identity.Models
{
[Required]
[BindProperty(Name = "email")]
public string Email
public required string Email
{
get; set;
}
[Required]
[BindProperty(Name = "client_id")]
public string ClientId
public required string ClientId
{
get; set;
}

View File

@@ -28,7 +28,7 @@ namespace Streetwriters.Identity.Models
[Required]
[StringLength(120, ErrorMessage = "Password must be longer than or equal to 8 characters.", MinimumLength = 8)]
[BindProperty(Name = "password")]
public string Password
public required string Password
{
get; set;
}
@@ -36,20 +36,20 @@ namespace Streetwriters.Identity.Models
[Required]
[BindProperty(Name = "email")]
[EmailAddress]
public string Email
public required string Email
{
get; set;
}
[BindProperty(Name = "username")]
public string Username
public string? Username
{
get; set;
}
[Required]
[BindProperty(Name = "client_id")]
public string ClientId
public required string ClientId
{
get; set;
}

View File

@@ -30,7 +30,7 @@ namespace Streetwriters.Identity.Models
[DataType(DataType.Text)]
[Display(Name = "Authenticator code")]
[BindProperty(Name = "code")]
public string Code { get; set; }
public required string Code { get; set; }
[BindProperty(Name = "rememberMachine")]
public bool RememberMachine { get; set; }

View File

@@ -27,7 +27,7 @@ namespace Streetwriters.Identity.Models
{
[Required]
[BindProperty(Name = "type")]
public string Type
public required string Type
{
get; set;
}
@@ -39,33 +39,33 @@ namespace Streetwriters.Identity.Models
}
[BindProperty(Name = "old_password")]
public string OldPassword
public string? OldPassword
{
get; set;
}
[BindProperty(Name = "new_password")]
public string NewPassword
public string? NewPassword
{
get; set;
}
[BindProperty(Name = "password")]
public string Password
public string? Password
{
get; set;
}
[BindProperty(Name = "new_email")]
public string NewEmail
public string? NewEmail
{
get; set;
}
[BindProperty(Name = "verification_code")]
public string VerificationCode
public string? VerificationCode
{
get; set;
}

View File

@@ -57,7 +57,7 @@ namespace Streetwriters.Identity
{
options.Limits.MaxRequestBodySize = long.MaxValue;
options.ListenAnyIP(Servers.IdentityServer.Port);
if (Servers.IdentityServer.IsSecure)
if (Servers.IdentityServer.IsSecure && Servers.IdentityServer.SSLCertificate != null)
{
options.ListenAnyIP(443, listenerOptions =>
{

View File

@@ -43,9 +43,9 @@ namespace Streetwriters.Identity.Services
{
var result = await base.ProcessAsync(validationResult);
if (result.TryGetValue("sub", out object userId))
if (result.TryGetValue("sub", out object? userId) && userId != null)
{
var user = await UserManager.FindByIdAsync(userId.ToString());
var user = await UserManager.FindByIdAsync(userId.ToString() ?? "");
if (user == null || user.Claims == null) return result;
var verifiedClaim = user.Claims.Find((c) => c.ClaimType == "verified");
@@ -57,7 +57,7 @@ namespace Streetwriters.Identity.Services
user.Claims.ForEach((claim) =>
{
if (claim.ClaimType == "verified" || claim.ClaimType == "hcli") return;
if (claim.ClaimType == null || claim.ClaimType == "verified" || claim.ClaimType == "hcli") return;
result.TryAdd(claim.ClaimType, claim.ClaimValue);
});
result.TryAdd("verified", user.EmailConfirmed.ToString().ToLowerInvariant());

View File

@@ -101,18 +101,18 @@ namespace Streetwriters.Identity.Services
public string GetPrimaryMethod(User user)
{
return this.GetClaimValue(user, MFAService.PRIMARY_METHOD_CLAIM, MFAMethods.Email);
return GetClaimValue(user, MFAService.PRIMARY_METHOD_CLAIM) ?? MFAMethods.Email;
}
public string GetSecondaryMethod(User user)
public string? GetSecondaryMethod(User user)
{
return this.GetClaimValue(user, MFAService.SECONDARY_METHOD_CLAIM);
return GetClaimValue(user, MFAService.SECONDARY_METHOD_CLAIM);
}
public string GetClaimValue(User user, string claimType, string defaultValue = null)
public static string? GetClaimValue(User user, string claimType)
{
var claim = user.Claims.FirstOrDefault((c) => c.ClaimType == claimType);
return claim != null ? claim.ClaimValue : defaultValue;
return claim?.ClaimValue;
}
public Task<int> GetRemainingValidCodesAsync(User user)
@@ -158,6 +158,8 @@ namespace Streetwriters.Identity.Services
await UserManager.ResetAuthenticatorKeyAsync(user);
unformattedKey = await UserManager.GetAuthenticatorKeyAsync(user);
}
ArgumentNullException.ThrowIfNull(unformattedKey);
ArgumentNullException.ThrowIfNull(user.Email);
return new AuthenticatorDetails
{
@@ -183,10 +185,12 @@ namespace Streetwriters.Identity.Services
switch (method)
{
case "email":
ArgumentNullException.ThrowIfNull(user.Email);
string emailOTP = await UserManager.GenerateTwoFactorTokenAsync(user, TokenOptions.DefaultPhoneProvider);
await EmailSender.Send2FACodeEmailAsync(user.Email, emailOTP, client);
break;
case "sms":
ArgumentNullException.ThrowIfNull(form.PhoneNumber);
await UserManager.SetPhoneNumberAsync(user, form.PhoneNumber);
var id = await SMSSender.SendOTPAsync(form.PhoneNumber, client);
logger.LogInformation("SMS OTP sent for user: {UserId}, SMS ID: {SmsId}", user.Id, id);
@@ -200,13 +204,14 @@ namespace Streetwriters.Identity.Services
{
if (method == MFAMethods.SMS)
{
var id = this.GetClaimValue(user, MFAService.SMS_ID_CLAIM);
var id = GetClaimValue(user, MFAService.SMS_ID_CLAIM);
if (string.IsNullOrEmpty(id)) throw new Exception("Could not find associated SMS verify id. Please try sending the code again.");
if (await SMSSender.VerifyOTPAsync(id, code))
{
// Auto confirm user phone number if not confirmed
if (!await UserManager.IsPhoneNumberConfirmedAsync(user))
{
ArgumentNullException.ThrowIfNull(user.PhoneNumber);
var token = await UserManager.GenerateChangePhoneNumberTokenAsync(user, user.PhoneNumber);
await UserManager.VerifyChangePhoneNumberTokenAsync(user, token, user.PhoneNumber);
}
@@ -238,7 +243,7 @@ namespace Streetwriters.Identity.Services
return method == MFAMethods.Email || method == MFAMethods.SMS ? TokenOptions.DefaultPhoneProvider : UserManager.Options.Tokens.AuthenticatorTokenProvider;
}
private string FormatKey(string unformattedKey)
private static string FormatKey(string unformattedKey)
{
var result = new StringBuilder();
int currentPosition = 0;
@@ -255,7 +260,7 @@ namespace Streetwriters.Identity.Services
return result.ToString().ToLowerInvariant();
}
private string GenerateQrCodeUri(string email, string unformattedKey, string issuer)
private static string GenerateQrCodeUri(string email, string unformattedKey, string issuer)
{
const string AuthenticatorUriFormat = "otpauth://totp/{0}:{1}?secret={2}&issuer={0}&digits=6";

View File

@@ -43,7 +43,7 @@ namespace Streetwriters.Identity.Services
public async Task GetProfileDataAsync(ProfileDataRequestContext context)
{
User user = await UserManager.GetUserAsync(context.Subject);
User? user = await UserManager.GetUserAsync(context.Subject);
if (user == null) return;
IList<string> roles = await UserManager.GetRolesAsync(user);

View File

@@ -60,35 +60,35 @@ namespace Streetwriters.Identity.Services
EmailSender = emailSender;
}
EmailTemplate Email2FATemplate = new EmailTemplate
readonly EmailTemplate Email2FATemplate = new()
{
Html = HtmlHelper.ReadMinifiedHtmlFile("Templates/Email2FACode.html"),
Text = File.ReadAllText("Templates/Email2FACode.txt"),
Subject = "Your {{app_name}} account 2FA code",
};
EmailTemplate ConfirmEmailTemplate = new EmailTemplate
readonly EmailTemplate ConfirmEmailTemplate = new()
{
Html = HtmlHelper.ReadMinifiedHtmlFile("Templates/ConfirmEmail.html"),
Text = File.ReadAllText("Templates/ConfirmEmail.txt"),
Subject = "Confirm your {{app_name}} account",
};
EmailTemplate ConfirmChangeEmailTemplate = new EmailTemplate
readonly EmailTemplate ConfirmChangeEmailTemplate = new()
{
Html = HtmlHelper.ReadMinifiedHtmlFile("Templates/EmailChangeConfirmation.html"),
Text = File.ReadAllText("Templates/EmailChangeConfirmation.txt"),
Subject = "Change {{app_name}} account email address",
};
EmailTemplate PasswordResetEmailTemplate = new EmailTemplate
readonly EmailTemplate PasswordResetEmailTemplate = new()
{
Html = HtmlHelper.ReadMinifiedHtmlFile("Templates/ResetAccountPassword.html"),
Text = File.ReadAllText("Templates/ResetAccountPassword.txt"),
Subject = "Reset {{app_name}} account password",
};
EmailTemplate FailedLoginAlertTemplate = new EmailTemplate
readonly EmailTemplate FailedLoginAlertTemplate = new()
{
Html = HtmlHelper.ReadMinifiedHtmlFile("Templates/FailedLoginAlert.html"),
Text = File.ReadAllText("Templates/FailedLoginAlert.txt"),
@@ -97,12 +97,12 @@ namespace Streetwriters.Identity.Services
public async Task Send2FACodeEmailAsync(string email, string code, IClient client)
{
var template = new EmailTemplate
var template = new EmailTemplate()
{
Html = Email2FATemplate.Html,
Text = Email2FATemplate.Text,
Subject = Email2FATemplate.Subject,
Data = new { app_name = client.Name, code = code },
Data = new { app_name = client.Name, code },
};
await EmailSender.SendEmailAsync(email, template, client, NNGnuPGContext);
}
@@ -113,7 +113,7 @@ namespace Streetwriters.Identity.Services
IClient client
)
{
var template = new EmailTemplate
var template = new EmailTemplate()
{
Html = ConfirmEmailTemplate.Html,
Text = ConfirmEmailTemplate.Text,
@@ -129,12 +129,12 @@ namespace Streetwriters.Identity.Services
IClient client
)
{
var template = new EmailTemplate
var template = new EmailTemplate()
{
Html = ConfirmChangeEmailTemplate.Html,
Text = ConfirmChangeEmailTemplate.Text,
Subject = ConfirmChangeEmailTemplate.Subject,
Data = new { app_name = client.Name, code = code },
Data = new { app_name = client.Name, code },
};
await EmailSender.SendEmailAsync(email, template, client, NNGnuPGContext);
}
@@ -145,7 +145,7 @@ namespace Streetwriters.Identity.Services
IClient client
)
{
var template = new EmailTemplate
var template = new EmailTemplate()
{
Html = PasswordResetEmailTemplate.Html,
Text = PasswordResetEmailTemplate.Text,
@@ -157,7 +157,7 @@ namespace Streetwriters.Identity.Services
public async Task SendFailedLoginAlertAsync(string email, string deviceInfo, IClient client)
{
var template = new EmailTemplate
var template = new EmailTemplate()
{
Html = FailedLoginAlertTemplate.Html,
Text = FailedLoginAlertTemplate.Text,
@@ -176,7 +176,7 @@ namespace Streetwriters.Identity.Services
{
IConfiguration PgpKeySettings { get; set; } = pgpKeySettings;
protected override string GetPasswordForKey(PgpSecretKey key)
protected override string? GetPasswordForKey(PgpSecretKey key)
{
return PgpKeySettings[key.KeyId.ToString("X")];
}

View File

@@ -15,7 +15,7 @@ namespace Streetwriters.Identity.Services
public async Task<UserModel?> GetUserAsync(string clientId, string userId)
{
var user = await userManager.FindByIdAsync(userId);
if (!await UserService.IsUserValidAsync(userManager, user, clientId))
if (user == null || !await UserService.IsUserValidAsync(userManager, user, clientId))
return null;
var claims = await userManager.GetClaimsAsync(user);
@@ -25,7 +25,9 @@ namespace Streetwriters.Identity.Services
{
await mfaService.EnableMFAAsync(user, MFAMethods.Email);
user = await userManager.FindByIdAsync(userId);
ArgumentNullException.ThrowIfNull(user);
}
ArgumentNullException.ThrowIfNull(user.Email);
return new UserModel
{
@@ -46,7 +48,7 @@ namespace Streetwriters.Identity.Services
public async Task DeleteUserAsync(string clientId, string userId, string password)
{
var user = await userManager.FindByIdAsync(userId);
if (!await UserService.IsUserValidAsync(userManager, user, clientId)) return;
if (user == null || !await UserService.IsUserValidAsync(userManager, user, clientId)) return;
if (!await userManager.CheckPasswordAsync(user, password)) throw new Exception("Wrong password.");

View File

@@ -32,7 +32,7 @@ namespace Streetwriters.Identity.Services
private static SubscriptionPlan? GetUserSubscriptionPlan(string clientId, User user)
{
var claimKey = GetClaimKey(clientId);
var status = user.Claims.FirstOrDefault((c) => c.ClaimType == claimKey).ClaimValue;
var status = user.Claims.FirstOrDefault((c) => c.ClaimType == claimKey)?.ClaimValue;
switch (status)
{
case "free":

View File

@@ -263,7 +263,7 @@ namespace Streetwriters.Identity
});
}
private void AddOperationalStore(IServiceCollection services, TokenCleanupOptions tokenCleanUpOptions = null)
private static void AddOperationalStore(IServiceCollection services, TokenCleanupOptions? tokenCleanUpOptions = null)
{
BsonClassMap.RegisterClassMap<PersistedGrant>(cm =>
{
@@ -279,7 +279,7 @@ namespace Streetwriters.Identity
{
q.UseMicrosoftDependencyInjectionJobFactory();
if (tokenCleanUpOptions.Enable)
if (tokenCleanUpOptions?.Enable == true)
{
var jobKey = new JobKey("TokenCleanupJob");
q.AddJob<TokenCleanupJob>(opts => opts.WithIdentity(jobKey));

View File

@@ -3,6 +3,7 @@
<PropertyGroup>
<TargetFramework>net8.0</TargetFramework>
<StartupObject>Streetwriters.Identity.Program</StartupObject>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>

View File

@@ -33,7 +33,7 @@ namespace Streetwriters.Identity.Validation
/// <returns></returns>
public static BearerTokenUsageValidationResult ValidateAuthorizationHeader(HttpContext context)
{
var authorizationHeader = context.Request.Headers["Authorization"].FirstOrDefault();
var authorizationHeader = context.Request.Headers.Authorization.FirstOrDefault();
if (!string.IsNullOrEmpty(authorizationHeader))
{
var header = authorizationHeader.Trim();

View File

@@ -84,7 +84,7 @@ namespace Streetwriters.Identity.Validation
var mfaCode = context.Request.Raw["mfa:code"];
var mfaMethod = context.Request.Raw["mfa:method"];
if (string.IsNullOrEmpty(mfaCode) || !MFAService.IsValidMFAMethod(mfaMethod, user))
if (string.IsNullOrEmpty(mfaCode) || string.IsNullOrEmpty(mfaMethod) || !MFAService.IsValidMFAMethod(mfaMethod, user))
{
var sendPhoneNumber = primaryMethod == MFAMethods.SMS || secondaryMethod == MFAMethods.SMS;
@@ -95,7 +95,7 @@ namespace Streetwriters.Identity.Validation
["error_description"] = "Multifactor authentication required.",
["data"] = JsonSerializer.Serialize(new MFARequiredResponse
{
PhoneNumber = sendPhoneNumber ? Regex.Replace(user.PhoneNumber, @"\d(?!\d{0,3}$)", "*") : null,
PhoneNumber = sendPhoneNumber && user.PhoneNumber != null ? Regex.Replace(user.PhoneNumber, @"\d(?!\d{0,3}$)", "*") : null,
PrimaryMethod = primaryMethod,
SecondaryMethod = secondaryMethod,
Token = token,
@@ -117,7 +117,6 @@ namespace Streetwriters.Identity.Validation
}
else
{
var provider = mfaMethod == MFAMethods.Email || mfaMethod == MFAMethods.SMS ? TokenOptions.DefaultPhoneProvider : UserManager.Options.Tokens.AuthenticatorTokenProvider;
var isMFACodeValid = await MFAService.VerifyOTPAsync(user, mfaCode, mfaMethod);
if (!isMFACodeValid)
{

View File

@@ -64,20 +64,17 @@ namespace Streetwriters.Identity.Validation
{
var email = context.Request.Raw["email"];
var clientId = context.Request.ClientId;
var user = await UserManager.FindRegisteredUserAsync(email, clientId);
if (user == null)
var user = await UserManager.FindRegisteredUserAsync(email, clientId) ?? new User
{
user = new User
{
Id = MongoDB.Bson.ObjectId.GenerateNewId(),
Email = email,
UserName = email,
NormalizedEmail = email,
NormalizedUserName = email,
EmailConfirmed = false,
SecurityStamp = ""
};
}
Id = MongoDB.Bson.ObjectId.GenerateNewId(),
Email = email,
UserName = email,
NormalizedEmail = email,
NormalizedUserName = email,
EmailConfirmed = false,
SecurityStamp = ""
};
var isMultiFactor = await UserManager.GetTwoFactorEnabledAsync(user);
var primaryMethod = isMultiFactor ? MFAService.GetPrimaryMethod(user) : MFAMethods.Email;
@@ -88,7 +85,7 @@ namespace Streetwriters.Identity.Validation
{
["additional_data"] = new MFARequiredResponse
{
PhoneNumber = sendPhoneNumber ? Regex.Replace(user.PhoneNumber, @"\d(?!\d{0,3}$)", "*") : null,
PhoneNumber = sendPhoneNumber && user.PhoneNumber != null ? Regex.Replace(user.PhoneNumber, @"\d(?!\d{0,3}$)", "*") : null,
PrimaryMethod = primaryMethod,
SecondaryMethod = secondaryMethod,
}
@@ -96,12 +93,5 @@ namespace Streetwriters.Identity.Validation
context.Result.IsError = false;
context.Result.Subject = await TokenGenerationService.TransformTokenRequestAsync(context.Request, user, GrantType, new string[] { Config.MFA_GRANT_TYPE_SCOPE });
}
string Pluralize(int? value, string singular, string plural)
{
if (value == null) return $"0 {plural}";
return value == 1 ? $"{value} {singular}" : $"{value} {plural}";
}
}
}

View File

@@ -67,6 +67,8 @@ namespace Streetwriters.Identity.Validation
context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant);
var httpContext = HttpContextAccessor.HttpContext;
if (httpContext == null) return;
var tokenResult = BearerTokenValidator.ValidateAuthorizationHeader(httpContext);
if (!tokenResult.TokenFound) return;

View File

@@ -58,6 +58,8 @@ namespace Streetwriters.Identity.Validation
context.Result = new GrantValidationResult(TokenRequestErrors.InvalidGrant);
var httpContext = HttpContextAccessor.HttpContext;
if (httpContext == null) return;
var tokenResult = BearerTokenValidator.ValidateAuthorizationHeader(httpContext);
if (!tokenResult.TokenFound) return;