Skip to content

Instantly share code, notes, and snippets.

@gmazzo
Last active March 6, 2020 08:47
Show Gist options
  • Save gmazzo/6caf8e83f15c060db61b362178603dfa to your computer and use it in GitHub Desktop.
Save gmazzo/6caf8e83f15c060db61b362178603dfa to your computer and use it in GitHub Desktop.
A `Completable` with `Semaphore`-like behavior: subscribers that lack permission share a single in-flight acquisition before proceeding.
import io.reactivex.Completable
import io.reactivex.Completable.defer
import io.reactivex.CompletableObserver
import java.util.concurrent.atomic.AtomicReference
/**
 * A [Completable] that gates its subscribers behind a permission.
 *
 * Every subscription first evaluates [hasPermission]. When it returns `false`, a cached wrapper
 * around [acquirePermission] is installed as the current gate, but only if no other acquisition
 * is already installed. The subscriber is then attached to whatever gate is current. Subscribers
 * arriving while an acquisition is in flight therefore share that single acquisition instead of
 * starting their own. When the acquisition terminates (normally or with an error), the gate is
 * reset to an already-completed [Completable].
 *
 * A subscriber for which [hasPermission] returns `true` still waits on an in-flight acquisition
 * if one is currently installed.
 *
 * @param hasPermission evaluated on every subscription; `false` means the permission must be acquired.
 * @param acquirePermission work that obtains the permission; `cache()` ensures each installed
 *   acquisition subscribes to it only once, no matter how many subscribers join.
 */
class CompletableSemaphore(
    private val hasPermission: () -> Boolean,
    private val acquirePermission: Completable
) : Completable() {

    // Sentinel gate meaning "no acquisition in flight"; it completes immediately.
    private val idle: Completable = complete()

    // Current gate: either `idle` or the cached acquisition that is still in flight.
    private val gate = AtomicReference(idle)

    private val entry: Completable = defer {
        if (!hasPermission()) {
            // doFinally is placed upstream of cache(), so the reset runs once per acquisition,
            // not once per subscriber.
            val acquisition = acquirePermission
                .doFinally { gate.set(idle) }
                .cache()
            // Install only when idle; otherwise this candidate is discarded and we join the
            // acquisition that is already installed.
            gate.compareAndSet(idle, acquisition)
        }
        gate.get()
    }

    override fun subscribeActual(observer: CompletableObserver) {
        entry.subscribe(observer)
    }
}
import io.reactivex.Completable
import io.reactivex.schedulers.Schedulers
import junit.framework.Assert.assertEquals
import org.junit.After
import org.junit.Assert.assertTrue
import org.junit.Test
import test_utils.BaseTest
import java.util.concurrent.ConcurrentSkipListSet
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Executors
import java.util.concurrent.ThreadFactory
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference
import kotlin.random.Random
/**
 * Stress test for [CompletableSemaphore].
 *
 * 200 actions, each gated by the semaphore, are scheduled on a 4-thread pool. Every action
 * increments [counter]; whenever the counter is a multiple of 10 the permission has to be
 * re-acquired. The test checks that all pool threads ran actions and that the permission was
 * acquired 20 times.
 */
class CompletableSemaphoreTest {
    // Incremented by every gated action.
    private val counter = AtomicInteger(0)

    // Number of times acquirePermission actually ran.
    private val acquireCount = AtomicInteger(0)

    // First error raised on a worker thread. Errors there never reach the JUnit thread by
    // themselves, so it is recorded here and rethrown by the test method.
    private val failure = AtomicReference<Throwable?>(null)

    private val semaphore by lazy {
        CompletableSemaphore(
            hasPermission = { !counter.get().shouldAcquire },
            acquirePermission = Completable.fromAction {
                val value = counter.get()
                val count = acquireCount.incrementAndGet()
                println("acquirePermission: ${Thread.currentThread().name}, counter=$value, acquireCount=$count")
                assertTrue("value=$value, shouldAcquire=${value.shouldAcquire}", value.shouldAcquire)
            })
    }

    // Kept separately because the Scheduler built by Schedulers.from() does not manage the
    // executor's lifecycle: Scheduler.shutdown() will not stop these (non-daemon) threads.
    private val executor = Executors.newFixedThreadPool(4, TF())
    private val scheduler = Schedulers.from(executor)

    @Test(timeout = 10000)
    fun multithreadingTest() {
        val lock = CountDownLatch(200)
        val threads = ConcurrentSkipListSet<String>()

        (1..lock.count)
            .map {
                Completable.fromAction {
                    threads.add(threadName)
                    val value = counter.incrementAndGet()
                    sleep()
                    println("run#$it: $threadName, counter=$value, acquireCount=${acquireCount.get()}")
                }
            }
            .map { semaphore.andThen(it) }
            .forEach { task ->
                task.subscribeOn(scheduler).subscribe(
                    { lock.countDown() },
                    { ex ->
                        // Remember the first failure and release the latch so the test thread
                        // wakes up and reports it instead of waiting for the timeout.
                        failure.compareAndSet(null, ex)
                        lock.unlock()
                    })
                sleep()
            }

        lock.await()
        // Surface worker-thread failures (e.g. the assertTrue in acquirePermission) on the test thread.
        failure.get()?.let { throw it }

        assertEquals((1..4).map { "thread$it" }.toSet(), threads)
        assertEquals(20, acquireCount.get())
    }

    @After
    fun tearDown() {
        scheduler.shutdown()
        // Actually stop the pool threads; the scheduler does not do this.
        executor.shutdownNow()
    }

    // True when the counter value requires a fresh permission (every 10th value, including 0).
    private val Int.shouldAcquire
        get() = this % 10 == 0

    private val threadName
        get() = Thread.currentThread().name

    // Counts the latch down count + 1 times, fully releasing it (extra count-downs are no-ops).
    private fun CountDownLatch.unlock() =
        (0..count).forEach { _ -> countDown() }

    // Random pause of 10..19 ms to interleave the worker threads.
    private fun sleep() =
        Thread.sleep(Random.nextLong(10) + 10)

    // Names pool threads "thread1", "thread2", ... so the test can assert which threads ran.
    private class TF : ThreadFactory {
        private val number = AtomicInteger(0)
        override fun newThread(r: Runnable) = Thread(r, "thread${number.incrementAndGet()}")
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment