Skip to content

Instantly share code, notes, and snippets.

@ZakTaccardi
Created February 2, 2019 04:45
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ZakTaccardi/aa951827aa8c02c0754c9f4b649edeb3 to your computer and use it in GitHub Desktop.
Save ZakTaccardi/aa951827aa8c02c0754c9f4b649edeb3 to your computer and use it in GitHub Desktop.
Some Rx-style operators on ReceiveChannel<T>
@file:JvmName("RxChannelExtensions")
import kotlinx.coroutines.experimental.Dispatchers
import kotlinx.coroutines.experimental.GlobalScope
import kotlinx.coroutines.experimental.channels.ReceiveChannel
import kotlinx.coroutines.experimental.channels.consumeEach
import kotlinx.coroutines.experimental.channels.consumes
import kotlinx.coroutines.experimental.channels.consumesAll
import kotlinx.coroutines.experimental.channels.produce
import kotlinx.coroutines.experimental.launch
import kotlinx.coroutines.experimental.sync.Mutex
import kotlinx.coroutines.experimental.sync.withLock
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference
import kotlin.coroutines.experimental.CoroutineContext
import kotlin.coroutines.experimental.EmptyCoroutineContext
/**
 * Re-emits every element of this [ReceiveChannel], invoking [sideEffect] on each element just
 * before it is forwarded downstream.
 *
 * This channel is consumed by the returned one and is handed to [consumes] as the completion
 * handler, so it is released when the returned channel completes.
 *
 * Equivalent to RxJava's `.doOnNext()` operator.
 *
 * @param context the context the forwarding coroutine runs in.
 * @param sideEffect action invoked for each element, in emission order, before it is sent.
 */
fun <E> ReceiveChannel<E>.doOnNext(
    context: CoroutineContext = Dispatchers.Unconfined,
    sideEffect: (E) -> Unit
): ReceiveChannel<E> {
    val upstream = this
    return GlobalScope.produce(context, onCompletion = upstream.consumes()) {
        upstream.consumeEach { element ->
            sideEffect(element)
            send(element)
        }
    }
}
/**
 * Merges several [ReceiveChannel]s carrying the same element type [T] into one [ReceiveChannel].
 *
 * Each source is drained by its own child coroutine of the producer, so an element is forwarded
 * as soon as its source delivers it. No ordering is imposed between different sources. The
 * [sources] are passed to [consumesAll] as the completion handler, so they are released when the
 * merged channel completes.
 *
 * @param sources the channels to merge.
 * @param context the context the merging coroutines run in.
 */
fun <T> merge(
    vararg sources: ReceiveChannel<T>,
    context: CoroutineContext = Dispatchers.Unconfined
): ReceiveChannel<T> {
    val releaseAll = consumesAll(*sources)
    return GlobalScope.produce(context, onCompletion = releaseAll) {
        for (source in sources) {
            launch {
                source.consumeEach { item -> send(item) }
            }
        }
    }
}
/**
 * Emits an element only when it is not equal to the previously emitted element. The first
 * element is always emitted. Use this to avoid emitting the same value twice in a row.
 *
 * Equality is checked with `==` (i.e. `.equals()`), so `null` elements are handled as well.
 *
 * Equivalent to RxJava's `.distinctUntilChanged()` operator.
 *
 * @param context the context the filtering coroutine runs in.
 */
fun <E> ReceiveChannel<E>.distinctUntilChanged(
    context: CoroutineContext = Dispatchers.Unconfined
): ReceiveChannel<E> = GlobalScope.produce(context, onCompletion = consumes()) {
    // consumeEach handles one element at a time inside this single coroutine, so plain local
    // state is enough; no atomics are required.
    // `hasPrevious` is tracked separately so that a leading `null` element is still emitted.
    var hasPrevious = false
    var previous: E? = null
    consumeEach { emission ->
        if (!hasPrevious || emission != previous) {
            hasPrevious = true
            previous = emission
            send(emission)
        }
    }
}
/**
 * Suppresses the first [skipCount] elements emitted by this [ReceiveChannel] and forwards every
 * element after that.
 *
 * A [skipCount] of zero or less forwards every element.
 *
 * @param skipCount the number of leading elements to drop.
 * @param context the context the forwarding coroutine runs in.
 */
fun <E> ReceiveChannel<E>.skip(
    skipCount: Int,
    context: CoroutineContext = Dispatchers.Unconfined
): ReceiveChannel<E> = GlobalScope.produce(context, onCompletion = consumes()) {
    // Elements are processed sequentially by consumeEach within this one coroutine, so a plain
    // counter is safe; no lock (which would also be held across the suspending `send`) or
    // atomic is needed.
    var skipped = 0
    consumeEach { emission: E ->
        if (skipped >= skipCount) {
            send(emission)
        } else {
            // emission was skipped
            skipped++
        }
    }
}
/**
 * A combine latest that takes in 3 sources and runs a [combineFunction] over their latest
 * emissions to emit [R].
 *
 * This will not emit until [sourceA], [sourceB] and [sourceC] have each emitted at least once.
 * After that, every element received from any source produces one emission, computed from the
 * most recent element of each source.
 *
 * Equivalent to RxJava's `.combineLatest()` operator.
 *
 * @param context the context the combining coroutines run in.
 * @param combineFunction invoked with the latest element of each source. Calls are made while
 *   holding a [Mutex], so they never overlap.
 */
fun <A : Any?, B : Any?, C : Any?, R> combineLatest(
    sourceA: ReceiveChannel<A>,
    sourceB: ReceiveChannel<B>,
    sourceC: ReceiveChannel<C>,
    context: CoroutineContext = Dispatchers.Unconfined,
    combineFunction: suspend (A, B, C) -> R
): ReceiveChannel<R> = GlobalScope.produce(context, onCompletion = consumesAll(sourceA, sourceB, sourceC)) {
    val latestA = AtomicReference<A>()
    val latestB = AtomicReference<B>()
    val latestC = AtomicReference<C>()
    // Initialization is tracked separately from the stored value so `null` is a valid element.
    var aInitialized = false
    var bInitialized = false
    var cInitialized = false
    // Serializes updates coming from the three source coroutines; combining and sending happen
    // while the lock is held, so emissions are produced one at a time.
    val mutex = Mutex()

    // Emits a combined value once every source has delivered at least one element.
    suspend fun combineAndSendIfInitialized() {
        if (aInitialized && bInitialized && cInitialized) {
            send(combineFunction(latestA.get(), latestB.get(), latestC.get()))
        }
    }

    // Each source is drained by its own child coroutine of this producer.
    launch {
        sourceA.consumeEach { a ->
            mutex.withLock {
                latestA.set(a)
                aInitialized = true
                combineAndSendIfInitialized()
            }
        }
    }
    launch {
        sourceB.consumeEach { b ->
            mutex.withLock {
                latestB.set(b)
                bInitialized = true
                combineAndSendIfInitialized()
            }
        }
    }
    launch {
        sourceC.consumeEach { c ->
            mutex.withLock {
                latestC.set(c)
                cInitialized = true
                combineAndSendIfInitialized()
            }
        }
    }
}
/**
 * A combine latest that takes in 4 sources and runs a [combineFunction] over their latest
 * emissions to emit [R].
 *
 * This will not emit until [sourceA], [sourceB], [sourceC] and [sourceD] have each emitted at
 * least once. After that, every element received from any source produces one emission, computed
 * from the most recent element of each source.
 *
 * Equivalent to RxJava's `.combineLatest()` operator.
 *
 * // TODO add tests
 *
 * @param context the context the combining coroutines run in.
 * @param combineFunction invoked with the latest element of each source. Calls are made while
 *   holding a [Mutex], so they never overlap.
 */
fun <A : Any?, B : Any?, C : Any?, D : Any?, R> combineLatest(
    sourceA: ReceiveChannel<A>,
    sourceB: ReceiveChannel<B>,
    sourceC: ReceiveChannel<C>,
    sourceD: ReceiveChannel<D>,
    context: CoroutineContext = Dispatchers.Unconfined,
    combineFunction: suspend (A, B, C, D) -> R
): ReceiveChannel<R> = GlobalScope.produce(
    context, onCompletion = consumesAll(sourceA, sourceB, sourceC, sourceD)
) {
    val latestA = AtomicReference<A>()
    val latestB = AtomicReference<B>()
    val latestC = AtomicReference<C>()
    val latestD = AtomicReference<D>()
    // Initialization is tracked separately from the stored value so `null` is a valid element.
    var aInitialized = false
    var bInitialized = false
    var cInitialized = false
    var dInitialized = false
    // Serializes updates coming from the four source coroutines; combining and sending happen
    // while the lock is held, so emissions are produced one at a time.
    val mutex = Mutex()

    // Emits a combined value once every source has delivered at least one element.
    suspend fun combineAndSendIfInitialized() {
        if (aInitialized && bInitialized && cInitialized && dInitialized) {
            send(combineFunction(latestA.get(), latestB.get(), latestC.get(), latestD.get()))
        }
    }

    // Each source is drained by its own child coroutine of this producer.
    launch {
        sourceA.consumeEach { a ->
            mutex.withLock {
                latestA.set(a)
                aInitialized = true
                combineAndSendIfInitialized()
            }
        }
    }
    launch {
        sourceB.consumeEach { b ->
            mutex.withLock {
                latestB.set(b)
                bInitialized = true
                combineAndSendIfInitialized()
            }
        }
    }
    launch {
        sourceC.consumeEach { c ->
            mutex.withLock {
                latestC.set(c)
                cInitialized = true
                combineAndSendIfInitialized()
            }
        }
    }
    launch {
        sourceD.consumeEach { d ->
            mutex.withLock {
                latestD.set(d)
                dInitialized = true
                combineAndSendIfInitialized()
            }
        }
    }
}
/**
 * Executes a [combineFunction] over the latest elements received from [sourceA] and [sourceB]
 * and emits its result [R].
 *
 * A combine latest that takes in 2 sources and runs a [combineFunction] over their latest
 * emissions to emit [R].
 *
 * This will not emit until [sourceA] and [sourceB] have each emitted at least once. After that,
 * every element received from either source produces one emission, computed from the most
 * recent element of each source.
 *
 * Equivalent to RxJava's `.combineLatest()` operator.
 *
 * @param context the context the combining coroutines run in.
 * @param combineFunction invoked with the latest element of each source. Calls are made while
 *   holding a [Mutex], so they never overlap.
 */
fun <A : Any?, B : Any?, R> combineLatest(
    sourceA: ReceiveChannel<A>,
    sourceB: ReceiveChannel<B>,
    context: CoroutineContext = Dispatchers.Unconfined,
    combineFunction: suspend (A, B) -> R
): ReceiveChannel<R> = GlobalScope.produce(context, onCompletion = consumesAll(sourceA, sourceB)) {
    val latestA = AtomicReference<A>()
    val latestB = AtomicReference<B>()
    // Initialization is tracked separately from the stored value so `null` is a valid element.
    var aInitialized = false
    var bInitialized = false
    // Serializes updates coming from the two source coroutines; combining and sending happen
    // while the lock is held, so emissions are produced one at a time.
    val mutex = Mutex()

    // Emits a combined value once both sources have delivered at least one element.
    suspend fun combineAndSendIfInitialized() {
        if (aInitialized && bInitialized) {
            send(combineFunction(latestA.get(), latestB.get()))
        }
    }

    // Each source is drained by its own child coroutine of this producer.
    launch {
        sourceA.consumeEach { a ->
            mutex.withLock {
                latestA.set(a)
                aInitialized = true
                combineAndSendIfInitialized()
            }
        }
    }
    launch {
        sourceB.consumeEach { b ->
            mutex.withLock {
                latestB.set(b)
                bInitialized = true
                combineAndSendIfInitialized()
            }
        }
    }
}
import com.nhaarman.mockitokotlin2.InOrderOnType
import com.nhaarman.mockitokotlin2.mock
import com.nhaarman.mockitokotlin2.times
import com.nhaarman.mockitokotlin2.verify
import com.nhaarman.mockitokotlin2.verifyNoMoreInteractions
import com.nhaarman.mockitokotlin2.verifyZeroInteractions
import kotlinx.coroutines.experimental.CoroutineScope
import kotlinx.coroutines.experimental.Dispatchers
import kotlinx.coroutines.experimental.channels.ArrayBroadcastChannel
import kotlinx.coroutines.experimental.channels.BroadcastChannel
import kotlinx.coroutines.experimental.channels.consumeEach
import kotlinx.coroutines.experimental.channels.sendBlocking
import kotlinx.coroutines.experimental.delay
import kotlinx.coroutines.experimental.launch
import kotlinx.coroutines.experimental.runBlocking
import org.junit.Before
import org.junit.Test
import kotlin.coroutines.experimental.CoroutineContext
/**
 * Tests for the Rx-style channel operators.
 *
 * Each pipeline is launched in [scope], which uses [Dispatchers.Unconfined]. Most tests verify
 * the mock observer immediately after sending, so they rely on that dispatcher running the
 * pipeline eagerly.
 */
class RxChannelExtensionsTest {

    private lateinit var scope: CoroutineScope

    @Before
    fun setUp() {
        scope = CoroutineScope(Dispatchers.Unconfined)
    }

    /**
     * Tests [doOnNext]
     */
    @Test
    fun operator_doOnNext() {
        val mockObserver = mock<(Int) -> Unit> { }
        val logger: (Int) -> Unit = { println(it) }
        val source = BroadcastChannel<Int>(10)

        fun verify(emission: Int, times: Int = 1) {
            verify(mockObserver, times(times)).invoke(emission)
        }

        scope.launch {
            source.openSubscription()
                .doOnNext {
                    logger(it)
                    mockObserver(it)
                }
                .consumeEach {}
        }

        source.sendBlocking(1)
        source.sendBlocking(2)
        source.sendBlocking(3)

        verify(1)
        verify(2)
        verify(3)
        verifyNoMoreInteractions(mockObserver)
    }

    /**
     * Tests [merge]
     */
    @Test
    fun operator_merge() {
        val mockObserver = mock<(Int) -> Unit> { }
        val logger: (Int) -> Unit = { println(it) }
        val source1 = BroadcastChannel<Int>(10)
        val source2 = BroadcastChannel<Int>(10)

        fun verify(emission: Int, times: Int = 1) {
            verify(mockObserver, times(times)).invoke(emission)
        }

        scope.launch {
            merge(
                source1.openSubscription(),
                source2.openSubscription()
            )
                .consumeEach {
                    mockObserver(it)
                    logger(it)
                }
        }

        verifyZeroInteractions(mockObserver)

        source1.sendBlocking(1)
        verify(1)
        source2.sendBlocking(2)
        verify(2)
        source1.sendBlocking(3)
        verify(3)
        source2.sendBlocking(4)
        verify(4)
        verifyNoMoreInteractions(mockObserver)
    }

    /**
     * Test for [distinctUntilChanged]
     */
    @Test
    fun operator_distinctUntilChanged() {
        val mockObserver = mock<(Int) -> Unit> { }
        val logger: (Int) -> Unit = { println(it) }
        val source = BroadcastChannel<Int>(1)

        fun verify(emission: Int, times: Int = 1) {
            verify(mockObserver, times(times)).invoke(emission)
        }

        scope.launch {
            source.openSubscription()
                .distinctUntilChanged()
                .consumeEach {
                    logger(it)
                    mockObserver(it)
                }
        }

        source.sendBlocking(0)
        source.sendBlocking(1)
        source.sendBlocking(1)
        source.sendBlocking(2)

        verify(0)
        verify(1)
        verify(2)
        verifyNoMoreInteractions(mockObserver)

        // 1 differs from the previous emission (2), so it is emitted again.
        source.sendBlocking(1)
        verify(1, times = 2)
        // The observer has already been invoked, so "no further interactions" is the intent here.
        verifyNoMoreInteractions(mockObserver)
    }

    /**
     * Tests [skip]
     */
    @Test
    fun operator_skip() {
        val mockObserver = mock<(Int) -> Unit> { }
        val source = BroadcastChannel<Int>(5)

        scope.launch {
            source.openSubscription()
                .skip(1)
                .doOnNext { println(it) }
                .consumeEach(mockObserver)
        }

        runBlocking {
            source.send(1)
            source.send(2)
            source.send(3)
        }

        // Mockito verification is synchronous; no coroutine is needed here.
        val inOrder = InOrderOnType(mockObserver)
        // first invocation is skipped
        inOrder.verify(mockObserver, times(1)).invoke(2)
        inOrder.verify(mockObserver, times(1)).invoke(3)
        verifyNoMoreInteractions(mockObserver)
    }

    /**
     * Tests the 2-source [combineLatest]
     */
    @Test
    fun operator_combineLatest() {
        fun runTest(context: CoroutineContext) {
            println("Running test for $context")
            val mockObserver = mock<(Pair<String, Int>) -> Unit> { }
            val logger: (Pair<String, Int>) -> Unit = { println(it) }

            val sourceNames = ArrayBroadcastChannel<String>(1000)
            val sourceAges = ArrayBroadcastChannel<Int>(1000)
            val sourceAChannel = sourceNames.openSubscription()
            val sourceBChannel = sourceAges.openSubscription()

            scope.launch {
                // Run the operator in the context under test, as the log line above claims.
                combineLatest(
                    sourceAChannel,
                    sourceBChannel,
                    context = context
                ) { name, age -> Pair(name, age) }
                    .consumeEach {
                        logger(it)
                        mockObserver(it)
                    }
            }

            val job = scope.launch {
                sourceNames.send("Zak")
                delay(10)
                sourceNames.send("Grace")
                delay(10)
                sourceAges.send(24)
                delay(10)
                sourceNames.send("Kelly")
                delay(10)
                sourceAges.send(25)
                delay(10)
                sourceNames.send("Jack")
                delay(10)
                sourceAges.send(27)
                delay(10)
                sourceAges.send(28) // happy birthday
                delay(10)
            }

            runBlocking {
                job.join()
            }

            val inOrder = InOrderOnType(mockObserver)
            inOrder.verify(mockObserver, times(1)).invoke(Pair("Grace", 24))
            inOrder.verify(mockObserver, times(1)).invoke(Pair("Kelly", 24))
            inOrder.verify(mockObserver, times(1)).invoke(Pair("Kelly", 25))
            inOrder.verify(mockObserver, times(1)).invoke(Pair("Jack", 25))
            inOrder.verify(mockObserver, times(1)).invoke(Pair("Jack", 27))
            inOrder.verify(mockObserver, times(1)).invoke(Pair("Jack", 28))
            verifyNoMoreInteractions(mockObserver)
            println("Test passed for $context")
        }

        runTest(Dispatchers.Unconfined)
    }

    /**
     * Tests the 3-source [combineLatest]
     */
    @Test
    fun operator_combineLatest3() {
        fun runTest(context: CoroutineContext) {
            println("Running test for $context")
            val mockObserver = mock<(Triple<Int, String, Boolean>) -> Unit> { }
            val logger: (Triple<Int, String, Boolean>) -> Unit = { println(it) }

            val sourceInts = ArrayBroadcastChannel<Int>(1000)
            val sourceStrings = ArrayBroadcastChannel<String>(1000)
            val sourceBooleans = ArrayBroadcastChannel<Boolean>(1000)
            val sourceAChannel = sourceInts.openSubscription()
            val sourceBChannel = sourceStrings.openSubscription()
            val sourceCChannel = sourceBooleans.openSubscription()

            scope.launch {
                // Run the operator in the context under test, as the log line above claims.
                combineLatest(
                    sourceAChannel,
                    sourceBChannel,
                    sourceCChannel,
                    context = context
                ) { int, string, boolean -> Triple(int, string, boolean) }
                    .consumeEach {
                        logger(it)
                        mockObserver(it)
                    }
            }

            // Suspending `send` is used inside the coroutine rather than the blocking variant.
            val job = scope.launch {
                sourceInts.send(0)
                delay(10)
                sourceStrings.send("0")
                delay(10)
                sourceBooleans.send(false)
                delay(10)
                sourceInts.send(1)
                delay(10)
                sourceStrings.send("1")
                delay(10)
                sourceBooleans.send(true)
                delay(10)
            }

            runBlocking {
                job.join()
            }

            val inOrder = InOrderOnType(mockObserver)

            fun verify(int: Int, string: String, boolean: Boolean) {
                inOrder.verify(mockObserver, times(1)).invoke(Triple(int, string, boolean))
            }

            verify(0, "0", false)
            verify(1, "0", false)
            verify(1, "1", false)
            verify(1, "1", true)
            verifyNoMoreInteractions(mockObserver)
            println("Test passed for $context")
        }

        runTest(Dispatchers.Unconfined)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment