Skip to content

Instantly share code, notes, and snippets.

@anand374
Created April 12, 2021 18:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anand374/4150380166f53fb09d92a80c7da55d84 to your computer and use it in GitHub Desktop.
Save anand374/4150380166f53fb09d92a80c7da55d84 to your computer and use it in GitHub Desktop.
Message Inspector for WCF SOAP AAD Token Validation
using Microsoft.IdentityModel.Tokens;
using System;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Configuration;
using System.IdentityModel.Tokens.Jwt;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Security.Claims;
using System.ServiceModel;
using System.ServiceModel.Channels;
using System.ServiceModel.Dispatcher;
namespace TokenValidator
{
    /// <summary>
    /// WCF dispatch message inspector that authorizes incoming requests by validating the JWT carried in the
    /// HTTP <c>Authorization: Bearer ...</c> header against Azure AD (v1.0 or v2.0) OpenID Connect metadata,
    /// and then checking the token's client ID claim (<c>appid</c> for v1.0, <c>azp</c> for v2.0) against the
    /// comma-separated <c>AllowedTenantIDs</c> appSetting.
    /// </summary>
    /// <remarks>
    /// On rejection, <see cref="AfterReceiveRequest"/> sets the incoming message property "Authorized" to false
    /// (checked by a custom invoker before the operation runs) and returns a <c>WcfErrorResponseData</c>, which
    /// WCF hands back to <see cref="BeforeSendReply"/> as correlation state to produce the HTTP status and headers.
    /// NOTE(review): on success no "Authorized" property is set; confirm the invoker treats "absent" as authorized.
    /// </remarks>
    public class BearerTokenMessageInspector : IDispatchMessageInspector
    {
        private const string BearerScheme = "Bearer ";
        private const string AuthorizedPropertyName = "Authorized";

        private static readonly string _audience;
        private static readonly string _authority;
        private static readonly OpenIdConnectCachingSecurityTokenProvider _securityTokenProviderV1;
        private static readonly OpenIdConnectCachingSecurityTokenProvider _securityTokenProviderV2;
        private static readonly bool _validateIssuer;
        private static readonly bool _validateAudience;
        private static readonly bool _validateLifetime;
        private static readonly bool _validateIssuerSigningKey;
        private static readonly bool _useV2;
        // Maximum number of validation attempts when the signing key is not found (keys are refreshed in between).
        private static readonly int _maxRetries;
        // Allowed client (application) IDs; despite the name these are compared with the appid/azp claim.
        private static readonly List<string> _allowedTenantIDs;
        // Last issuer/keys read from the active metadata provider (shared across requests).
        private static List<Microsoft.IdentityModel.Tokens.SecurityKey> _signingKeys;
        private static string _issuer = string.Empty;

        /// <summary>
        /// Static constructor that reads all configuration-driven settings from appSettings and creates the
        /// v1.0 and v2.0 OpenID Connect metadata providers.
        /// </summary>
        /// <remarks>
        /// <c>bool.Parse</c> throws on a missing or malformed validation switch, so misconfiguration fails type
        /// initialization instead of silently defaulting. <c>MaxRetries</c> falls back to 2 when absent or invalid.
        /// </remarks>
        static BearerTokenMessageInspector()
        {
            NameValueCollection appSettings = ConfigurationManager.AppSettings;
            _audience = appSettings["AADAudience"];
            _authority = appSettings["AADAuthority"];
            // Trim whitespace around each ID and drop empty entries (e.g. from a trailing comma): an empty entry
            // would otherwise match tokens that carry no client ID claim at all. A missing key means "allow none".
            _allowedTenantIDs = (appSettings["AllowedTenantIDs"] ?? string.Empty)
                .Split(new[] { ',' }, StringSplitOptions.RemoveEmptyEntries)
                .Select(id => id.Trim())
                .Where(id => id.Length > 0)
                .ToList();
            _validateIssuer = bool.Parse(appSettings["ValidateIssuer"]);
            _validateAudience = bool.Parse(appSettings["ValidateAudience"]);
            _validateIssuerSigningKey = bool.Parse(appSettings["ValidateIssuerSigningKey"]);
            _validateLifetime = bool.Parse(appSettings["ValidateLifetime"]);
            _useV2 = bool.Parse(appSettings["useV2"]);
            _securityTokenProviderV1 = GetConfig();
            _securityTokenProviderV2 = GetConfig(true);
            if (!Int32.TryParse(appSettings["MaxRetries"], out _maxRetries))
            {
                _maxRetries = 2;
            }
        }

        /// <summary>
        /// Called by WCF after a request is received and before it is dispatched. Validates the bearer token.
        /// </summary>
        /// <param name="request">The incoming message.</param>
        /// <param name="channel">The incoming channel (unused).</param>
        /// <param name="instanceContext">The service instance context (unused).</param>
        /// <returns>
        /// <c>null</c> when the request is authorized; otherwise a <c>WcfErrorResponseData</c> describing the
        /// failure, which WCF passes to <see cref="BeforeSendReply"/> as the correlation state.
        /// </returns>
        public object AfterReceiveRequest(ref Message request, IClientChannel channel, InstanceContext instanceContext)
        {
            // Guard the message before touching its properties.
            if (request == null)
            {
                return new WcfErrorResponseData(HttpStatusCode.BadRequest, string.Empty, new KeyValuePair<string, string>("InvalidOperation", "Request Body Empty."));
            }

            // Requests without an HttpRequestMessageProperty (e.g. non-HTTP transports) are treated like a missing
            // Authorization header, so they are rejected and flagged rather than crashing on a null dereference.
            object httpProperty;
            request.Properties.TryGetValue(HttpRequestMessageProperty.Name, out httpProperty);
            var requestMessage = httpProperty as HttpRequestMessageProperty;
            string authHeader = requestMessage == null ? null : requestMessage.Headers["Authorization"];

            WcfErrorResponseData error;
            try
            {
                if (string.IsNullOrEmpty(authHeader))
                {
                    error = new WcfErrorResponseData(HttpStatusCode.Unauthorized, string.Empty, new KeyValuePair<string, string>("WWW-Authenticate", "Error: Authorization Header empty! Please pass a Token using Bearer scheme."));
                }
                else if (Authenticate(authHeader))
                {
                    return null;
                }
                else
                {
                    // Authenticate returned false: the header is not a Bearer header, or the token validated
                    // but its client ID is not in the allow-list.
                    string clientID = GetClientID(authHeader);
                    error = new WcfErrorResponseData(HttpStatusCode.Unauthorized, string.Empty, new KeyValuePair<string, string>("WWW-Authenticate", "Token with Client ID \"" + clientID + "\" failed validation with Error Messsage - " + "The client ID: " + clientID + " might not be in the allowed list."));
                }
            }
            catch (Exception e)
            {
                // Token validation threw (bad signature, expired, wrong audience/issuer, unreadable token, ...).
                string clientID = GetClientID(authHeader);
                error = new WcfErrorResponseData(HttpStatusCode.Unauthorized, string.Empty, new KeyValuePair<string, string>("WWW-Authenticate", "Token with Client ID \"" + clientID + "\" failed validation with Error Messsage - " + e.Message));
            }

            // This is checked before the custom invoker invokes the method; if unauthorized, nothing is invoked.
            // The indexer replaces an existing value instead of failing on a duplicate key like Add would.
            OperationContext.Current.IncomingMessageProperties[AuthorizedPropertyName] = false;
            return error;
        }

        /// <summary>
        /// Validates the bearer token in <paramref name="authHeader"/> and checks its client ID claim.
        /// </summary>
        /// <param name="authHeader">The raw Authorization header value, expected as "Bearer &lt;jwt&gt;".</param>
        /// <returns>
        /// True if the token validates and its client ID is allowed; false if the header is not a Bearer header
        /// or the client ID is not allowed.
        /// </returns>
        /// <exception cref="Exception">Propagates token validation failures from <see cref="ValidateToken"/>.</exception>
        private bool Authenticate(string authHeader)
        {
            string jwtToken;
            if (!TryGetBearerToken(authHeader, out jwtToken))
            {
                return false;
            }
            PopulateIssuerAndKeys();
            var validationParameters = GenerateTokenValidationParameters(_signingKeys, _issuer);
            return ValidateToken(jwtToken, validationParameters);
        }

        /// <summary>
        /// Validates the token against <paramref name="validationParameters"/>. When the signing key is not found
        /// (typically after an AAD key rollover), the metadata is refreshed, fresh validation parameters are built
        /// from the new keys/issuer, and validation is retried, up to <c>MaxRetries</c> attempts in total.
        /// </summary>
        /// <param name="jwtToken">The compact-serialized JWT (starting with "eY...").</param>
        /// <param name="validationParameters">The parameters to validate against for the first attempt.</param>
        /// <returns>True if the token is valid and its client ID is allowed, false if only the client ID check failed.</returns>
        /// <exception cref="SecurityTokenSignatureKeyNotFoundException">The key is still missing after the last attempt.</exception>
        /// <exception cref="Exception">Any other validation failure is rethrown immediately; retrying cannot fix it.</exception>
        private bool ValidateToken(string jwtToken, TokenValidationParameters validationParameters)
        {
            var tokenHandler = new JwtSecurityTokenHandler();
            for (int attempt = 1; ; attempt++)
            {
                try
                {
                    SecurityToken validatedToken;
                    tokenHandler.ValidateToken(jwtToken, validationParameters, out validatedToken);
                    // The token itself is valid; the allow-list outcome will not change on retry.
                    return CheckTenantID(validatedToken);
                }
                catch (SecurityTokenSignatureKeyNotFoundException)
                {
                    if (attempt >= _maxRetries)
                    {
                        throw; // preserves the original stack trace
                    }
                    // Likely a key rollover: refresh metadata and rebuild the parameters, since the previous
                    // parameters still reference the old key list and issuer.
                    RefreshKeys();
                    validationParameters = GenerateTokenValidationParameters(_signingKeys, _issuer);
                }
            }
        }

        /// <summary>
        /// Builds the validation parameters from the configured switches, audience and the given issuer/keys.
        /// </summary>
        /// <param name="signingKeys">The signing (public) keys published by the AAD (v1.0/v2.0) issuer.</param>
        /// <param name="issuer">The issuer URI of the AAD (v1.0/v2.0) endpoint.</param>
        /// <returns>A new <see cref="TokenValidationParameters"/> instance.</returns>
        private TokenValidationParameters GenerateTokenValidationParameters(IList<Microsoft.IdentityModel.Tokens.SecurityKey> signingKeys, string issuer)
        {
            return new TokenValidationParameters
            {
                ValidAudience = _audience,
                ValidIssuer = issuer,
                ValidateIssuer = _validateIssuer,
                ValidateAudience = _validateAudience,
                ValidateIssuerSigningKey = _validateIssuerSigningKey,
                ValidateLifetime = _validateLifetime,
                IssuerSigningKeys = signingKeys
            };
        }

        /// <summary>
        /// Creates the OpenID Connect metadata provider for the AAD v1.0 or v2.0 endpoint.
        /// </summary>
        /// <param name="useV2">True to target "{AADAuthority}/v2.0"; false for the v1.0 authority itself.</param>
        /// <returns>The provider for "{authority}/.well-known/openid-configuration".</returns>
        private static OpenIdConnectCachingSecurityTokenProvider GetConfig(bool useV2 = false)
        {
            // Drop a trailing '/' so the composed URLs never contain "//".
            string authority = _authority == null ? null : _authority.TrimEnd('/');
            if (useV2)
            {
                authority = authority + "/v2.0";
            }
            string stsDiscoveryEndpoint = $"{authority}/.well-known/openid-configuration";
            // One HttpClient per provider; in this file GetConfig is only called from the static constructor.
            var metadataClient = new HttpClient();
            return new OpenIdConnectCachingSecurityTokenProvider(stsDiscoveryEndpoint, metadataClient);
        }

        /// <summary>
        /// Checks whether the validated token's client ID (appid for v1.0, azp for v2.0) is in the allow-list.
        /// </summary>
        /// <param name="validatedSecurityToken">The token returned by successful validation.</param>
        /// <returns>
        /// True if a non-empty client ID is found and appears in "AllowedTenantIDs" (case-insensitive, since the
        /// IDs are GUIDs); false otherwise.
        /// </returns>
        private bool CheckTenantID(SecurityToken validatedSecurityToken)
        {
            string clientID = ExtractClientID(validatedSecurityToken as JwtSecurityToken);
            // A token without the claim must never be allowed, even if the allow-list were to contain "".
            return !string.IsNullOrEmpty(clientID)
                && _allowedTenantIDs.Contains(clientID, StringComparer.OrdinalIgnoreCase);
        }

        /// <summary>
        /// Extracts the client ID from an Authorization header for error reporting. The token is only decoded,
        /// not validated, so the result must not be used for authorization decisions.
        /// </summary>
        /// <param name="authHeader">The Authorization header in Bearer format; may be null.</param>
        /// <returns>The appid/azp claim value, or an empty string if unavailable.</returns>
        private string GetClientID(string authHeader)
        {
            string jwtToken;
            if (!TryGetBearerToken(authHeader, out jwtToken))
            {
                return string.Empty;
            }
            try
            {
                return ExtractClientID(new JwtSecurityTokenHandler().ReadJwtToken(jwtToken));
            }
            catch (Exception e)
            {
                // Best effort: a malformed token just yields an empty client ID in the error message.
                LogManager.WriteError(String.Format("Unable to extract clientID from Token! : {0}", e.Message));
                return string.Empty;
            }
        }

        /// <summary>
        /// Returns the client ID claim of <paramref name="token"/>: "appid" for AAD v1.0 tokens, "azp" for v2.0.
        /// </summary>
        /// <param name="token">The decoded token; may be null.</param>
        /// <returns>The claim value, or an empty string if the token or claim is missing.</returns>
        private static string ExtractClientID(JwtSecurityToken token)
        {
            if (token == null)
            {
                return string.Empty;
            }
            string claimType = _useV2 ? "azp" : "appid";
            Claim claim = token.Claims.FirstOrDefault(c => c.Type == claimType);
            return claim == null ? string.Empty : claim.Value;
        }

        /// <summary>
        /// Splits a "Bearer &lt;jwt&gt;" header value (scheme matched ordinally, case-insensitive).
        /// </summary>
        /// <param name="authHeader">The raw header value; may be null.</param>
        /// <param name="jwtToken">The token text with surrounding whitespace trimmed, or null.</param>
        /// <returns>True if the header uses the Bearer scheme.</returns>
        private static bool TryGetBearerToken(string authHeader, out string jwtToken)
        {
            if (authHeader != null && authHeader.StartsWith(BearerScheme, StringComparison.OrdinalIgnoreCase))
            {
                jwtToken = authHeader.Substring(BearerScheme.Length).Trim();
                return true;
            }
            jwtToken = null;
            return false;
        }

        /// <summary>
        /// The metadata provider for the configured endpoint version ("useV2").
        /// </summary>
        private static OpenIdConnectCachingSecurityTokenProvider CurrentProvider
        {
            get { return _useV2 ? _securityTokenProviderV2 : _securityTokenProviderV1; }
        }

        /// <summary>
        /// Forces a metadata refresh on the active provider and re-reads issuer and signing keys. Used after a
        /// signing-key-not-found failure (key rollover).
        /// NOTE(review): the original comment states the OpenID Connect library otherwise refreshes keys every
        /// 24 hours (AutomaticRefreshInterval); confirm against OpenIdConnectCachingSecurityTokenProvider.
        /// </summary>
        private void RefreshKeys()
        {
            CurrentProvider.RequestRefresh();
            PopulateIssuerAndKeys();
        }

        /// <summary>
        /// Copies the active provider's issuer and signing keys into the shared <c>_issuer</c>/<c>_signingKeys</c> fields.
        /// </summary>
        private void PopulateIssuerAndKeys()
        {
            var provider = CurrentProvider;
            _issuer = provider.Issuer;
            _signingKeys = provider.SecurityTokens.ToList();
        }

        /// <summary>
        /// Called by WCF before the reply is sent. If <see cref="AfterReceiveRequest"/> produced an error, sets the
        /// HTTP status code and error headers on the reply.
        /// </summary>
        /// <param name="reply">The reply message about to be sent.</param>
        /// <param name="correlationState">The value returned by <see cref="AfterReceiveRequest"/>; a
        /// <c>WcfErrorResponseData</c> on failure, null on success.</param>
        public void BeforeSendReply(ref Message reply, object correlationState)
        {
            var error = correlationState as WcfErrorResponseData;
            if (error == null || reply == null)
            {
                return;
            }
            var responseProperty = new HttpResponseMessageProperty();
            reply.Properties[HttpResponseMessageProperty.Name] = responseProperty;
            responseProperty.StatusCode = error.StatusCode;
            var headers = error.Headers;
            if (headers == null)
            {
                return;
            }
            foreach (var header in headers)
            {
                responseProperty.Headers.Add(header.Key, SanitizeHeaderValue(header.Value));
            }
        }

        /// <summary>
        /// Replaces CR/LF with spaces. Validation error messages can span several lines, and header values
        /// containing line breaks are rejected by <c>WebHeaderCollection.Add</c>.
        /// </summary>
        /// <param name="value">The header value; may be null.</param>
        /// <returns>The value on a single line, or null if <paramref name="value"/> was null.</returns>
        private static string SanitizeHeaderValue(string value)
        {
            return value == null ? null : value.Replace("\r", " ").Replace("\n", " ");
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment