Skip to content

Instantly share code, notes, and snippets.

@wcho21
Last active February 16, 2024 08:02
Show Gist options
  • Save wcho21/e216323133a8de85adc6a25c3644a0af to your computer and use it in GitHub Desktop.
Save wcho21/e216323133a8de85adc6a25c3644a0af to your computer and use it in GitHub Desktop.
Floating Point Implementation
def readNumberLiteral(literal: str) -> tuple[bool, float]:
    """
    Parse a decimal literal such as "12.25", "-0.125" or "+7".

    Returns a tuple ``(positive, value)``:
    - ``positive`` is False only when the literal starts with "-";
    - ``value`` is the signed number (negative whenever ``positive`` is
      False, for integer and decimal literals alike). A literal without
      "." yields an int; one with "." yields a float.

    Parsing is lenient: it stops at the first unexpected character and
    returns what has been read so far, so "" and "+" read as (True, 0).
    Only the ASCII digits 0-9 are accepted.
    """
    i = 0
    # read sign; slicing (not indexing) keeps an empty literal from raising
    first = literal[:1]
    positive = first != "-"
    if first in ("+", "-"):
        i += 1
    # read whole numbers
    num = 0
    # compare against "0".."9" instead of str.isdigit(), which also accepts
    # non-ASCII digits (e.g. "²") that the ord() arithmetic would misread
    while i < len(literal) and "0" <= literal[i] <= "9":
        num = num * 10 + (ord(literal[i]) - ord("0"))
        i += 1
    # no decimal part: return the signed integer (same sign convention as below)
    if i == len(literal) or literal[i] != ".":
        return (positive, num if positive else -num)
    # read decimal part: keep accumulating digits, divide by 10**count once at the end
    i += 1
    factor = 1
    while i < len(literal) and "0" <= literal[i] <= "9":
        num = num * 10 + (ord(literal[i]) - ord("0"))
        i += 1
        factor *= 10
    divided = num / factor
    return (positive, divided if positive else -divided)
def roundToNearestEven(binary: int) -> int:
    """
    Remove the three low-order bits (guard, round, sticky) from ``binary``
    and round what remains to nearest, breaking exact ties toward an even
    result.
    """
    kept, tail = binary >> 3, binary & 0b111
    halfway = 0b100
    # round up when strictly above halfway, or on an exact tie with an odd kept value
    roundUp = tail > halfway or (tail == halfway and (kept & 1) == 1)
    return kept + 1 if roundUp else kept
def truncate(binary: int) -> int:
    """
    Discard the three low-order bits (guard, round, sticky) of ``binary``
    without any rounding.
    """
    # floor division by 2**3 is identical to an arithmetic right shift by 3 for every int
    return binary // 0b1000
class Float:
    """
    Software model of an IEEE 754 single-precision (binary32) number.

    A value is kept as its three raw fields: the sign bit, the 8-bit biased
    exponent (bias 127) and the 23-bit fraction, whose leading 1 is implicit
    for normal numbers. Only zeros and normal numbers are modeled;
    subnormals, infinities and NaNs are not supported.
    """

    # Rounding used by __add__: "NEAREST_EVEN" (round to nearest, ties to
    # even); any other value truncates the extra bits.
    roundMode = "NEAREST_EVEN"

    def __init__(self, sign: int, exp: int, frac: int):
        self.sign = sign  # 0 = positive, 1 = negative
        self.exp = exp  # biased exponent field; exp 0 with frac 0 encodes zero
        self.frac = frac  # 23-bit fraction field (implicit leading 1 omitted)

    @staticmethod
    def fromLiteral(literal: str) -> "Float":
        """
        Convert a decimal literal (e.g. "-0.01") to the nearest Float.

        The sign is taken from the literal itself, so "-0" gives negative
        zero and negative integer literals such as "-5" stay negative.
        """
        positive, num = readNumberLiteral(literal)
        sign = 0 if positive else 1
        if num == 0:
            return Float(sign, 0, 0)
        # convert the magnitude and apply the literal's sign explicitly,
        # rather than depending on whether ``num`` itself carries the sign
        result = Float.fromNumber(abs(num))
        result.sign = sign
        return result

    @staticmethod
    def fromNumber(number: float) -> "Float":
        """
        Convert a Python int/float to the nearest Float.

        Conversion always rounds to nearest, ties to even (independent of
        ``Float.roundMode``). Zero of either sign maps to the zero encoding.

        Raises:
            ValueError: if ``number`` is infinite or NaN.
            OverflowError: if the rounded exponent falls outside the normal
                range, i.e. the result would need an infinity or a subnormal.
        """
        import math

        if math.isnan(number) or math.isinf(number):
            raise ValueError("infinities and NaNs are not supported")
        # copysign also sees the sign of -0.0, which ``number < 0`` misses
        sign = 1 if math.copysign(1.0, number) < 0 else 0
        number = abs(number)
        if number == 0:
            # without this, the exponent search below would walk down to 2**-1075
            return Float(sign, 0, 0)

        # find the leading bit: digit = 2**exp with digit <= number < 2 * digit
        exp = 0
        digit = 1.0
        while digit * 2 <= number:
            exp += 1
            digit *= 2
        while digit > number:
            exp -= 1
            digit /= 2

        # extract 1 + 23 + 3 bits: implicit leading 1, 23 fraction bits, then
        # guard, round and sticky positions used for rounding
        acc = 0.0
        binary = 0
        for _ in range(1 + 23 + 3):
            binary <<= 1
            if acc + digit <= number:
                acc += digit
                binary |= 1
            digit /= 2
        # fold any remainder below the last extracted bit into the sticky bit;
        # otherwise a value just above halfway would look like an exact tie
        if acc < number:
            binary |= 1
        # drop the implicit leading 1 (always set: the first digit is <= number)
        binary -= 1 << 26

        rounded = roundToNearestEven(binary)
        if rounded >= (1 << 23):
            # rounding carried out of the fraction (1.11...1 -> 10.00...0)
            rounded -= (1 << 23)
            exp += 1
        biased = exp + 127
        if not 1 <= biased <= 254:
            raise OverflowError(
                "exponent out of the normal single-precision range "
                "(subnormals and infinities are not supported)"
            )
        return Float(sign, biased, rounded)

    @staticmethod
    def _shiftRightSticky(value: int, shift: int) -> int:
        """Shift ``value`` right by ``shift`` bits, OR-ing any lost 1 bits into bit 0."""
        if shift <= 0:
            return value
        lost = value & ((1 << shift) - 1)
        return (value >> shift) | (1 if lost else 0)

    def isZero(self) -> bool:
        """Return True for +0 and -0 (exponent and fraction fields both zero)."""
        return self.exp == 0 and self.frac == 0

    def __eq__(self, other) -> bool:
        """
        Compare the raw fields, except that +0 and -0 are equal.

        Raises:
            TypeError: if ``other`` is not a Float.
        """
        if not isinstance(other, Float):
            raise TypeError("not comparable")
        if self.isZero() and other.isZero():
            return True
        return (self.sign, self.exp, self.frac) == (other.sign, other.exp, other.frac)

    def __add__(self, other) -> "Float":
        """
        Add two Floats of the same sign, rounding per ``Float.roundMode``.

        Raises:
            TypeError: if ``other`` is not a Float.
            NotImplementedError: if the operands have different signs.
            OverflowError: if the result exceeds the normal exponent range.
        """
        if not isinstance(other, Float):
            raise TypeError("not comparable")
        if self.sign != other.sign:
            raise NotImplementedError("not implemented for different signs")
        sign = self.sign
        greater, smaller = (self, other) if self.exp >= other.exp else (other, self)
        if smaller.isZero():
            return Float(greater.sign, greater.exp, greater.frac)

        # significands with the implicit leading 1 plus three extra low bits
        # (guard, round, sticky): 1 + 23 + 3 = 27 bits each
        greaterFrac = (greater.frac | (1 << 23)) << 3
        smallerFrac = (smaller.frac | (1 << 23)) << 3

        # align by exp; bits shifted out are OR-ed into the sticky bit so that
        # rounding can tell "exactly halfway" from "just above halfway"
        smallerFrac = Float._shiftRightSticky(smallerFrac, greater.exp - smaller.exp)

        added = greaterFrac + smallerFrac
        exp = greater.exp
        # normalize BEFORE rounding: a carry out of the leading bit makes the
        # sum 28 bits wide, so move it back to 27 bits (keeping sticky)
        if added >= (1 << 27):
            added = Float._shiftRightSticky(added, 1)
            exp += 1

        if Float.roundMode == "NEAREST_EVEN":
            rounded = roundToNearestEven(added)
        else:
            rounded = truncate(added)
        # rounding up can carry into a new leading bit (1.11...1 + ulp); the
        # significand is then exactly 2**24, so this shift loses nothing
        if rounded >= (1 << 24):
            rounded >>= 1
            exp += 1
        if exp >= 255:
            raise OverflowError(
                "result exceeds the normal single-precision range (infinity is not supported)"
            )
        # keep the last 23 bits to drop the implicit leading 1
        return Float(sign, exp, rounded & ((1 << 23) - 1))

    def __repr__(self):
        expStr = bin(self.exp)[2:].rjust(8, "0")
        fracStr = bin(self.frac)[2:].rjust(23, "0")
        return f"[Float: sign={self.sign}, exp={expStr}, frac={fracStr}]"
# test codes
if __name__ == "__main__":
    def testFloat():
        positive_zero = Float(0, 0, 0)  # +0
        negative_zero = Float(1, 0, 0)  # -0
        # +0 and -0 differ only in the sign bit but must compare equal
        assert positive_zero == negative_zero
        # +2.75 = 1.011b x 2**1: biased exponent 1 + 127 = 128, fraction 011000...
        two_point_seven_five = Float(0, 1 << 7, 0b011 << 20)
        assert two_point_seven_five != positive_zero
    testFloat()

    def testReadNumberLiteral():
        positive1, num1 = readNumberLiteral("12.25")
        assert positive1 is True
        assert num1 == 12.25
        positive2, num2 = readNumberLiteral("-0.125")
        assert positive2 is False
        assert num2 == -0.125
    testReadNumberLiteral()

    def testRoundToNearestEven():
        assert roundToNearestEven(0b0101) == 1
        assert roundToNearestEven(0b10_0000_0000_0000_0000_0000_0101) == 0b100_0000_0000_0000_0000_0001
    testRoundToNearestEven()

    def testTruncate():
        assert truncate(0b0101) == 0
        assert truncate(0b10_0000_0000_0000_0000_0000_0101) == 0b100_0000_0000_0000_0000_0000
    testTruncate()

    def testFromLiteral():
        f1 = Float.fromLiteral("0")
        assert f1.sign == 0
        assert f1.exp == 0
        assert f1.frac == 0
        f2 = Float.fromLiteral("-0")
        assert f2.sign == 1
        assert f2.exp == 0
        assert f2.frac == 0
        f3 = Float.fromLiteral("0.1")
        assert f3.sign == 0
        assert f3.exp == 0b01111011
        assert f3.frac == 0b10011001100110011001101
        f4 = Float.fromLiteral("0.2")
        assert f4.sign == 0
        assert f4.exp == 0b01111100
        assert f4.frac == 0b10011001100110011001101
        f5 = Float.fromLiteral("0.01")
        assert f5.sign == 0
        assert f5.exp == 0b01111000
        assert f5.frac == 0b01000111101011100001010
        f6 = Float.fromLiteral("-0.01")
        assert f6.sign == 1
        assert f6.exp == 0b01111000
        assert f6.frac == 0b01000111101011100001010
        f7 = Float.fromLiteral("123.45678")
        assert f7.sign == 0
        assert f7.exp == 0b10000101
        assert f7.frac == 0b11101101110100111011111
        f8 = Float.fromLiteral("0.123456789")
        assert f8.sign == 0
        assert f8.exp == 0b01111011
        assert f8.frac == 0b11111001101011011101010
        f9 = Float.fromLiteral("0.99999998509883880615234375")  # should be rounded to 1
        assert f9.sign == 0
        assert f9.exp == 0b01111111
        assert f9.frac == 0
    testFromLiteral()

    def testAdd():
        f1 = Float.fromLiteral("0.1") + Float.fromLiteral("0.2")
        assert f1.sign == 0
        assert f1.exp == 0b01111101
        assert f1.frac == 0b00110011001100110011010
        assert f1 == Float.fromLiteral("0.3")
        f2 = Float.fromLiteral("0.2") + Float.fromLiteral("0.1")
        assert f2.sign == 0
        assert f2.exp == 0b01111101
        assert f2.frac == 0b00110011001100110011010
        f3 = Float.fromLiteral("-0.1") + Float.fromLiteral("-0.2")
        assert f3.sign == 1
        assert f3.exp == 0b01111101
        assert f3.frac == 0b00110011001100110011010
        f4 = Float.fromLiteral("0.1") + Float.fromLiteral("0")
        assert f4.sign == 0
        assert f4.exp == 0b01111011
        assert f4.frac == 0b10011001100110011001101
        f5 = Float.fromLiteral("0.1") + Float.fromLiteral("0.01")
        assert f5.sign == 0
        assert f5.exp == 0b01111011
        assert f5.frac == 0b11000010100011110101110
        f6 = Float.fromLiteral("123.45") + Float.fromLiteral("1234.567")
        assert f6.sign == 0
        assert f6.exp == 0b10001001
        assert f6.frac == 0b01010011100000010001011
        # roundMode is class-wide state: restore it even if an assertion fails
        previousMode = Float.roundMode
        Float.roundMode = "TRUNCATE"
        try:
            f1 = Float.fromLiteral("0.1") + Float.fromLiteral("0.2")
            assert f1.sign == 0
            assert f1.exp == 0b01111101
            assert f1.frac == 0b00110011001100110011001
            assert f1 != Float.fromLiteral("0.3")
        finally:
            Float.roundMode = previousMode
    testAdd()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment