Skip to content

Instantly share code, notes, and snippets.

@msaroufim
Created May 21, 2024 04:37
Show Gist options
  • Save msaroufim/4ba4b9ea09b07912111f33d16382eb9e to your computer and use it in GitHub Desktop.
Save msaroufim/4ba4b9ea09b07912111f33d16382eb9e to your computer and use it in GitHub Desktop.
import torch
# >>> import sys
# >>> size_of_bool = sys.getsizeof(True) # or sys.getsizeof(False)
# >>> print(size_of_bool)
# 28
# In C it's 1 byte.
# Nice tutorial about how a bool is actually 1 byte and how to bit-pack 8 bools:
# https://www.youtube.com/watch?v=LRJclCVrvQI
# The main reason we don't do this is that memory alignment issues make performance worse.
import torch
def pack(uint4_1, uint4_2):
    """
    Pack two uint4 values into a single uint8.

    ``uint4_1`` is stored in the low nibble (bits 0-3) and ``uint4_2`` in the
    high nibble (bits 4-7) of each packed element.

    Args:
        uint4_1 (torch.Tensor): A tensor of uint4 values (integers in 0-15).
        uint4_2 (torch.Tensor): A tensor of uint4 values (integers in 0-15).

    Returns:
        torch.Tensor: A tensor of dtype ``torch.uint8`` holding the packed values.

    Raises:
        AssertionError: If any element of either input is outside 0-15.
    """
    assert (uint4_1 >= 0).all() and (uint4_1 < 16).all(), "Values in uint4_1 must be in the range 0-15"
    assert (uint4_2 >= 0).all() and (uint4_2 < 16).all(), "Values in uint4_2 must be in the range 0-15"
    # Mask in the caller's dtype first (so non-integer inputs still fail on the
    # bitwise op), then narrow to uint8 *before* shifting. Shifting in a signed
    # 8-bit dtype would overflow for high nibbles >= 8, and shifting in a wider
    # dtype would return something other than the documented uint8.
    low_nibble = (uint4_1 & 0x0F).to(torch.uint8)
    high_nibble = (uint4_2 & 0x0F).to(torch.uint8)
    return low_nibble | (high_nibble << 4)
def unpack(uint8):
    """
    Unpack a uint8 value into two uint4 values.

    The first returned tensor holds the low nibble (bits 0-3) of every
    element and the second holds the high nibble (bits 4-7).

    Args:
        uint8 (torch.Tensor): A tensor of uint8 values.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Two tensors of uint4 values,
        ``(low_nibble, high_nibble)``.
    """
    low_nibble = uint8 & 0x0F
    # Shift the high nibble down, then mask so only four bits survive.
    high_nibble = (uint8 >> 4) & 0x0F
    return low_nibble, high_nibble
def uint4_vector_addition(vec1, vec2):
    """
    Perform vector addition on two vectors of uint4 values packed in uint8.

    Every uint8 element carries two independent uint4 lanes: the low nibble
    (bits 0-3) and the high nibble (bits 4-7). Each lane is added modulo 16,
    and a carry out of one lane never reaches the other.

    Args:
        vec1 (torch.Tensor): A tensor of uint8 values containing packed uint4 values.
        vec2 (torch.Tensor): A tensor of uint8 values containing packed uint4 values.

    Returns:
        torch.Tensor: A tensor of uint8 values containing the result of the addition.

    Raises:
        AssertionError: If either input is not of dtype ``torch.uint8``.
    """
    assert vec1.dtype == torch.uint8 and vec2.dtype == torch.uint8, "Input tensors must be of type uint8"
    # Low lanes: isolate, add (at most 15 + 15 = 30, fits in uint8), drop the carry.
    low_sum = ((vec1 & 0x0F) + (vec2 & 0x0F)) & 0x0F
    # High lanes: bring them down to bits 0-3, add, drop the carry.
    high_sum = ((vec1 >> 4) + (vec2 >> 4)) & 0x0F
    # Both sums are already reduced modulo 16, so they can be recombined
    # directly without re-validating their range.
    return low_sum | (high_sum << 4)
# Build two example vectors; each element packs a (low, high) pair of uint4 values
vec1 = pack(torch.tensor([3, 12]), torch.tensor([7, 1])).to(torch.uint8)
vec2 = pack(torch.tensor([4, 6]), torch.tensor([2, 9])).to(torch.uint8)
# Add the packed vectors lane by lane
result = uint4_vector_addition(vec1, vec2)
# Split the packed sums back into per-element (low, high) pairs for inspection
result_unpacked = list(zip(*unpack(result)))
# Verify against the hand-computed sums
expected_result = [(7, 9), (2, 10)]  # ((3+4)%16, (7+2)%16), ((12+6)%16, (1+9)%16)
assert result_unpacked == expected_result, f"Expected {expected_result}, but got {result_unpacked}"
print("All assertions passed!")
print(result_unpacked)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment