Skip to content

Instantly share code, notes, and snippets.

@paulo-raca
Last active October 1, 2023 22:16
Show Gist options
  • Save paulo-raca/ef6a827046a5faec95024ff406d3a692 to your computer and use it in GitHub Desktop.
Save paulo-raca/ef6a827046a5faec95024ff406d3a692 to your computer and use it in GitHub Desktop.
Condition Variables for Kotlin Coroutines
package kotlinx.coroutines.sync
import kotlinx.coroutines.NonCancellable
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.withContext
import kotlinx.coroutines.withTimeout
import kotlinx.coroutines.withTimeoutOrNull
import java.util.function.Predicate
import kotlin.time.Duration
import kotlin.time.ExperimentalTime
import kotlin.time.nanoseconds
/**
 * A condition variable for coroutines, bound to a [Mutex].
 *
 * [await], [signal] and [signalAll] require [mutex] to be locked when they are called; [await]
 * releases it while suspended and reacquires it before it returns or throws.
 *
 * This should be part of kotlin-coroutines: https://github.com/Kotlin/kotlinx.coroutines/issues/2531
 */
class Condition(val mutex: Mutex) {
    /**
     * Waiters currently parked in [await], in arrival order.
     *
     * Each entry is a private, initially locked [Mutex] that one waiter is suspended on;
     * unlocking it resumes that waiter. The set is only modified after [ensureLocked] has passed
     * or after [mutex] has been reacquired.
     */
    val waiting = LinkedHashSet<Mutex>()

    /**
     * Suspends this coroutine until [predicate] returns true or [timeout] has elapsed.
     *
     * [predicate] is evaluated first without waiting, and again after every return from [await].
     * The associated mutex is unlocked while this coroutine is waiting inside [await].
     *
     * @param timeout total time budget across all internal waits
     * @param owner mutex owner token, forwarded to [await]
     * @param predicate condition to wait for
     * @return true if [predicate] returned true, false if the timeout elapsed first
     * @throws IllegalStateException if a wait is needed and the mutex is not locked
     */
    @ExperimentalTime
    suspend fun awaitUntil(timeout: Duration = Duration.INFINITE, owner: Any? = null, predicate: () -> Boolean): Boolean {
        val start = System.nanoTime()
        while (!predicate()) {
            val elapsed = (System.nanoTime() - start).nanoseconds
            val remainingTimeout = timeout - elapsed
            if (remainingTimeout < Duration.ZERO) {
                return false // Timeout elapsed without success
            }
            // The result is ignored on purpose: the loop re-checks the predicate and the clock.
            await(remainingTimeout, owner)
        }
        return true
    }

    /**
     * Suspends this coroutine until it is woken by [signal] or [signalAll], or [timeout] has elapsed.
     *
     * The associated mutex is unlocked while this coroutine is waiting and is always reacquired
     * before this function returns or throws, including when the coroutine is cancelled.
     *
     * Note: if the coroutine is cancelled after a signal was already delivered to it, that signal
     * is consumed and is not forwarded to another waiter.
     *
     * @param timeout maximum time to wait
     * @param owner mutex owner token used to unlock and relock [mutex]
     * @return true if this coroutine was woken by [signal] or [signalAll] (even if the timeout fired
     * at the same moment), false if the timeout elapsed without a signal
     * @throws IllegalStateException if the mutex is not locked on entry
     */
    @ExperimentalTime
    suspend fun await(timeout: Duration = Duration.INFINITE, owner: Any? = null): Boolean {
        ensureLocked(owner, "await")
        val waiter = Mutex(true)
        // Register before releasing the mutex so a signal sent right after the unlock is not missed.
        waiting.add(waiter)
        mutex.unlock(owner)
        var signalled = false
        try {
            // withTimeoutOrNull only absorbs its own timeout; a timeout raised by an enclosing
            // withTimeout keeps propagating instead of being reported as a local timeout.
            withTimeoutOrNull(timeout) {
                waiter.lock()
            }
        } finally {
            // Reacquire even if this coroutine is already cancelled: the caller (e.g. withLock)
            // is going to unlock the mutex and must actually own it at that point.
            withContext(NonCancellable) {
                mutex.lock(owner)
            }
            // signal()/signalAll() remove the waiter before waking it, so a waiter that is no
            // longer registered was signalled, regardless of how the wait above ended.
            signalled = !waiting.remove(waiter)
        }
        return signalled
    }

    /**
     * Wakes up one coroutine blocked in [await], the one that has been waiting the longest.
     * Does nothing if there are no waiters.
     *
     * @param owner mutex owner token used to verify that the mutex is held
     * @throws IllegalStateException if the mutex is not locked
     */
    suspend fun signal(owner: Any? = null) {
        ensureLocked(owner, "signal")
        val it = waiting.iterator()
        if (it.hasNext()) {
            val waiter = it.next()
            it.remove()
            waiter.unlock()
        }
    }

    /**
     * Wakes up all coroutines blocked in [await].
     *
     * @param owner mutex owner token used to verify that the mutex is held
     * @throws IllegalStateException if the mutex is not locked
     */
    suspend fun signalAll(owner: Any? = null) {
        ensureLocked(owner, "signalAll")
        val it = waiting.iterator()
        while (it.hasNext()) {
            val waiter = it.next()
            it.remove()
            waiter.unlock()
        }
    }

    /**
     * Verifies that [mutex] is locked before [funcName] proceeds.
     *
     * With a non-null [owner] the mutex must be held by that owner; with a null [owner] this can
     * only check that the mutex is locked by someone.
     *
     * @throws IllegalStateException if the check fails
     */
    internal fun ensureLocked(owner: Any?, funcName: String) {
        val isLocked = if (owner == null) mutex.isLocked else mutex.holdsLock(owner)
        check(isLocked) { "$funcName requires a locked mutex" }
    }
}
/**
 * Creates a new [Condition] bound to this mutex.
 */
fun Mutex.newCondition(): Condition = Condition(this)
package kotlinx.coroutines.sync
import kotlinx.coroutines.sync.newCondition
import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import org.junit.jupiter.api.Assertions
import org.junit.jupiter.api.Test
import java.util.stream.Collectors
import java.util.stream.IntStream
import kotlin.time.ExperimentalTime
import kotlin.time.milliseconds
import kotlin.time.seconds
/**
 * Tests for [Condition] running against a real [Mutex] with wall-clock delays and timeouts.
 */
@ExperimentalTime
class ConditionTest {
    val lock = Mutex()
    val cond = lock.newCondition()

    /**
     * Starts [count] child coroutines that each take [lock] and wait on [cond] for up to one
     * second, yielding the result of that wait.
     */
    private fun CoroutineScope.startWaiters(count: Int): List<Deferred<Boolean>> =
        List(count) {
            async {
                lock.withLock {
                    cond.await(1.seconds)
                }
            }
        }

    /** With nobody signalling, await must time out and report false. */
    @Test
    fun testAwaitWithoutSignal() {
        runBlocking {
            lock.withLock {
                Assertions.assertFalse(cond.await(1.seconds))
            }
        }
    }

    /** A signal sent while waiting wakes the waiter once; a second await then times out. */
    @Test
    fun testAwaitSignal() {
        runBlocking {
            launch {
                delay(500)
                lock.withLock {
                    cond.signal()
                }
            }
            lock.withLock {
                Assertions.assertTrue(cond.await(1.seconds))
                Assertions.assertFalse(cond.await(1.seconds))
            }
        }
    }

    /** A signal sent before anyone waits is not remembered: a later await times out. */
    @Test
    fun testSignalAwait() {
        runBlocking {
            lock.withLock {
                cond.signal()
            }
            lock.withLock {
                delay(500)
                Assertions.assertFalse(cond.await(1.seconds))
            }
        }
    }

    /** signal() wakes exactly one of several waiters; the rest time out. */
    @Test
    fun testNotifyOnce() {
        runBlocking {
            val waiters = startWaiters(5)
            // Give every waiter time to park on the condition before signalling.
            delay(100.milliseconds)
            lock.withLock {
                cond.signal()
            }
            val results = waiters.awaitAll()
            val successCount = results.count { it }
            Assertions.assertEquals(1, successCount)
        }
    }

    /** signalAll() wakes every waiter. */
    @Test
    fun testNotifyAll() {
        runBlocking {
            val waiters = startWaiters(5)
            // Give every waiter time to park on the condition before signalling.
            delay(100.milliseconds)
            lock.withLock {
                cond.signalAll()
            }
            val results = waiters.awaitAll()
            val successCount = results.count { it }
            Assertions.assertEquals(results.size, successCount)
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment