Skip to content

Instantly share code, notes, and snippets.

@barchito
Created July 11, 2017 21:20
Show Gist options
  • Save barchito/89bffc5d1b8ec01b785e77ebac295d81 to your computer and use it in GitHub Desktop.
Save barchito/89bffc5d1b8ec01b785e77ebac295d81 to your computer and use it in GitHub Desktop.
/// <summary>
/// Throttling handler whose rate-limit identity combines the caller's Authorization header
/// (or "anon" when the header is absent), the client IP and the lower-cased request path.
/// </summary>
public class CustomThrottlingHandler : ThrottlingHandler
{
    /// <summary>Builds the identity the throttling policy counts requests against.</summary>
    /// <param name="request">The incoming request.</param>
    /// <returns>A <see cref="RequestIdentity"/> keyed on Authorization header, IP and path.</returns>
    protected override RequestIdentity SetIdentity(HttpRequestMessage request)
    {
        IEnumerable<string> authorizationValues;
        string clientKey = request.Headers.TryGetValues("Authorization", out authorizationValues)
            ? authorizationValues.First()
            : "anon";

        var identity = new RequestIdentity();
        identity.ClientKey = clientKey;
        identity.ClientIp = base.GetClientIp(request).ToString();
        identity.Endpoint = request.RequestUri.AbsolutePath.ToLowerInvariant();
        return identity;
    }
}
/// <summary>
/// Web API bootstrap: CORS, message handlers (JWT validation, request tracing, throttling),
/// attribute and convention routes, and Application Insights exception logging.
/// </summary>
public static class WebApiConfig
{
    /// <summary>Registers Web API services, message handlers and routes.</summary>
    /// <param name="config">The configuration to populate.</param>
    /// <exception cref="ConfigurationErrorsException">
    /// The <c>MaxCallPerSecondFromIp</c> app setting is missing or is not an integer.
    /// </exception>
    public static void Register(HttpConfiguration config)
    {
        // Web API configuration and services
        var cors = new EnableCorsAttribute(origins: ConfigurationManager.AppSettings["AccessControlAllowOrigin"], headers: "*", methods: "*");
        config.EnableCors(cors);

        // NOTE(review): Web API runs request handlers in registration order, and
        // JsonWebTokenValidationHandler returns its error response without calling the rest of the
        // pipeline, so requests rejected for a bad token never reach CustomThrottlingHandler and are
        // not rate-limited. Consider registering the throttling handler first.
        config.MessageHandlers.Add(new JsonWebTokenValidationHandler()
        {
            Audience = ConfigurationManager.AppSettings["Aud"], // client id
            SymmetricKey = ConfigurationManager.AppSettings["Secret"] // client secret
        });
        config.MessageHandlers.Add(new TraceMessageHandler());

        var maxCallPerSecondFromIp = ReadRequiredIntSetting("MaxCallPerSecondFromIp");
        config.MessageHandlers.Add(new CustomThrottlingHandler()
        {
            Policy = new ThrottlePolicy(perSecond: maxCallPerSecondFromIp)
            {
                IpThrottling = true
            },
            Repository = new CacheRepository()
        });

        // Web API routes
        config.MapHttpAttributeRoutes();
        config.Routes.MapHttpRoute(
            name: "DefaultApi",
            routeTemplate: "api/{controller}/{id}",
            defaults: new { id = RouteParameter.Optional }
        );

        config.Services.Add(typeof(IExceptionLogger), new AiExceptionLogger());
    }

    /// <summary>
    /// Reads an integer app setting using the invariant culture (config values are
    /// machine-readable), failing with an error that names the setting instead of an
    /// opaque ArgumentNullException/FormatException.
    /// </summary>
    private static int ReadRequiredIntSetting(string key)
    {
        var raw = ConfigurationManager.AppSettings[key];
        int value;
        if (!int.TryParse(raw, System.Globalization.NumberStyles.Integer, System.Globalization.CultureInfo.InvariantCulture, out value))
        {
            throw new ConfigurationErrorsException(
                string.Format("App setting '{0}' must be set to an integer but was '{1}'.", key, raw));
        }
        return value;
    }
}
/// <summary>
/// Best-effort request tracer. For paths under <c>/api/</c> it sends either the query-string
/// pairs (GET) or the raw request body (any other method) to Application Insights as a trace,
/// then forwards the request unchanged. A tracing failure never fails the request.
/// </summary>
/// <remarks>
/// NOTE(review): query strings and request bodies are logged verbatim and may contain
/// credentials or personal data. Confirm this is acceptable for the deployment.
/// </remarks>
public class TraceMessageHandler : DelegatingHandler
{
    // TelemetryClient is thread-safe, and Application Insights recommends sharing one
    // instance rather than constructing a new one per request.
    private static readonly TelemetryClient Telemetry = new TelemetryClient();

    /// <summary>Traces the request (best-effort) and then passes it down the pipeline.</summary>
    protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        try
        {
            await TraceRequestAsync(request);
        }
        catch (Exception ex)
        {
            // Tracing is deliberately best-effort. Keep the request going, but leave a
            // diagnostic trail instead of silently swallowing the failure.
            System.Diagnostics.Trace.TraceWarning("TraceMessageHandler failed to trace request: {0}", ex);
        }
        return await base.SendAsync(request, cancellationToken);
    }

    /// <summary>Sends one Application Insights trace for <paramref name="request"/> if it targets /api/.</summary>
    private static async Task TraceRequestAsync(HttpRequestMessage request)
    {
        // Route matching in Web API is case-insensitive, so "/API/..." must be traced too.
        if (!request.RequestUri.AbsolutePath.StartsWith("/api/", StringComparison.OrdinalIgnoreCase))
        {
            return;
        }

        if (request.Method == HttpMethod.Get)
        {
            // A key may repeat (?id=1&id=2). Join its values instead of letting ToDictionary
            // throw on the duplicate, which previously dropped the whole trace.
            var query = request.GetQueryNameValuePairs()
                .GroupBy(pair => pair.Key, StringComparer.Ordinal)
                .ToDictionary(
                    group => group.Key,
                    group => string.Join(",", group.Select(pair => pair.Value)),
                    StringComparer.Ordinal);
            Telemetry.TrackTrace("Query Request", query);
        }
        else
        {
            var body = request.Content == null
                ? string.Empty
                : await request.Content.ReadAsStringAsync();
            Telemetry.TrackTrace("Data Request", new Dictionary<string, string> { { "props", body } });
        }
    }
}
/// <summary>
/// Message handler that validates the JWT carried in the Authorization header with
/// <see cref="JsonWebToken.ValidateToken"/>, using <see cref="SymmetricKey"/>,
/// <see cref="Audience"/> and <see cref="Issuer"/>.
/// On success it sets <see cref="Thread.CurrentPrincipal"/> and, when an HttpContext exists,
/// <c>HttpContext.Current.User</c>.
/// When validation rejects the token it short-circuits with 401; any other failure gives 500.
/// Requests without exactly one Authorization header are passed through untouched.
/// </summary>
public class JsonWebTokenValidationHandler : DelegatingHandler
{
    private const string BearerPrefix = "Bearer ";
    private const string UnauthorizedMessage = "Wrong token, or you are not authorized.";
    private const string ValidationErrorMessage = "An error occurred while validating the authorization token.";

    /// <summary>Signing key handed to JsonWebToken.ValidateToken (base64url-encoded there).</summary>
    public string SymmetricKey { get; set; }

    /// <summary>Expected audience; null/empty disables the audience check.</summary>
    public string Audience { get; set; }

    /// <summary>Expected issuer; null/empty disables the issuer check.</summary>
    public string Issuer { get; set; }

    /// <summary>
    /// Extracts the token from the single Authorization header, stripping a leading
    /// "Bearer " scheme (matched case-insensitively) when present.
    /// </summary>
    /// <returns>False when there is no Authorization header or more than one.</returns>
    private static bool TryRetrieveToken(HttpRequestMessage request, out string token)
    {
        token = null;
        IEnumerable<string> authzHeaders;
        if (!request.Headers.TryGetValues("Authorization", out authzHeaders) || authzHeaders.Count() > 1)
        {
            // Fail if no Authorization header or more than one Authorization headers
            // are found in the HTTP request
            return false;
        }

        var headerValue = authzHeaders.First();
        // Authentication scheme names are case-insensitive (RFC 7235 §2.1).
        token = headerValue.StartsWith(BearerPrefix, StringComparison.OrdinalIgnoreCase)
            ? headerValue.Substring(BearerPrefix.Length).Trim()
            : headerValue;
        return true;
    }

    /// <summary>Validates the token (if any) and either forwards the request or returns an error.</summary>
    protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
    {
        string token;
        HttpResponseMessage errorResponse = null;
        if (TryRetrieveToken(request, out token))
        {
            try
            {
                Thread.CurrentPrincipal = JsonWebToken.ValidateToken(
                    token,
                    this.SymmetricKey,
                    audience: this.Audience,
                    checkExpiration: true,
                    issuer: this.Issuer);
                if (HttpContext.Current != null)
                {
                    HttpContext.Current.User = new UserPrincipal(Thread.CurrentPrincipal.Identity);
                }
            }
            catch (Exception ex) when (ex is Jose.JoseException || ex is JsonWebToken.TokenValidationException)
            {
                // The token itself was rejected: this is the client's problem.
                errorResponse = CreateCorsErrorResponse(request, HttpStatusCode.Unauthorized, UnauthorizedMessage);
            }
            catch (Exception ex)
            {
                // Anything else is unexpected (e.g. misconfiguration). Log it rather than
                // silently converting it, and don't tell the client it is unauthorized.
                System.Diagnostics.Trace.TraceError("Unexpected error while validating JWT: {0}", ex);
                errorResponse = CreateCorsErrorResponse(request, HttpStatusCode.InternalServerError, ValidationErrorMessage);
            }
        }

        // base.SendAsync stays outside the try so downstream failures are not reported as token errors.
        return errorResponse != null
            ? Task.FromResult(errorResponse)
            : base.SendAsync(request, cancellationToken);
    }

    /// <summary>
    /// Creates an error response that carries the configured Access-Control-Allow-Origin header.
    /// Presumably this is needed because the short-circuited response bypasses the CORS
    /// machinery; TODO confirm.
    /// </summary>
    private static HttpResponseMessage CreateCorsErrorResponse(HttpRequestMessage request, HttpStatusCode statusCode, string message)
    {
        var response = request.CreateErrorResponse(statusCode, message);
        response.Headers.Add("Access-Control-Allow-Origin", ConfigurationManager.AppSettings["AccessControlAllowOrigin"]);
        return response;
    }
}
/// <summary>
/// Validates a compact JWT against a base64url-encoded symmetric key (via <c>JWT.Decode</c>).
/// It then applies optional audience, expiration and issuer checks and converts the payload
/// into a <see cref="ClaimsPrincipal"/>.
/// </summary>
public static class JsonWebToken
{
    private const string NameClaimType = "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/name";
    private const string RoleClaimType = "http://schemas.microsoft.com/ws/2008/06/identity/claims/role";
    private const string ActorClaimType = "http://schemas.xmlsoap.org/ws/2009/09/identity/claims/actor";
    private const string DefaultIssuer = "LOCAL AUTHORITY";
    private const string StringClaimValueType = "http://www.w3.org/2001/XMLSchema#string";

    private static readonly System.Globalization.CultureInfo Invariant = System.Globalization.CultureInfo.InvariantCulture;
    private static readonly DateTime UnixEpoch = new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);

    // sort claim types by relevance: the first present one becomes the identity's name
    private static readonly string[] claimTypesForUserName = new string[] { "name", "email", "user_id", "sub", "user_metadata" };

    // JWT claims that are not copied onto the resulting identity
    private static readonly string[] claimsToExclude = new string[] { "iss", "sub", "aud", "exp", "iat", "identities" };

    /// <summary>Verifies <paramref name="token"/> and builds a principal from its payload.</summary>
    /// <param name="token">Compact-serialized JWT.</param>
    /// <param name="secretKey">Signing key, base64url-encoded.</param>
    /// <param name="audience">Expected "aud". When set and the token has an "aud" claim
    /// (string or array), one of its values must match exactly.</param>
    /// <param name="checkExpiration">When true and the token has an "exp" claim, reject the token
    /// if exp is at or before the current UTC time.</param>
    /// <param name="issuer">Expected "iss". When null/empty, the token's own "iss" (if any) is
    /// used as the issuer of the generated claims.</param>
    /// <returns>A principal wrapping a "Federation" <see cref="ClaimsIdentity"/>.</returns>
    /// <exception cref="TokenValidationException">The token is malformed or fails an aud/exp/iss check.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="secretKey"/> is null.</exception>
    /// <exception cref="FormatException"><paramref name="secretKey"/> is not valid base64url.</exception>
    /// <remarks>Exceptions that <c>JWT.Decode</c> raises for tokens it rejects propagate unchanged.</remarks>
    public static ClaimsPrincipal ValidateToken(string token, string secretKey, string audience = null, bool checkExpiration = false, string issuer = null)
    {
        // Decode the key outside the token-parsing guard: a bad key is a server
        // configuration problem, not a client error, so its exceptions are not re-labelled.
        var key = Base64UrlDecode(secretKey);
        var payloadData = DecodePayload(token, key);

        ValidateAudience(payloadData, audience);
        if (checkExpiration)
        {
            ValidateExpiration(payloadData);
        }
        issuer = ResolveIssuer(payloadData, issuer);

        return new ClaimsPrincipal(ClaimsIdentityFromJwt(payloadData, issuer));
    }

    /// <summary>
    /// Verifies and decodes the token into its JSON payload. Parse failures caused by a
    /// malformed token (bad base64, bad JSON, wrong segment count) are reported as
    /// <see cref="TokenValidationException"/> so callers can treat them as a rejected token.
    /// </summary>
    private static Dictionary<string, object> DecodePayload(string token, byte[] key)
    {
        Dictionary<string, object> payload;
        try
        {
            var payloadJson = JWT.Decode(token, key);
            // JavaScriptSerializer instance members are not documented as thread-safe,
            // so use a fresh serializer per call instead of a shared static one.
            payload = new JavaScriptSerializer().Deserialize<Dictionary<string, object>>(payloadJson);
        }
        catch (Exception ex) when (ex is FormatException || ex is ArgumentException || ex is IndexOutOfRangeException || ex is InvalidOperationException)
        {
            throw new TokenValidationException(string.Format("Malformed token: {0}", ex.Message), ex);
        }

        if (payload == null)
        {
            throw new TokenValidationException("Malformed token: payload is not a JSON object.");
        }
        return payload;
    }

    /// <summary>
    /// Checks "aud" against <paramref name="audience"/>.
    /// NOTE(review): as before, a token with no "aud" claim is accepted even when an audience
    /// is expected. Confirm that is intended.
    /// </summary>
    private static void ValidateAudience(IDictionary<string, object> payload, string audience)
    {
        object aud;
        if (string.IsNullOrEmpty(audience) || !payload.TryGetValue("aud", out aud))
        {
            return;
        }

        // "aud" may be a single string or an array of strings (RFC 7519 §4.1.3).
        var audienceList = aud as IList;
        var tokenAudiences = audienceList != null
            ? audienceList.Cast<object>().Select(a => Convert.ToString(a, Invariant)).ToList()
            : new List<string> { Convert.ToString(aud, Invariant) };

        if (!tokenAudiences.Any(a => a.Equals(audience, StringComparison.Ordinal)))
        {
            throw new TokenValidationException(string.Format("Audience mismatch. Expected: '{0}' and got: '{1}'", audience, string.Join(", ", tokenAudiences)));
        }
    }

    /// <summary>
    /// Rejects the token when "exp" is at or before now (UTC). "exp" is a NumericDate, which
    /// may be non-integral (RFC 7519 §2). It is parsed culture-invariantly, and a value that
    /// is not a number or is outside the DateTime range is a validation failure.
    /// </summary>
    private static void ValidateExpiration(IDictionary<string, object> payload)
    {
        object exp;
        if (!payload.TryGetValue("exp", out exp))
        {
            return;
        }

        var text = Convert.ToString(exp, Invariant);
        double seconds;
        DateTime validTo;
        if (!double.TryParse(text, System.Globalization.NumberStyles.Float, Invariant, out seconds)
            || double.IsNaN(seconds) || double.IsInfinity(seconds)
            || !TryFromUnixTime(seconds, out validTo))
        {
            throw new TokenValidationException(string.Format("Token expiration claim is not a valid NumericDate: '{0}'", text));
        }

        var now = DateTime.UtcNow;
        if (DateTime.Compare(validTo, now) <= 0)
        {
            throw new TokenValidationException(
                string.Format("Token is expired. Expiration: '{0}'. Current: '{1}'", validTo, now));
        }
    }

    /// <summary>
    /// Checks "iss" against <paramref name="issuer"/>. When no issuer is expected, the token's
    /// own "iss" is adopted. Returns the issuer to stamp on the generated claims (null when
    /// neither is present).
    /// </summary>
    private static string ResolveIssuer(IDictionary<string, object> payload, string issuer)
    {
        object iss;
        if (!payload.TryGetValue("iss", out iss))
        {
            return issuer;
        }

        var tokenIssuer = Convert.ToString(iss, Invariant);
        if (string.IsNullOrEmpty(issuer))
        {
            // if issuer is not specified, set issuer with jwt[iss]
            return tokenIssuer;
        }
        if (!tokenIssuer.Equals(issuer, StringComparison.Ordinal))
        {
            throw new TokenValidationException(string.Format("Token issuer mismatch. Expected: '{0}' and got: '{1}'", issuer, iss));
        }
        return issuer;
    }

    /// <summary>
    /// Flattens the JWT payload into claims. Arrays produce one claim per non-null item.
    /// <c>user_metadata</c> produces one single-property JSON claim per entry plus one claim
    /// with the whole object as JSON. Other objects are serialized as JSON, JSON nulls are
    /// skipped, and everything else uses ToString().
    /// It also adds a NameIdentifier claim from "sub" and a name claim from the first present
    /// type in <see cref="claimTypesForUserName"/>, then drops <see cref="claimsToExclude"/>.
    /// </summary>
    private static List<Claim> ClaimsFromJwt(IDictionary<string, object> jwtData, string issuer)
    {
        var list = new List<Claim>();
        issuer = issuer ?? DefaultIssuer;
        foreach (KeyValuePair<string, object> pair in jwtData)
        {
            var claimType = pair.Key;

            // A JSON null has no value to carry, and Claim rejects null values.
            if (pair.Value == null)
            {
                continue;
            }

            var source = pair.Value as IList;
            if (source != null)
            {
                foreach (var item in source)
                {
                    if (item != null)
                    {
                        list.Add(new Claim(claimType, item.ToString(), StringClaimValueType, issuer, issuer));
                    }
                }
                continue;
            }

            if (claimType.Equals("user_metadata", StringComparison.InvariantCultureIgnoreCase))
            {
                var metadata = pair.Value as Dictionary<string, object>;
                if (metadata != null)
                {
                    foreach (var item in metadata)
                    {
                        // Serialize instead of string-building so quotes/backslashes are escaped.
                        var entry = new Dictionary<string, string> { { item.Key, Convert.ToString(item.Value) } };
                        list.Add(new Claim(claimType, JsonConvert.SerializeObject(entry), StringClaimValueType, issuer, issuer));
                    }
                }
                list.Add(new Claim(claimType, JsonConvert.SerializeObject(pair.Value), StringClaimValueType, issuer, issuer));
                continue;
            }

            // Nested objects would otherwise stringify as their .NET type name.
            var nested = pair.Value as IDictionary<string, object>;
            var value = nested != null ? JsonConvert.SerializeObject(nested) : pair.Value.ToString();
            list.Add(new Claim(claimType, value, StringClaimValueType, issuer, issuer));
        }

        var subjectClaim = list.FirstOrDefault(c => c.Type == "sub");
        if (subjectClaim != null)
        {
            list.Add(new Claim(ClaimTypes.NameIdentifier, subjectClaim.Value, StringClaimValueType, issuer, issuer));
        }

        // set claim for user name
        foreach (var candidateType in claimTypesForUserName)
        {
            var match = list.FirstOrDefault(c => c.Type == candidateType);
            if (match != null)
            {
                list.Add(new Claim(NameClaimType, match.Value, StringClaimValueType, issuer, issuer));
                break;
            }
        }

        // dont include specific jwt claims
        return list.Where(c => !claimsToExclude.Contains(c.Type)).ToList();
    }

    /// <summary>
    /// Builds the "Federation" identity (name/role claim types as above) from the payload claims.
    /// A "user_id" claim additionally yields a NameIdentifier claim.
    /// </summary>
    private static ClaimsIdentity ClaimsIdentityFromJwt(IDictionary<string, object> jwtData, string issuer)
    {
        var subject = new ClaimsIdentity("Federation", NameClaimType, RoleClaimType);
        var claims = ClaimsFromJwt(jwtData, issuer);
        foreach (Claim claim in claims)
        {
            var type = claim.Type;
            if (type == ActorClaimType)
            {
                // NOTE(review): subject.Actor is never assigned in this class, so this
                // "second actor" guard cannot fire as written.
                if (subject.Actor != null)
                {
                    throw new InvalidOperationException(string.Format(
                        "Jwt10401: Only a single 'Actor' is supported. Found second claim of type: '{0}', value: '{1}'", new object[] { "actor", claim.Value }));
                }
                subject.AddClaim(new Claim(type, claim.Value, claim.ValueType, issuer, issuer, subject));
                continue;
            }
            if (type == "user_id")
            {
                subject.AddClaim(new Claim(ClaimTypes.NameIdentifier, claim.Value, claim.ValueType, issuer, issuer, subject));
            }
            subject.AddClaim(new Claim(type, claim.Value, claim.ValueType, issuer, issuer, subject));
        }
        return subject;
    }

    /// <summary>
    /// Converts seconds since the Unix epoch to a UTC DateTime. Returns false when the value
    /// falls outside the DateTime range (with a one-second margin for AddSeconds' rounding).
    /// </summary>
    private static bool TryFromUnixTime(double unixTime, out DateTime result)
    {
        var maxSeconds = (DateTime.MaxValue - UnixEpoch).TotalSeconds - 1;
        var minSeconds = (DateTime.MinValue - UnixEpoch).TotalSeconds + 1;
        if (unixTime > maxSeconds || unixTime < minSeconds)
        {
            result = default(DateTime);
            return false;
        }
        result = UnixEpoch.AddSeconds(unixTime);
        return true;
    }

    /// <summary>Raised when a token is malformed or fails an audience/expiration/issuer check.</summary>
    public class TokenValidationException : Exception
    {
        public TokenValidationException(string message)
            : base(message)
        {
        }

        public TokenValidationException(string message, Exception innerException)
            : base(message, innerException)
        {
        }
    }

    /// <summary>Decodes a base64url string (RFC 4648 §5) by mapping it to padded standard base64.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="arg"/> is null.</exception>
    /// <exception cref="FormatException">The input is not valid base64url.</exception>
    private static byte[] Base64UrlDecode(string arg)
    {
        if (arg == null)
        {
            throw new ArgumentNullException(nameof(arg));
        }
        string s = arg;
        s = s.Replace('-', '+'); // 62nd char of encoding
        s = s.Replace('_', '/'); // 63rd char of encoding
        switch (s.Length % 4) // Pad with trailing '='s
        {
            case 0: break; // No pad chars in this case
            case 2: s += "=="; break; // Two pad chars
            case 3: s += "="; break; // One pad char
            default:
                // A length of 4n+1 can never be produced by base64 encoding.
                throw new FormatException("Illegal base64url string!");
        }
        return Convert.FromBase64String(s); // Standard base64 decoder
    }
}
/// <summary>
/// Web API exception logger that forwards each logged exception to Application Insights.
/// </summary>
public class AiExceptionLogger : ExceptionLogger
{
    // TelemetryClient is thread-safe, and Application Insights recommends reusing one
    // instance instead of constructing a new client for every exception.
    private static readonly TelemetryClient Telemetry = new TelemetryClient();

    /// <summary>Tracks <c>context.Exception</c> (when present), then defers to the base logger.</summary>
    public override void Log(ExceptionLoggerContext context)
    {
        if (context != null && context.Exception != null)
        {
            Telemetry.TrackException(context.Exception);
        }
        base.Log(context);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment