Created
November 9, 2018 23:12
-
-
Save Shengaero/ddd073bfa8c29f57a9bb4feaf29b8b08 to your computer and use it in GitHub Desktop.
Simple proof of concept for a Kotlin coroutine-based rate-limiter API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import kotlinx.coroutines.* | |
import kotlinx.coroutines.channels.Channel | |
import org.junit.jupiter.api.* | |
import java.util.concurrent.TimeUnit | |
import kotlin.coroutines.CoroutineContext | |
import kotlin.math.min | |
import kotlin.system.measureTimeMillis | |
import kotlin.test.assertEquals | |
import kotlin.test.assertFalse | |
import kotlin.test.assertTrue | |
/**
 * A limiter that hands out at most [maxUses] uses per refresh period.
 */
interface RateLimiter {

    /** Maximum number of uses allowed within a single refresh period. */
    val maxUses: Int

    /** Number of uses recorded so far in the current period. */
    val currentUses: Int

    /** Length of one refresh period (the `RateLimiter` factory in this file converts it to milliseconds). */
    val refreshDelay: Long

    /** Whether this limiter has been closed. */
    val isClosed: Boolean

    /** Takes one use, suspending until one is available. */
    suspend fun request()

    /** Takes one use only if it is available right now; returns `false` instead of suspending. */
    fun tryRequest(): Boolean

    /** Shuts the limiter down, optionally recording [cause] as the reason. */
    fun close(cause: Throwable? = null)
}
/**
 * Channel-backed implementation of [RateLimiter].
 *
 * Each use taken places a token into [sync], whose buffer holds [maxUses] tokens.
 * A background "refresh" coroutine, started by the first request of a period,
 * waits [refreshDelay] and then polls the tokens back out, freeing those uses.
 *
 * @param refreshDelay length of a refresh period; passed straight to `delay`, so it is
 *   interpreted as milliseconds.
 * @param maxUses number of uses allowed per period, and the capacity of the channel.
 * @param context context the refresh coroutine is launched in.
 */
private class RateLimiterImpl(
    override val refreshDelay: Long,
    override val maxUses: Int,
    val context: CoroutineContext
): RateLimiter {
    // Manages concurrent access to this RateLimiter. Sync is an inline
    // class wrapper around a Channel whose capacity is maxUses.
    private val sync = Sync(Channel(maxUses))

    // Close cause, stored so request() can rethrow it after close(cause).
    private var closeCause: Throwable? = null

    // The background "coroutine" of the RateLimiter. It is started by
    // request/tryRequest when none is active. It waits refreshDelay, frees
    // the uses taken in that period, and restarts itself when uses were
    // recorded by the time it finishes.
    private var refresh: Job? = null

    // Guards updates to currentUses during a refresh. It is created once
    // the refresh delay has elapsed, then completed and cleared after the
    // refresh has released its tokens. request() awaits it so that its
    // increment happens after the reset to 0.
    private var reset: CompletableDeferred<Unit>? = null

    // Uses recorded in the current period, capped at maxUses by the setter.
    @Volatile override var currentUses = 0
        set(value) { field = min(value, maxUses) }

    init { sync.init() }

    override val isClosed get() = sync.isClosed

    override suspend fun request() {
        closeCause?.let { throw it }
        initRefresh()
        // Suspends while every use of the current period is taken.
        sync.acquire()
        // If a refresh is in progress, let it finish resetting currentUses
        // before this use is counted.
        reset?.await()
        currentUses++
    }

    override fun tryRequest(): Boolean {
        initRefresh()
        if (!sync.tryAcquire()) return false
        currentUses++
        return true
    }

    override fun close(cause: Throwable?) {
        if (sync.isClosed) return
        closeCause = cause
        // Close the channel first so any initRefresh() racing with this call
        // already sees isClosed == true and does not schedule a new refresh.
        sync.close(cause)
        refresh?.cancel()
        refresh = null
    }

    /**
     * Starts the refresh coroutine unless one is already active or the limiter is closed.
     *
     * @param forceReset start a new refresh even if [refresh] is still set and active;
     *   used when a finished refresh chains into the next period.
     */
    private fun initRefresh(forceReset: Boolean = false) {
        // Check isClosed as well: close(null) leaves closeCause unset, and
        // refreshes must not be scheduled after close in either case.
        if (closeCause != null || isClosed) return
        val current = this.refresh
        if (!forceReset && current != null && current.isActive) return

        val job = GlobalScope.launch(context) {
            delay(refreshDelay)
            val gate = CompletableDeferred<Unit>()
            reset = gate
            try {
                currentUses = 0
                // Free every use taken in the period that just ended. Each
                // successful poll removes the oldest token, and it admits a
                // suspended request() if one is waiting. So maxUses polls clear
                // a full buffer of old tokens; stop early once a poll finds
                // nothing. (Breaking as soon as the buffer stops being full
                // would free only one use when nobody is waiting.)
                for (slot in 1..maxUses) {
                    if (!sync.release()) break
                }
            } finally {
                // Always unblock request() callers awaiting the reset, even if
                // releasing threw.
                gate.complete(Unit)
                reset = null
            }
        }
        job.invokeOnCompletion { e ->
            // A failure carrying our own close cause is already being handled by close().
            if (e != null && closeCause == e) return@invokeOnCompletion
            when (e) {
                null -> when {
                    currentUses > 0 -> initRefresh(forceReset = true)
                    else -> this.refresh = null
                }
                is CancellationException -> this.refresh = null
                else -> close(e)
            }
        }
        this.refresh = job
    }
}
/**
 * Zero-cost wrapper that uses a buffered [Channel] of [Unit] to count taken uses.
 *
 * Taking a use sends a token into the buffer, and freeing a use polls one out.
 * An empty channel therefore means every use is available, and a full channel
 * means none are.
 */
@Suppress("EXPERIMENTAL_FEATURE_WARNING", "unused", "NOTHING_TO_INLINE")
private inline class Sync(private val channel: Channel<Unit>) {

    /** True when the buffer holds as many tokens as it can. */
    inline val isFull: Boolean
        inline get() = channel.isFull

    /** True once the channel is closed for either sending or receiving. */
    inline val isClosed: Boolean
        inline get() = with(channel) { isClosedForSend || isClosedForReceive }

    /**
     * Polls tokens out while the channel is not full, stopping at the first poll
     * that finds nothing. A freshly created channel is empty, so for one this
     * returns immediately without changing anything.
     */
    inline fun init() {
        while (!isFull && release()) {
            // the loop condition does all the work
        }
    }

    /** Takes a use by sending a token, suspending while the buffer is full. */
    suspend inline fun acquire() = channel.send(Unit)

    /** Takes a use without suspending; `false` means the token could not be added. */
    inline fun tryAcquire(): Boolean = channel.offer(Unit)

    /** Frees a use by polling one token out; `false` means there was none to poll. */
    inline fun release(): Boolean = null != channel.poll()

    /** Closes the underlying channel, optionally with [cause]. */
    inline fun close(cause: Throwable?): Boolean = channel.close(cause)
}
/**
 * Creates a [RateLimiter] that allows at most [maxUses] uses per refresh period.
 *
 * The limiter's background refresh coroutine is launched in [context].
 *
 * @param context coroutine context for the limiter's background refresh.
 * @param maxUses number of uses allowed per period; must be > 0.
 * @param refreshDelay length of a period, expressed in [unit]; must be > 0. It must also
 *   be at least one millisecond once converted, because the limiter works in whole
 *   milliseconds.
 * @param unit time unit of [refreshDelay]; defaults to milliseconds.
 * @throws IllegalArgumentException if any of the constraints above is violated.
 */
// Factory functions may share their type's name; silence the naming inspection.
@Suppress("FunctionName")
fun RateLimiter(
    context: CoroutineContext,
    maxUses: Int,
    refreshDelay: Long,
    unit: TimeUnit = TimeUnit.MILLISECONDS
): RateLimiter {
    require(maxUses > 0) { "maxUses must be > 0" }
    require(refreshDelay > 0) { "refreshDelay must be > 0" }
    val trueDelay = unit.toMillis(refreshDelay)
    // Sub-millisecond delays (e.g. 500 µs) truncate to 0 ms, which would make
    // every refresh fire immediately; reject them instead of silently accepting.
    require(trueDelay > 0) { "refreshDelay must be at least 1 millisecond" }
    return RateLimiterImpl(trueDelay, maxUses, context)
}
/**
 * Wall-clock tests for the [RateLimiter] returned by the `RateLimiter` factory.
 *
 * Every test gets a fresh single-threaded dispatcher named after the test, which is
 * closed in [destroy]. Each limiter is closed in a `finally`, so a failing assertion
 * does not leave the limiter open.
 */
class RateLimiterTests: CoroutineTestBase() {
    private lateinit var dispatcher: ExecutorCoroutineDispatcher

    @BeforeEach fun init(info: TestInfo) {
        dispatcher = newPool(info.displayName)
    }

    @Test fun `Test RateLimiter Usage Count`() = runTest {
        RateLimiter(dispatcher, 5, 3, TimeUnit.SECONDS).closing { rateLimiter ->
            assertEquals(0, rateLimiter.currentUses)
            rateLimiter.request()
            assertEquals(1, rateLimiter.currentUses)
            rateLimiter.request()
            rateLimiter.request()
            assertEquals(3, rateLimiter.currentUses)
            // Just past the 3 s period started by the first request.
            delay(3010)
            assertEquals(0, rateLimiter.currentUses)
        }
    }

    @RepeatedTest(5) fun `Test RateLimiter Suspension On Limit Hit`(info: RepetitionInfo) = runTest {
        val limit = info.currentRepetition
        RateLimiter(dispatcher, limit, 1, TimeUnit.SECONDS).closing { rateLimiter ->
            repeat(limit) {
                assertTrue(rateLimiter.currentUses < limit)
                rateLimiter.request()
            }
            // Every use is taken, so the next request has to wait for the refresh.
            assertTrue(1000 <= measureTimeMillis { rateLimiter.request() })
        }
    }

    @Test fun `Test RateLimiter Usage In Parallel`() = runTest {
        RateLimiter(dispatcher, 3, 4, TimeUnit.SECONDS).closing { rateLimiter ->
            val third = launch {
                delay(3000)
                rateLimiter.request()
                assertEquals(3, rateLimiter.currentUses)
            }
            launch {
                delay(1000)
                rateLimiter.request()
                assertEquals(1, rateLimiter.currentUses)
            }
            launch {
                delay(2000)
                rateLimiter.request()
                assertEquals(2, rateLimiter.currentUses)
            }
            third.join()
            // All 3 uses are taken and the 4 s period started at ~1 s has not ended.
            assertFalse(rateLimiter.tryRequest())
            assertTrue(1000 <= measureTimeMillis { rateLimiter.request() })
        }
    }

    @Test fun `Test RateLimiter Request Queue Delays Reasonably`() = runTest {
        RateLimiter(dispatcher, 3, 2, TimeUnit.SECONDS).closing { rateLimiter ->
            // 7 requests at 3 per 2 s period need at least two refreshes.
            assertTrue(4000 <= measureTimeMillis { repeat(7) { rateLimiter.request() } })
        }
    }

    @AfterEach fun destroy() {
        dispatcher.close()
    }

    private fun newPool(testName: String): ExecutorCoroutineDispatcher {
        return newSingleThreadContext("RateLimiterTests ThreadPool - $testName")
    }

    // Runs [block] with this limiter and always closes it afterwards, even when an
    // assertion inside [block] throws.
    private inline fun RateLimiter.closing(block: (RateLimiter) -> Unit) {
        try {
            block(this)
        } finally {
            close()
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment