Skip to content

Instantly share code, notes, and snippets.

@nmwsharp
Last active August 15, 2024 01:43
Show Gist options
  • Save nmwsharp/54d04af87872a4988809f128e1a1d233 to your computer and use it in GitHub Desktop.
Save nmwsharp/54d04af87872a4988809f128e1a1d233 to your computer and use it in GitHub Desktop.
Pretty-print tables summarizing properties of tensor arrays in NumPy, PyTorch, JAX, etc. — now on pip: `pip install arrgh`
Pretty-print tables summarizing properties of tensor arrays in NumPy, PyTorch, JAX, etc.
Now on pip! `pip install arrgh` (https://github.com/nmwsharp/arrgh)
name | dtype | shape | type | device | min | max | mean
--------------------------------------------------------------------------------------------------------------
[None] | None | N/A | NoneType | | N/A | N/A | N/A
intval1 | int | scalar | int | | 7 | 7 | 7
intval2 | int | scalar | int | | -3 | -3 | -3
floatval0 | float | scalar | float | | 42 | 42 | 42
floatval1 | float | scalar | float | | 5.5e-12 | 5.5e-12 | 5.5e-12
floatval2 | float | scalar | float | | 7.72324e+44 | 7.72324e+44 | 7.72324e+44
npval1 | int64 | [100] | numpy.ndarray | | 0 | 99 | 49.5
npval2 | int64 | [10000] | numpy.ndarray | | 0 | 9999 | 4999.5
npval3 | uint64 | [10000] | numpy.ndarray | | 0 | 9999 | 4999.5
npval4 | float32 | [100, 10, 10] | numpy.ndarray | | 0 | 9999 | 4999.5
[temporary] | float32 | [10, 8] | numpy.ndarray | | 2 | 99 | 50.5
npval5 | int64 | [] | numpy.int64 | | 9999 | 9999 | 9999
torchval1 | torch.float32 | [1000, 12, 3] | torch.Tensor | cpu | -4.08445 | 3.90982 | 0.00404567
torchval2 | torch.float32 | [1000, 12, 3] | torch.Tensor | cuda:0 | -3.87309 | 3.90342 | 0.00339224
torchval3 | torch.int64 | [1000] | torch.Tensor | cpu | 0 | 999 | N/A
torchval4 | torch.int64 | [] | torch.Tensor | cpu | 0 | 0 | N/A
def printarr(*arrs, float_width: int = 6) -> None:
    """
    Print a pretty table giving name, shape, dtype, type, and content information for input tensors or scalars.

    Call like: printarr(my_arr, some_other_arr, maybe_a_scalar). Accepts a variable number of arguments.

    Inputs can be:
        - Numpy tensor arrays
        - Pytorch tensor arrays
        - Jax tensor arrays
        - Python ints / floats
        - None

    It may also work with other array-like types, but they have not been tested. An input that
    lacks a `dtype` / `shape` attribute, or whose min/max/mean cannot be computed, is shown
    with 'N/A' in the corresponding column instead of aborting the whole table.

    Names are recovered by identity lookup among the caller's local variables. An argument
    that is not bound to a caller local (e.g. a slice expression) is shown as "[temporary]".
    Because the lookup is by identity, interned values such as small ints may be reported
    under another local name that happens to refer to the same object.

    Use the `float_width` option to set the minimum field width (in characters) used when
    formatting min/max/mean values. Numbers are formatted with the `g` format spec, i.e.
    6 significant digits.

    Author: Nicholas Sharp (nmwsharp.com)
    Canonical source: https://gist.github.com/nmwsharp/54d04af87872a4988809f128e1a1d233
    License: This snippet may be used under an MIT license, and it is also released into the public domain.
    Please retain this docstring as a reference.
    """
    import inspect  # local import keeps this copy-pasteable snippet self-contained

    # The caller's frame is used to recover variable names. currentframe() may return None
    # on Python implementations without stack-frame support; names then use the default.
    current = inspect.currentframe()
    frame = current.f_back if current is not None else None
    del current
    default_name = "[temporary]"

    ## helpers to gather data about each array
    def name_from_outer_scope(a):
        if a is None:
            return '[None]'
        if frame is None:
            return default_name
        for k, v in frame.f_locals.items():
            if v is a:
                return k
        return default_name

    def dtype_str(a):
        if a is None:
            return 'None'
        if isinstance(a, int):
            return 'int'
        if isinstance(a, float):
            return 'float'
        dtype = getattr(a, 'dtype', None)
        return 'N/A' if dtype is None else str(dtype)

    def shape_str(a):
        if a is None:
            return 'N/A'
        if isinstance(a, (int, float)):
            return 'scalar'
        shape = getattr(a, 'shape', None)
        return 'N/A' if shape is None else str(list(shape))

    def type_str(a):
        # Same text as repr(type(a)) without the "<class '...'>" wrapper:
        # builtin types carry no module prefix (e.g. 'int', 'NoneType').
        t = type(a)
        module = getattr(t, '__module__', None)
        if module in (None, 'builtins'):
            return t.__qualname__
        return f"{module}.{t.__qualname__}"

    def device_str(a):
        if not hasattr(a, 'device'):
            return ""
        dev = str(a.device)
        # heuristic: jax returns some goofy long string we don't want, ignore it
        return dev if len(dev) < 10 else ""

    def format_float(x):
        return f"{x:{float_width}g}"

    def stat_str(compute):
        # Best effort: some reductions are unsupported (e.g. mean of an integer torch
        # tensor, or min of an object without .min()); show 'N/A' rather than fail.
        try:
            return format_float(compute())
        except Exception:
            return "N/A"

    def minmaxmean_str(a):
        if a is None:
            return ('N/A', 'N/A', 'N/A')
        if isinstance(a, (int, float)):
            s = format_float(a)
            return (s, s, s)
        return (stat_str(lambda: a.min()),
                stat_str(lambda: a.max()),
                stat_str(lambda: a.mean()))

    try:
        props = ['name', 'dtype', 'shape', 'type', 'device', 'min', 'max', 'mean']

        # precompute all of the properties for each input
        str_props = []
        for a in arrs:
            min_s, max_s, mean_s = minmaxmean_str(a)
            str_props.append({
                'name': name_from_outer_scope(a),
                'dtype': dtype_str(a),
                'shape': shape_str(a),
                'type': type_str(a),
                'device': device_str(a),
                'min': min_s,
                'max': max_s,
                'mean': mean_s,
            })

        # if any property got all empty strings, don't bother printing it; remove it from the list
        props = [p for p in props if any(sp[p] for sp in str_props)]

        # Column width is the widest cell, but never narrower than the header label;
        # otherwise the header and the rows drift out of alignment.
        width = {p: max([len(p)] + [len(sp[p]) for sp in str_props]) for p in props}

        def format_row(cells):
            parts = []
            for p in props:
                prefix = "" if p == 'name' else " | "
                align = ">" if p == 'name' else "<"
                parts.append(f"{prefix}{cells[p]:{align}{width[p]}}")
            return "".join(parts)

        # print a header
        header_str = format_row({p: p for p in props})
        print(header_str)
        print("-" * len(header_str))

        # now print the actual arrays
        for sp in str_props:
            print(format_row(sp))
    finally:
        # break the reference to the caller's frame to avoid a reference cycle
        del frame
if __name__ == "__main__":
    ## test it!

    # plain python values
    noneval = None
    intval1 = 7
    intval2 = -3
    floatval0 = 42.0
    floatval1 = 5.5 * 1e-12
    floatval2 = 7.7232412351231231234 * 1e44

    # numpy values
    import numpy as np
    npval1 = np.arange(100)
    npval2 = np.arange(10000)
    npval3 = np.arange(10000).astype(np.uint64)
    npval4 = np.arange(10000).astype(np.float32).reshape(100, 10, 10)
    npval5 = np.arange(10000)[-1]

    # torch values (optional). The CUDA tensor is only created when a GPU is available,
    # so the demo still runs on CPU-only machines; otherwise it stays None.
    torchval1 = None
    torchval2 = None
    torchval3 = None
    torchval4 = None
    try:
        import torch
        torchval1 = torch.randn((1000, 12, 3))
        if torch.cuda.is_available():
            torchval2 = torch.randn((1000, 12, 3)).cuda()
        torchval3 = torch.arange(1000)
        torchval4 = torch.arange(1000)[0]
    except ModuleNotFoundError:
        pass

    # jax values (optional)
    jaxval1 = None
    jaxval2 = None
    jaxval3 = None
    jaxval4 = None
    try:
        import jax.numpy as jnp
        jaxval1 = jnp.linspace(0, 1, 10000)
        jaxval2 = jnp.linspace(0, 1, 10000).reshape(100, 10, 10)
        jaxval3 = jnp.arange(1000)
        jaxval4 = jnp.arange(1000)[0]
    except ModuleNotFoundError:
        pass

    printarr(noneval,
             intval1, intval2,
             floatval0, floatval1, floatval2,
             npval1, npval2, npval3, npval4, npval4[0, :, 2:], npval5,
             torchval1, torchval2, torchval3, torchval4,
             jaxval1, jaxval2, jaxval3, jaxval4,
             )
@nmwsharp
Copy link
Author

Thanks @afspies! I incorporated that fix into the version now on pip.

@nmwsharp
Copy link
Author

This gist is now a small package on pip: `pip install arrgh` (https://github.com/nmwsharp/arrgh)

@afspies
Copy link

afspies commented Jun 16, 2023

Awesome! Thanks for your work on this :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment