Skip to content

Instantly share code, notes, and snippets.

@Shengaero
Created November 9, 2018 23:12
Show Gist options
  • Save Shengaero/ddd073bfa8c29f57a9bb4feaf29b8b08 to your computer and use it in GitHub Desktop.
Save Shengaero/ddd073bfa8c29f57a9bb4feaf29b8b08 to your computer and use it in GitHub Desktop.
Simple proof of concept for a Kotlin coroutine-based rate-limiter API
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import org.junit.jupiter.api.*
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import kotlin.coroutines.CoroutineContext
import kotlin.math.min
import kotlin.system.measureTimeMillis
import kotlin.test.assertEquals
import kotlin.test.assertFalse
import kotlin.test.assertTrue
/**
 * A limiter that hands out at most [maxUses] uses per refresh period.
 *
 * Callers either suspend until a use is free ([request]) or make a single
 * non-suspending attempt ([tryRequest]).
 */
interface RateLimiter {
    /** Length of one refresh period, in milliseconds. */
    val refreshDelay: Long

    /** Upper bound on the number of uses granted within one refresh period. */
    val maxUses: Int

    /** Number of uses already consumed in the current refresh period. */
    val currentUses: Int

    /** Whether this limiter has been shut down via [close]. */
    val isClosed: Boolean

    /**
     * Attempts to consume one use immediately, without suspending.
     *
     * @return `true` if a use was granted, `false` if none was available.
     */
    fun tryRequest(): Boolean

    /** Consumes one use, suspending the caller until one becomes available. */
    suspend fun request()

    /**
     * Shuts this limiter down.
     *
     * @param cause optional reason for closing; `null` means a normal close.
     */
    fun close(cause: Throwable? = null)
}
/**
 * Channel-backed [RateLimiter].
 *
 * Every granted use occupies one slot of a buffered channel of capacity [maxUses] (see [Sync]).
 * Once all slots are occupied, further [request] calls suspend and [tryRequest] calls fail.
 * A background refresh job does the following:
 * - It is started lazily by the first request/tryRequest while no refresh is active.
 * - It waits [refreshDelay] milliseconds, resets [currentUses] to zero and frees up to
 *   [maxUses] channel slots.
 * - It reschedules itself if uses were recorded in the new period.
 *
 * @param refreshDelay length of one refresh period in milliseconds (passed to [delay]).
 * @param maxUses maximum uses per refresh period; also the channel capacity.
 * @param context context the refresh job is launched in.
 */
private class RateLimiterImpl(
    override val refreshDelay: Long,
    override val maxUses: Int,
    val context: CoroutineContext
) : RateLimiter {
    // Permit bookkeeping: one buffered element per use consumed in the current period.
    private val sync = Sync(Channel(maxUses))

    // Cause passed to close(); request() rethrows it once the limiter is closed.
    @Volatile private var closeCause: Throwable? = null

    // Background refresh job: waits refreshDelay, frees the used slots and, if uses were
    // recorded in the new period, schedules a follow-up refresh (see initRefresh).
    @Volatile private var refresh: Job? = null

    // Non-null only while a refresh is freeing slots. request() awaits it so that its
    // increment of the use counter lands after the counter has been reset to zero.
    @Volatile private var reset: CompletableDeferred<Unit>? = null

    // Uses recorded in the current period. request()/tryRequest() may run concurrently on
    // different threads, and `++` on a @Volatile field is not atomic (updates can be lost),
    // so an AtomicInteger is used instead. Values are clamped to maxUses.
    private val uses = AtomicInteger(0)

    override val currentUses: Int
        get() = uses.get()

    init { sync.init() }

    override val isClosed get() = sync.isClosed

    override suspend fun request() {
        closeCause?.let { throw it }
        initRefresh()
        sync.acquire()
        reset?.await()
        recordUse()
    }

    override fun tryRequest(): Boolean {
        initRefresh()
        if (!sync.tryAcquire()) return false
        recordUse()
        return true
    }

    override fun close(cause: Throwable?) {
        if (!sync.isClosed) {
            closeCause = cause
            refresh?.cancel()
            refresh = null
            sync.close(cause)
        }
    }

    /** Atomically increments the use counter without letting it exceed [maxUses]. */
    private fun recordUse() {
        uses.updateAndGet { min(it + 1, maxUses) }
    }

    /**
     * Launches the refresh job unless one is already active. When [forceReset] is true,
     * a new job is launched regardless. Does nothing once the limiter was closed with a cause.
     */
    private fun initRefresh(forceReset: Boolean = false) {
        if (closeCause != null) return
        val current = refresh
        if (!forceReset && current != null && current.isActive) return
        val job = GlobalScope.launch(context) {
            delay(refreshDelay)
            reset = CompletableDeferred()
            uses.set(0)
            // Free every slot used during the finished period. Polling a full buffer also lets
            // one suspended sender (a waiting request()) move its element in. After up to
            // maxUses polls, the channel therefore holds exactly the uses granted to waiters.
            // Stop early once a poll finds the channel empty.
            var freed = 0
            while (freed < maxUses && sync.release()) freed++
            reset?.complete(Unit)
            reset = null
        }
        // Publish the job before registering the completion handler. If the job had already
        // completed, the handler runs immediately and may install a successor job, which
        // must not then be overwritten by this (finished) one.
        refresh = job
        job.invokeOnCompletion { e ->
            if (e != null && closeCause == e) return@invokeOnCompletion
            when (e) {
                null -> when {
                    currentUses > 0 -> initRefresh(forceReset = true)
                    else -> refresh = null
                }
                is CancellationException -> refresh = null
                else -> close(e)
            }
        }
    }
}
/**
 * Zero-cost wrapper that uses a buffered [Channel] as a pool of permits.
 *
 * Each element stored in the channel stands for one use that has been handed out, and each
 * free buffer slot is an available use.
 * - [acquire] sends an element and suspends while the buffer is full.
 * - [tryAcquire] offers one without suspending.
 * - [release] polls one element out, freeing its slot.
 */
@Suppress("EXPERIMENTAL_FEATURE_WARNING", "unused", "NOTHING_TO_INLINE")
private inline class Sync(private val channel: Channel<Unit>) {
    /** True when every buffer slot is occupied. */
    inline val isFull: Boolean
        inline get() = channel.isFull

    /** True once the underlying channel is closed for receiving or for sending. */
    inline val isClosed: Boolean
        inline get() = channel.isClosedForReceive || channel.isClosedForSend

    /**
     * Polls elements while the buffer is not full, stopping at the first failed poll.
     * On a freshly created (empty) channel the first poll fails, so this frees nothing.
     */
    inline fun init() {
        while (!isFull && release()) {
            // keep polling: loop ends when the buffer reports full or a poll finds it empty
        }
    }

    /** Occupies one slot, suspending while the buffer is full. */
    suspend inline fun acquire() {
        channel.send(Unit)
    }

    /** Occupies one slot if one is free right now; returns whether it succeeded. */
    inline fun tryAcquire(): Boolean {
        return channel.offer(Unit)
    }

    /** Frees one occupied slot; returns `false` if nothing was stored. */
    inline fun release(): Boolean {
        val taken = channel.poll()
        return taken != null
    }

    /** Closes the underlying channel with the optional [cause]. */
    inline fun close(cause: Throwable?): Boolean {
        return channel.close(cause)
    }
}
/**
 * Creates a [RateLimiter] that grants at most [maxUses] uses per refresh period of
 * [refreshDelay] [unit]s. Its background refresh job runs in [context].
 *
 * The delay is converted to milliseconds. A positive delay shorter than one millisecond
 * (for example `500` with [TimeUnit.MICROSECONDS]) is rounded up to 1 ms rather than
 * truncated to 0, which would make every refresh fire with no delay at all.
 *
 * @throws IllegalArgumentException if [maxUses] or [refreshDelay] is not positive.
 */
@Suppress("TestFunctionName")
fun RateLimiter(
    context: CoroutineContext,
    maxUses: Int,
    refreshDelay: Long,
    unit: TimeUnit = TimeUnit.MILLISECONDS
): RateLimiter {
    require(maxUses > 0) { "maxUses must be > 0" }
    require(refreshDelay > 0) { "refreshDelay must be > 0" }
    // toMillis truncates sub-millisecond amounts; never let a positive delay collapse to 0.
    val delayMillis = unit.toMillis(refreshDelay).coerceAtLeast(1L)
    return RateLimiterImpl(delayMillis, maxUses, context)
}
/**
 * Timing-based tests for [RateLimiter].
 *
 * Each test gets its own single-threaded dispatcher, created in [init] and closed in
 * [destroy]; the limiter's refresh job runs on it.
 */
class RateLimiterTests: CoroutineTestBase() {
    private lateinit var dispatcher: ExecutorCoroutineDispatcher

    @BeforeEach fun init(info: TestInfo) {
        dispatcher = newPool(info.displayName)
    }

    @Test fun `Test RateLimiter Usage Count`() = runTest {
        val rateLimiter = RateLimiter(dispatcher, 5, 3, TimeUnit.SECONDS)
        assertEquals(0, rateLimiter.currentUses)
        rateLimiter.request()
        assertEquals(1, rateLimiter.currentUses)
        rateLimiter.request()
        rateLimiter.request()
        assertEquals(3, rateLimiter.currentUses)
        // Slightly longer than the 3 s refresh period, so the counter has been reset.
        delay(3010)
        assertEquals(0, rateLimiter.currentUses)
        rateLimiter.close()
    }

    @RepeatedTest(5) fun `Test RateLimiter Suspension On Limit Hit`(info: RepetitionInfo) = runTest {
        val limit = info.currentRepetition
        val rateLimiter = RateLimiter(dispatcher, limit, 1, TimeUnit.SECONDS)
        // Start timing before the first request. The 1 s refresh period starts with that
        // request, so timing only the final (suspending) request could measure slightly
        // under 1000 ms and make the assertion flaky.
        val elapsed = measureTimeMillis {
            repeat(limit) {
                assertTrue(rateLimiter.currentUses < limit)
                rateLimiter.request()
            }
            rateLimiter.request()
        }
        assertTrue(1000 <= elapsed)
        rateLimiter.close()
    }

    @Test fun `Test RateLimiter Usage In Parallel`() = runTest {
        val rateLimiter = RateLimiter(dispatcher, 3, 4, TimeUnit.SECONDS)
        // Three staggered requests (~1 s, ~2 s, ~3 s) exhaust the limit within one period.
        val block = launch {
            delay(3000)
            rateLimiter.request()
            assertEquals(3, rateLimiter.currentUses)
        }
        launch {
            delay(1000)
            rateLimiter.request()
            assertEquals(1, rateLimiter.currentUses)
        }
        launch {
            delay(2000)
            rateLimiter.request()
            assertEquals(2, rateLimiter.currentUses)
        }
        block.join()
        // Limit reached: a non-suspending attempt fails, and a suspending one waits for the refresh.
        assertFalse(rateLimiter.tryRequest())
        assertTrue(1000 <= measureTimeMillis { rateLimiter.request() })
        rateLimiter.close()
    }

    @Test fun `Test RateLimiter Request Queue Delays Reasonably`() = runTest {
        val rateLimiter = RateLimiter(dispatcher, 3, 2, TimeUnit.SECONDS)
        // 7 requests at 3 per 2 s need at least two refreshes, so at least 4 s in total.
        assertTrue(4000 <= measureTimeMillis { repeat(7) { rateLimiter.request() } })
        rateLimiter.close()
    }

    @AfterEach fun destroy() {
        dispatcher.close()
    }

    private fun newPool(testName: String): ExecutorCoroutineDispatcher {
        return newSingleThreadContext("RateLimiterTests ThreadPool - $testName")
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment