Last active
August 15, 2024 01:43
-
-
Save nmwsharp/54d04af87872a4988809f128e1a1d233 to your computer and use it in GitHub Desktop.
Pretty print tables summarizing properties of tensor arrays in numpy, pytorch, jax, etc. --- now on pip: `pip install arrgh`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Pretty print tables summarizing properties of tensor arrays in numpy, pytorch, jax, etc. | |
Now on pip! `pip install arrgh` https://github.com/nmwsharp/arrgh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
name | dtype | shape | type | device | min | max | mean | |
-------------------------------------------------------------------------------------------------------------- | |
[None] | None | N/A | NoneType | | N/A | N/A | N/A | |
intval1 | int | scalar | int | | 7 | 7 | 7 | |
intval2 | int | scalar | int | | -3 | -3 | -3 | |
floatval0 | float | scalar | float | | 42 | 42 | 42 | |
floatval1 | float | scalar | float | | 5.5e-12 | 5.5e-12 | 5.5e-12 | |
floatval2 | float | scalar | float | | 7.72324e+44 | 7.72324e+44 | 7.72324e+44 | |
npval1 | int64 | [100] | numpy.ndarray | | 0 | 99 | 49.5 | |
npval2 | int64 | [10000] | numpy.ndarray | | 0 | 9999 | 4999.5 | |
npval3 | uint64 | [10000] | numpy.ndarray | | 0 | 9999 | 4999.5 | |
npval4 | float32 | [100, 10, 10] | numpy.ndarray | | 0 | 9999 | 4999.5 | |
[temporary] | float32 | [10, 8] | numpy.ndarray | | 2 | 99 | 50.5 | |
npval5 | int64 | [] | numpy.int64 | | 9999 | 9999 | 9999 | |
torchval1 | torch.float32 | [1000, 12, 3] | torch.Tensor | cpu | -4.08445 | 3.90982 | 0.00404567 | |
torchval2 | torch.float32 | [1000, 12, 3] | torch.Tensor | cuda:0 | -3.87309 | 3.90342 | 0.00339224 | |
torchval3 | torch.int64 | [1000] | torch.Tensor | cpu | 0 | 999 | N/A | |
torchval4 | torch.int64 | [] | torch.Tensor | cpu | 0 | 0 | N/A |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def printarr(*arrs, float_width=6):
    """
    Print a pretty table giving name, shape, dtype, type, and content information for input tensors or scalars.

    Call like: printarr(my_arr, some_other_arr, maybe_a_scalar). Accepts a variable number of arguments.

    Inputs can be:
        - Numpy tensor arrays
        - Pytorch tensor arrays
        - Jax tensor arrays
        - Python ints / floats
        - None
    It may also work with other array-like types, but they have not been tested.

    Use the `float_width` option to specify the number of significant digits to which
    floating point values are printed.

    Author: Nicholas Sharp (nmwsharp.com)
    Canonical source: https://gist.github.com/nmwsharp/54d04af87872a4988809f128e1a1d233
    License: This snippet may be used under an MIT license, and it is also released into the public domain.
    Please retain this docstring as a reference.
    """
    import inspect  # local import keeps this snippet self-contained when pasted elsewhere

    # The caller's frame, used to recover the variable name each input is bound to.
    frame = inspect.currentframe().f_back
    default_name = "[temporary]"

    ## helpers to gather data about each array

    def name_from_outer_scope(a):
        # Scan the caller's locals for a variable bound to this exact object (identity,
        # not equality). Unnamed expressions fall back to the default placeholder.
        if a is None:
            return '[None]'
        name = default_name
        for k, v in frame.f_locals.items():
            if v is a:
                name = k
                break
        return name

    def dtype_str(a):
        if a is None:
            return 'None'
        if isinstance(a, int):
            return 'int'
        if isinstance(a, float):
            return 'float'
        return str(a.dtype)

    def shape_str(a):
        if a is None:
            return 'N/A'
        if isinstance(a, (int, float)):
            return 'scalar'
        return str(list(a.shape))

    def type_str(a):
        return str(type(a))[8:-2]  # TODO this is weird... what's the better way?

    def device_str(a):
        if hasattr(a, 'device'):
            device_str = str(a.device)
            if len(device_str) < 10:
                # heuristic: jax returns some goofy long string we don't want, ignore it
                return device_str
        return ""

    def format_float(x):
        # BUG FIX: the original used f"{x:{float_width}g}", which treats float_width as
        # a minimum *field width*; the docstring promises precision, so use '.Ng'.
        # (Default output is unchanged: 'g' already defaults to 6 significant digits.)
        return f"{x:.{float_width}g}"

    def minmaxmean_str(a):
        if a is None:
            return ('N/A', 'N/A', 'N/A')
        if isinstance(a, (int, float)):
            return (format_float(a), format_float(a), format_float(a))
        # compute min/max/mean. if anything goes wrong, just print 'N/A'
        # (e.g. integer torch tensors raise on .mean()). Catch Exception only,
        # so KeyboardInterrupt/SystemExit still propagate.
        min_str = "N/A"
        try: min_str = format_float(a.min())
        except Exception: pass
        max_str = "N/A"
        try: max_str = format_float(a.max())
        except Exception: pass
        mean_str = "N/A"
        try: mean_str = format_float(a.mean())
        except Exception: pass
        return (min_str, max_str, mean_str)

    try:
        props = ['name', 'dtype', 'shape', 'type', 'device', 'min', 'max', 'mean']

        # precompute all of the properties for each input
        str_props = []
        for a in arrs:
            minmaxmean = minmaxmean_str(a)
            str_props.append({
                'name'   : name_from_outer_scope(a),
                'dtype'  : dtype_str(a),
                'shape'  : shape_str(a),
                'type'   : type_str(a),
                'device' : device_str(a),
                'min'    : minmaxmean[0],
                'max'    : minmaxmean[1],
                'mean'   : minmaxmean[2],
            })

        # compute the width of each column from its values
        maxlen = {p: 0 for p in props}
        for sp in str_props:
            for p in props:
                maxlen[p] = max(maxlen[p], len(sp[p]))

        # if any property got all empty strings, don't bother printing it, remove it from the list
        props = [p for p in props if maxlen[p] > 0]

        # BUG FIX: widen the surviving columns to fit their header text as well; otherwise
        # a header longer than every value in its column (e.g. "device" vs "cpu") makes
        # the header row wider than the data rows and misaligns the table.
        for p in props:
            maxlen[p] = max(maxlen[p], len(p))

        # print a header ('name' is right-aligned, everything else left-aligned)
        header_str = ""
        for p in props:
            prefix = "" if p == 'name' else " | "
            fmt_key = ">" if p == 'name' else "<"
            header_str += f"{prefix}{p:{fmt_key}{maxlen[p]}}"
        print(header_str)
        print("-" * len(header_str))

        # now print the actual arrays
        for strp in str_props:
            for p in props:
                prefix = "" if p == 'name' else " | "
                fmt_key = ">" if p == 'name' else "<"
                print(f"{prefix}{strp[p]:{fmt_key}{maxlen[p]}}", end='')
            print("")
    finally:
        # break the reference cycle created by holding the caller's frame
        del frame
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if __name__ == "__main__":
    ## test it!

    # plain python values
    noneval = None
    intval1 = 7
    intval2 = -3
    floatval0 = 42.0
    floatval1 = 5.5 * 1e-12
    floatval2 = 7.7232412351231231234 * 1e44

    # numpy values
    import numpy as np
    npval1 = np.arange(100)
    npval2 = np.arange(10000)
    npval3 = np.arange(10000).astype(np.uint64)
    npval4 = np.arange(10000).astype(np.float32).reshape(100, 10, 10)
    npval5 = np.arange(10000)[-1]

    # torch values (left as None if torch is not installed)
    torchval1 = None
    torchval2 = None
    torchval3 = None
    torchval4 = None
    try:
        import torch
        torchval1 = torch.randn((1000, 12, 3))
        # BUG FIX: only move a tensor to the GPU when CUDA is actually available.
        # An unconditional .cuda() raises RuntimeError (not the ModuleNotFoundError
        # caught below) on CPU-only torch installs, crashing the demo.
        if torch.cuda.is_available():
            torchval2 = torch.randn((1000, 12, 3)).cuda()
        torchval3 = torch.arange(1000)
        torchval4 = torch.arange(1000)[0]
    except ModuleNotFoundError:
        pass

    # jax values (left as None if jax is not installed)
    jaxval1 = None
    jaxval2 = None
    jaxval3 = None
    jaxval4 = None
    try:
        import jax
        import jax.numpy as jnp
        jaxval1 = jnp.linspace(0, 1, 10000)
        jaxval2 = jnp.linspace(0, 1, 10000).reshape(100, 10, 10)
        jaxval3 = jnp.arange(1000)
        jaxval4 = jnp.arange(1000)[0]
    except ModuleNotFoundError:
        pass

    printarr(noneval,
             intval1, intval2,
             floatval0, floatval1, floatval2,
             npval1, npval2, npval3, npval4, npval4[0, :, 2:], npval5,
             torchval1, torchval2, torchval3, torchval4,
             jaxval1, jaxval2, jaxval3, jaxval4,
             )
Thanks @afspies! I incorporated that fix into the version now on pip.
This gist is now a small package on pip! pip install arrgh
https://github.com/nmwsharp/arrgh
Awesome! Thanks for your work on this :)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi,
Thanks for this! I ran into an issue when all arrays were on CPU - in this case, the header for "device" was actually longer than the longest value in the rows, causing columns to misalign.
Adding `maxlen[p] = max(maxlen[p], len(p))` after the empty-column filter (around line 111) fixed this, though there is probably a nicer way.