Created
May 21, 2024 04:37
-
-
Save msaroufim/4ba4b9ea09b07912111f33d16382eb9e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
# >>> import sys | |
# >>> size_of_bool = sys.getsizeof(True) # or sys.getsizeof(False) | |
# >>> print(size_of_bool) | |
# 28 | |
# In C it's 1 byte
# Nice tutorial about how bool is actually 1 byte and how to bit pack 8 bools | |
# https://www.youtube.com/watch?v=LRJclCVrvQI | |
# The main reason we don't do this is that memory alignment issues make performance worse
import torch | |
def pack(uint4_1, uint4_2):
    """
    Pack two uint4 values into a single uint8.

    ``uint4_1`` is stored in the low nibble (bits 0-3) and ``uint4_2`` in the
    high nibble (bits 4-7) of each output element.

    Args:
        uint4_1 (torch.Tensor): An integer tensor of values in the range 0-15.
        uint4_2 (torch.Tensor): An integer tensor of values in the range 0-15.

    Returns:
        torch.Tensor: A tensor of dtype ``torch.uint8`` holding the packed values.

    Raises:
        AssertionError: If any value in either input is outside the range 0-15.
    """
    # Raise explicitly instead of using `assert` so the range check is not
    # stripped when Python runs with -O; the exception type and messages are
    # kept so existing callers see the same failure.
    for values, name in ((uint4_1, "uint4_1"), (uint4_2, "uint4_2")):
        if not ((values >= 0).all() and (values < 16).all()):
            raise AssertionError(f"Values in {name} must be in the range 0-15")

    # Mask in the input dtype first, then cast to uint8 *before* shifting:
    # shifting in a narrow signed dtype such as int8 would overflow for
    # nibbles >= 8, and the result should be uint8 whatever integer dtype the
    # caller passed in.
    low_nibble = (uint4_1 & 0x0F).to(torch.uint8)
    high_nibble = (uint4_2 & 0x0F).to(torch.uint8)
    return low_nibble | (high_nibble << 4)
def unpack(uint8):
    """
    Split packed uint8 values back into their two uint4 halves.

    Args:
        uint8 (torch.Tensor): A tensor of uint8 values, each holding two uint4 values.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The low-nibble values (bits 0-3)
        followed by the high-nibble values (bits 4-7).
    """
    nibble_mask = 0x0F
    low_nibble = uint8 & nibble_mask
    # Shift the upper four bits down, then mask so only one nibble remains.
    high_nibble = (uint8 >> 4) & nibble_mask
    return low_nibble, high_nibble
def uint4_vector_addition(vec1, vec2):
    """
    Add two vectors of packed uint4 values nibble by nibble.

    Each uint8 element of the inputs holds two uint4 values. Corresponding
    nibbles are added modulo 16 and the sums are packed back into uint8.

    Args:
        vec1 (torch.Tensor): A tensor of uint8 values containing packed uint4 values.
        vec2 (torch.Tensor): A tensor of uint8 values containing packed uint4 values.

    Returns:
        torch.Tensor: A tensor of uint8 values containing the packed sums.
    """
    assert (
        vec1.dtype == torch.uint8 and vec2.dtype == torch.uint8
    ), "Input tensors must be of type uint8"
    # Pair the low nibbles of both inputs, then the high nibbles, and reduce
    # each sum modulo 16 so it still fits in four bits.
    nibble_sums = [(a + b) % 16 for a, b in zip(unpack(vec1), unpack(vec2))]
    return pack(*nibble_sums)
# Build two example vectors; every element is a pair of uint4 values packed into one uint8
vec1 = torch.tensor(
    [pack(torch.tensor(first), torch.tensor(second)) for first, second in ((3, 7), (12, 1))],
    dtype=torch.uint8,
)
vec2 = torch.tensor(
    [pack(torch.tensor(first), torch.tensor(second)) for first, second in ((4, 2), (6, 9))],
    dtype=torch.uint8,
)
# Add the packed vectors nibble by nibble
result = uint4_vector_addition(vec1, vec2)
# Split each packed sum back into its two uint4 values for inspection
result_unpacked = [unpack(packed) for packed in result]
# Compare against sums worked out by hand
expected_result = [(7, 9), (2, 10)]  # ((3+4)%16, (7+2)%16), ((12+6)%16, (1+9)%16)
assert result_unpacked == expected_result, f"Expected {expected_result}, but got {result_unpacked}"
print("All assertions passed!")
print(result_unpacked)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment