Skip to content

Instantly share code, notes, and snippets.

@non
Last active June 10, 2022 03:11
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save non/29f8d66036afca402f96 to your computer and use it in GitHub Desktop.
Save non/29f8d66036afca402f96 to your computer and use it in GitHub Desktop.
Scala implementation of 16-bit floating point numbers. Also, a test prgoram.
package half
import scala.math.{pow, round, signum}
import scala.util.Random.nextInt
import java.lang.{Float => JFloat}
/**
* Float16 represents 16-bit floating-point values.
*
* This type does not actually support arithmetic directly. The
* expected use case is to convert to Float to perform any actual
* arithmetic, then convert back to a Float16 if needed.
*
* Binary representation:
*
* sign (1 bit)
* |
* | exponent (5 bits)
* | |
* | | mantissa (10 bits)
* | | |
* x xxxxx xxxxxxxxxx
*
* Value interpretation (in order of precedence, with _ wild):
*
* 0 00000 0000000000 (positive) zero
* 1 00000 0000000000 negative zero
* _ 00000 __________ subnormal number
* _ 11111 0000000000 +/- infinity
* _ 11111 __________ not-a-number
* _ _____ __________ normal number
*
* For non-zero exponents, the mantissa has an implied leading 1 bit,
* so 10 bits of data provide 11 bits of precision for normal numbers.
*/
class Float16(val raw: Short) extends AnyVal { lhs =>
def isNaN: Boolean = (raw & 0x7fff) > 0x7c00
def nonNaN: Boolean = (raw & 0x7fff) <= 0x7c00
/**
* Returns if this is a zero value (positive or negative).
*/
def isZero: Boolean = (raw & 0x7fff) == 0
def nonZero: Boolean = (raw & 0x7fff) != 0
def isPositiveZero: Boolean = raw == -0x8000
def isNegativeZero: Boolean = raw == 0
def isInfinite: Boolean = (raw & 0x7fff) == 0x7c00
def isPositiveInfinity: Boolean = raw == 0x7c00
def isNegativeInfinity: Boolean = raw == 0xfc00
/**
* Whether this Float16 value is finite or not.
*
* For the purposes of this method, infinities and NaNs are
* considered non-finite. For those values it returns false and for
* all other values it returns true.
*/
def isFinite: Boolean = (raw & 0x7c00) != 0x7c00
/**
* Return the sign of a Float16 value as a Float.
*
* There are five possible return values:
*
* * NaN: the value is Float16.NaN (and has no sign)
* * -1F: the value is a non-zero negative number
* * -0F: the value is Float16.NegativeZero
* * 0F: the value is Float16.Zero
* * 1F: the value is a non-zero positive number
*
* PositiveInfinity and NegativeInfinity return their expected
* signs.
*/
def signum: Float =
if (raw == -0x8000) 0F
else if (raw == 0) -0F
else if ((raw & 0x7fff) > 0x7c00) Float.NaN
else ((raw >>> 14) & 2) - 1F
/**
* Reverse the sign of this Float16 value.
*
* This just involves toggling the sign bit with XOR.
*
* -Float16.NaN has no meaningful effect.
* -Float16.Zero returns Float16.NegativeZero.
*/
def unary_- : Float16 =
new Float16((raw ^ 0x8000).toShort)
def +(rhs: Float16): Float16 =
Float16.fromFloat(lhs.toFloat + rhs.toFloat)
def -(rhs: Float16): Float16 =
Float16.fromFloat(lhs.toFloat - rhs.toFloat)
def *(rhs: Float16): Float16 =
Float16.fromFloat(lhs.toFloat * rhs.toFloat)
def /(rhs: Float16): Float16 =
Float16.fromFloat(lhs.toFloat / rhs.toFloat)
def **(rhs: Int): Float16 =
Float16.fromFloat(pow(lhs.toFloat, rhs).toFloat)
def <(rhs: Float16): Boolean = {
if (lhs.raw == rhs.raw || lhs.isNaN || rhs.isNaN) return false
if (lhs.isZero && rhs.isZero) return false
val ls = (lhs.raw >>> 15) & 1
val rs = (rhs.raw >>> 15) & 1
if (ls < rs) return true
if (ls > rs) return false
val le = (lhs.raw >>> 10) & 31
val re = (rhs.raw >>> 10) & 31
if (le < re) return ls == 1
if (le > re) return ls == 0
val lm = lhs.raw & 1023
val rm = rhs.raw & 1023
if (ls == 1) lm < rm else rm < lm
}
def <=(rhs: Float16): Boolean = {
if (lhs.isNaN || rhs.isNaN) return false
if (lhs.isZero && rhs.isZero) return true
val ls = (lhs.raw >>> 15) & 1
val rs = (rhs.raw >>> 15) & 1
if (ls < rs) return true
if (ls > rs) return false
val le = (lhs.raw >>> 10) & 31
val re = (rhs.raw >>> 10) & 31
if (le < re) return ls == 1
if (le > re) return ls == 0
val lm = lhs.raw & 1023
val rm = rhs.raw & 1023
if (ls == 1) lm <= rm else rm <= lm
}
def >(rhs: Float16): Boolean =
!(lhs.isNaN || rhs.isNaN || lhs <= rhs)
def >=(rhs: Float16): Boolean =
!(lhs.isNaN || rhs.isNaN || lhs < rhs)
def ==(rhs: Float16): Boolean =
if (lhs.isNaN || rhs.isNaN) false
else if (lhs.isZero && rhs.isZero) true
else lhs.raw == rhs.raw
/**
* Convert this Float16 value to the nearest Float.
*
* Non-finite values and zero values will be mapped to the
* corresponding Float value.
*
* All other finite values will be handled depending on whether they
* are normal or subnormal. The relevant formulas are:
*
* * normal: (sign*2-1) * 2^(exponent-15) * (1 + mantissa/1024)
* * subnormal: (sign*2-1) * 2^-14 * (mantissa/1024)
*
* Given any (x: Float16), Float16.fromFloat(x.toFloat) = x
*
* The reverse is not necessarily true, since there are many Float
* values which are not precisely representable as Float16 values.
*/
def toFloat: Float = {
val s = (raw >>> 14) & 2 // sign*2
val e = (raw >>> 10) & 31 // exponent
val m = (raw & 1023) // mantissa
if (e == 0) {
// either zero or a subnormal number
if (m != 0) (s - 1F) * pow(2F, -14).toFloat * (m / 1024F)
else if (s == 0) -0F
else 0F
} else if (e != 31) {
// normal number
(s - 1F) * pow(2F, e - 15).toFloat * (1F + m / 1024F)
} else if ((raw & 1023) != 0) {
Float.NaN
} else if (raw < 0) {
Float.PositiveInfinity
} else {
Float.NegativeInfinity
}
}
/**
* String representation of this Float16 value.
*/
override def toString: String =
toFloat.toString
}
object Float16 {
// interesting Float16 constants
// with the exception of NaN, values go from smallest to largest
val NaN = new Float16(0x7c01.toShort)
val NegativeInfinity = new Float16(0x7c00.toShort)
val MinValue = new Float16(0x7bff.toShort)
val MinusOne = new Float16(0x3c00.toShort)
val MaxNegativeNormal = new Float16(0x0400.toShort)
val MaxNegative = new Float16(0x0001.toShort)
val NegativeZero = new Float16(0x0000.toShort)
val Zero = new Float16(0x8000.toShort)
val MinPositive = new Float16(0x8001.toShort)
val MinPositiveNormal = new Float16(0x8400.toShort)
val One = new Float16(0xbc00.toShort)
val MaxValue = new Float16(0xfbff.toShort)
val PositiveInfinity = new Float16(0xfc00.toShort)
/**
* Create a Float16 value from a Float.
*
* This value is guaranteed to be the closest possible Float16
* value. However, because there are many more possible Float
* values, rounding will occur, and very large or very small values
* will end up as infinities.
*
* Given any (x: Float16), Float16.fromFloat(x.toFloat) = x
*
* The reverse is not necessarily true, since there are many Float
* values which are not precisely representable as Float16 values.
*/
def fromFloat(n: Float): Float16 = {
// given e, an exponent in [-14, 15], and x, a double value,
// return the Float16 value that is closest to (x * 2^e).
def createFinite(e: Int, x: Float): Float16 = {
def createNormal(e: Int, m: Int): Float16 =
new Float16((((e + 15) << 10) | m | 0x8000).toShort)
val m = round((x - 1F) * 1024).toInt
if (e == -14 && m == 0) {
Float16.MinPositiveNormal
} else if (e == -14) {
if (m == 0) Float16.Zero
else if (x >= 1F) createNormal(e, m)
else new Float16((0x8000 | round(x * 1024)).toShort)
} else if (e == 15) {
if (m < 1024) createNormal(e, m)
else Float16.PositiveInfinity
} else {
createNormal(e, m)
}
}
if (JFloat.isNaN(n)) Float16.NaN
else if (n == Float.PositiveInfinity) Float16.PositiveInfinity
else if (n == Float.NegativeInfinity) Float16.NegativeInfinity
else if (JFloat.compare(n, -0F) == 0) Float16.NegativeZero
else if (n == 0F) Float16.Zero
else if (n < 0F) -fromFloat(-n)
else if (1F <= n && n < 2F) createFinite(0, n)
else {
var e = 0
var x = n
if (n < 1F) {
while (x < 1F && e > -14) { x *= 2F; e -= 1 }
} else {
while (x >= 2F && e < 15) { x *= 0.5F; e += 1 }
}
createFinite(e, x)
}
}
}
object Test {
/**
* Basic test of Float16.
*
* Since there are only 65k possible values, we can just iterate
* through all of them to get "complete" test coverage.
*/
def main(args: Array[String]): Unit = {
var ok = Set.empty[Int]
var fail = Set.empty[Int]
var error = Set.empty[Int]
var start = Short.MinValue.toInt
val limit = Short.MaxValue.toInt + 1
val total = limit - start
while (start < limit) {
try {
val n0 = new Float16(start.toShort)
val x0 = n0.toFloat
val n1 = Float16.fromFloat(x0)
val x1 = n1.toFloat
val nr = new Float16(nextInt.toShort)
val xr = nr.toFloat
if (JFloat.compare(x0, x1) != 0) {
fail += start
} else if (JFloat.compare(n0.signum, signum(x0)) != 0) {
fail += start
} else if (n0.isNaN != JFloat.isNaN(x0)) {
fail += start
} else if (n0.isInfinite != JFloat.isInfinite(x0)) {
fail += start
} else if (n0.isFinite != JFloat.isFinite(x0)) {
fail += start
} else if (JFloat.compare((-n0).toFloat, -x0) != 0) {
fail += start
} else if (!n0.isNaN && !(n0 == n0)) {
fail += start
} else if (!n0.isNaN && !(n0 <= n0)) {
fail += start
} else if ((n0 < nr) != (x0 < xr)) {
fail += start
} else if ((n0 <= nr) != (x0 <= xr)) {
fail += start
} else if ((n0 == nr) != (x0 == xr)) {
fail += start
} else {
ok += start
}
} catch { case e: Exception =>
//throw e
error += start
} finally {
start += 1
}
}
println("ok: %5d/%5d" format (ok.size, total))
println("fail: %5d/%5d" format (fail.size, total))
println("error: %5d/%5d" format (error.size, total))
}
}
@ctongfei
Copy link

Hello Erik, I'm interested in using this class in my project (http://github.com/ctongfei/nexus) (attribute to you of course). What is the license of this code?

@non
Copy link
Author

non commented Feb 12, 2018

@ctongfei Sorry for the late response. You're free to use this under Apache 2 if you still want/need it.

@eirirlar
Copy link

Thanks. No chance you also have scodec codecs for this?

@xiaoQ77
Copy link

xiaoQ77 commented Jun 10, 2022

Hi, it seems that the setting of the sign bit is the opposite of what is really means

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment