Skip to content

Instantly share code, notes, and snippets.

@hrules6872
Last active April 2, 2023 17:03
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hrules6872/db04c302088275a8e1c972a67b5d13f2 to your computer and use it in GitHub Desktop.
Save hrules6872/db04c302088275a8e1c972a67b5d13f2 to your computer and use it in GitHub Desktop.
Zustand 🐻 implementation for Kotlin :) https://github.com/pmndrs/zustand
/*
* Copyright (c) 2022. Héctor de Isidro - hrules6872
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/** Composition local carrying the [Store] installed by the nearest enclosing [StoreProvider]; fails if none was installed. */
private val LocalStore = compositionLocalOf<Store<*>> { error("Store not provided") }
/**
 * Installs [store] in the composition so that [content], and every composable it calls,
 * can look it up with [store] / [dispatcher] / [subscribeTo] / [subscribe].
 *
 * @param store the store to make available.
 * @param content composable content, invoked with [store] as its receiver.
 */
@Composable
fun <STATE : State> StoreProvider(store: Store<STATE>, content: @Composable Store<STATE>.() -> Unit) {
    val providedStore = LocalStore.provides(store)
    CompositionLocalProvider(providedStore) {
        store.content()
    }
}
/**
 * Returns the store installed by the nearest enclosing [StoreProvider].
 *
 * The cast to `Store<STATE>` is unchecked, so a mismatched STATE is not detected here. If no
 * provider is present, the composition local's default throws "Store not provided".
 */
@Composable
@Suppress("UNCHECKED_CAST")
fun <STATE : State> store(): Store<STATE> {
    val provided: Store<*> = LocalStore.current
    return provided as Store<STATE>
}
/** Returns a function that dispatches a single effect to the store provided by [StoreProvider]. */
@Composable
inline fun <reified STATE : State> dispatcher(): (Effect<Store<STATE>, STATE>) -> Unit {
    val current = store<STATE>()
    return current.dispatcher()
}

/** Returns this store's single-effect `dispatch` as a bound function reference. */
@Composable
fun <STATE : State> Store<STATE>.dispatcher(): (Effect<Store<STATE>, STATE>) -> Unit {
    val dispatchSingle: (Effect<Store<STATE>, STATE>) -> Unit = this::dispatch
    return dispatchSingle
}
/** Returns a function that dispatches an array of effects to the store provided by [StoreProvider]. */
@Composable
inline fun <reified STATE : State> multiDispatcher(): (Array<out Effect<Store<STATE>, STATE>>) -> Unit {
    val current = store<STATE>()
    return current.multiDispatcher()
}

/** Returns this store's vararg `dispatch` as a bound function reference taking an array of effects. */
@Composable
fun <STATE : State> Store<STATE>.multiDispatcher(): (Array<out Effect<Store<STATE>, STATE>>) -> Unit {
    val dispatchMany: (Array<out Effect<Store<STATE>, STATE>>) -> Unit = this::dispatch
    return dispatchMany
}
/**
 * Subscribes to the part of the provided store's state picked by [selector].
 * Looks the store up with [store] and delegates to the [Store.subscribeTo] extension.
 */
@Composable
inline fun <reified STATE : State, VALUE> subscribeTo(
    crossinline selector: @DisallowComposableCalls STATE.() -> VALUE
): MutableState<VALUE> {
    val current: Store<STATE> = store()
    return current.subscribeTo(selector)
}
/**
 * Subscribes the calling composable to the slice of this store's state picked by [selector].
 *
 * The returned [MutableState] starts at `state.selector()` and is overwritten with the selected
 * value each time the store invokes the subscription. Both the holder and the subscription are
 * keyed on this store. If a different store instance is passed on a later composition, the old
 * subscription is disposed and a new one is registered, instead of silently keeping the first
 * store's values.
 *
 * Writing to the returned state does not change the store; dispatch an effect instead.
 * NOTE(review): [selector] is captured when the subscription is created; a selector whose result
 * depends on values that change between compositions keeps using the values captured first.
 */
@Composable
inline fun <STATE : State, VALUE> Store<STATE>.subscribeTo(
    crossinline selector: @DisallowComposableCalls (STATE.() -> VALUE)
): MutableState<VALUE> {
    // Keyed on the store: a new store must not reuse the holder created for the previous one.
    val result: MutableState<VALUE> = remember(this) { mutableStateOf(state.selector()) }
    DisposableEffect(this, result) {
        val unsubscribe: Unsubscribe = subscribe { state -> result.value = state.selector() }
        onDispose(unsubscribe)
    }
    return result
}
/** Subscribes to the whole state of the store provided by [StoreProvider]. */
@Composable
inline fun <reified STATE : State> subscribe(): MutableState<STATE> {
    val current: Store<STATE> = store()
    return current.subscribe()
}
/**
 * Subscribes the calling composable to the whole state of this store.
 *
 * The returned [MutableState] starts at the current state and is overwritten each time the store
 * invokes the subscription. Both the holder and the subscription are keyed on this store. If a
 * different store instance is passed on a later composition, the old subscription is disposed and
 * a new one is registered.
 *
 * Writing to the returned state does not change the store; dispatch an effect instead.
 */
@Composable
fun <STATE : State> Store<STATE>.subscribe(): MutableState<STATE> {
    // Keyed on the store: a new store must not reuse the holder created for the previous one.
    val result: MutableState<STATE> = remember(this) { mutableStateOf(state) }
    DisposableEffect(this, result) {
        val unsubscribe: Unsubscribe = subscribe { state -> result.value = state }
        onDispose(unsubscribe)
    }
    return result
}
/*
* Copyright (c) 2022. Héctor de Isidro - hrules6872
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * UI tests for the Compose bindings: the initial values exposed by [subscribeTo], and re-rendering
 * after an effect is dispatched through [dispatcher] (using both [subscribeTo] and [subscribe]).
 */
@OptIn(ExperimentalCoroutinesApi::class)
@RunWith(AndroidJUnit4::class)
class ComposeStoreTest {
    @get:Rule
    val compose = createComposeRule()

    @Test
    fun test_initial_subscribeTo() = runTest {
        val testStore = createTestStore(initialState = TestState(counter = 100, other = 200))
        compose.setContent {
            StoreProvider(testStore) {
                val counter by subscribeTo { counter }
                val other by subscribeTo { other }
                assertThat(counter).isEqualTo(testStore.state.counter)
                assertThat(other).isEqualTo(testStore.state.other)
            }
        }
        compose.awaitIdle()
    }

    @Test
    fun test_dispatch_and_subscribeTo() = runTest {
        val testStore = createTestStore(initialState = TestState(counter = 100, other = 200))
        compose.setContent {
            StoreProvider(testStore) {
                val counter by subscribeTo { counter }
                val dispatch = dispatcher()
                CounterButton(label = "$counter", onClick = { dispatch(::incrementCounterByOneEffect) })
            }
        }
        compose.awaitIdle()
        assertClickIncrementsCounter(testStore)
    }

    @Test
    fun test_dispatch_and_subscribe_to_all() = runTest {
        val testStore = createTestStore(initialState = TestState(counter = 100, other = 200))
        compose.setContent {
            StoreProvider(testStore) {
                val state by subscribe()
                val dispatch = dispatcher()
                CounterButton(label = "${state.counter}", onClick = { dispatch(::incrementCounterByOneEffect) })
            }
        }
        compose.awaitIdle()
        assertClickIncrementsCounter(testStore)
    }

    /** Button tagged "button" whose only content is [label]. */
    @Composable
    private fun CounterButton(label: String, onClick: () -> Unit) {
        Button(onClick = onClick, modifier = Modifier.testTag("button")) {
            Text(label)
        }
    }

    /** Checks the button shows the store's counter, clicks it, and checks it then shows counter + 1. */
    private suspend fun assertClickIncrementsCounter(store: Store<TestState>) {
        val button = compose.onNodeWithTag("button")
        val before = store.state.counter
        button.assertTextEquals("$before")
        button.performClick()
        compose.awaitIdle()
        button.assertTextEquals("${before?.plus(1)}")
    }
}
/*
* Copyright (c) 2022. Héctor de Isidro - hrules6872
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * Checks that a [SameThreadEnforcedStore] rejects a dispatch from a background thread,
 * while a plain store accepts it.
 */
@OptIn(ExperimentalCoroutinesApi::class)
class SameThreadEnforcedStoreTest {
    @Test(expected = IllegalStateException::class)
    fun `dispatch an Effect from a thread other than thread from which the SameThreadEnforcedStore was created will throw an IllegalStateException`() {
        val enforcedStore = createSameThreadEnforcedStore(createTestStore())
        dispatchFromBackgroundThread(enforcedStore)
    }

    @Test
    fun `dispatch an Effect from a thread other than thread from which the Store was created will not throw an exception`() {
        val plainStore = createTestStore()
        dispatchFromBackgroundThread(plainStore)
    }

    /** Dispatches one increment to [target] from [Dispatchers.Default], asserting that this really is another thread. */
    private fun dispatchFromBackgroundThread(target: Store<TestState>) {
        val creatorThread = currentThreadName()
        runTest {
            withContext(Dispatchers.Default) {
                assertThat(creatorThread).isNotEqualTo(currentThreadName())
                target.dispatch(::incrementCounterByOneEffect)
            }
        }
    }
}
/** Name of the thread executing the caller. */
private fun currentThreadName(): String {
    val thread: Thread = Thread.currentThread()
    return thread.name
}
/*
* Copyright (c) 2022. Héctor de Isidro - hrules6872
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/** Marker interface implemented by every state type a [Store] can hold. */
interface State
/** Shorthand for [Store]; used as the store type of a [SideEffect]. */
typealias STORE<STATE> = Store<STATE>
/**
 * A unit of work run against a store. It receives the store (to read its current state) and a
 * [Mutator] it may call any number of times to replace that state.
 */
typealias Effect<STORE, STATE> = (store: STORE, mutate: Mutator<STATE>) -> Unit
/** Like [Effect], but also receives caller-supplied params as its first argument. */
typealias EffectWith<STORE, STATE, PARAMS> = (params: PARAMS, store: STORE, mutate: Mutator<STATE>) -> Unit
/** Replaces the store's state with the given new state. */
typealias Mutator<STATE> = (STATE) -> Unit
/** Callback invoked with a store's state when the subscriber is notified. */
typealias Subscription<STATE> = (state: STATE) -> Unit
/** Cancels the subscription it was returned for. */
typealias Unsubscribe = () -> Unit
/**
 * Pairs a subscription with the binder that decides when it is notified.
 *
 * Deliberately not a `data class`: removal from the subscription list relies on identity equality,
 * so two registrations of the same binder/callback pair stay distinct entries. The `componentN`
 * operators allow destructuring as `(binder, subscription)`.
 */
private class SubscriptionWithBinder<STATE>(
    val binder: Binder<STATE>,
    val subscription: Subscription<STATE>?,
) {
    operator fun component1(): Binder<STATE> = binder

    operator fun component2(): Subscription<STATE>? = subscription
}
/**
 * Selects the value a subscriber depends on. The store evaluates it on the previous and on the
 * new state to decide whether the subscriber must be notified.
 */
typealias Binder<STATE> = STATE.() -> Any?

/**
 * Binder used when a subscriber does not supply one: it binds to the whole state.
 *
 * It returns the state itself rather than a digest such as its `hashCode()`. Two different states
 * can share a hash code (e.g. data classes whose fields hash to the same combined value), and such
 * a change would look like "no change", so the subscriber would silently miss the update.
 *
 * @param bind the state to bind to.
 * @return [bind] unchanged.
 */
@Suppress("UnusedPrivateMember")
private fun <STATE> defaultBinder(bind: STATE): Any? = bind
/** An [Effect] registered when a store is created; [createStore] runs each one after a mutation made through its mutator. */
typealias SideEffect<STATE> = Effect<STORE<STATE>, STATE>
/**
 * Holds a [State] value, lets callers observe it through subscriptions, and lets them change it by
 * dispatching effects. [createStore] provides the basic implementation.
 */
interface Store<STATE : State> {
    /** The current state. */
    val state: STATE
    /** Subscribes [subscription] without an explicit binder; returns a function that unsubscribes it. */
    fun subscribe(subscription: Subscription<STATE>): Unsubscribe
    /** Subscribes [subscription], using [binder] to select the value whose changes trigger notifications. */
    fun subscribe(binder: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe
    /** Subscribes [subscription] to changes in any of the values selected by [binders]. */
    fun subscribe(vararg binders: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe
    /** Runs [effect] against this store. */
    fun dispatch(effect: Effect<Store<STATE>, STATE>)
    /** Runs each of [effects] against this store. */
    fun dispatch(vararg effects: Effect<Store<STATE>, STATE>)
    /** Runs [effect] against this store, passing it [params]. */
    fun <PARAMS> dispatch(effect: EffectWith<Store<STATE>, STATE, PARAMS>, params: PARAMS)
    /** Runs every effect in [effects] with the params mapped to it. */
    fun <PARAMS> dispatch(effects: Map<EffectWith<Store<STATE>, STATE, PARAMS>, PARAMS>)
    /** Replaces the whole state with [newState]. */
    fun rehydrate(newState: STATE)
}
/**
 * Creates a plain (unsynchronized) [Store] that starts from [initialState].
 *
 * Each mutation is committed first and only then delivered to subscribers. A subscriber or side
 * effect that reads [Store.state] therefore sees the new value, and an effect dispatched from
 * inside a subscriber builds on that value instead of being overwritten afterwards.
 *
 * Subscribers are notified most-recently-registered first. A subscriber is notified only when the
 * value selected by its [Binder] differs, by `equals`, between the previous and the new state.
 *
 * After the subscribers, every entry of [sideEffects] runs with a mutator that commits and
 * notifies but does not run the side effects again, so a side effect cannot re-trigger itself.
 *
 * Concurrent dispatching can lose updates; wrap the store with [createThreadSafeStore] when it is
 * shared between threads.
 *
 * @param initialState the state held before any mutation.
 * @param sideEffects effects run after every mutation made by a dispatched effect or [Store.rehydrate].
 */
fun <STATE : State> createStore(
    initialState: STATE,
    vararg sideEffects: SideEffect<STATE>
): Store<STATE> = object : Store<STATE> {
    private val _state: AtomicReference<STATE> = AtomicReference(initialState)
    override val state: STATE get() = _state.get()
    private val subscriptions = CopyOnWriteArrayList<SubscriptionWithBinder<STATE>>()

    override fun subscribe(subscription: Subscription<STATE>): Unsubscribe = subscribe(::defaultBinder, subscription)

    override fun subscribe(binder: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe {
        val subscriptionWithBinder = SubscriptionWithBinder(binder, subscription)
        subscriptions.add(subscriptionWithBinder)
        // Deliver the current state right away so a new subscriber never starts out empty.
        subscription(state)
        return { subscriptions.remove(subscriptionWithBinder) }
    }

    override fun subscribe(vararg binders: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe {
        // No binders: nothing to observe, so register nothing (and never call subscription).
        if (binders.isEmpty()) return {}
        // One combined binder means one registration: the subscription is called once initially and
        // at most once per mutation, however many of the bound values changed.
        val combined: Binder<STATE> = { binders.map { binder -> this.binder() } }
        return subscribe(combined, subscription)
    }

    override fun dispatch(effect: Effect<Store<STATE>, STATE>) = effect(this, ::mutate)

    override fun dispatch(vararg effects: Effect<Store<STATE>, STATE>) = effects.forEach { effect -> effect(this, ::mutate) }

    override fun <PARAMS> dispatch(effect: EffectWith<Store<STATE>, STATE, PARAMS>, params: PARAMS) = effect(params, this, ::mutate)

    override fun <PARAMS> dispatch(effects: Map<EffectWith<Store<STATE>, STATE, PARAMS>, PARAMS>) = effects.forEach { (effect, params) ->
        effect(params, this, ::mutate)
    }

    /** Mutator handed to dispatched effects: commits and notifies, then runs the side effects. */
    private fun mutate(newState: STATE) {
        emit(newState)
        sideEffects.forEach { sideEffect -> sideEffect(this, ::emit) }
    }

    /**
     * Commits [newState], then notifies every subscriber whose bound value changed.
     * The commit happens before any callback runs, so callbacks observe [newState] via [state].
     */
    private fun emit(newState: STATE) {
        val oldState = _state.getAndSet(newState)
        subscriptions
            .reversed()
            .filter { (binder, _) -> oldState.binder() != newState.binder() }
            .forEach { (_, subscription) -> subscription?.invoke(newState) }
    }

    override fun rehydrate(newState: STATE) = mutate(newState)
}
/*
* Copyright (c) 2022. Héctor de Isidro - hrules6872
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * Unit tests for the plain store built by [createTestStore]. They cover effect dispatching (with
 * and without params), side effects, subscriptions with and without binders, and rehydration.
 */
class StoreTest {
    private val store: Store<TestState> = createTestStore()

    // Set from callbacks so each test can prove its callback actually ran.
    private var notified = false

    @Test
    fun `Effect dispatching`() {
        fun effect(store: Store<TestState>, mutate: Mutator<TestState>) {
            mutate(store.state.copy(counter = 100))
            assertThat(store.state.counter).isEqualTo(100)
            mutate(store.state.copy(counter = 200))
            assertThat(store.state.counter).isEqualTo(200)
        }
        store.dispatch(::effect)
    }

    @Test
    fun `Effect dispatching with params`() {
        class Params(val value: Int)
        fun effect(params: Params, store: Store<TestState>, mutate: Mutator<TestState>) {
            mutate(store.state.copy(counter = 100 + params.value))
            assertThat(store.state.counter).isEqualTo(100 + params.value)
            mutate(store.state.copy(counter = 200 + params.value))
            assertThat(store.state.counter).isEqualTo(200 + params.value)
        }
        store.dispatch(::effect, Params(1))
    }

    @Test
    @Suppress("UnusedPrivateMember")
    fun `Effect dispatching with SideEffects`() {
        fun sideEffect(store: Store<TestState>, mutate: Mutator<TestState>) {
            notified = true
            assertThat(store.state.counter).isEqualTo(100)
        }
        fun effect(store: Store<TestState>, mutate: Mutator<TestState>) {
            mutate(store.state.copy(counter = 100))
        }
        createTestStore(TestState(), ::sideEffect).dispatch(::effect)
        // Without this check the test would also pass if the side effect were never invoked.
        assertThat(notified).isTrue
    }

    @Test
    fun `get initial status upon subscription`() {
        store.subscribe { state ->
            notified = true
            assertThat(state.counter).isNull()
        }
        assertThat(notified).isTrue
    }

    @Test
    fun `Subscribe with no binding`() {
        store.run {
            subscribe { state ->
                notified = true
                // discard initial status
                state.counter?.let { assertThat(state.counter).isEqualTo(1) }
            }
            dispatch(::incrementCounterByOneEffect)
        }
        assertThat(notified).isTrue
    }

    @Test
    fun `Subscribe with binding`() {
        store.run {
            subscribe({ other }) { state ->
                notified = true
                // discard initial status
                state.other?.let { assertThat(state.other).isEqualTo(1) }
                assertThat(state.counter).isNull()
            }
            dispatch(::incrementOtherByOneEffect)
        }
        assertThat(notified).isTrue
    }

    @Test
    fun `Subscribe with a list of bindings`() {
        store.run {
            subscribe({ other }, { counter }) { state ->
                notified = true
                state.other?.let { assertThat(state.other).isEqualTo(1) }
                assertThat(state.counter).isNull()
            }
            dispatch(::incrementOtherByOneEffect)
        }
        assertThat(notified).isTrue
    }

    @Test
    fun `rehydrate state`() {
        store.run {
            subscribe { state ->
                notified = true
                state.counter?.let { assertThat(state.counter).isEqualTo(100) }
            }
            rehydrate(TestState(counter = 100))
            assertThat(state.counter).isEqualTo(100)
        }
        assertThat(notified).isTrue
    }
}
/*
* Copyright (c) 2022. Héctor de Isidro - hrules6872
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/** Minimal state used by the store tests; both fields start unset (`null`). */
data class TestState(
    val counter: Int? = null,
    val other: Int? = null,
) : State

/** Builds a plain [createStore] store for tests, starting from [initialState] and running [sideEffects]. */
fun createTestStore(initialState: TestState = TestState(), vararg sideEffects: SideEffect<TestState>) =
    createStore(initialState, *sideEffects)

/** Effect that increments [TestState.counter], treating an unset counter as 0. */
fun incrementCounterByOneEffect(store: Store<TestState>, mutate: Mutator<TestState>) {
    val current = store.state
    val incremented = (current.counter ?: 0) + 1
    mutate(current.copy(counter = incremented))
}
/*
* Copyright (c) 2022. Héctor de Isidro - hrules6872
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/** Minimal state used by the store tests; both fields start unset (`null`). */
data class TestState(
    val counter: Int? = null,
    val other: Int? = null,
) : State

/** Builds a plain [createStore] store for tests, starting from [initialState] and running [sideEffects]. */
fun createTestStore(initialState: TestState = TestState(), vararg sideEffects: SideEffect<TestState>) =
    createStore(initialState, *sideEffects)

/** Effect that increments [TestState.counter], treating an unset counter as 0. */
fun incrementCounterByOneEffect(store: Store<TestState>, mutate: Mutator<TestState>) {
    val current = store.state
    val incremented = (current.counter ?: 0) + 1
    mutate(current.copy(counter = incremented))
}

/** Effect that increments [TestState.other], treating an unset value as 0. */
fun incrementOtherByOneEffect(store: Store<TestState>, mutate: Mutator<TestState>) {
    val current = store.state
    val incremented = (current.other ?: 0) + 1
    mutate(current.copy(other = incremented))
}
/*
* Copyright (c) 2022. Héctor de Isidro - hrules6872
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * Wraps [store] in a [ThreadSafeStore], which synchronizes every call made through the wrapper.
 * The locking may have a performance impact on JVM/Native.
 */
fun <STATE : State> createThreadSafeStore(store: Store<STATE>): ThreadSafeStore<STATE> =
    ThreadSafeStore(store = store)

/** Creates a [createStore] store from [initialState] and [sideEffects] and wraps it in a [ThreadSafeStore]. */
fun <STATE : State> createThreadSafeStore(
    initialState: STATE,
    vararg sideEffects: SideEffect<STATE>,
): ThreadSafeStore<STATE> {
    val inner = createStore(initialState, *sideEffects)
    return createThreadSafeStore(inner)
}
/**
 * [Store] decorator that serializes every call on this instance's monitor.
 *
 * Each member is `@Synchronized` on the wrapper, so at most one thread at a time reads [state],
 * subscribes, dispatches or rehydrates through it. The lock is held for the whole `dispatch` call,
 * which includes whatever the wrapped store runs synchronously during it. (Previously each member
 * also wrapped its body in `synchronized(this)`, acquiring the same monitor twice.)
 *
 * NOTE(review): effects are forwarded to the wrapped [store]. A [createStore] store passes the
 * effect itself and its own mutator, so a mutation made after `dispatch` returns (e.g. from a
 * coroutine the effect launched) is not guarded by this lock.
 */
class ThreadSafeStore<STATE : State>(private val store: Store<STATE>) : Store<STATE> by store {
    @get:Synchronized
    override val state: STATE get() = store.state

    @Synchronized
    override fun subscribe(subscription: Subscription<STATE>): Unsubscribe = store.subscribe(subscription)

    @Synchronized
    override fun subscribe(binder: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe =
        store.subscribe(binder, subscription)

    @Synchronized
    override fun subscribe(vararg binders: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe =
        store.subscribe(binders = binders, subscription)

    @Synchronized
    override fun dispatch(effect: Effect<Store<STATE>, STATE>) = store.dispatch(effect)

    @Synchronized
    override fun dispatch(vararg effects: Effect<Store<STATE>, STATE>) = store.dispatch(effects = effects)

    @Synchronized
    override fun <PARAMS> dispatch(effect: EffectWith<Store<STATE>, STATE, PARAMS>, params: PARAMS) = store.dispatch(effect, params)

    @Synchronized
    override fun <PARAMS> dispatch(effects: Map<EffectWith<Store<STATE>, STATE, PARAMS>, PARAMS>) = store.dispatch(effects)

    @Synchronized
    override fun rehydrate(newState: STATE) = store.rehydrate(newState)
}
/**
 * Wraps [store] in a [SameThreadEnforcedStore], whose functions may only be called from the thread
 * that calls this factory. Calls from any other thread throw an [IllegalStateException].
 */
fun <STATE : State> createSameThreadEnforcedStore(store: Store<STATE>): SameThreadEnforcedStore<STATE> =
    SameThreadEnforcedStore(store = store)

/** Creates a [createStore] store from [initialState] and [sideEffects] and wraps it in a [SameThreadEnforcedStore]. */
fun <STATE : State> createSameThreadEnforcedStore(
    initialState: STATE,
    vararg sideEffects: SideEffect<STATE>,
): SameThreadEnforcedStore<STATE> {
    val inner = createStore(initialState, *sideEffects)
    return createSameThreadEnforcedStore(inner)
}
/**
 * [Store] decorator that only accepts calls from the thread that constructed it.
 *
 * Every `subscribe`, `dispatch` and `rehydrate` call first checks that it runs on that thread.
 * Otherwise it throws [IllegalStateException]. Reading [state] is delegated to [store] unchecked.
 *
 * The owning thread is compared by identity, not by name. Thread names are not unique, so a name
 * comparison can accept a different thread that happens to share the owner's name. Identity also
 * avoids having to strip the "@coroutine#n" suffix that coroutine debug mode appends to names.
 */
class SameThreadEnforcedStore<STATE : State>(private val store: Store<STATE>) : Store<STATE> by store {
    private val storeThread: Thread = Thread.currentThread()

    override fun subscribe(subscription: Subscription<STATE>): Unsubscribe {
        checkIsSameThread()
        return store.subscribe(subscription)
    }

    override fun subscribe(binder: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe {
        checkIsSameThread()
        return store.subscribe(binder, subscription)
    }

    override fun subscribe(vararg binders: Binder<STATE>, subscription: Subscription<STATE>): Unsubscribe {
        checkIsSameThread()
        return store.subscribe(binders = binders, subscription)
    }

    override fun dispatch(effect: Effect<Store<STATE>, STATE>) {
        checkIsSameThread()
        store.dispatch(effect)
    }

    override fun dispatch(vararg effects: Effect<Store<STATE>, STATE>) {
        checkIsSameThread()
        store.dispatch(effects = effects)
    }

    override fun <PARAMS> dispatch(effect: EffectWith<Store<STATE>, STATE, PARAMS>, params: PARAMS) {
        checkIsSameThread()
        store.dispatch(effect, params)
    }

    override fun <PARAMS> dispatch(effects: Map<EffectWith<Store<STATE>, STATE, PARAMS>, PARAMS>) {
        checkIsSameThread()
        store.dispatch(effects)
    }

    override fun rehydrate(newState: STATE) {
        checkIsSameThread()
        store.rehydrate(newState)
    }

    /** @throws IllegalStateException when called from any thread other than the one that created this store. */
    private fun checkIsSameThread() {
        val current = Thread.currentThread()
        check(current === storeThread) {
            """
            You may not call the store from a thread other than the thread on which it was created.
            This store was created on: '${storeThread.name}' and current thread is '${current.name}'
            """.trimIndent()
        }
    }
}
/*
* Copyright (c) 2022. Héctor de Isidro - hrules6872
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
 * Hammers a plain store and a [ThreadSafeStore] with concurrent increments from [Dispatchers.Default].
 *
 * Example timings from one run (performance impact of locking):
 * thread safe: completed 100000 actions in 48 ms
 * non-thread safe: completed 100000 actions in 24 ms
 *
 * NOTE(review): the non-thread-safe test asserts that the race actually lost an update. That is
 * likely with many coroutines but not guaranteed, so it can fail spuriously.
 */
@OptIn(ExperimentalCoroutinesApi::class)
class ThreadSafeStoreTest {
    @Test
    fun `massive Effect dispatching using a non-thread safe Store should return an unexpected state`() {
        val plainStore = createTestStore()
        incrementMassively("non-thread safe", plainStore) { finalCounter ->
            assertThat(finalCounter).isNotEqualTo(SYNC_NUM_COROUTINES * SYNC_NUM_REPEATS)
        }
    }

    @Test
    fun `massive Effect dispatching using a thread safe Store should return an expected state`() {
        val guardedStore = createThreadSafeStore(createTestStore())
        incrementMassively("thread safe", guardedStore) { finalCounter ->
            assertThat(finalCounter).isEqualTo(SYNC_NUM_COROUTINES * SYNC_NUM_REPEATS)
        }
    }

    /** Runs the concurrent increments on [Dispatchers.Default], then passes the resulting counter to [verify]. */
    private fun incrementMassively(tag: String, target: Store<TestState>, verify: (Int?) -> Unit) {
        runTest {
            withContext(Dispatchers.Default) {
                massiveRunner(tag, SYNC_NUM_COROUTINES, SYNC_NUM_REPEATS) {
                    target.dispatch(::incrementCounterByOneEffect)
                }
                verify(target.state.counter)
            }
        }
    }
}
/**
 * Launches [numCoroutines] coroutines that each run [block] [numRepeats] times. It waits for all of
 * them to finish, then prints the total number of actions and the elapsed wall time, prefixed by [tag].
 */
private suspend fun massiveRunner(tag: String, numCoroutines: Int, numRepeats: Int, block: suspend () -> Unit) {
    val elapsedMs = measureTimeMillis {
        coroutineScope {
            for (coroutineIndex in 0 until numCoroutines) {
                launch {
                    for (repeatIndex in 0 until numRepeats) {
                        block()
                    }
                }
            }
        }
    }
    println("$tag: completed ${numCoroutines * numRepeats} actions in $elapsedMs ms")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment