Skip to content

Commit

Permalink
Prevent TrnMatchPolicy and TrnRequirementType from being overriden (#741
Browse files Browse the repository at this point in the history
)
  • Loading branch information
gunndabad authored Oct 17, 2023
1 parent 5d044d6 commit ce4382b
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ public class AuthorizationController : Controller
private readonly TeacherIdentityServerDbContext _dbContext;
private readonly IClock _clock;
private readonly TrnTokenHelper _trnTokenHelper;
private readonly IConfiguration _configuration;

public AuthorizationController(
TeacherIdentityApplicationManager applicationManager,
Expand All @@ -36,7 +37,8 @@ public AuthorizationController(
UserClaimHelper userClaimHelper,
TeacherIdentityServerDbContext dbContext,
IClock clock,
TrnTokenHelper trnTokenHelper)
TrnTokenHelper trnTokenHelper,
IConfiguration configuration)
{
_applicationManager = applicationManager;
_authorizationManager = authorizationManager;
Expand All @@ -46,6 +48,7 @@ public AuthorizationController(
_dbContext = dbContext;
_clock = clock;
_trnTokenHelper = trnTokenHelper;
_configuration = configuration;
}

[HttpGet("~/connect/authorize")]
Expand Down Expand Up @@ -94,32 +97,54 @@ public async Task<IActionResult> Authorize()

if (userRequirements.HasFlag(UserRequirements.TrnHolder))
{
trnRequirementType = await GetTrnRequirementType(request);
trnMatchPolicy = await GetTrnMatchPolicy(request);
var client = (await _applicationManager.FindByClientIdAsync(request.ClientId!))!;
var allowTrnConfigurationOverrides = client.ClientId == "testclient" || _configuration.GetValue<bool>("AllowTrnConfigurationOverrides", false);

if (trnRequirementType is null)
if (allowTrnConfigurationOverrides)
{
return Forbid(
authenticationSchemes: OpenIddictServerAspNetCoreDefaults.AuthenticationScheme,
properties: new AuthenticationProperties(new Dictionary<string, string?>()
var requestedTrnRequirement = request["trn_requirement"];
if (requestedTrnRequirement.HasValue)
{
if (Enum.TryParse<TrnRequirementType>(requestedTrnRequirement?.Value as string, out var parsedTrnRequirementType))
{
[OpenIddictServerAspNetCoreConstants.Properties.Error] = Errors.InvalidRequest,
[OpenIddictServerAspNetCoreConstants.Properties.ErrorDescription] =
"Invalid trn_requirement specified."
}));
}

if (trnMatchPolicy is null)
{
return Forbid(
authenticationSchemes: OpenIddictServerAspNetCoreDefaults.AuthenticationScheme,
properties: new AuthenticationProperties(new Dictionary<string, string?>()
trnRequirementType = parsedTrnRequirementType;
}
else
{
return Forbid(
authenticationSchemes: OpenIddictServerAspNetCoreDefaults.AuthenticationScheme,
properties: new AuthenticationProperties(new Dictionary<string, string?>()
{
[OpenIddictServerAspNetCoreConstants.Properties.Error] = Errors.InvalidRequest,
[OpenIddictServerAspNetCoreConstants.Properties.ErrorDescription] =
"Invalid trn_requirement specified."
}));
}
}

var requestedTrnMatchPolicy = request["trn_match_policy"];
if (requestedTrnMatchPolicy.HasValue)
{
if (Enum.TryParse<TrnMatchPolicy>(requestedTrnMatchPolicy?.Value as string, out var parsedTrnMatchPolicy))
{
trnMatchPolicy = parsedTrnMatchPolicy;
}
else
{
[OpenIddictServerAspNetCoreConstants.Properties.Error] = Errors.InvalidRequest,
[OpenIddictServerAspNetCoreConstants.Properties.ErrorDescription] =
"Invalid trn_match_policy specified."
}));
return Forbid(
authenticationSchemes: OpenIddictServerAspNetCoreDefaults.AuthenticationScheme,
properties: new AuthenticationProperties(new Dictionary<string, string?>()
{
[OpenIddictServerAspNetCoreConstants.Properties.Error] = Errors.InvalidRequest,
[OpenIddictServerAspNetCoreConstants.Properties.ErrorDescription] =
"Invalid trn_match_policy specified."
}));
}
}
}

trnRequirementType ??= client.TrnRequirementType;
trnMatchPolicy ??= client.TrnMatchPolicy;
}

var sessionId = request["session_id"]?.Value as string;
Expand Down Expand Up @@ -490,42 +515,6 @@ private static IEnumerable<string> GetDestinations(Claim claim, ClaimsPrincipal
return signedInUser;
}

private async Task<TrnRequirementType?> GetTrnRequirementType(OpenIddictRequest request)
{
var requestedTrnRequirement = request["trn_requirement"];

if (requestedTrnRequirement.HasValue)
{
if (!Enum.TryParse<TrnRequirementType>(requestedTrnRequirement?.Value as string, out var parsedTrnRequirementType))
{
return null;
}

return parsedTrnRequirementType;
}

var client = (await _applicationManager.FindByClientIdAsync(request.ClientId!))!;
return client.TrnRequirementType;
}

private async Task<TrnMatchPolicy?> GetTrnMatchPolicy(OpenIddictRequest request)
{
var trnMatchPolicy = request["trn_match_policy"];

if (trnMatchPolicy.HasValue)
{
if (!Enum.TryParse<TrnMatchPolicy>(trnMatchPolicy?.Value as string, out var parsedTrnMatchPolicy))
{
return null;
}

return parsedTrnMatchPolicy;
}

var client = (await _applicationManager.FindByClientIdAsync(request.ClientId!))!;
return client.TrnMatchPolicy;
}

private async Task<ClaimsPrincipal?> InitializeAuthenticationState(User? signedInUser, EnhancedTrnToken? trnToken,
AuthenticationState authenticationState)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,5 +68,6 @@
"UserVerification": {
"UseFixedPin": true,
"Pin": "00000"
}
},
"AllowTrnConfigurationOverrides": true
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
"UserVerification": {
"UseFixedPin": true
},
"BlockEstablishmentEmailDomains": true
"BlockEstablishmentEmailDomains": true,
"AllowTrnConfigurationOverrides": true
},
"TestClient": {
"ClientId": "testclient",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@
"BlockEstablishmentEmailDomains": true,
"WebHooks": {
"WebHooksCacheDurationSeconds": 30
}
},
"AllowTrnConfigurationOverrides": true
}

0 comments on commit ce4382b

Please sign in to comment.