Skip to content

Instantly share code, notes, and snippets.

@coryt
Created January 26, 2015 00:24
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save coryt/f1d040047cc45b6650ff to your computer and use it in GitHub Desktop.
Save coryt/f1d040047cc45b6650ff to your computer and use it in GitHub Desktop.
ServiceStack plugin to throttle API requests
/// <summary>
/// ServiceStack plugin that rate-limits requests whose DTO type is decorated with
/// <c>ThrottleInfoAttribute</c>. Each request is checked by a Lua script stored in Redis
/// (embedded resource <c>Throttling.rate_limit.lua</c>), keyed by client IP and operation name;
/// a non-null script result rejects the request with HTTP 429.
/// </summary>
/// <remarks>
/// Redis failures are deliberately treated as "allow": the request proceeds and the error is traced.
/// </remarks>
public class ThrottlePlugin : IPlugin
{
    // RedisClient is not thread-safe, and the global request filter runs for concurrent requests,
    // so every call on _redisClient made from the filter is serialized through this lock.
    // NOTE(review): for high throughput consider a pooled client manager instead of one locked client.
    private readonly object _redisLock = new object();
    private readonly RedisClient _redisClient;

    // Request DTO type -> throttle limits, computed once in Register so the per-request lookup is a single read.
    private readonly Dictionary<Type, ThrottleInfoAttribute> _throttleInfoMap = new Dictionary<Type, ThrottleInfoAttribute>();

    // SHA1 of the loaded rate-limit script, used for EVALSHA on each request.
    private string _scriptSha;

    /// <summary>Creates the plugin and the Redis connection used for rate-limit bookkeeping.</summary>
    /// <param name="redisHost">host name</param>
    /// <param name="redisPort">port</param>
    /// <param name="redisPassword">password, or <c>null</c> when none is required</param>
    public ThrottlePlugin(string redisHost, int redisPort, string redisPassword = null)
    {
        _redisClient = new RedisClient(redisHost, redisPort, redisPassword);
    }

    /// <summary>
    /// Caches the throttle settings of every registered operation, loads the rate-limit Lua script
    /// into Redis, and installs a global request filter that enforces the limits.
    /// </summary>
    /// <param name="appHost">The ServiceStack host this plugin is registered with.</param>
    public void Register(IAppHost appHost)
    {
        RegisterThrottleInfoForAllRoutes(appHost);

        // Store the Lua script in Redis so each request only sends its SHA.
        // NOTE(review): RemoveAllLuaScripts presumably flushes the script cache for every client of this
        // Redis server, not just this plugin - confirm that is acceptable on a shared server.
        _redisClient.RemoveAllLuaScripts();
        _scriptSha = _redisClient.LoadLuaScript(ReadLuaScriptResource("rate_limit.lua"));

        appHost.GlobalRequestFilters.Add((request, response, requestDto) =>
        {
            // Requests without a DTO, or whose DTO type has no ThrottleInfoAttribute, are not throttled.
            ThrottleInfoAttribute throttleInfo;
            if (requestDto == null || !_throttleInfoMap.TryGetValue(requestDto.GetType(), out throttleInfo))
                return;

            var key = string.Format("{0}:{1}", request.RemoteIp, request.OperationName);

            string result;
            try
            {
                lock (_redisLock)
                {
                    result = _redisClient.ExecLuaShaAsString(_scriptSha,
                        new[] { key },
                        new[]
                        {
                            throttleInfo.PerMinute.ToString(),
                            throttleInfo.PerHour.ToString(),
                            throttleInfo.PerDay.ToString(),
                            SecondsFromUnixTime().ToString()
                        });
                }
            }
            catch (Exception ex)
            {
                // Fail open: a Redis outage must not take the API down, so record the error and let the request through.
                System.Diagnostics.Trace.TraceError(
                    "ThrottlePlugin: rate-limit check failed for key '{0}', allowing request. {1}", key, ex);
                return;
            }

            // Assumes the Lua script returns nil while the caller is within its limits - confirm against rate_limit.lua.
            if (result != null)
            {
                response.StatusCode = 429;
                response.StatusDescription = "Too many Requests. Back-off and try again later.";
                response.Close();
            }
        });
    }

    /// <summary>
    /// Pre-computes the <c>ThrottleInfoAttribute</c> of every registered request DTO type so that
    /// request filtering does not need reflection.
    /// </summary>
    private void RegisterThrottleInfoForAllRoutes(IAppHost appHost)
    {
        foreach (var operation in appHost.Metadata.Operations)
        {
            // FirstOrDefault: most request DTOs carry no ThrottleInfoAttribute, and First() would throw for them.
            var throttleAttribute = operation.RequestType
                .GetCustomAttributes(typeof(ThrottleInfoAttribute))
                .OfType<ThrottleInfoAttribute>()
                .FirstOrDefault();
            if (throttleAttribute == null)
                continue;

            // Indexer instead of Add so a request type that appears twice does not throw.
            _throttleInfoMap[operation.RequestType] = throttleAttribute;
        }
    }

    /// <summary>Current UTC time as whole seconds since the Unix epoch (1970-01-01T00:00:00Z).</summary>
    /// <returns>Seconds as a <c>long</c>, so the value does not overflow after January 2038.</returns>
    private static long SecondsFromUnixTime()
    {
        TimeSpan sinceEpoch = DateTime.UtcNow - new DateTime(1970, 1, 1, 0, 0, 0, DateTimeKind.Utc);
        return (long)sinceEpoch.TotalSeconds;
    }

    /// <summary>Reads an embedded resource named <c>Throttling.{resourceName}</c> from this assembly as text.</summary>
    /// <param name="resourceName">Resource file name without the <c>Throttling.</c> prefix.</param>
    /// <returns>The full text of the resource.</returns>
    /// <exception cref="InvalidOperationException">The resource is not embedded in the assembly.</exception>
    private string ReadLuaScriptResource(string resourceName)
    {
        var fullName = "Throttling." + resourceName;
        Stream stream = Assembly.GetExecutingAssembly().GetManifestResourceStream(fullName);

        // GetManifestResourceStream returns null (rather than throwing) for a missing resource.
        if (stream == null)
            throw new InvalidOperationException(
                string.Format("Embedded resource '{0}' was not found in {1}.", fullName, Assembly.GetExecutingAssembly().FullName));

        using (stream)
        using (var reader = new StreamReader(stream))
        {
            return reader.ReadToEnd();
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment