@mayhewsw
Created May 2, 2019 17:15
Display PyTorch weight matrices using matplotlib.
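import torch
import matplotlib.pyplot as plt

# I learned this one in AllenNLP, hence the file name (best.th).
p = "path/to/model/best.th"
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
w = torch.load(p, map_location="cpu")

for k in w.keys():
    print(k)
    t = w[k]
    try:
        # Display the tensor as a matrix: each entry becomes one colored cell.
        plt.matshow(t.detach().cpu().numpy())
        plt.title(k)
        plt.show()
    except Exception as e:
        # Sometimes it fails: matshow needs a 2-D array, so scalars,
        # 1-D biases and 4-D conv kernels raise here.
        plt.close("all")  # discard any half-built figure
        print("oops:", e)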
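A variant that avoids the blanket try/except, sketched under some assumptions of my own: it checks tensor dimensions up front, skips 0-D and 1-D tensors, and flattens higher-dimensional tensors (such as conv kernels) to `[out, everything_else]` so they still show up as a matrix. That flattening is just one possible choice. It also writes PNG files through the non-interactive `Agg` backend instead of opening windows, which helps on a remote machine. The function name `plot_weight_matrices`, the `out_prefix` parameter and the demo model are all hypothetical and exist only for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: write PNG files, no windows
import matplotlib.pyplot as plt
import torch
from torch import nn


def plot_weight_matrices(state_dict, out_prefix="weights"):
    """Save one heatmap per matrix-shaped tensor; return the names plotted.

    2-D tensors are drawn as-is. Tensors with more than two dimensions
    (e.g. conv kernels shaped [out, in, kH, kW]) are flattened to
    [out, in*kH*kW]. Non-tensors and 0-D/1-D tensors are skipped.
    """
    plotted = []
    for name, value in state_dict.items():
        if not torch.is_tensor(value) or value.dim() < 2:
            continue  # nothing matrix-shaped to show
        mat = value.detach().float().cpu()
        if mat.dim() > 2:
            mat = mat.reshape(mat.shape[0], -1)
        fig, ax = plt.subplots()
        im = ax.matshow(mat.numpy())
        fig.colorbar(im, ax=ax)  # color scale shows the value range
        ax.set_title(name)
        fig.savefig(f"{out_prefix}_{name.replace('.', '_')}.png")
        plt.close(fig)  # free memory; large models have many layers
        plotted.append(name)
    return plotted


# Demo on a small in-memory model instead of a checkpoint file.
model = nn.Sequential(nn.Conv2d(1, 4, 3), nn.ReLU(), nn.Flatten(), nn.Linear(4, 2))
names = plot_weight_matrices(model.state_dict())
print(names)  # → ['0.weight', '3.weight']
```

With a real checkpoint, the same helper could be called as `plot_weight_matrices(torch.load(p, map_location="cpu"))`. The biases (`0.bias`, `3.bias`) are skipped rather than reported as errors.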