Instantly share code, notes, and snippets.

Embed
What would you like to do?
Better type-safety with units in Kotlin
// Here's how unit type-safety can prevent you from performing invalid unit arithmetic
/**
 * Demonstrates how unit wrappers reject invalid arithmetic at compile time
 * while still allowing same-unit operations.
 *
 * @param args command-line arguments (unused).
 */
fun main(args: Array<String>) {
    // `oranges` is only referenced by the rejected comparison below; it is kept
    // so that line can be uncommented to observe the compiler error.
    val oranges = Oranges(4)
    val apples = Apples(4)

    // Each of the following is rejected by the compiler, so they are left
    // commented out to keep this file buildable. Uncomment one to see the error.
    // println("Can't compare apples to oranges: ${apples == oranges}")
    // println("Can't add two unknown things to my apples: ${apples + 2}")
    // println("Can't divide apples by apples: ${apples / apples}")

    // Valid syntax
    println("Comparing apples to apples: ${apples == apples }") // "Comparing apples to apples: true"
    println("Adding apples to apples: ${apples + apples}") // "Adding apples to apples: 8 Apples"
    println("Dividing apples by 4: ${apples / 4.0}") // "Dividing apples by 4: 1 Apples"
}
// Here's how you declare your unit:
/** A whole number of oranges; arithmetic results are rebuilt through the [Oranges] constructor. */
class Oranges(value: Int) : IntUnit<Oranges>(value, ::Oranges)
/** A whole number of apples; arithmetic results are rebuilt through the [Apples] constructor. */
class Apples(value: Int) : IntUnit<Apples>(value, ::Apples)
// Here are the classes that make it happen:
/**
 * Base class for a numeric magnitude tagged with its own unit type.
 *
 * [TYPE] is the wrapped numeric type. [THIS] is the concrete unit class itself,
 * so [Comparable.compareTo] only accepts values of the same unit. Every
 * [Number] conversion delegates to the wrapped value.
 *
 * @property value the wrapped magnitude; visible to subclasses only.
 */
abstract class Unit<TYPE : Number, THIS : Unit<TYPE, THIS>>(
    protected val value: TYPE
) : Number(), Comparable<THIS> {

    /**
     * Two units are equal when [other] is an instance of this object's runtime
     * class and [compareTo] reports 0 for it.
     */
    @Suppress("UNCHECKED_CAST")
    override fun equals(other: Any?): Boolean =
        // The isInstance guard runs first, so the unchecked cast is only reached
        // for objects of this runtime class.
        javaClass.isInstance(other) && compareTo(other as THIS) == 0

    /** Mixes the runtime class into the hash so equal magnitudes of different units differ. */
    override fun hashCode(): Int = value.hashCode() xor javaClass.hashCode()

    /** Renders as "<value> <SimpleClassName>". */
    override fun toString(): String = "$value ${javaClass.simpleName}"

    override fun toInt(): Int = value.toInt()

    override fun toLong(): Long = value.toLong()

    override fun toDouble(): Double = value.toDouble()

    override fun toFloat(): Float = value.toFloat()

    override fun toByte(): Byte = value.toByte()

    // Number.toChar() is deprecated in the Kotlin standard library; its documented
    // replacement is toInt().toChar(), so the conversion goes through Int here
    // instead of calling the deprecated member on the wrapped value.
    @Suppress("OVERRIDE_DEPRECATION")
    override fun toChar(): Char = value.toInt().toChar()

    override fun toShort(): Short = value.toShort()
}
/**
 * A unit whose magnitude is stored as an [Int].
 *
 * Addition, subtraction and comparison are only defined against the same unit
 * type [T]. Scaling (`*`, `/`, `%`) is only defined against plain numbers.
 * Every operation returns a fresh [T] built through [constructor]. Results that
 * are not already an [Int] (from [Double], [Float] or [Long] operands) are
 * converted back with `toInt()`.
 *
 * @param value the magnitude to wrap.
 * @param constructor factory that produces a new [T] from a raw magnitude.
 */
abstract class IntUnit<T : IntUnit<T>>(
    value: Int,
    private val constructor: (Int) -> T
) : Unit<Int, T>(value) {

    override fun compareTo(other: T): Int = value.compareTo(other.value)

    // --- same-unit arithmetic ---
    operator fun plus(other: T): T = constructor(value + other.value)
    operator fun minus(other: T): T = constructor(value - other.value)

    // --- unary operators ---
    operator fun inc(): T = constructor(value + 1)
    operator fun dec(): T = constructor(value - 1)
    operator fun unaryMinus(): T = constructor(-value)
    operator fun unaryPlus(): T = constructor(+value)

    // --- scaling by a plain number ---
    operator fun times(other: Byte): T = constructor(value * other)
    operator fun times(other: Short): T = constructor(value * other)
    operator fun times(other: Int): T = constructor(value * other)
    operator fun times(other: Long): T = constructor((value * other).toInt())
    operator fun times(other: Float): T = constructor((value * other).toInt())
    operator fun times(other: Double): T = constructor((value * other).toInt())

    // --- division by a plain number ---
    operator fun div(other: Byte): T = constructor(value / other)
    operator fun div(other: Short): T = constructor(value / other)
    operator fun div(other: Int): T = constructor(value / other)
    operator fun div(other: Long): T = constructor((value / other).toInt())
    operator fun div(other: Float): T = constructor((value / other).toInt())
    operator fun div(other: Double): T = constructor((value / other).toInt())

    // --- remainder by a plain number ---
    operator fun rem(other: Byte): T = constructor(value % other)
    operator fun rem(other: Short): T = constructor(value % other)
    operator fun rem(other: Int): T = constructor(value % other)
    operator fun rem(other: Long): T = constructor((value % other).toInt())
    operator fun rem(other: Float): T = constructor((value % other).toInt())
    operator fun rem(other: Double): T = constructor((value % other).toInt())
}
/**
 * A unit whose magnitude is stored as a [Long].
 *
 * Addition, subtraction and comparison are only defined against the same unit
 * type [T]. Scaling (`*`, `/`, `%`) is only defined against plain numbers.
 * Every operation returns a fresh [T] built through [constructor]. Results
 * from [Double] or [Float] operands are converted back with `toLong()`.
 *
 * @param value the magnitude to wrap.
 * @param constructor factory that produces a new [T] from a raw magnitude.
 */
abstract class LongUnit<T : LongUnit<T>>(
    value: Long,
    private val constructor: (Long) -> T
) : Unit<Long, T>(value) {

    override fun compareTo(other: T): Int = value.compareTo(other.value)

    // --- same-unit arithmetic ---
    operator fun plus(other: T): T = constructor(value + other.value)
    operator fun minus(other: T): T = constructor(value - other.value)

    // --- unary operators ---
    operator fun inc(): T = constructor(value + 1)
    operator fun dec(): T = constructor(value - 1)
    operator fun unaryMinus(): T = constructor(-value)
    operator fun unaryPlus(): T = constructor(+value)

    // --- scaling by a plain number ---
    operator fun times(other: Byte): T = constructor(value * other)
    operator fun times(other: Short): T = constructor(value * other)
    operator fun times(other: Int): T = constructor(value * other)
    operator fun times(other: Long): T = constructor(value * other)
    operator fun times(other: Float): T = constructor((value * other).toLong())
    operator fun times(other: Double): T = constructor((value * other).toLong())

    // --- division by a plain number ---
    operator fun div(other: Byte): T = constructor(value / other)
    operator fun div(other: Short): T = constructor(value / other)
    operator fun div(other: Int): T = constructor(value / other)
    operator fun div(other: Long): T = constructor(value / other)
    operator fun div(other: Float): T = constructor((value / other).toLong())
    operator fun div(other: Double): T = constructor((value / other).toLong())

    // --- remainder by a plain number ---
    operator fun rem(other: Byte): T = constructor(value % other)
    operator fun rem(other: Short): T = constructor(value % other)
    operator fun rem(other: Int): T = constructor(value % other)
    operator fun rem(other: Long): T = constructor(value % other)
    operator fun rem(other: Float): T = constructor((value % other).toLong())
    operator fun rem(other: Double): T = constructor((value % other).toLong())
}
/**
 * A unit whose magnitude is stored as a [Double].
 *
 * Addition, subtraction and comparison are only defined against the same unit
 * type [T]. Scaling (`*`, `/`, `%`) is only defined against plain numbers.
 * Every operation returns a fresh [T] built through [constructor].
 *
 * @param value the magnitude to wrap.
 * @param constructor factory that produces a new [T] from a raw magnitude.
 */
abstract class DoubleUnit<T : DoubleUnit<T>>(
    value: Double,
    private val constructor: (Double) -> T
) : Unit<Double, T>(value) {

    override fun compareTo(other: T): Int = value.compareTo(other.value)

    // --- same-unit arithmetic ---
    operator fun plus(other: T): T = constructor(value + other.value)
    operator fun minus(other: T): T = constructor(value - other.value)

    // --- unary operators ---
    operator fun inc(): T = constructor(value + 1)
    operator fun dec(): T = constructor(value - 1)
    operator fun unaryMinus(): T = constructor(-value)
    operator fun unaryPlus(): T = constructor(+value)

    // --- scaling by a plain number ---
    operator fun times(other: Byte): T = constructor(value * other)
    operator fun times(other: Short): T = constructor(value * other)
    operator fun times(other: Int): T = constructor(value * other)
    operator fun times(other: Long): T = constructor(value * other)
    operator fun times(other: Float): T = constructor(value * other)
    operator fun times(other: Double): T = constructor(value * other)

    // --- division by a plain number ---
    operator fun div(other: Byte): T = constructor(value / other)
    operator fun div(other: Short): T = constructor(value / other)
    operator fun div(other: Int): T = constructor(value / other)
    operator fun div(other: Long): T = constructor(value / other)
    operator fun div(other: Float): T = constructor(value / other)
    operator fun div(other: Double): T = constructor(value / other)

    // --- remainder by a plain number ---
    operator fun rem(other: Byte): T = constructor(value % other)
    operator fun rem(other: Short): T = constructor(value % other)
    operator fun rem(other: Int): T = constructor(value % other)
    operator fun rem(other: Long): T = constructor(value % other)
    operator fun rem(other: Float): T = constructor(value % other)
    operator fun rem(other: Double): T = constructor(value % other)
}
/**
 * A unit whose magnitude is stored as a [Float].
 *
 * Addition, subtraction and comparison are only defined against the same unit
 * type [T]. Scaling (`*`, `/`, `%`) is only defined against plain numbers.
 * Every operation returns a fresh [T] built through [constructor]. Results
 * from [Double] operands are converted back with `toFloat()`.
 *
 * @param value the magnitude to wrap.
 * @param constructor factory that produces a new [T] from a raw magnitude.
 */
abstract class FloatUnit<T : FloatUnit<T>>(
    value: Float,
    private val constructor: (Float) -> T
) : Unit<Float, T>(value) {

    override fun compareTo(other: T): Int = value.compareTo(other.value)

    // --- same-unit arithmetic ---
    operator fun plus(other: T): T = constructor(value + other.value)
    operator fun minus(other: T): T = constructor(value - other.value)

    // --- unary operators ---
    operator fun inc(): T = constructor(value + 1)
    operator fun dec(): T = constructor(value - 1)
    operator fun unaryMinus(): T = constructor(-value)
    operator fun unaryPlus(): T = constructor(+value)

    // --- scaling by a plain number ---
    operator fun times(other: Byte): T = constructor(value * other)
    operator fun times(other: Short): T = constructor(value * other)
    operator fun times(other: Int): T = constructor(value * other)
    operator fun times(other: Long): T = constructor(value * other)
    operator fun times(other: Float): T = constructor(value * other)
    operator fun times(other: Double): T = constructor((value * other).toFloat())

    // --- division by a plain number ---
    operator fun div(other: Byte): T = constructor(value / other)
    operator fun div(other: Short): T = constructor(value / other)
    operator fun div(other: Int): T = constructor(value / other)
    operator fun div(other: Long): T = constructor(value / other)
    operator fun div(other: Float): T = constructor(value / other)
    operator fun div(other: Double): T = constructor((value / other).toFloat())

    // --- remainder by a plain number ---
    operator fun rem(other: Byte): T = constructor(value % other)
    operator fun rem(other: Short): T = constructor(value % other)
    operator fun rem(other: Int): T = constructor(value % other)
    operator fun rem(other: Long): T = constructor(value % other)
    operator fun rem(other: Float): T = constructor(value % other)
    operator fun rem(other: Double): T = constructor((value % other).toFloat())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment