// NOTE: The page this file was copied from warned that it contains bidirectional Unicode text
// that may be interpreted or compiled differently than it appears. Review it in an editor that
// reveals hidden Unicode characters. (Learn more about bidirectional Unicode characters.)
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.TimeUnit;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
public class SimilarRequestLimitingFilter implements Filter { | |
private static Logger logger = LoggerFactory.getLogger(SimilarRequestLimitingFilter.class); | |
private Multimap<String, Thread> multimap; | |
private int similarRequestLimit = 5; | |
private Cache<String, String> loggedKeysSet; | |
private String newline = System.getProperty("line.separator"); | |
@Override | |
public void init(FilterConfig filterConfig) throws ServletException { | |
multimap = Multimaps.synchronizedSetMultimap(HashMultimap.<String, Thread>create()); | |
String similarRequestLimitParam = filterConfig.getInitParameter("similarRequestLimit"); | |
if (similarRequestLimitParam != null && similarRequestLimitParam.length() > 0) { | |
similarRequestLimit = Integer.parseInt(similarRequestLimitParam); | |
} | |
loggedKeysSet = CacheBuilder.newBuilder().maximumSize(100).expireAfterWrite(5, TimeUnit.SECONDS).build(); | |
} | |
@Override | |
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) | |
throws IOException, ServletException { | |
String key = null; | |
Thread currentThread = Thread.currentThread(); | |
try { | |
if (request instanceof HttpServletRequest) { | |
HttpServletRequest req = (HttpServletRequest)request; | |
key = createUniqueRequestKey(req); | |
if (limitReached(key)) { | |
// to shorten work to be done in "finally" | |
key = null; | |
// status code 429 Too Many Requests (RFC 6585) | |
((HttpServletResponse)response).setStatus(429); | |
((HttpServletResponse)response).getWriter().println("429 Too Many Requests"); | |
return; | |
} | |
multimap.put(key, currentThread); | |
} | |
chain.doFilter(request, response); | |
} finally { | |
if (key != null) { | |
try { | |
multimap.remove(key, currentThread); | |
} catch (Throwable t) { | |
multimap.clear(); | |
} | |
} | |
} | |
} | |
/** | |
* Determines what identifies a request as "similar". This implementation | |
* concatenates the {@link HttpServletRequest#getRemoteAddr()} and the | |
* {@link HttpServletRequest#getRequestURI()}. | |
* | |
* Subclasses can override this | |
* method to customize the key on which requests are grouped. | |
* | |
* @param req | |
* @return | |
*/ | |
protected String createUniqueRequestKey(HttpServletRequest req) { | |
return req.getRemoteAddr() + req.getRequestURI(); | |
} | |
private boolean limitReached(String key) { | |
Collection<Thread> collection = multimap.get(key); | |
if (collection != null && collection.size() >= similarRequestLimit) { | |
logTraces(key, collection); | |
return true; | |
} | |
return false; | |
} | |
private void logTraces(String key, Collection<Thread> collection) { | |
// only acquire lock if necessary | |
if (loggedKeysSet.getIfPresent(key) == null) { | |
// lock on cache to do atomic putIfNotPresent | |
synchronized (loggedKeysSet) { | |
// must re-check if key exists once we have the lock | |
// as some other thread may have gone before is | |
if (loggedKeysSet.getIfPresent(key) == null) { | |
loggedKeysSet.put(key, key); | |
// and continue | |
} else { | |
return; | |
} | |
} | |
} else { | |
return; | |
} | |
// only do IO/logging when we're the first thread to detect a reached limit | |
// and do it outside of the synchronized block to avoid blocking other threads | |
StringBuilder sb = new StringBuilder(); | |
for(Thread t : collection) { | |
sb.append(t.getName()).append(newline); | |
StackTraceElement[] stackTrace = t.getStackTrace(); | |
for (int i = 0; i < stackTrace.length; i++) { | |
sb.append(stackTrace[i]).append(newline); | |
} | |
sb.append(newline); | |
} | |
logger.error("Limit per IP reached for {} \n{}", key, sb); | |
} | |
@Override | |
public void destroy() { | |
} | |
} |
// (End of copied page. The original page footer read: "Sign up for free to join this
// conversation on GitHub. Already have an account? Sign in to comment")