Last active
February 16, 2024 08:02
-
-
Save wcho21/e216323133a8de85adc6a25c3644a0af to your computer and use it in GitHub Desktop.
Floating Point Implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def readNumberLiteral(literal: str) -> tuple[bool, float]:
    """
    Parse a decimal number literal such as "12.25", "-0.125" or "+7".

    The literal is an optional sign ("+" or "-"), a run of ASCII digits,
    and an optional "." followed by more ASCII digits. Parsing stops at the
    first character that does not fit that shape.

    Returns:
        (positive, value): ``positive`` is False only for a leading "-",
        and ``value`` already carries that sign on both paths.
        Integer-only literals yield an ``int``; literals with a "." yield
        a ``float``.

    Raises:
        ValueError: if ``literal`` is empty.
    """
    if not literal:
        raise ValueError("empty number literal")
    i = 0
    # read sign
    positive = literal[i] != "-"
    if literal[i] in "+-":
        i += 1
    # read whole-number digits (ASCII only: str.isdigit() also accepts
    # characters like "²", whose code point is not ord("0") + digit)
    num = 0
    while i < len(literal) and "0" <= literal[i] <= "9":
        num = num * 10 + (ord(literal[i]) - ord("0"))
        i += 1
    # no fractional part: apply the sign here too, so both paths agree
    if i == len(literal) or literal[i] != ".":
        return (positive, num if positive else -num)
    # read fractional digits into the same integer accumulator and track
    # the matching power of ten, so only one division happens at the end
    i += 1
    factor = 1
    while i < len(literal) and "0" <= literal[i] <= "9":
        num = num * 10 + (ord(literal[i]) - ord("0"))
        i += 1
        factor *= 10
    divided = num / factor
    return (positive, divided if positive else -divided)


def roundToNearestEven(binary: int) -> int:
    """
    Drop the three trailing guard/round/sticky bits, rounding ties to even.

    The low three bits are compared with the halfway pattern 0b100: below
    it the kept bits are returned unchanged, above it they are incremented,
    and on an exact tie they are incremented only if the last kept bit is 1.
    """
    kept = binary >> 3
    extra = binary & 0b111
    halfway = 0b100
    round_up = extra > halfway or (extra == halfway and kept & 1 == 1)
    return kept + 1 if round_up else kept


def truncate(binary: int) -> int:
    """
    Round toward zero by discarding the three trailing guard/round/sticky bits.
    """
    extra_bits = 3  # width of the guard/round/sticky tail
    return binary >> extra_bits


class Float:
    """
    Software model of a single-precision (binary32) floating-point number.

    The fields hold the raw encoding: ``sign`` (0 or 1), the biased 8-bit
    ``exp`` and the 23-bit ``frac``. The leading 1 of a normal number is
    implicit. Only normal numbers and zero are modelled; subnormals,
    infinities and NaN are not supported.
    """

    # rounding used by __add__: "NEAREST_EVEN"; any other value truncates
    roundMode = "NEAREST_EVEN"

    def __init__(self, sign: int, exp: int, frac: int):
        self.sign = sign
        self.exp = exp
        self.frac = frac

    @staticmethod
    def fromLiteral(literal: str) -> "Float":
        """
        Parse a decimal literal (see readNumberLiteral) into a Float.

        Zero keeps the literal's sign, so "-0" becomes negative zero.
        """
        positive, num = readNumberLiteral(literal)
        magnitude = abs(num)
        if magnitude == 0:
            return Float(0 if positive else 1, 0, 0)
        # take the sign from the parsed flag rather than from num, so the
        # result is right whether or not num already carries the sign
        return Float.fromNumber(magnitude if positive else -magnitude)

    @staticmethod
    def fromNumber(number: float) -> "Float":
        """
        Round a Python number to the nearest single-precision Float.

        The magnitude is expanded bit by bit into 1 leading bit, 23 fraction
        bits and 3 extra bits (guard, round, sticky). The result is then
        rounded to nearest, with ties going to even.

        Raises:
            ValueError: if ``number`` is infinite or NaN, or if its rounded
                magnitude falls outside the normal single-precision range
                (subnormals are not modelled).
        """
        import math

        # copysign keeps the sign of -0.0, which `number < 0` would lose
        sign = 1 if math.copysign(1.0, number) < 0 else 0
        number = abs(number)
        if not math.isfinite(number):
            raise ValueError(f"cannot convert {number!r} to Float")
        if number == 0:
            return Float(sign, 0, 0)

        # find exp such that 2**exp <= number < 2**(exp + 1)
        exp = 0
        digit = 1.0
        while digit * 2 <= number:
            exp += 1
            digit *= 2
        while digit > number:
            exp -= 1
            digit /= 2

        # greedy binary expansion; each partial sum is a run of at most 27
        # consecutive powers of two, so acc stays exact in a double
        acc = 0.0
        binary = 0
        for _ in range(1 + 23 + 3):
            bit = 1 if acc + digit <= number else 0
            acc += digit * bit
            binary = (binary << 1) | bit
            digit /= 2
        # the last bit is the sticky bit: it must also reflect any nonzero
        # bit beyond it, otherwise e.g. 1 + 2**-24 + 2**-30 looks like a tie
        if acc < number:
            binary |= 1
        # drop the leading 1, which is implicit in the encoding
        binary &= (1 << (23 + 3)) - 1

        rounded = roundToNearestEven(binary)
        # rounding up can carry out of the 23 fraction bits (1.11...1 -> 10.0)
        if rounded >= (1 << 23):
            rounded -= (1 << 23)
            exp += 1
        biased = exp + 127
        if not 1 <= biased <= 254:
            raise ValueError(f"{number!r} is outside the normal single-precision range")
        return Float(sign, biased, rounded)

    def isZero(self) -> bool:
        """Return True for +0 and -0 (exp and frac both 0)."""
        return self.exp == 0 and self.frac == 0

    def __eq__(self, other) -> bool:
        """
        Compare encodings field by field; +0 and -0 compare equal.

        Raises:
            Exception: if ``other`` is not a Float.
        """
        if not isinstance(other, Float):
            raise Exception("not comparable")
        if self.isZero() and other.isZero():
            return True
        return (self.sign, self.exp, self.frac) == (other.sign, other.exp, other.frac)

    def __add__(self, other) -> "Float":
        """
        Add two Floats, rounding according to Float.roundMode.

        Operands of different signs are supported only when one of them is
        zero. Both significands carry three extra low bits (guard, round,
        sticky). The sticky bit collects every 1 bit shifted out during
        alignment and normalization. The sum is normalized before it is
        rounded, so it is rounded exactly once.

        Raises:
            Exception: if ``other`` is not a Float, or for two nonzero
                operands of different signs.
        """
        if not isinstance(other, Float):
            raise Exception("not comparable")
        # zero operands: x + 0 == x; (-0) + (-0) stays -0, and any other
        # zero sum is +0 (the round-to-nearest / round-toward-zero rule)
        if self.isZero() and other.isZero():
            return Float(self.sign & other.sign, 0, 0)
        if other.isZero():
            return Float(self.sign, self.exp, self.frac)
        if self.isZero():
            return Float(other.sign, other.exp, other.frac)
        if self.sign != other.sign:
            raise Exception("not implemented for different signs")
        sign = self.sign
        greater, smaller = (self, other) if self.exp >= other.exp else (other, self)

        extra = 3  # guard, round, sticky
        # prepend the implicit leading 1 and append the extra bits
        greaterFrac = (greater.frac | (1 << 23)) << extra
        smallerFrac = (smaller.frac | (1 << 23)) << extra
        # align by exponent; any 1 bit shifted out sets the sticky bit
        shift = greater.exp - smaller.exp
        aligned = smallerFrac >> shift
        if (aligned << shift) != smallerFrac:
            aligned |= 1
        added = greaterFrac + aligned

        exp = greater.exp
        # both significands are in [1, 2), so the sum is in [1, 4): at most
        # one right shift normalizes it; keep the dropped bit as sticky
        if added >= (1 << (24 + extra)):
            added = (added >> 1) | (added & 1)
            exp += 1

        if Float.roundMode == "NEAREST_EVEN":
            rounded = roundToNearestEven(added)
        else:
            rounded = truncate(added)
        # rounding up can carry into bit 24 (1.11...1 -> 10.00...0)
        if rounded >= (1 << 24):
            rounded >>= 1
            exp += 1
        # NOTE(review): exponent overflow (exp reaching 255) is not handled
        # keep the 23 fraction bits, dropping the implicit leading 1
        return Float(sign, exp, rounded & ((1 << 23) - 1))

    def __repr__(self):
        expStr = bin(self.exp)[2:].rjust(8, "0")
        fracStr = bin(self.frac)[2:].rjust(23, "0")
        return f"[Float: sign={self.sign}, exp={expStr}, frac={fracStr}]"


# test codes
if __name__ == "__main__":
    def testFloat():
        positive_zero = Float(0, 0, 0)  # +0
        negative_zero = Float(1, 0, 0)  # -0
        assert positive_zero == negative_zero
        two_point_seven_five = Float(0, 1 << 7, 0b011 << 20)  # +2.75 = 1.011b * 2**1
        assert two_point_seven_five != positive_zero
    testFloat()

    def testReadNumberLiteral():
        positive1, num1 = readNumberLiteral("12.25")
        assert positive1 == True
        assert num1 == 12.25
        positive2, num2 = readNumberLiteral("-0.125")
        assert positive2 == False
        assert num2 == -0.125
    testReadNumberLiteral()

    def testRoundToNearestEven():
        assert roundToNearestEven(0b0101) == 1
        assert roundToNearestEven(0b10_0000_0000_0000_0000_0000_0101) == 0b100_0000_0000_0000_0000_0001
    testRoundToNearestEven()

    def testTruncate():
        assert truncate(0b0101) == 0
        assert truncate(0b10_0000_0000_0000_0000_0000_0101) == 0b100_0000_0000_0000_0000_0000
    testTruncate()

    def testFromLiteral():
        f1 = Float.fromLiteral("0")
        assert f1.sign == 0
        assert f1.exp == 0
        assert f1.frac == 0
        f2 = Float.fromLiteral("-0")
        assert f2.sign == 1
        assert f2.exp == 0
        assert f2.frac == 0
        f3 = Float.fromLiteral("0.1")
        assert f3.sign == 0
        assert f3.exp == 0b01111011
        assert f3.frac == 0b10011001100110011001101
        f4 = Float.fromLiteral("0.2")
        assert f4.sign == 0
        assert f4.exp == 0b01111100
        assert f4.frac == 0b10011001100110011001101
        f5 = Float.fromLiteral("0.01")
        assert f5.sign == 0
        assert f5.exp == 0b01111000
        assert f5.frac == 0b01000111101011100001010
        f6 = Float.fromLiteral("-0.01")
        assert f6.sign == 1
        assert f6.exp == 0b01111000
        assert f6.frac == 0b01000111101011100001010
        f7 = Float.fromLiteral("123.45678")
        assert f7.sign == 0
        assert f7.exp == 0b10000101
        assert f7.frac == 0b11101101110100111011111
        f8 = Float.fromLiteral("0.123456789")
        assert f8.sign == 0
        assert f8.exp == 0b01111011
        assert f8.frac == 0b11111001101011011101010
        f9 = Float.fromLiteral("0.99999998509883880615234375")  # should be rounded to 1
        assert f9.sign == 0
        assert f9.exp == 0b01111111
        assert f9.frac == 0
    testFromLiteral()

    def testAdd():
        f1 = Float.fromLiteral("0.1") + Float.fromLiteral("0.2")
        assert f1.sign == 0
        assert f1.exp == 0b01111101
        assert f1.frac == 0b00110011001100110011010
        assert f1 == Float.fromLiteral("0.3")
        f2 = Float.fromLiteral("0.2") + Float.fromLiteral("0.1")
        assert f2.sign == 0
        assert f2.exp == 0b01111101
        assert f2.frac == 0b00110011001100110011010
        f3 = Float.fromLiteral("-0.1") + Float.fromLiteral("-0.2")
        assert f3.sign == 1
        assert f3.exp == 0b01111101
        assert f3.frac == 0b00110011001100110011010
        f4 = Float.fromLiteral("0.1") + Float.fromLiteral("0")
        assert f4.sign == 0
        assert f4.exp == 0b01111011
        assert f4.frac == 0b10011001100110011001101
        f5 = Float.fromLiteral("0.1") + Float.fromLiteral("0.01")
        assert f5.sign == 0
        assert f5.exp == 0b01111011
        assert f5.frac == 0b11000010100011110101110
        f6 = Float.fromLiteral("123.45") + Float.fromLiteral("1234.567")
        assert f6.sign == 0
        assert f6.exp == 0b10001001
        assert f6.frac == 0b01010011100000010001011
        # roundMode is class-level state: restore it so later code is unaffected
        previous_mode = Float.roundMode
        Float.roundMode = "TRUNCATE"
        try:
            f1 = Float.fromLiteral("0.1") + Float.fromLiteral("0.2")
            assert f1.sign == 0
            assert f1.exp == 0b01111101
            assert f1.frac == 0b00110011001100110011001
            assert f1 != Float.fromLiteral("0.3")
        finally:
            Float.roundMode = previous_mode
    testAdd()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment