@syed-ahmed
Created June 1, 2022 20:23
Floating-point intrinsic errors demonstrated using PyTorch
import torch
import ctypes
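def print_binary_val(tensor):
    # Reinterpret the value's fp32 bits as an unsigned 32-bit integer and print them.
    # bin() drops leading zeros, so positive values print fewer than 32 digits.
    print(bin(ctypes.c_uint32.from_buffer(ctypes.c_float(tensor.item())).value))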
# Error 1: Round-off when representing the same number in a lower precision.
# fp16 keeps only 10 mantissa bits, so 1.0001 rounds to exactly 1.0.
A = torch.tensor(1.0001, dtype=torch.float32).cuda()
B = A.half()
print_binary_val(A)
# 0b111111100000000000001101000111
print_binary_val(B)
# 0b111111100000000000000000000000
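# Error 2: A large difference between the addends' exponents
# discards the smaller number: B is below half the spacing between
# adjacent fp32 values near A, so C has the same bits as A.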
A = torch.tensor(2.25e30, dtype=torch.float32).cuda()
B = torch.tensor(4.25e2, dtype=torch.float32).cuda()
C = A + B
print_binary_val(A)
# 0b1110001111000110011000100100011
print_binary_val(C)
# 0b1110001111000110011000100100011
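# Error 3: Severe cancellation.
# A and B are close, so each square is rounded to fp32 first and the
# subtraction then cancels the leading digits, leaving mostly rounding
# error. The factored version subtracts before multiplying and retains
# the significant digits.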
A = torch.tensor(7.65432e4, dtype=torch.float32).cuda()
B = torch.tensor(7.6543e4, dtype=torch.float32).cuda()
C = torch.square(A) - torch.square(B) # gives tensor(31232.)
D = (A+B)*(A-B) # gives tensor(31095.6348)
print_binary_val(C)
# 0b1000110111101000000000000000000
print_binary_val(D)
# 0b1000110111100101110111101000101
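# Error 4: Overflow/underflow occurs when an operation produces
# intermediate floats outside the limits of the datatype. Here the
# sum of squares (2e40) exceeds the fp32 maximum (about 3.4e38) even
# though the final norm is representable. If we don't upcast the
# operation to a higher precision, we get infinity.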
A = torch.tensor([1e20, 1e20], dtype=torch.float32).cuda()
B = A.norm() # produces tensor(inf)
C = A.double().norm() # produces tensor(1.4142e+20, dtype=torch.float64), representable in fp32
print_binary_val(B)
# 0b1111111100000000000000000000000
print_binary_val(C)
# 0b1100000111101010101001110110011
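
The script above needs a CUDA device. As a rough cross-check, the four errors can be sketched on the CPU with only the standard library. The helpers `f32`, `f16`, and `bits32` are hypothetical names introduced here, not part of the gist or of PyTorch. The sketch assumes that rounding to fp32 after every operation is a fair model of fp32 arithmetic; real PyTorch kernels may order or accumulate their reductions differently.

```python
import math
import struct


def f32(x: float) -> float:
    """Round an fp64 Python float to the nearest fp32 value (inf on overflow)."""
    try:
        return struct.unpack("<f", struct.pack("<f", x))[0]
    except OverflowError:
        # struct refuses to pack finite values beyond the fp32 range;
        # IEEE 754 round-to-nearest would produce infinity here.
        return math.copysign(math.inf, x)


def f16(x: float) -> float:
    """Round an fp64 Python float to the nearest fp16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def bits32(x: float) -> str:
    """Zero-padded 32-bit pattern of x stored as fp32."""
    return format(struct.unpack("<I", struct.pack("<f", x))[0], "032b")


# Error 1: round-off when narrowing fp32 to fp16.
a1 = f32(1.0001)
b1 = f16(a1)
print(bits32(a1))  # 00111111100000000000001101000111
print(bits32(b1))  # 00111111100000000000000000000000

# Error 2: the small addend is below half the fp32 spacing near the large one.
a2 = f32(2.25e30)
b2 = f32(4.25e2)
c2 = f32(a2 + b2)
_, exp2 = math.frexp(a2)
ulp2 = 2.0 ** (exp2 - 24)  # spacing between adjacent fp32 values near a2
print(c2 == a2, ulp2 > b2)

# Error 3: cancellation, squares-then-subtract vs. factored form.
a3 = f32(7.65432e4)
b3 = f32(7.6543e4)
naive3 = f32(f32(a3 * a3) - f32(b3 * b3))
factored3 = f32(f32(a3 + b3) * f32(a3 - b3))
print(naive3, factored3)

# Error 4: the sum of squares overflows fp32 although the norm is representable.
x4 = f32(1e20)
sum_sq32 = f32(f32(x4 * x4) + f32(x4 * x4))
norm32 = math.sqrt(sum_sq32)                     # fp32 accumulation
norm64 = f32(math.sqrt(x4 * x4 + x4 * x4))       # fp64 accumulation, then narrow
print(norm32, norm64)
```

Rounding through `struct` keeps the sketch free of third-party dependencies. It is meant for checking the bit patterns and magnitudes by hand, not as a replacement for running the tensors on the GPU.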