Skip to content

Instantly share code, notes, and snippets.

@krishvishal
Last active October 22, 2022 18:27
Show Gist options
  • Star 11 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save krishvishal/e6bebc0d809a31f56cbccf5e15f24016 to your computer and use it in GitHub Desktop.
Save krishvishal/e6bebc0d809a31f56cbccf5e15f24016 to your computer and use it in GitHub Desktop.
Visualize weights in PyTorch
from model import Net
from trainer import Trainer
import torch
from torch import nn
from matplotlib import pyplot as plt
model = Net()
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only
# machine. NOTE(review): torch.load unpickles the file -- only load trusted
# checkpoints. The checkpoint is expected to be a dict with a 'state_dict' entry.
ckpt = torch.load('path_to_checkpoint', map_location='cpu')
model.load_state_dict(ckpt['state_dict'])
# detach() is the supported replacement for the legacy .data accessor; cpu() is
# a no-op for CPU parameters and guarantees .numpy() will not fail.
filter = model.conv1.weight.detach().cpu().numpy()
# Normalize into [0, 1] for imshow: x / (2 * max|x|) + 0.5 maps
# [-max|x|, max|x|] onto [0, 1] with zero at mid-grey. The scale is taken from
# the loaded weights rather than from a constant that only fit one checkpoint.
max_abs = float(abs(filter).max())
if max_abs > 0:
    filter = filter / (2 * max_abs) + 0.5
else:
    filter = filter * 0 + 0.5  # all-zero weights: draw uniform mid-grey
def plot_kernels(tensor, num_cols=8, *, channels_first=False):
    """Plot a batch of RGB convolution kernels as a grid of images.

    Args:
        tensor: 4-D array of kernels. By default the layout must be
            (num_kernels, H, W, 3), i.e. channels last, which is the layout
            ``imshow`` expects. Values are passed straight to ``imshow``, so
            float data should already be scaled into [0, 1].
        num_cols: Number of grid columns -- choose the grid width you want.
            The number of rows is derived from it.
        channels_first: If True, ``tensor`` is instead laid out
            (num_kernels, 3, H, W) -- the layout of PyTorch ``Conv2d``
            weights -- and is transposed to channels-last before plotting.
            The transpose uses ``ndarray.transpose(*axes)``, so pass a NumPy
            array (e.g. ``weight.detach().cpu().numpy()``).

    Raises:
        ValueError: if ``tensor`` is not 4-D, its channel axis is not 3, or
            ``num_cols`` is not positive.
    """
    if tensor.ndim != 4:
        raise ValueError("assumes a 4D tensor")
    if channels_first:
        if tensor.shape[1] != 3:
            raise ValueError("dim 1 needs to be 3 to plot with channels_first=True")
        tensor = tensor.transpose(0, 2, 3, 1)
    elif tensor.shape[-1] != 3:
        raise ValueError("last dim needs to be 3 to plot")
    if num_cols < 1:
        raise ValueError("num_cols must be a positive integer")

    num_kernels = tensor.shape[0]
    # Ceiling division: `1 + n // cols` produced an empty extra row whenever
    # num_kernels was an exact multiple of num_cols. max() keeps the figure
    # height non-zero for an empty batch.
    num_rows = max(1, -(-num_kernels // num_cols))
    fig = plt.figure(figsize=(num_cols, num_rows))
    for i, kernel in enumerate(tensor):
        ax = fig.add_subplot(num_rows, num_cols, i + 1)
        ax.imshow(kernel)
        ax.axis('off')  # hides the axis lines, ticks and tick labels
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()
# PyTorch Conv2d weights are laid out (out_channels, in_channels, kH, kW).
# Assuming conv1 is an nn.Conv2d with 3 input channels (TODO confirm for Net),
# move the channel axis last so each kernel is an (H, W, 3) image, as
# plot_kernels and imshow expect. Without this, 3x3 kernels are silently
# plotted with the kernel width treated as colour channels.
plot_kernels(filter.transpose(0, 2, 3, 1))
@sayloren
Copy link

I think you might need a [0] after the [i] in your ax1.imshow(tensor[i]) so that it plots the right dimensions when visualizing your weights?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment