Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@Dico200
Last active October 16, 2018 14:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Dico200/0de362aabcc0bfd6e6f6bf4f591a8055 to your computer and use it in GitHub Desktop.
Save Dico200/0de362aabcc0bfd6e6f6bf4f591a8055 to your computer and use it in GitHub Desktop.
Code to access CoroutineContext from non-suspend functions
package io.dico.parcels2.util
import kotlinx.coroutines.*
import kotlin.coroutines.CoroutineContext
/**
 * The [CoroutineContext] recorded for the current thread, or `null` when none has been recorded.
 *
 * This lets non-suspending code look up the context of the coroutine that (directly or
 * indirectly) called it. The value is read from a thread-local, so it is not visible to code
 * that has moved to a different thread.
 *
 * Only populated when the coroutine's context contains [CoroutineContextStoringElement].
 */
val localCoroutineContext: CoroutineContext?
    get() {
        return threadLocalWithContext.get()
    }
// State type kept per thread: the stored coroutine context, or null when nothing is stored.
// An alias is used to avoid confusion inside [CoroutineContextStoringElement], where the
// element's own state type and the CoroutineContext parameters would otherwise look alike.
private typealias ElementType = CoroutineContext?
// Per-thread slot holding the stored coroutine context.
// Declared private at the top level so it stays encapsulated within this file.
private val threadLocalWithContext: ThreadLocal<ElementType> = ThreadLocal()
/**
 * Context element that publishes the coroutine's [CoroutineContext] through a thread-local,
 * making it readable via [localCoroutineContext].
 *
 * The element's state is the value the thread-local held before [updateThreadContext] ran;
 * [restoreThreadContext] puts that value back.
 */
object CoroutineContextStoringElement : ThreadContextElement<ElementType> {

    private object Key : CoroutineContext.Key<CoroutineContextStoringElement>

    override val key: CoroutineContext.Key<CoroutineContextStoringElement>
        get() = Key

    /**
     * Stores [context] in the thread-local and returns the value that was there before,
     * so that it can later be handed to [restoreThreadContext].
     */
    override fun updateThreadContext(context: CoroutineContext): ElementType =
        threadLocalWithContext.get().also { threadLocalWithContext.set(context) }

    /**
     * Puts [oldState] back into the thread-local, or clears the thread-local entirely
     * when there was no previous value.
     */
    override fun restoreThreadContext(context: CoroutineContext, oldState: ElementType) {
        when (oldState) {
            null -> threadLocalWithContext.remove()
            else -> threadLocalWithContext.set(oldState)
        }
    }
}
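/*
 * Usage example.
 *
 * Expected output when main() is run (only call 1 is expected to print, because it is the
 * only call made from inside a coroutine launched with CoroutineContextStoringElement):
 *
 * 1 - Context element count: 3
 */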
/**
 * Demo scope whose context consists of [CoroutineContextStoringElement], used to show which
 * calls to [synchronousFunction] can see a coroutine context.
 */
class App : CoroutineScope {

    override val coroutineContext: CoroutineContext =
        CoroutineContextStoringElement /* + Job() + Dispatchers.Main */

    /**
     * Calls [synchronousFunction] from several places: outside any coroutine (0, 2, 5),
     * inside a coroutine launched in this scope (1), and inside `runBlocking` (3, 4).
     */
    fun launchMyCoroutine() {
        synchronousFunction(0)

        val launched = launch {
            // It should print some information about the context only for this call
            synchronousFunction(1)
        }

        synchronousFunction(2)

        runBlocking {
            synchronousFunction(3)
            launched.join()
            synchronousFunction(4)
        }

        synchronousFunction(5)
    }

    /**
     * Prints the number of elements in [localCoroutineContext], prefixed with [num].
     * Prints nothing when no context is stored for the current thread.
     */
    fun synchronousFunction(num: Int) {
        val stored = localCoroutineContext ?: return
        // Count the elements by folding over the context; the element itself is not needed.
        val elementCount = stored.fold(0) { count, _ -> count + 1 }
        println("$num - Context element count: $elementCount")
    }
}
/** Entry point: runs the usage demo. */
fun main() = App().launchMyCoroutine()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment