Skip to content

Instantly share code, notes, and snippets.

@konrad-kaminski
Last active July 14, 2022 23:50
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save konrad-kaminski/d7808070f4218349674589e1dc97264a to your computer and use it in GitHub Desktop.
Save konrad-kaminski/d7808070f4218349674589e1dc97264a to your computer and use it in GitHub Desktop.
CountDownLatch naive implementation
/*
* Copyright 2016-2017 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kotlinx.coroutines.experimental.sync
import kotlinx.coroutines.experimental.CancellationException
import kotlinx.coroutines.experimental.suspendCancellableCoroutine
import kotlinx.coroutines.experimental.withTimeoutOrNull
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock
import kotlin.concurrent.withLock
import kotlin.coroutines.experimental.Continuation
/**
* Equivalent of [java.util.concurrent.CountDownLatch] for coroutines.
*
* A latch is created with an initial count through the companion factory
* (`CountDownLatch(n)`). Coroutines calling [await] suspend until the count
* has been brought down to zero by calls to [countDown]. The count can only
* go down; this interface offers no way to reset it.
*/
interface CountDownLatch {
/**
* Decrements the count of the latch, resuming all suspended coroutines if
* the count reaches zero.
*
* If the current count is greater than zero then it is decremented.
* If the new count is zero then all suspended coroutines are resumed.
*
* If the current count equals zero then nothing happens.
*/
fun countDown()
/**
* Returns the current count.
*
* This method is typically used for debugging and testing purposes.
*
* @return the current count
*/
fun getCount(): Long
/**
* Causes the current coroutine to suspend until the latch has counted down to
* zero, unless the coroutine is cancelled.
*
* If the current count is zero then this method returns immediately.
*
* If the current count is greater than zero then the current
* coroutine is suspended and awaits until one of two things happens:
*
* * the count reaches zero due to invocations of the [countDown] method; or
* * the coroutine is cancelled.
*
* If the current coroutine:
*
* * is already cancelled; or
* * is cancelled while waiting,
*
* then [CancellationException] is thrown.
*
* NOTE(review): the implementation in this file returns straight away when the
* count is already zero and does not appear to check for cancellation on that
* path - confirm whether the "already cancelled" guarantee holds in that case.
*
* @throws CancellationException if the current coroutine is cancelled
* while waiting or is already cancelled
*/
@Throws(CancellationException::class)
suspend fun await()
/**
* Causes the current coroutine to suspend until the latch has counted down to
* zero, unless the coroutine is cancelled or the specified waiting time elapses.
*
* If the current count is zero then this method returns immediately.
*
* If the current count is greater than zero then the current
* coroutine is suspended and awaits until one of three things happens:
*
* * the count reaches zero due to invocations of the [countDown] method; or
* * the coroutine is cancelled; or
* * the specified waiting time elapses.
*
* If the current coroutine:
*
* * is already cancelled; or
* * is cancelled while waiting,
*
* then [CancellationException] is thrown.
*
* If the specified waiting time elapses then the value `false`
* is returned. If the time is less than or equal to zero, the method
* will not wait at all.
*
* @param time the maximum time to wait
* @param unit the time unit of the `time` argument; defaults to milliseconds
* @return `true` if the count reached zero and `false`
* if the waiting time elapsed before the count reached zero
* @throws CancellationException if the current coroutine is cancelled
* while waiting or is already cancelled
*/
@Throws(CancellationException::class)
suspend fun await(time: Long, unit: TimeUnit = TimeUnit.MILLISECONDS): Boolean
/**
* Factory for [CountDownLatch] instances.
*
* Because [invoke] is an operator, a latch is created with constructor-like
* syntax: `CountDownLatch(3)`.
*/
companion object {
/**
* Creates new [CountDownLatch] instance.
*
* @param initialCount initial count of the latch; must not be negative.
* @throws IllegalArgumentException if `initialCount` is negative
*/
operator fun invoke(initialCount: Int): CountDownLatch = CountDownLatchImpl(initialCount)
}
}
/**
 * Lock-based implementation of [CountDownLatch].
 *
 * Both [count] and the list of waiting [continuations] are read and modified
 * only while [lock] is held. Waiters are resumed *outside* the lock, from a
 * snapshot taken under it, so that resuming a coroutine never runs while this
 * latch's lock is owned.
 *
 * @param initialCount initial count of the latch; must not be negative.
 * @throws IllegalArgumentException if [initialCount] is negative.
 */
internal class CountDownLatchImpl(initialCount: Int) : CountDownLatch {
    private var count = initialCount
    private val continuations = mutableListOf<Continuation<Unit>>() // should probably use LockFreeLinkedListNode instead
    private val lock = ReentrantLock()

    init {
        require(initialCount >= 0) { "initialCount < 0" }
    }

    /**
     * Decrements the count if it is positive. The call that brings the count to
     * zero resumes every coroutine registered by [await]; further calls are no-ops.
     */
    override fun countDown() {
        // Decide and detach the waiters atomically with respect to await() and to the
        // cancellation handlers, which also mutate `continuations` under the lock.
        // Iterating or clearing the live list without the lock would race with them.
        val toResume: List<Continuation<Unit>> = lock.withLock {
            if (count == 0 || --count > 0) {
                emptyList<Continuation<Unit>>()
            } else {
                continuations.toList().also { continuations.clear() }
            }
        }
        // Resume outside the lock, working on the private snapshot only.
        toResume.forEach { it.resume(Unit) }
    }

    /** Returns the current count, read under the lock. */
    override fun getCount(): Long = lock.withLock { count.toLong() }

    /**
     * Waits for the latch with a time limit by running [await] inside
     * `withTimeoutOrNull`.
     *
     * @return `true` if [await] completed, `false` if `withTimeoutOrNull` returned `null`.
     */
    override suspend fun await(time: Long, unit: TimeUnit): Boolean =
        withTimeoutOrNull(time, unit) { await() } != null

    /**
     * Returns immediately if the count is zero; otherwise registers the current
     * continuation as a waiter and suspends until [countDown] resumes it or the
     * coroutine is cancelled.
     */
    override suspend fun await() {
        // `locked` tracks whether this call still owns `lock`. Ownership is given up
        // either inside the suspend block (once the waiter is fully registered) or in
        // `finally` (when no suspension was needed or the block failed early).
        var locked = true
        lock.lock()
        try {
            if (count > 0) {
                // The first argument is passed as `true` and paired with the explicit
                // initCancellability() call below - presumably it defers cancellability
                // until the waiter is registered; confirm against the library version in use.
                suspendCancellableCoroutine<Unit>(true) { cont ->
                    continuations += cont
                    cont.initCancellability()
                    cont.invokeOnCompletion {
                        // A cancelled waiter must not linger in the list.
                        if (cont.isCancelled) {
                            lock.withLock { continuations -= cont }
                        }
                    }
                    // Release the lock before the coroutine actually suspends, so the
                    // lock is never held across the suspension point.
                    lock.unlock()
                    locked = false
                }
            }
        } finally {
            if (locked) {
                lock.unlock()
            }
        }
    }
}
/*
* Copyright 2016-2017 JetBrains s.r.o.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package kotlinx.coroutines.experimental.sync
import kotlinx.coroutines.experimental.TestBase
import kotlinx.coroutines.experimental.launch
import kotlinx.coroutines.experimental.runBlocking
import kotlinx.coroutines.experimental.yield
import org.junit.Assert.assertEquals
import org.junit.Test
/**
* Tests for [CountDownLatch].
*
* `expect(n)` and `finish(n)` are inherited from `TestBase` (not shown here);
* presumably they assert that the call is the n-th checkpoint reached, so the
* numbers below spell out the expected interleaving - verify against `TestBase`.
*/
class CountDownLatchTest : TestBase() {
// A coroutine awaiting a latch of 2 must stay suspended after the first
// countDown() and continue only after the second one.
@Test
fun testSimple() = runBlocking {
val latch = CountDownLatch(2)
expect(1)
// Launched in runBlocking's own context; the numbering expects its body
// to start only once the parent yields (checkpoint 4 follows 3).
launch(context) {
expect(4)
latch.await() // suspends
expect(7) // now latch is down
}
expect(2)
latch.countDown()
expect(3)
// First yield: the child runs up to await() and suspends (count is still 1).
yield()
expect(5)
// Second countDown brings the count to zero and resumes the child...
latch.countDown()
expect(6)
// ...which gets to run on this yield (checkpoint 7) before the test finishes.
yield()
finish(8)
}
// Each countDown() lowers getCount() by exactly one, from the initial value down to zero.
@Test
fun countDownTest() {
val latch = CountDownLatch(3)
assertEquals(3, latch.getCount())
latch.countDown()
assertEquals(2, latch.getCount())
latch.countDown()
assertEquals(1, latch.getCount())
latch.countDown()
assertEquals(0, latch.getCount())
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment