Skip to content

Instantly share code, notes, and snippets.

@r4nc1d
Last active December 15, 2021 22:49
Show Gist options
  • Save r4nc1d/f01e4594917843189299eed98f5043e2 to your computer and use it in GitHub Desktop.
Save r4nc1d/f01e4594917843189299eed98f5043e2 to your computer and use it in GitHub Desktop.
/// <summary>
/// Middleware that limits how many requests each client may make per time window.
/// Over-limit requests are answered with HTTP 429 (Too Many Requests) without reaching the rest of the
/// pipeline; every response gets the X-RateLimit-* headers produced by <see cref="Throttler"/>.
/// Request counters are stored in an <see cref="IDistributedCache"/>.
/// </summary>
public class RateLimitMiddleware
{
    // Requests allowed per window. Development gets a much higher ceiling so local work is not throttled.
    private const int DevelopmentRequestLimit = 10000;
    private const int ProductionRequestLimit = 240;

    // Length of one rate-limit window, in seconds.
    private const int WindowInSeconds = 60;

    private readonly RequestDelegate next;
    private readonly IDistributedCache distributedCache;
    private readonly IHostingEnvironment hostingEnvironment;

    /// <summary>Creates the middleware; normally constructed by <c>UseMiddleware</c> with dependencies from DI.</summary>
    /// <param name="next">The next middleware in the pipeline.</param>
    /// <param name="distributedCache">Cache holding the per-client request counters.</param>
    /// <param name="hostingEnvironment">Used to choose between the development and production request limits.</param>
    public RateLimitMiddleware(RequestDelegate next, IDistributedCache distributedCache, IHostingEnvironment hostingEnvironment)
    {
        this.next = next;
        this.distributedCache = distributedCache;
        this.hostingEnvironment = hostingEnvironment;
    }

    /// <summary>
    /// Rejects the request with 429 if the caller's quota is used up; otherwise counts it, adds the
    /// rate-limit headers and forwards it to the next middleware.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    public async Task InvokeAsync(HttpContext context)
    {
        var requestLimit = this.hostingEnvironment.IsDevelopment() ? DevelopmentRequestLimit : ProductionRequestLimit;
        var throttler = new Throttler(
            this.distributedCache,
            clientId: GetClientId(context),
            requestLimit: requestLimit,
            timeoutInSeconds: WindowInSeconds);

        if (throttler.RequestShouldBeThrottled())
        {
            // Rejected requests are not counted against the quota.
            context.Response.StatusCode = (int)HttpStatusCode.TooManyRequests;
            ApplyRateLimitHeaders(context, throttler);
            return;
        }

        await throttler.IncrementRequestCount();
        ApplyRateLimitHeaders(context, throttler);
        await this.next(context);
    }

    /// <summary>
    /// Chooses the identity the quota is tracked under: the authenticated user name if there is one,
    /// otherwise the caller's remote IP address (prefixed with "ip:").
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    /// <returns>The client id, or null when neither a user name nor a remote address is available.</returns>
    private static string GetClientId(HttpContext context)
    {
        var userName = context.User?.Identity?.Name;
        if (!string.IsNullOrEmpty(userName))
        {
            return userName;
        }

        // Without this fallback every anonymous caller would share one counter, so a single client could
        // exhaust the quota for all of them. The prefix keeps addresses from colliding with user names.
        // NOTE(review): behind a reverse proxy this is the proxy's address unless forwarded headers are processed.
        var remoteIp = context.Connection.RemoteIpAddress;
        return remoteIp == null ? null : "ip:" + remoteIp;
    }

    /// <summary>Copies the throttler's X-RateLimit-* headers onto the response.</summary>
    /// <param name="context">The current HTTP context.</param>
    /// <param name="throttler">The throttler tracking this client.</param>
    private static void ApplyRateLimitHeaders(HttpContext context, Throttler throttler)
    {
        foreach (var header in throttler.GetRateLimitHeaders())
        {
            // Indexer rather than Add: Add throws when a header with the same name is already present.
            context.Response.Headers[header.Key] = header.Value;
        }
    }
}
/// <summary>Pipeline registration helpers for <see cref="RateLimitMiddleware"/>.</summary>
public static class RateLimitMiddlewareExtensions
{
    /// <summary>Adds <see cref="RateLimitMiddleware"/> to the application's request pipeline.</summary>
    /// <param name="builder">The application builder to register the middleware on.</param>
    /// <returns>The builder returned by <c>UseMiddleware</c>, allowing further calls to be chained.</returns>
    public static IApplicationBuilder UseRateLimiting(this IApplicationBuilder builder) =>
        builder.UseMiddleware<RateLimitMiddleware>();
}
/// <summary>
/// Fixed-window request counter for one client, stored in an <see cref="IDistributedCache"/>.
/// A client may make <c>requestLimit</c> requests per window of <c>timeoutInSeconds</c> seconds; a new
/// window starts whenever no unexpired counter exists for the client.
/// </summary>
/// <remarks>
/// The read-increment-write in <see cref="IncrementRequestCount"/> is not atomic. Concurrent requests from
/// the same client can read the same count and overwrite each other's increment, so the limit can be
/// slightly exceeded under load. Fixing that needs an atomic counter, which <see cref="IDistributedCache"/>
/// does not provide.
/// </remarks>
public class Throttler
{
    private readonly IDistributedCache distributedCache;
    private readonly string clientId;
    private readonly int requestLimit;
    private readonly int timeoutInSeconds;

    /// <summary>Creates a throttler for a single client.</summary>
    /// <param name="distributedCache">Cache that stores the client's counter.</param>
    /// <param name="clientId">Identity the counter is keyed by; null maps to the shared key "Throttle:".</param>
    /// <param name="requestLimit">Maximum number of requests allowed per window.</param>
    /// <param name="timeoutInSeconds">Window length in seconds; must be positive.</param>
    /// <exception cref="ArgumentNullException"><paramref name="distributedCache"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="timeoutInSeconds"/> is zero or negative.</exception>
    public Throttler(IDistributedCache distributedCache, string clientId, int requestLimit, int timeoutInSeconds)
    {
        // A non-positive window could never be stored (cache entry TTLs must be positive), so reject it up front.
        if (timeoutInSeconds <= 0)
        {
            throw new ArgumentOutOfRangeException(nameof(timeoutInSeconds), timeoutInSeconds, "The window length must be positive.");
        }

        this.distributedCache = distributedCache ?? throw new ArgumentNullException(nameof(distributedCache));
        this.clientId = clientId;
        this.requestLimit = requestLimit;
        this.timeoutInSeconds = timeoutInSeconds;
    }

    /// <summary>Returns true when the client has already used its full quota for the current window.</summary>
    public bool RequestShouldBeThrottled()
    {
        var throttleInfo = this.GetThrottleInfoFromCache();
        return throttleInfo.RequestCount >= this.requestLimit;
    }

    /// <summary>
    /// Adds one request to the client's counter and writes it back, expiring at the end of the current window.
    /// </summary>
    /// <returns>The task returned by the cache write.</returns>
    public Task IncrementRequestCount()
    {
        // Captured before the read: the window returned below always ends after this instant,
        // so the time-to-live computed from it is always positive.
        var now = DateTimeOffset.UtcNow;
        var throttleInfo = this.GetThrottleInfoFromCache();
        throttleInfo.RequestCount++;

        // Expire the entry when its window ends. Re-applying the full window length as a relative TTL on
        // every write would slide the expiry forward, so a steadily-calling client would never be reset.
        var options = new DistributedCacheEntryOptions
        {
            AbsoluteExpirationRelativeToNow = throttleInfo.ExpiresAt - now
        };

        return this.distributedCache.SetAsync(this.GetKey(this.clientId), throttleInfo, options);
    }

    /// <summary>
    /// Builds the X-RateLimit-Limit, X-RateLimit-Remaining and X-RateLimit-Reset header values for the client.
    /// X-RateLimit-Reset is the end of the current window in Unix epoch seconds.
    /// </summary>
    public Dictionary<string, string> GetRateLimitHeaders()
    {
        var throttleInfo = this.GetThrottleInfoFromCache();
        var requestsRemaining = Math.Max(this.requestLimit - throttleInfo.RequestCount, 0);
        var invariant = System.Globalization.CultureInfo.InvariantCulture;

        // Header values are machine-read, so they must not depend on the server's culture.
        return new Dictionary<string, string>
        {
            { "X-RateLimit-Limit", this.requestLimit.ToString(invariant) },
            { "X-RateLimit-Remaining", requestsRemaining.ToString(invariant) },
            { "X-RateLimit-Reset", throttleInfo.ExpiresAt.ToUnixTimeSeconds().ToString(invariant) }
        };
    }

    /// <summary>
    /// Reads the client's counter from the cache. A missing, null or already-ended window is
    /// replaced by a fresh window with a count of zero.
    /// </summary>
    private ThrottleInfo GetThrottleInfoFromCache()
    {
        var throttleInfo = this.distributedCache.GetOrCreate(this.GetKey(this.clientId), () => this.CreateWindow());

        // The cache may still return an entry whose window has ended (e.g. one written with a sliding TTL),
        // so staleness is judged by ExpiresAt rather than by the entry's presence.
        if (throttleInfo == null || throttleInfo.ExpiresAt <= DateTimeOffset.UtcNow)
        {
            return this.CreateWindow();
        }

        return throttleInfo;
    }

    /// <summary>Creates an empty counter whose window ends <c>timeoutInSeconds</c> from now.</summary>
    private ThrottleInfo CreateWindow()
    {
        return new ThrottleInfo
        {
            ExpiresAt = DateTimeOffset.UtcNow.AddSeconds(this.timeoutInSeconds),
            RequestCount = 0
        };
    }

    /// <summary>Builds the cache key for a client id.</summary>
    private string GetKey(string key)
    {
        return "Throttle:" + key;
    }

    /// <summary>Cached per-client state: end of the current window and requests counted in it.</summary>
    private class ThrottleInfo
    {
        public DateTimeOffset ExpiresAt { get; set; }

        public int RequestCount { get; set; }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment