From 0ad00c97470178ca29e36db9e8108ac67a65a50d Mon Sep 17 00:00:00 2001 From: Abdullah Atta Date: Thu, 8 Jun 2023 12:55:27 +0500 Subject: [PATCH] identity: make 2fa truly mandatory --- .../Controllers/AccountController.cs | 7 +++++++ .../Controllers/MFAController.cs | 16 ++-------------- .../Controllers/SignupController.cs | 3 +++ Streetwriters.Identity/Services/MFAService.cs | 6 +++--- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/Streetwriters.Identity/Controllers/AccountController.cs b/Streetwriters.Identity/Controllers/AccountController.cs index dba6fce..b5dc51c 100644 --- a/Streetwriters.Identity/Controllers/AccountController.cs +++ b/Streetwriters.Identity/Controllers/AccountController.cs @@ -29,6 +29,7 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Mvc; using Streetwriters.Common; +using Streetwriters.Common.Enums; using Streetwriters.Common.Messages; using Streetwriters.Common.Models; using Streetwriters.Identity.Enums; @@ -172,6 +173,12 @@ namespace Streetwriters.Identity.Controllers var claims = await UserManager.GetClaimsAsync(user); var marketingConsentClaim = claims.FirstOrDefault((claim) => claim.Type == $"{client.Id}:marketing_consent"); + if (!await UserManager.GetTwoFactorEnabledAsync(user)) + { + await MFAService.EnableMFAAsync(user, MFAMethods.Email); + user = await UserManager.GetUserAsync(User); + } + return Ok(new UserModel { UserId = user.Id.ToString(), diff --git a/Streetwriters.Identity/Controllers/MFAController.cs b/Streetwriters.Identity/Controllers/MFAController.cs index 899d6f1..1dee0e9 100644 --- a/Streetwriters.Identity/Controllers/MFAController.cs +++ b/Streetwriters.Identity/Controllers/MFAController.cs @@ -74,21 +74,9 @@ namespace Streetwriters.Identity.Controllers } [HttpDelete] - public async Task Disable2FA() + public IActionResult Disable2FA() { - var user = await UserManager.GetUserAsync(User); - - if (!await UserManager.GetTwoFactorEnabledAsync(user)) - { - return BadRequest("Cannot disable 2FA as it's not currently enabled"); - } - - if (await MFAService.DisableMFAAsync(user)) - { - return Ok(); - } - - return BadRequest("Failed to disable 2FA."); + return BadRequest("2FA is mandatory and cannot be disabled."); } [HttpGet("codes")] diff --git a/Streetwriters.Identity/Controllers/SignupController.cs b/Streetwriters.Identity/Controllers/SignupController.cs index 18a606b..f1f0d1c 100644 --- a/Streetwriters.Identity/Controllers/SignupController.cs +++ b/Streetwriters.Identity/Controllers/SignupController.cs @@ -27,6 +27,7 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Identity; using Microsoft.AspNetCore.Mvc; using Streetwriters.Common; +using Streetwriters.Common.Enums; using Streetwriters.Common.Models; using Streetwriters.Identity.Enums; using Streetwriters.Identity.Interfaces; @@ -109,6 +110,8 @@ namespace Streetwriters.Identity.Controllers var callbackUrl = Url.TokenLink(user.Id.ToString(), code, client.Id, TokenType.CONFRIM_EMAIL, Request.Scheme); await EmailSender.SendConfirmationEmailAsync(user.Email, callbackUrl, client); + await MFAService.EnableMFAAsync(user, MFAMethods.Email); + return Ok(new { userId = user.Id.ToString() diff --git a/Streetwriters.Identity/Services/MFAService.cs b/Streetwriters.Identity/Services/MFAService.cs index e72ece2..af61907 100644 --- a/Streetwriters.Identity/Services/MFAService.cs +++ b/Streetwriters.Identity/Services/MFAService.cs @@ -82,7 +82,7 @@ namespace Streetwriters.Identity.Services public string GetPrimaryMethod(User user) { - return this.GetClaimValue(user, MFAService.PRIMARY_METHOD_CLAIM); + return this.GetClaimValue(user, MFAService.PRIMARY_METHOD_CLAIM, MFAMethods.Email); } public string GetSecondaryMethod(User user) @@ -90,10 +90,10 @@ namespace Streetwriters.Identity.Services return this.GetClaimValue(user, MFAService.SECONDARY_METHOD_CLAIM); } - public string GetClaimValue(User user, string claimType) + public string GetClaimValue(User user, string claimType, string defaultValue = null) { var claim = user.Claims.FirstOrDefault((c) => c.ClaimType == claimType); - return claim != null ? claim.ClaimValue : null; + return claim != null ? claim.ClaimValue : defaultValue; } public Task GetRemainingValidCodesAsync(User user)