Skip to content

Instantly share code, notes, and snippets.

@fathonyfath
Last active September 24, 2023 07:06
Show Gist options
  • Save fathonyfath/6b791677f3732c13f456e1661796b9a5 to your computer and use it in GitHub Desktop.
Save fathonyfath/6b791677f3732c13f456e1661796b9a5 to your computer and use it in GitHub Desktop.
An implementation of a saved-state holder that automatically registers its value with the SavedStateRegistry, making it easier to use in an Activity or Fragment.
import android.os.Bundle
import android.view.View
import android.widget.Button
import android.widget.TextView
import androidx.fragment.app.Fragment
/**
 * Sample screen with a counter and increment/decrement buttons.
 *
 * The counter is backed by [savedStateHolder] under the key `"count"`. Its value is
 * therefore written to and restored from this fragment's saved-state registry.
 */
class SampleFragment : Fragment(R.layout.fragment_sample) {
    // Starts at 0 unless a value was restored from the saved-state registry.
    private var count by savedStateHolder("count") { 0 }

    override fun onViewCreated(view: View, savedInstanceState: Bundle?) {
        super.onViewCreated(view, savedInstanceState)
        // View references are kept local instead of in fields. A fragment outlives its
        // view, for example while on the back stack after onDestroyView, so fields would
        // keep the old view hierarchy reachable.
        val countLabel: TextView = view.findViewById(R.id.count_label)
        val decrement: Button = view.findViewById(R.id.decrement)
        val increment: Button = view.findViewById(R.id.increment)

        // Writes the given counter value into the label.
        fun render(value: Int) {
            countLabel.text = value.toString()
        }

        render(count)
        increment.setOnClickListener { render(++count) }
        decrement.setOnClickListener { render(--count) }
    }
}
import android.os.Build
import android.os.Bundle
import android.os.IBinder
import android.os.Parcelable
import android.util.Size
import android.util.SizeF
import androidx.activity.ComponentActivity
import androidx.core.os.bundleOf
import androidx.fragment.app.Fragment
import androidx.lifecycle.DefaultLifecycleObserver
import androidx.lifecycle.LifecycleOwner
import androidx.savedstate.SavedStateRegistry
import androidx.savedstate.SavedStateRegistryOwner
import java.io.Serializable
import kotlin.properties.ReadWriteProperty
import kotlin.reflect.KProperty
/**
 * A [ReadWriteProperty] delegate whose value is saved to and restored from a
 * [SavedStateRegistry] under the key [name].
 *
 * Lifecycle:
 * - `initialValueCreator` is invoked once, at construction, to produce the starting value.
 * - When the owner's lifecycle reaches ON_CREATE, the holder consumes any bundle
 *   previously saved under [name]. If one exists, its value replaces the current value.
 * - The holder then registers itself as the [SavedStateRegistry.SavedStateProvider] for
 *   [name]. Registration is skipped if a provider is already registered under that key.
 *
 * On save, the value is stored alongside the ordinal of its [SupportedTypes] entry.
 * If the value's type is not one of [SupportedTypes], [saveState] throws
 * [IllegalArgumentException].
 */
@Suppress("UNCHECKED_CAST", "DEPRECATION")
class SavedStateHolder<T> internal constructor(
    private val name: String,
    savedStateRegistryOwner: SavedStateRegistryOwner,
    initialValueCreator: () -> T,
) : SavedStateRegistry.SavedStateProvider, ReadWriteProperty<Any, T> {

    private val lifecycle = savedStateRegistryOwner.lifecycle

    // Captured once, so that restore and registration both talk to the same registry.
    private val savedStateRegistry = savedStateRegistryOwner.savedStateRegistry

    private var value: T = initialValueCreator()

    init {
        lifecycle.addObserver(object : DefaultLifecycleObserver {
            override fun onCreate(owner: LifecycleOwner) {
                restore()
                tryToRegister()
            }
        })
    }

    /** Returns a bundle holding the current value and the ordinal of its [SupportedTypes]. */
    override fun saveState(): Bundle = bundleOf(
        TYPE_KEY to getSupportedTypeForValue(value).ordinal,
        VALUE_KEY to value,
    )

    override fun getValue(thisRef: Any, property: KProperty<*>): T = value

    override fun setValue(thisRef: Any, property: KProperty<*>, value: T) {
        this.value = value
    }

    /**
     * Replaces [value] with the state previously saved under [name], if any.
     * The registry hands out restored state only once per key; later calls see `null`.
     */
    private fun restore() {
        val savedBundle = savedStateRegistry.consumeRestoredStateForKey(name) ?: return
        value = if (ENABLE_TYPE_SAFE_UNWRAP) {
            val type = SupportedTypes.entries[savedBundle.getInt(TYPE_KEY)]
            savedBundle.readTyped(type, VALUE_KEY) as T
        } else {
            savedBundle.get(VALUE_KEY) as T
        }
    }

    /** Registers this holder as the provider for [name] unless that key is already taken. */
    private fun tryToRegister() {
        if (savedStateRegistry.getSavedStateProvider(name) == null) {
            savedStateRegistry.registerSavedStateProvider(name, this)
        }
    }

    /**
     * Reads [key] with the typed [Bundle] getter that matches [type].
     * Returns `null` for [SupportedTypes.IBinder], [SupportedTypes.Size] and
     * [SupportedTypes.SizeF] on SDK levels where the getter is unavailable.
     */
    private fun Bundle.readTyped(type: SupportedTypes, key: String): Any? = when (type) {
        SupportedTypes.Null -> getString(key)
        SupportedTypes.Boolean -> getBoolean(key)
        SupportedTypes.Byte -> getByte(key)
        SupportedTypes.Char -> getChar(key)
        SupportedTypes.Double -> getDouble(key)
        SupportedTypes.Float -> getFloat(key)
        SupportedTypes.Int -> getInt(key)
        SupportedTypes.Long -> getLong(key)
        SupportedTypes.Short -> getShort(key)
        SupportedTypes.Bundle -> getBundle(key)
        SupportedTypes.CharSequence -> getCharSequence(key)
        SupportedTypes.Parcelable -> getParcelable(key)
        SupportedTypes.BooleanArray -> getBooleanArray(key)
        SupportedTypes.ByteArray -> getByteArray(key)
        SupportedTypes.CharArray -> getCharArray(key)
        SupportedTypes.DoubleArray -> getDoubleArray(key)
        SupportedTypes.FloatArray -> getFloatArray(key)
        SupportedTypes.IntArray -> getIntArray(key)
        SupportedTypes.LongArray -> getLongArray(key)
        SupportedTypes.ShortArray -> getShortArray(key)
        SupportedTypes.ParcelableArray -> getParcelableArray(key)
        SupportedTypes.StringArray -> getStringArray(key)
        SupportedTypes.CharSequenceArray -> getCharSequenceArray(key)
        SupportedTypes.Serializable -> getSerializable(key)
        SupportedTypes.IBinder -> if (Build.VERSION.SDK_INT >= 18) getBinder(key) else null
        SupportedTypes.Size -> if (Build.VERSION.SDK_INT >= 21) getSize(key) else null
        SupportedTypes.SizeF -> if (Build.VERSION.SDK_INT >= 21) getSizeF(key) else null
    }

    /**
     * Classifies [value] into a [SupportedTypes] entry.
     * The checks run in the same order as the type checks in `bundleOf`.
     *
     * @throws IllegalArgumentException if [value] (or an array's component type) is unsupported.
     */
    private fun getSupportedTypeForValue(value: Any?): SupportedTypes = when (value) {
        null -> SupportedTypes.Null
        is Boolean -> SupportedTypes.Boolean
        is Byte -> SupportedTypes.Byte
        is Char -> SupportedTypes.Char
        is Double -> SupportedTypes.Double
        is Float -> SupportedTypes.Float
        is Int -> SupportedTypes.Int
        is Long -> SupportedTypes.Long
        is Short -> SupportedTypes.Short
        is Bundle -> SupportedTypes.Bundle
        is CharSequence -> SupportedTypes.CharSequence
        is Parcelable -> SupportedTypes.Parcelable
        is BooleanArray -> SupportedTypes.BooleanArray
        is ByteArray -> SupportedTypes.ByteArray
        is CharArray -> SupportedTypes.CharArray
        is DoubleArray -> SupportedTypes.DoubleArray
        is FloatArray -> SupportedTypes.FloatArray
        is IntArray -> SupportedTypes.IntArray
        is LongArray -> SupportedTypes.LongArray
        is ShortArray -> SupportedTypes.ShortArray
        is Array<*> -> {
            // Object arrays are classified by their runtime component type. Every array
            // class has a non-null componentType, so the assertion cannot fail.
            val componentType = value::class.java.componentType!!
            when {
                Parcelable::class.java.isAssignableFrom(componentType) -> SupportedTypes.ParcelableArray
                String::class.java.isAssignableFrom(componentType) -> SupportedTypes.StringArray
                CharSequence::class.java.isAssignableFrom(componentType) -> SupportedTypes.CharSequenceArray
                Serializable::class.java.isAssignableFrom(componentType) -> SupportedTypes.Serializable
                else -> {
                    val valueType = componentType.canonicalName
                    throw IllegalArgumentException("Illegal value array type $valueType")
                }
            }
        }
        is Serializable -> SupportedTypes.Serializable
        else -> when {
            Build.VERSION.SDK_INT >= 18 && value is IBinder -> SupportedTypes.IBinder
            Build.VERSION.SDK_INT >= 21 && value is Size -> SupportedTypes.Size
            Build.VERSION.SDK_INT >= 21 && value is SizeF -> SupportedTypes.SizeF
            else -> {
                val valueType = value.javaClass.canonicalName
                throw IllegalArgumentException("Illegal value type $valueType")
            }
        }
    }

    /**
     * Categories of values that [saveState] can store.
     * The ordinal of each entry is persisted and later read back through `entries[ordinal]`,
     * so entries must not be reordered or removed.
     */
    enum class SupportedTypes {
        Null,
        Boolean, Byte, Char, Double, Float, Int, Long, Short,
        Bundle, CharSequence, Parcelable,
        BooleanArray, ByteArray, CharArray, DoubleArray, FloatArray, IntArray, LongArray, ShortArray,
        ParcelableArray, StringArray, CharSequenceArray,
        Serializable,
        IBinder, Size, SizeF
    }

    companion object {
        private const val TYPE_KEY = "type"
        private const val VALUE_KEY = "value"

        // When true, restore reads the value with the typed Bundle getter selected by the
        // saved type ordinal. When false, it uses the untyped Bundle.get.
        private const val ENABLE_TYPE_SAFE_UNWRAP = false
    }
}
/**
 * Creates a [SavedStateHolder] that uses this activity as its [SavedStateRegistryOwner]
 * and stores the value under [name].
 *
 * @param creator invoked immediately, at holder construction, to supply the initial value.
 */
fun <T> ComponentActivity.savedStateHolder(
    name: String,
    creator: () -> T
): SavedStateHolder<T> = SavedStateHolder(
    name = name,
    savedStateRegistryOwner = this,
    initialValueCreator = creator,
)
/**
 * Creates a [SavedStateHolder] that uses this fragment as its [SavedStateRegistryOwner]
 * and stores the value under [name].
 *
 * @param creator invoked immediately, at holder construction, to supply the initial value.
 */
fun <T> Fragment.savedStateHolder(
    name: String,
    creator: () -> T
): SavedStateHolder<T> = SavedStateHolder(
    name = name,
    savedStateRegistryOwner = this,
    initialValueCreator = creator,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment