Skip to content

Instantly share code, notes, and snippets.

@amrynsky
Last active October 1, 2023 19:52
Show Gist options
  • Save amrynsky/bd379b8e107e787611daf34f8ca5309b to your computer and use it in GitHub Desktop.
Save amrynsky/bd379b8e107e787611daf34f8ca5309b to your computer and use it in GitHub Desktop.
Rate limiter in Java (Token bucket & sliding window)
package com.example.demo;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.time.temporal.TemporalUnit;
import java.util.ArrayDeque;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
 * Two in-memory rate limiters (a sliding-window log and a fixed-window token refill), a per-tenant
 * factory, and JUnit tests that drive them through an injectable clock.
 */
@Slf4j
public class RateLimiterTest {

    /**
     * Sliding-window-log limiter: remembers the timestamp of every accepted request and admits a new
     * request only while fewer than {@code requestsPerUnit} accepted requests lie within the trailing
     * {@code refreshPeriod}. Rejected requests are not recorded, so retries made while throttled do
     * not extend the throttling. All state access is guarded by {@code synchronized}.
     */
    public static class SlidingWindowRateLimiter extends RateLimiter {
        /** Timestamps of accepted requests, oldest first; holds at most requestsPerUnit entries. */
        private final ArrayDeque<Instant> requests;

        public SlidingWindowRateLimiter(int requestsPerUnit, long refreshTime, TemporalUnit timeUnit) {
            super(requestsPerUnit, refreshTime, timeUnit);
            this.requests = new ArrayDeque<>(requestsPerUnit);
        }

        /**
         * Decides whether a request arriving now may proceed.
         *
         * @return {@code true} if the request is admitted (and recorded), {@code false} if the window is full
         */
        @Override
        public synchronized boolean allowRequest() {
            // Read the clock once so eviction, logging and recording all agree on "now".
            Instant now = getCurrentTime();
            // Evict entries strictly older than refreshPeriod; an entry exactly refreshPeriod old still counts.
            while (!requests.isEmpty()
                    && Duration.between(requests.peekFirst(), now).compareTo(refreshPeriod) > 0) {
                requests.removeFirst();
            }
            log.info("[{}] request, size: {}", now, requests.size());
            if (requests.size() >= requestsPerUnit) {
                // Do not record the rejected request: it must not consume capacity in future windows.
                return false;
            }
            requests.addLast(now);
            return true;
        }
    }

    /**
     * Despite its name, this is a fixed-window counter rather than a continuously refilling token
     * bucket. It starts with {@code requestsPerUnit} tokens, and once more than {@code refreshPeriod}
     * has elapsed since the last refill, the full allowance is restored at once. The first window
     * opens at the first request, measured with the limiter's (possibly injected) clock. All state
     * access is guarded by {@code synchronized}.
     */
    public static class TokenBucketRateLimiter extends RateLimiter {
        /** Tokens left in the current window. */
        private int tokens;
        /** Start of the current window; null until the first request. */
        private Instant lastRefresh;

        public TokenBucketRateLimiter(int requestsPerUnit, long refreshTime, TemporalUnit timeUnit) {
            super(requestsPerUnit, refreshTime, timeUnit);
            this.tokens = requestsPerUnit;
        }

        /**
         * Consumes one token if available, refilling first when the current window has expired.
         *
         * @return {@code true} if a token was consumed, {@code false} if none were left
         */
        @Override
        public synchronized boolean allowRequest() {
            Instant now = getCurrentTime();
            if (lastRefresh == null) {
                // Anchor the window to the limiter's clock rather than the wall clock at construction,
                // so an injected time (set after construction) is honoured from the very first request.
                lastRefresh = now;
            } else if (Duration.between(lastRefresh, now).compareTo(refreshPeriod) > 0) {
                log.info("[{}] refresh tokens: {}", now, requestsPerUnit);
                tokens = requestsPerUnit;
                lastRefresh = now;
            }
            log.info("[{}] request, tokens: {}", now, tokens);
            if (tokens < 1) {
                return false;
            }
            tokens--;
            return true;
        }
    }

    /** Shared configuration and an overridable clock for the limiter implementations above. */
    public abstract static class RateLimiter {
        /** Maximum number of requests admitted per refreshPeriod; always positive. */
        protected final int requestsPerUnit;
        /** Length of the rate-limiting window. */
        protected final Duration refreshPeriod;
        /**
         * Clock override used by tests; null means "use Instant.now()". Volatile because it is written
         * by callers outside the subclasses' synchronized methods.
         */
        private volatile Instant currentTime;

        /**
         * @param requestsPerUnit maximum requests per window; must be positive
         * @param refreshTime     window length in {@code timeUnit}; must be positive
         * @param timeUnit        unit of {@code refreshTime}; units with an estimated duration
         *                        (e.g. MONTHS) are rejected by {@link Duration#of} with a DateTimeException
         * @throws IllegalArgumentException if {@code requestsPerUnit} or {@code refreshTime} is not positive
         */
        protected RateLimiter(int requestsPerUnit, long refreshTime, TemporalUnit timeUnit) {
            if (requestsPerUnit <= 0) {
                throw new IllegalArgumentException("requestsPerUnit must be positive: " + requestsPerUnit);
            }
            if (refreshTime <= 0) {
                throw new IllegalArgumentException("refreshTime must be positive: " + refreshTime);
            }
            this.requestsPerUnit = requestsPerUnit;
            this.refreshPeriod = Duration.of(refreshTime, timeUnit);
        }

        /** Returns whether a request arriving now may proceed, updating the limiter's state. */
        public abstract boolean allowRequest();

        /** Returns the injected time if one was set, otherwise the system clock. */
        protected Instant getCurrentTime() {
            Instant injected = this.currentTime;
            return (injected != null) ? injected : Instant.now();
        }

        /** Pins the limiter's clock to {@code time}; pass null to return to the system clock. */
        void setCurrentTime(Instant time) {
            this.currentTime = time;
        }
    }

    /** Lazily creates and caches one independent limiter per tenant key; safe for concurrent use. */
    public static class RateLimiterFactory {
        private final Supplier<RateLimiter> factory;
        private final Map<String, RateLimiter> rateLimiters = new ConcurrentHashMap<>();

        /** @param factory creates a fresh limiter the first time a tenant is seen */
        public RateLimiterFactory(Supplier<RateLimiter> factory) {
            this.factory = factory;
        }

        /** Returns the limiter for {@code tenant}, creating it on first use. */
        public RateLimiter get(String tenant) {
            return rateLimiters.computeIfAbsent(tenant, t -> factory.get());
        }
    }

    @Test
    void allowedRequestsOverTime() {
        var rateLimiter = new SlidingWindowRateLimiter(3, 1, ChronoUnit.SECONDS);
        var time = Instant.now();
        for (int i = 0; i < 10; i++) {
            // 334 ms spacing keeps at most 3 requests inside any 1 s window.
            time = time.plus(Duration.ofMillis(334));
            rateLimiter.setCurrentTime(time);
            assertTrue(rateLimiter.allowRequest());
        }
    }

    @Test
    void rateLimitedRequests() {
        var rateLimiter = new SlidingWindowRateLimiter(2, 1, ChronoUnit.SECONDS);
        var time = Instant.now();
        rateLimiter.setCurrentTime(time);
        assertTrue(rateLimiter.allowRequest());
        assertTrue(rateLimiter.allowRequest());
        assertFalse(rateLimiter.allowRequest());
    }

    @Test
    void rejectedRequestsDoNotConsumeCapacity() {
        var rateLimiter = new SlidingWindowRateLimiter(2, 1, ChronoUnit.SECONDS);
        var start = Instant.now();
        rateLimiter.setCurrentTime(start);
        assertTrue(rateLimiter.allowRequest());
        assertTrue(rateLimiter.allowRequest());
        rateLimiter.setCurrentTime(start.plusMillis(900));
        assertFalse(rateLimiter.allowRequest());
        assertFalse(rateLimiter.allowRequest());
        // Both accepted requests have left the window; the rejected retries must not keep it shut.
        rateLimiter.setCurrentTime(start.plusMillis(1001));
        assertTrue(rateLimiter.allowRequest());
        assertTrue(rateLimiter.allowRequest());
    }

    @Test
    void tokenBucketFollowsInjectedClock() {
        var rateLimiter = new TokenBucketRateLimiter(2, 1, ChronoUnit.SECONDS);
        var start = Instant.EPOCH;
        rateLimiter.setCurrentTime(start);
        assertTrue(rateLimiter.allowRequest());
        assertTrue(rateLimiter.allowRequest());
        assertFalse(rateLimiter.allowRequest());
        rateLimiter.setCurrentTime(start.plusMillis(1001));
        assertTrue(rateLimiter.allowRequest());
    }

    @Test
    void rateLimitedRequestsMultitenant() {
        var rateLimiters = new RateLimiterFactory(() ->
                new SlidingWindowRateLimiter(2, 1, ChronoUnit.SECONDS));
        var time = Instant.now();
        rateLimiters.get("c1").setCurrentTime(time);
        assertTrue(rateLimiters.get("c1").allowRequest());
        assertTrue(rateLimiters.get("c1").allowRequest());
        assertFalse(rateLimiters.get("c1").allowRequest());
        // Another tenant owns an independent limiter and is unaffected by c1 being exhausted.
        rateLimiters.get("c2").setCurrentTime(time);
        assertTrue(rateLimiters.get("c2").allowRequest());
    }

    @Test
    void allowedRequestsOverTimeMultitenant() {
        // 3 requests/s with 334 ms spacing; a 2 requests/s limit would (correctly) reject the third call.
        var rateLimiters = new RateLimiterFactory(() ->
                new SlidingWindowRateLimiter(3, 1, ChronoUnit.SECONDS));
        var time = Instant.now();
        for (int i = 0; i < 10; i++) {
            time = time.plus(Duration.ofMillis(334));
            rateLimiters.get("c1").setCurrentTime(time);
            rateLimiters.get("c2").setCurrentTime(time);
            assertTrue(rateLimiters.get("c1").allowRequest());
            assertTrue(rateLimiters.get("c2").allowRequest());
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment