Last active
September 24, 2023 07:06
-
-
Save fathonyfath/6b791677f3732c13f456e1661796b9a5 to your computer and use it in GitHub Desktop.
A property delegate built on SavedStateRegistry that automatically registers its value, making saved state easier to use in Activities and Fragments.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import android.os.Bundle | |
import android.view.View | |
import android.widget.Button | |
import android.widget.TextView | |
import androidx.fragment.app.Fragment | |
/**
 * Sample screen with a counter whose value is held by [savedStateHolder].
 *
 * The counter is persisted through the fragment's SavedStateRegistry, so no manual
 * `onSaveInstanceState` / Bundle handling is needed here.
 */
class SampleFragment : Fragment(R.layout.fragment_sample) {

    // Backed by a SavedStateHolder: restored when the fragment reaches ON_CREATE,
    // written back whenever the registry asks for saved state. Starts at 0.
    private var count by savedStateHolder("count") { 0 }

    override fun onViewCreated(view: View, savedInstanceState: Bundle?) {
        super.onViewCreated(view, savedInstanceState)

        // Keep view references local. A fragment on the back stack outlives its view, and
        // fields would keep the destroyed hierarchy reachable until the next onViewCreated.
        val countLabel: TextView = view.findViewById(R.id.count_label)
        val decrement: Button = view.findViewById(R.id.decrement)
        val increment: Button = view.findViewById(R.id.increment)

        countLabel.text = count.toString()
        increment.setOnClickListener { countLabel.text = (++count).toString() }
        decrement.setOnClickListener { countLabel.text = (--count).toString() }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import android.os.Build | |
import android.os.Bundle | |
import android.os.IBinder | |
import android.os.Parcelable | |
import android.util.Size | |
import android.util.SizeF | |
import androidx.activity.ComponentActivity | |
import androidx.core.os.bundleOf | |
import androidx.fragment.app.Fragment | |
import androidx.lifecycle.DefaultLifecycleObserver | |
import androidx.lifecycle.LifecycleOwner | |
import androidx.savedstate.SavedStateRegistry | |
import androidx.savedstate.SavedStateRegistryOwner | |
import java.io.Serializable | |
import kotlin.properties.ReadWriteProperty | |
import kotlin.reflect.KProperty | |
/**
 * Property delegate whose value is persisted through a [SavedStateRegistry].
 *
 * When the owner's lifecycle reaches `ON_CREATE`, the holder does two things:
 * - it consumes any bundle previously saved under [name], replacing the initial value with it;
 * - it then registers itself as the [SavedStateRegistry.SavedStateProvider] for that key.
 *
 * [saveState] writes the current value with [bundleOf], so the value must be one of the types
 * [bundleOf] accepts. Any other type makes [saveState] throw [IllegalArgumentException].
 *
 * If some provider is already registered under [name], this holder skips registration, and
 * its value is then not saved.
 *
 * NOTE(review): the lifecycle and registry are captured once, at construction. If the owner
 * can replace them later (e.g. a Fragment instance being reused), confirm this still works.
 *
 * @param T type of the stored value.
 * @property name key used both for the restored-state lookup and the provider registration.
 */
@Suppress("UNCHECKED_CAST", "DEPRECATION")
class SavedStateHolder<T> internal constructor(
    private val name: String,
    savedStateRegistryOwner: SavedStateRegistryOwner,
    initialValueCreator: () -> T,
) : SavedStateRegistry.SavedStateProvider, ReadWriteProperty<Any, T> {

    private val lifecycle = savedStateRegistryOwner.lifecycle
    private val savedStateRegistry = savedStateRegistryOwner.savedStateRegistry
    private var value: T = initialValueCreator()

    init {
        lifecycle.addObserver(object : DefaultLifecycleObserver {
            override fun onCreate(owner: LifecycleOwner) {
                restoreValue()
                tryToRegister()
                // Restoring and registering is this observer's only job. Detach once it is
                // done instead of staying subscribed for the owner's whole lifetime.
                lifecycle.removeObserver(this)
            }
        })
    }

    /** Writes the current value, tagged with its [SupportedTypes] ordinal, into a new bundle. */
    override fun saveState(): Bundle = bundleOf(
        TYPE_KEY to getSupportedTypeForValue(value).ordinal,
        VALUE_KEY to value,
    )

    override fun getValue(thisRef: Any, property: KProperty<*>): T = value

    override fun setValue(thisRef: Any, property: KProperty<*>, value: T) {
        this.value = value
    }

    /** Replaces [value] with the state saved under [name], if the registry holds any. */
    private fun restoreValue() {
        val savedBundle = savedStateRegistry.consumeRestoredStateForKey(name) ?: return
        value = if (ENABLE_TYPE_SAFE_UNWRAP) {
            val type = SupportedTypes.entries[savedBundle.getInt(TYPE_KEY)]
            savedBundle.getTyped(type, VALUE_KEY) as T
        } else {
            savedBundle.get(VALUE_KEY) as T
        }
    }

    /** Registers this holder for [name] unless a provider is already registered under that key. */
    private fun tryToRegister() {
        if (savedStateRegistry.getSavedStateProvider(name) == null) {
            savedStateRegistry.registerSavedStateProvider(name, this)
        }
    }

    /**
     * Reads [key] with the typed Bundle getter that matches [type].
     *
     * The IBinder, Size and SizeF entries return null below API 18, 21 and 21 respectively.
     */
    private fun Bundle.getTyped(type: SupportedTypes, key: String): Any? = when (type) {
        SupportedTypes.Null -> getString(key)
        SupportedTypes.Boolean -> getBoolean(key)
        SupportedTypes.Byte -> getByte(key)
        SupportedTypes.Char -> getChar(key)
        SupportedTypes.Double -> getDouble(key)
        SupportedTypes.Float -> getFloat(key)
        SupportedTypes.Int -> getInt(key)
        SupportedTypes.Long -> getLong(key)
        SupportedTypes.Short -> getShort(key)
        SupportedTypes.Bundle -> getBundle(key)
        SupportedTypes.CharSequence -> getCharSequence(key)
        SupportedTypes.Parcelable -> getParcelable(key)
        SupportedTypes.BooleanArray -> getBooleanArray(key)
        SupportedTypes.ByteArray -> getByteArray(key)
        SupportedTypes.CharArray -> getCharArray(key)
        SupportedTypes.DoubleArray -> getDoubleArray(key)
        SupportedTypes.FloatArray -> getFloatArray(key)
        SupportedTypes.IntArray -> getIntArray(key)
        SupportedTypes.LongArray -> getLongArray(key)
        SupportedTypes.ShortArray -> getShortArray(key)
        SupportedTypes.ParcelableArray -> getParcelableArray(key)
        SupportedTypes.StringArray -> getStringArray(key)
        SupportedTypes.CharSequenceArray -> getCharSequenceArray(key)
        SupportedTypes.Serializable -> getSerializable(key)
        SupportedTypes.IBinder -> if (Build.VERSION.SDK_INT >= 18) getBinder(key) else null
        SupportedTypes.Size -> if (Build.VERSION.SDK_INT >= 21) getSize(key) else null
        SupportedTypes.SizeF -> if (Build.VERSION.SDK_INT >= 21) getSizeF(key) else null
    }

    /**
     * Classifies [value] with the same branch order [bundleOf] uses when writing it.
     *
     * @throws IllegalArgumentException if the value (or array component) type is unsupported.
     */
    private fun getSupportedTypeForValue(value: Any?): SupportedTypes = when (value) {
        null -> SupportedTypes.Null
        is Boolean -> SupportedTypes.Boolean
        is Byte -> SupportedTypes.Byte
        is Char -> SupportedTypes.Char
        is Double -> SupportedTypes.Double
        is Float -> SupportedTypes.Float
        is Int -> SupportedTypes.Int
        is Long -> SupportedTypes.Long
        is Short -> SupportedTypes.Short
        // Bundle implements Parcelable, so it must be matched first.
        is Bundle -> SupportedTypes.Bundle
        is CharSequence -> SupportedTypes.CharSequence
        is Parcelable -> SupportedTypes.Parcelable
        is BooleanArray -> SupportedTypes.BooleanArray
        is ByteArray -> SupportedTypes.ByteArray
        is CharArray -> SupportedTypes.CharArray
        is DoubleArray -> SupportedTypes.DoubleArray
        is FloatArray -> SupportedTypes.FloatArray
        is IntArray -> SupportedTypes.IntArray
        is LongArray -> SupportedTypes.LongArray
        is ShortArray -> SupportedTypes.ShortArray
        is Array<*> -> {
            // A Class that represents an array type always has a component type.
            val componentType = checkNotNull(value.javaClass.componentType)
            when {
                // String is a CharSequence, so the String check must precede CharSequence.
                Parcelable::class.java.isAssignableFrom(componentType) -> SupportedTypes.ParcelableArray
                String::class.java.isAssignableFrom(componentType) -> SupportedTypes.StringArray
                CharSequence::class.java.isAssignableFrom(componentType) -> SupportedTypes.CharSequenceArray
                Serializable::class.java.isAssignableFrom(componentType) -> SupportedTypes.Serializable
                else -> {
                    val valueType = componentType.canonicalName
                    throw IllegalArgumentException("Illegal value array type $valueType")
                }
            }
        }
        is Serializable -> SupportedTypes.Serializable
        else -> when {
            Build.VERSION.SDK_INT >= 18 && value is IBinder -> SupportedTypes.IBinder
            Build.VERSION.SDK_INT >= 21 && value is Size -> SupportedTypes.Size
            Build.VERSION.SDK_INT >= 21 && value is SizeF -> SupportedTypes.SizeF
            else -> {
                val valueType = value.javaClass.canonicalName
                throw IllegalArgumentException("Illegal value type $valueType")
            }
        }
    }

    /**
     * Value kinds, one per branch of the type classification used by [saveState].
     *
     * The entry's ordinal is stored in the saved bundle and read back by ordinal. Reordering
     * or inserting entries therefore changes how existing saved bundles are interpreted.
     */
    enum class SupportedTypes {
        Null,
        Boolean, Byte, Char, Double, Float, Int, Long, Short,
        Bundle, CharSequence, Parcelable,
        BooleanArray, ByteArray, CharArray, DoubleArray, FloatArray, IntArray, LongArray, ShortArray,
        ParcelableArray, StringArray, CharSequenceArray,
        Serializable,
        IBinder, Size, SizeF,
    }

    private companion object {
        /** Bundle key holding the [SupportedTypes] ordinal of the saved value. */
        const val TYPE_KEY = "type"

        /** Bundle key holding the saved value itself. */
        const val VALUE_KEY = "value"

        /** When true, restore reads the value with the typed getter chosen by [TYPE_KEY]. */
        const val ENABLE_TYPE_SAFE_UNWRAP = false
    }
}
/**
 * Creates a [SavedStateHolder] tied to this activity's lifecycle and saved-state registry.
 *
 * Meant to be used as a property delegate, e.g. `private var count by savedStateHolder("count") { 0 }`.
 *
 * @param name key under which the value is saved and restored.
 * @param creator produces the initial value; invoked once, when the holder is constructed.
 */
fun <T> ComponentActivity.savedStateHolder(
    name: String,
    creator: () -> T,
): SavedStateHolder<T> = SavedStateHolder(
    name = name,
    savedStateRegistryOwner = this,
    initialValueCreator = creator,
)
/**
 * Creates a [SavedStateHolder] tied to this fragment's lifecycle and saved-state registry.
 *
 * Meant to be used as a property delegate, e.g. `private var count by savedStateHolder("count") { 0 }`.
 *
 * @param name key under which the value is saved and restored.
 * @param creator produces the initial value; invoked once, when the holder is constructed.
 */
fun <T> Fragment.savedStateHolder(
    name: String,
    creator: () -> T,
): SavedStateHolder<T> = SavedStateHolder(
    name = name,
    savedStateRegistryOwner = this,
    initialValueCreator = creator,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment