Created
April 30, 2023 04:24
-
-
Save ashvardanian/301b0614252941ac8a3137ac72a18892 to your computer and use it in GitHub Desktop.
Reads a binary matrix from disk, inferring the type of scalars from filename.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def read_matrix(filename: str, start_row: int = 0, count_rows: Optional[int] = None) -> np.ndarray:
    """
    Read a matrix from a *.ibin, *.hbin, *.fbin, or *.dbin file.

    The file begins with two int32 values (number of rows, number of
    columns), followed by the scalars stored row after row. The scalar
    type is inferred from the filename extension:
    .fbin -> float32, .dbin -> float64, .hbin -> float16, .ibin -> int32.

    Args:
        filename: path to the matrix file.
        start_row: index of the first row (vector) to read.
        count_rows: number of rows to read. If None (or negative), read
            every row from ``start_row`` to the end. Requests that run
            past the last row are clamped to the rows available.

    Returns:
        A ``numpy.ndarray`` of shape ``(rows_read, cols)``.

    Raises:
        ValueError: if the extension is not recognized, ``start_row`` is
            negative, or the file is too short to hold the header.
    """
    scalar_types = {
        '.fbin': np.float32,
        '.dbin': np.float64,
        '.hbin': np.float16,
        '.ibin': np.int32,
    }
    for suffix, scalar_type in scalar_types.items():
        if filename.endswith(suffix):
            dtype = np.dtype(scalar_type)
            break
    else:
        raise ValueError('Unknown file type')

    if start_row < 0:
        # A negative offset would seek back into the header.
        raise ValueError('start_row must be non-negative')

    with open(filename, 'rb') as f:
        header = np.fromfile(f, count=2, dtype=np.int32)
        if header.size != 2:
            raise ValueError('File is too short to contain the matrix header')
        # Convert to Python ints: products of np.int32 scalars can wrap
        # around for large matrices, corrupting the count and the offset.
        total_rows, cols = int(header[0]), int(header[1])

        available = max(total_rows - start_row, 0)
        if count_rows is None or count_rows < 0:
            rows = available
        else:
            rows = min(count_rows, available)

        if rows == 0 or cols == 0:
            return np.empty((rows, cols), dtype=dtype)

        # The offset is relative to the current position, i.e. just past
        # the 8-byte header that was consumed above.
        arr = np.fromfile(
            f, count=rows * cols, dtype=dtype,
            offset=start_row * dtype.itemsize * cols)
    return arr.reshape(rows, cols)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment