Skip to content

Instantly share code, notes, and snippets.

@arienkock
Created October 8, 2015 08:33
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Embed
What would you like to do?
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.TimeUnit;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
/**
 * Servlet filter that answers HTTP 429 (Too Many Requests, RFC 6585) when too many "similar"
 * requests are already being processed concurrently.
 *
 * <p>Similarity is defined by {@link #createUniqueRequestKey(HttpServletRequest)} (by default the
 * remote address followed by the request URI). The maximum number of concurrent similar requests
 * is read from the {@code similarRequestLimit} init parameter and defaults to 5.
 *
 * <p>When the limit is hit, the names and stack traces of the threads currently serving that key
 * are logged. A small cache (at most 100 keys, entries expire 5 seconds after being written)
 * throttles this logging to roughly once per key per 5 seconds.
 *
 * <p>Thread-safety: the limit check and the registration of the current thread are performed
 * atomically while holding the synchronized multimap's monitor.
 */
public class SimilarRequestLimitingFilter implements Filter {
    private static final Logger logger = LoggerFactory.getLogger(SimilarRequestLimitingFilter.class);
    private static final String NEWLINE = System.getProperty("line.separator");

    /** In-flight requests: similarity key -> threads currently processing a request with that key. */
    private Multimap<String, Thread> multimap;
    private int similarRequestLimit = 5;
    /** Keys whose thread dumps were logged recently; used only to throttle logging. */
    private Cache<String, String> loggedKeysSet;

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        multimap = Multimaps.synchronizedSetMultimap(HashMultimap.<String, Thread>create());
        String similarRequestLimitParam = filterConfig.getInitParameter("similarRequestLimit");
        if (similarRequestLimitParam != null && similarRequestLimitParam.trim().length() > 0) {
            similarRequestLimit = parseLimit(similarRequestLimitParam.trim());
        }
        loggedKeysSet = CacheBuilder.newBuilder().maximumSize(100).expireAfterWrite(5, TimeUnit.SECONDS).build();
    }

    /**
     * Parses the {@code similarRequestLimit} init parameter.
     *
     * @param value the trimmed, non-empty parameter value
     * @return the parsed limit, always at least 1
     * @throws ServletException if the value is not an integer or is less than 1 (a limit below 1
     *     would reject every HTTP request)
     */
    private static int parseLimit(String value) throws ServletException {
        int limit;
        try {
            limit = Integer.parseInt(value);
        } catch (NumberFormatException e) {
            throw new ServletException("Init parameter similarRequestLimit is not an integer: " + value, e);
        }
        if (limit < 1) {
            throw new ServletException("Init parameter similarRequestLimit must be >= 1, got: " + limit);
        }
        return limit;
    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        if (!(request instanceof HttpServletRequest)) {
            chain.doFilter(request, response);
            return;
        }
        String key = createUniqueRequestKey((HttpServletRequest) request);
        Thread currentThread = Thread.currentThread();

        Collection<Thread> busyThreads = null;
        boolean registered = false;
        // Check and register atomically. A Multimaps.synchronizedSetMultimap uses the returned
        // multimap as its mutex, so holding its monitor excludes every other multimap operation.
        // Without this, several threads could pass the size check at once and exceed the limit.
        synchronized (multimap) {
            Collection<Thread> threads = multimap.get(key);
            if (threads.size() >= similarRequestLimit) {
                // Snapshot under the lock: iterating the live view unlocked is unsafe.
                busyThreads = new ArrayList<Thread>(threads);
            } else {
                // put() returns false if this thread is already registered for the key
                // (e.g. a re-entrant dispatch); in that case the outer call owns the entry.
                registered = multimap.put(key, currentThread);
            }
        }

        if (busyThreads != null) {
            logTraces(key, busyThreads);
            HttpServletResponse httpResponse = (HttpServletResponse) response;
            // status code 429 Too Many Requests (RFC 6585)
            httpResponse.setStatus(429);
            httpResponse.getWriter().println("429 Too Many Requests");
            return;
        }

        try {
            chain.doFilter(request, response);
        } finally {
            if (registered) {
                unregister(key, currentThread);
            }
        }
    }

    /**
     * Removes {@code thread} from the in-flight registry for {@code key}.
     *
     * <p>Deliberately best-effort: if removal fails for any reason, the whole registry is cleared
     * so that a finished request can never stay counted against the limit forever.
     */
    private void unregister(String key, Thread thread) {
        try {
            multimap.remove(key, thread);
        } catch (Throwable t) {
            multimap.clear();
            logger.error("Failed to unregister request for key {}; cleared in-flight registry", key, t);
        }
    }

    /**
     * Determines what identifies a request as "similar". This implementation
     * concatenates the {@link HttpServletRequest#getRemoteAddr()} and the
     * {@link HttpServletRequest#getRequestURI()}.
     *
     * <p>Subclasses can override this method to customize the key on which requests are grouped.
     *
     * @param req the incoming HTTP request
     * @return the key under which concurrent requests are counted against the limit
     */
    protected String createUniqueRequestKey(HttpServletRequest req) {
        return req.getRemoteAddr() + req.getRequestURI();
    }

    /**
     * Logs the name and current stack trace of each thread in {@code threads}, unless a dump for
     * {@code key} was already logged within the cache's expiry window.
     *
     * @param key the similarity key whose limit was reached
     * @param threads a snapshot (not a live view) of the threads serving {@code key}
     */
    private void logTraces(String key, Collection<Thread> threads) {
        // Cache.asMap().putIfAbsent is atomic: only the first thread to record the key logs.
        if (loggedKeysSet.asMap().putIfAbsent(key, key) != null) {
            return;
        }
        // The (potentially slow) stack walking and logging run without holding any lock.
        StringBuilder sb = new StringBuilder();
        for (Thread t : threads) {
            sb.append(t.getName()).append(NEWLINE);
            for (StackTraceElement element : t.getStackTrace()) {
                sb.append(element).append(NEWLINE);
            }
            sb.append(NEWLINE);
        }
        logger.error("Limit per IP reached for {} \n{}", key, sb);
    }

    @Override
    public void destroy() {
        // Nothing to release: the registry and cache are plain in-memory structures.
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment