-
-
Save anand374/4150380166f53fb09d92a80c7da55d84 to your computer and use it in GitHub Desktop.
Message Inspector for WCF SOAP AAD Token Validation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Microsoft.IdentityModel.Tokens; | |
using System; | |
using System.Collections.Generic; | |
using System.Collections.Specialized; | |
using System.Configuration; | |
using System.IdentityModel.Tokens.Jwt; | |
using System.Linq; | |
using System.Net; | |
using System.Net.Http; | |
using System.Security.Claims; | |
using System.ServiceModel; | |
using System.ServiceModel.Channels; | |
using System.ServiceModel.Dispatcher; | |
namespace TokenValidator
{
    /// <summary>
    /// WCF dispatch message inspector that authenticates each incoming request using an Azure AD
    /// (v1.0 or v2.0) JWT passed in the HTTP <c>Authorization</c> header with the Bearer scheme.
    /// A request is accepted only when the token validates against the AAD OpenID Connect metadata
    /// and the calling application's client ID (<c>appid</c> claim for v1.0, <c>azp</c> claim for v2.0)
    /// is listed in the <c>AllowedTenantIDs</c> app setting.
    /// </summary>
    /// <remarks>
    /// Rejected requests are flagged with an <c>"Authorized" = false</c> incoming message property
    /// (intended to be checked by a custom operation invoker before the operation runs), and the
    /// returned <see cref="WcfErrorResponseData"/> is turned into an HTTP error reply in
    /// <see cref="BeforeSendReply"/>.
    /// </remarks>
    public class BearerTokenMessageInspector : IDispatchMessageInspector
    {
        private const string BearerScheme = "Bearer ";
        private const string AuthorizedPropertyName = "Authorized";

        private static readonly string _audience;
        private static readonly string _authority;
        private static readonly OpenIdConnectCachingSecurityTokenProvider _securityTokenProvider;
        private static readonly bool _validateIssuer;
        private static readonly bool _validateAudience;
        private static readonly bool _validateLifetime;
        private static readonly bool _validateIssuerSigningKey;
        private static readonly bool _useV2;
        private static readonly int _maxRetries;
        private static readonly HashSet<string> _allowedClientIDs;

        /// <summary>
        /// Static constructor that initializes all configuration-driven settings from appSettings:
        /// AADAudience, AADAuthority, AllowedTenantIDs, ValidateIssuer, ValidateAudience,
        /// ValidateIssuerSigningKey, ValidateLifetime, useV2 and MaxRetries.
        /// </summary>
        static BearerTokenMessageInspector()
        {
            NameValueCollection appSettings = ConfigurationManager.AppSettings;
            _audience = appSettings["AADAudience"];
            _authority = appSettings["AADAuthority"];

            // Despite the key name, the entries are client (application) IDs that are matched
            // against the token's appid/azp claim.
            _allowedClientIDs = ParseIdList(appSettings["AllowedTenantIDs"]);

            // A missing or malformed flag falls back to the secure choice (validate) rather than
            // throwing: an exception here would make this type permanently unusable.
            _validateIssuer = ParseFlag(appSettings["ValidateIssuer"], true);
            _validateAudience = ParseFlag(appSettings["ValidateAudience"], true);
            _validateIssuerSigningKey = ParseFlag(appSettings["ValidateIssuerSigningKey"], true);
            _validateLifetime = ParseFlag(appSettings["ValidateLifetime"], true);
            _useV2 = ParseFlag(appSettings["useV2"], false);

            // Only the metadata endpoint matching the configured token version is ever used.
            _securityTokenProvider = GetConfig(_useV2);

            // MaxRetries is the total number of validation attempts; at least one is always made.
            int maxRetries;
            _maxRetries = int.TryParse(appSettings["MaxRetries"], out maxRetries) ? Math.Max(1, maxRetries) : 2;
        }

        /// <summary>
        /// Called by WCF right after a request is received. Validates the bearer token in the
        /// HTTP Authorization header.
        /// </summary>
        /// <param name="request">The incoming request message.</param>
        /// <param name="channel">The incoming channel (unused).</param>
        /// <param name="instanceContext">The service instance context (unused).</param>
        /// <returns>
        /// <c>null</c> when the request is authorized; otherwise a <see cref="WcfErrorResponseData"/>
        /// describing the failure, which WCF passes to <see cref="BeforeSendReply"/> as correlation state.
        /// </returns>
        public object AfterReceiveRequest(ref Message request, IClientChannel channel, InstanceContext instanceContext)
        {
            // The null check must precede any access to request.Properties.
            if (request == null)
            {
                return Reject(new WcfErrorResponseData(HttpStatusCode.BadRequest, string.Empty, new KeyValuePair<string, string>("InvalidOperation", "Request Body Empty.")));
            }

            // Non-HTTP transports carry no HttpRequestMessageProperty; treat that as a missing header.
            object httpRequestProperty;
            request.Properties.TryGetValue(HttpRequestMessageProperty.Name, out httpRequestProperty);
            var requestMessage = httpRequestProperty as HttpRequestMessageProperty;
            string authHeader = requestMessage == null ? null : requestMessage.Headers["Authorization"];

            if (string.IsNullOrEmpty(authHeader))
            {
                return Reject(Unauthorized("Error: Authorization Header empty! Please pass a Token using Bearer scheme."));
            }

            string ignoredToken;
            if (!TryGetBearerToken(authHeader, out ignoredToken))
            {
                return Reject(Unauthorized("Error: Authorization Header does not contain a Token using Bearer scheme."));
            }

            WcfErrorResponseData error;
            try
            {
                if (Authenticate(authHeader))
                {
                    return null;
                }

                // The token is valid but its client ID is not in the allowed list.
                string clientID = GetClientID(authHeader);
                error = Unauthorized(FormatFailure(clientID, "The client ID: " + clientID + " might not be in the allowed list."));
            }
            catch (Exception e)
            {
                error = Unauthorized(FormatFailure(GetClientID(authHeader), e.Message));
            }

            return Reject(error);
        }

        /// <summary>
        /// Validates the bearer token and the client-ID claim.
        /// </summary>
        /// <param name="authHeader">The Authorization header value, expected as "Bearer &lt;jwt&gt;".</param>
        /// <returns>True if the token and its client-ID claim are valid, false otherwise.</returns>
        private bool Authenticate(string authHeader)
        {
            string jwtToken;
            if (!TryGetBearerToken(authHeader, out jwtToken))
            {
                return false;
            }

            return ValidateToken(jwtToken, CreateValidationParameters());
        }

        /// <summary>
        /// Validates the token against the validation parameters. Signing-key rollover is handled by
        /// refreshing the metadata and rebuilding the parameters when
        /// <see cref="SecurityTokenSignatureKeyNotFoundException"/> is thrown, up to MaxRetries
        /// attempts in total (configured in web.config).
        /// </summary>
        /// <param name="jwtToken">The base64url-encoded JWT (typically starting with "eyJ").</param>
        /// <param name="validationParameters">Parameters for the first attempt; rebuilt from refreshed metadata on retries.</param>
        /// <returns>True if the token is valid and its client ID is allowed, false if it is valid but the client is not allowed.</returns>
        /// <exception cref="SecurityTokenSignatureKeyNotFoundException">The signing key is still unknown after all attempts.</exception>
        /// <remarks>Any other validation exception (expired token, wrong audience, ...) is rethrown immediately, since retrying cannot change the outcome.</remarks>
        private bool ValidateToken(string jwtToken, TokenValidationParameters validationParameters)
        {
            var tokenHandler = new JwtSecurityTokenHandler();
            for (int attempt = 1; ; attempt++)
            {
                try
                {
                    SecurityToken validatedToken;
                    tokenHandler.ValidateToken(jwtToken, validationParameters, out validatedToken);
                    return CheckClientID(validatedToken);
                }
                catch (SecurityTokenSignatureKeyNotFoundException)
                {
                    if (attempt >= _maxRetries)
                    {
                        throw;
                    }

                    // A key rollover has likely occurred. Refresh the metadata and rebuild the
                    // parameters so that the retry uses the new keys.
                    RefreshKeys();
                    validationParameters = CreateValidationParameters();
                }
            }
        }

        /// <summary>
        /// Builds the validation parameters for validating the token.
        /// </summary>
        /// <param name="signingKeys">The signing (public) keys of the AAD (V1/V2) issuer.</param>
        /// <param name="issuer">The URI of the AAD (V1/V2) issuer.</param>
        /// <returns>The configured <see cref="TokenValidationParameters"/>.</returns>
        private TokenValidationParameters GenerateTokenValidationParameters(IList<Microsoft.IdentityModel.Tokens.SecurityKey> signingKeys, string issuer)
        {
            return new TokenValidationParameters
            {
                ValidAudience = _audience,
                ValidIssuer = issuer,
                ValidateIssuer = _validateIssuer,
                ValidateAudience = _validateAudience,
                ValidateIssuerSigningKey = _validateIssuerSigningKey,
                ValidateLifetime = _validateLifetime,
                IssuerSigningKeys = signingKeys
            };
        }

        /// <summary>
        /// Builds validation parameters from a snapshot of the provider's current issuer and keys.
        /// Local values are used instead of shared static fields so that concurrent requests do not
        /// overwrite each other's state.
        /// </summary>
        private TokenValidationParameters CreateValidationParameters()
        {
            string issuer = _securityTokenProvider.Issuer;
            List<Microsoft.IdentityModel.Tokens.SecurityKey> signingKeys = _securityTokenProvider.SecurityTokens.ToList();
            return GenerateTokenValidationParameters(signingKeys, issuer);
        }

        /// <summary>
        /// Creates the OpenID configuration provider for AAD v1.0 or v2.0 tokens.
        /// </summary>
        /// <param name="useV2">True to target the v2.0 openid-configuration endpoint.</param>
        /// <returns>The provider for the AAD v1.0 or v2.0 metadata endpoint.</returns>
        private static OpenIdConnectCachingSecurityTokenProvider GetConfig(bool useV2 = false)
        {
            // A trailing slash in the configured authority would otherwise produce "//" in the URL.
            string authority = (_authority ?? string.Empty).TrimEnd('/');
            if (useV2)
            {
                authority = $"{authority}/v2.0";
            }

            string stsDiscoveryEndpoint = $"{authority}/.well-known/openid-configuration";
            // HTTP client used to fetch the OpenID Connect metadata.
            HttpClient metadataClient = new HttpClient();
            return new OpenIdConnectCachingSecurityTokenProvider(stsDiscoveryEndpoint, metadataClient);
        }

        /// <summary>
        /// Checks whether the calling application is allowed to use this service.
        /// </summary>
        /// <param name="validatedSecurityToken">The validated token containing the claims.</param>
        /// <returns>True if the token's client ID is non-empty and in the configured allowed list, false otherwise.</returns>
        private bool CheckClientID(SecurityToken validatedSecurityToken)
        {
            string clientID = ExtractClientID(validatedSecurityToken as JwtSecurityToken);
            // Rejecting an empty ID prevents a token without an appid/azp claim from matching.
            return clientID.Length > 0 && _allowedClientIDs.Contains(clientID);
        }

        /// <summary>
        /// Extracts the client ID from the Authorization header without validating the token.
        /// The result is used only for error messages.
        /// </summary>
        /// <param name="authHeader">The Authorization HTTP header in Bearer format.</param>
        /// <returns>The client ID from the appid/azp claim, or an empty string if it cannot be read.</returns>
        private string GetClientID(string authHeader)
        {
            try
            {
                string jwtToken;
                if (!TryGetBearerToken(authHeader, out jwtToken))
                {
                    return string.Empty;
                }

                return ExtractClientID(new JwtSecurityTokenHandler().ReadJwtToken(jwtToken));
            }
            catch (Exception e)
            {
                LogManager.WriteError(String.Format("Unable to extract clientID from Token! : {0}", e.Message));
                return string.Empty;
            }
        }

        /// <summary>
        /// Returns the calling application's client ID: the "appid" claim for AAD v1.0 tokens or
        /// the "azp" claim for v2.0 tokens. Returns an empty string when the token or claim is absent.
        /// </summary>
        private static string ExtractClientID(JwtSecurityToken token)
        {
            if (token == null)
            {
                return string.Empty;
            }

            string claimType = _useV2 ? "azp" : "appid";
            Claim claim = token.Claims.FirstOrDefault(c => c.Type == claimType);
            return claim == null ? string.Empty : claim.Value;
        }

        /// <summary>
        /// Splits "Bearer &lt;jwt&gt;" into its token part, matching the scheme case-insensitively.
        /// </summary>
        /// <returns>True if the header uses the Bearer scheme and carries a non-empty token.</returns>
        private static bool TryGetBearerToken(string authHeader, out string jwtToken)
        {
            jwtToken = null;
            if (authHeader == null || !authHeader.StartsWith(BearerScheme, StringComparison.OrdinalIgnoreCase))
            {
                return false;
            }

            jwtToken = authHeader.Substring(BearerScheme.Length).Trim();
            return jwtToken.Length > 0;
        }

        /// <summary>
        /// Forces a refresh of the OpenID metadata (issuer and signing keys) after a suspected key
        /// rollover.
        /// </summary>
        private void RefreshKeys()
        {
            _securityTokenProvider.RequestRefresh();
        }

        /// <summary>
        /// Marks the current request as unauthorized so that the custom invoker skips the operation,
        /// then returns <paramref name="error"/> for use as correlation state.
        /// </summary>
        private static WcfErrorResponseData Reject(WcfErrorResponseData error)
        {
            OperationContext context = OperationContext.Current;
            MessageProperties properties = context == null ? null : context.IncomingMessageProperties;
            if (properties != null)
            {
                // The indexer (unlike Add) does not throw if the property is already present.
                properties[AuthorizedPropertyName] = false;
            }

            return error;
        }

        /// <summary>Creates a 401 response carrying <paramref name="message"/> in a WWW-Authenticate header.</summary>
        private static WcfErrorResponseData Unauthorized(string message)
        {
            return new WcfErrorResponseData(HttpStatusCode.Unauthorized, string.Empty, new KeyValuePair<string, string>("WWW-Authenticate", message));
        }

        /// <summary>Formats the standard token-failure message for a client ID and a reason.</summary>
        private static string FormatFailure(string clientID, string reason)
        {
            return $"Token with Client ID \"{clientID}\" failed validation with Error Message - {reason}";
        }

        /// <summary>
        /// Sends the Unauthorized/BadRequest reply when the request was rejected in
        /// <see cref="AfterReceiveRequest"/>.
        /// </summary>
        /// <param name="reply">The reply message; null for one-way operations.</param>
        /// <param name="correlationState">The value returned by <see cref="AfterReceiveRequest"/>: a <see cref="WcfErrorResponseData"/> on failure, otherwise null.</param>
        public void BeforeSendReply(ref Message reply, object correlationState)
        {
            var error = correlationState as WcfErrorResponseData;
            if (error == null || reply == null)
            {
                return;
            }

            var responseProperty = new HttpResponseMessageProperty { StatusCode = error.StatusCode };
            reply.Properties[HttpResponseMessageProperty.Name] = responseProperty;

            var headers = error.Headers;
            if (headers == null)
            {
                return;
            }

            foreach (var header in headers)
            {
                responseProperty.Headers.Add(header.Key, SanitizeHeaderValue(header.Value));
            }
        }

        /// <summary>
        /// Replaces control characters (including CR/LF) with spaces. Header values can embed
        /// exception messages and untrusted claim values, and WebHeaderCollection rejects CR/LF
        /// with an exception.
        /// </summary>
        private static string SanitizeHeaderValue(string value)
        {
            if (string.IsNullOrEmpty(value))
            {
                return value;
            }

            return new string(value.Select(c => char.IsControl(c) ? ' ' : c).ToArray());
        }

        /// <summary>Parses a boolean app setting, returning <paramref name="defaultValue"/> when missing or malformed.</summary>
        private static bool ParseFlag(string value, bool defaultValue)
        {
            bool parsed;
            return bool.TryParse(value, out parsed) ? parsed : defaultValue;
        }

        /// <summary>
        /// Parses a comma-separated ID list. Entries are trimmed, empty entries are dropped, and
        /// matching is case-insensitive (IDs are GUIDs).
        /// </summary>
        private static HashSet<string> ParseIdList(string value)
        {
            var ids = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
            if (string.IsNullOrWhiteSpace(value))
            {
                return ids;
            }

            foreach (string entry in value.Split(new[] { ',' }, StringSplitOptions.RemoveEmptyEntries))
            {
                string id = entry.Trim();
                if (id.Length > 0)
                {
                    ids.Add(id);
                }
            }

            return ids;
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment