Skip to content

Instantly share code, notes, and snippets.

@Niranjankumar-c
Created October 11, 2019 06:24
Show Gist options
  • Save Niranjankumar-c/1beb9f8260f1b209719c99d0258a17d7 to your computer and use it in GitHub Desktop.
Save Niranjankumar-c/1beb9f8260f1b209719c99d0258a17d7 to your computer and use it in GitHub Desktop.
plotting weights
def plot_weights(model, layer_num, single_channel=True, collated=False):
    """Plot the filter weights of one convolutional layer of ``model``.

    The layer is looked up as ``model.features[layer_num]``. Only
    ``nn.Conv2d`` layers are supported; for any other layer type a message
    is printed and nothing is plotted.

    Args:
        model: Object exposing an indexable ``features`` attribute holding
            the network's layers.
        layer_num: Index into ``model.features`` of the layer to plot.
        single_channel: When true, the weights are handed to one of the
            single-channel plotters. When false, the weights are handed to
            ``plot_filters_multi_channel``, which is only attempted if
            dimension 1 of the weight tensor has size 3.
        collated: Only consulted when ``single_channel`` is true; selects
            ``plot_filters_single_channel_big`` instead of
            ``plot_filters_single_channel``.

    Returns:
        None. Unsupported requests are reported with ``print`` rather than
        by raising.
    """
    layer = model.features[layer_num]

    # Guard clause: only convolutional layers can be visualized.
    if not isinstance(layer, nn.Conv2d):
        print("Can only visualize layers which are convolutional")
        return

    # Reuse the layer fetched above instead of indexing the model again.
    weight_tensor = layer.weight.data

    if single_channel:
        if collated:
            plot_filters_single_channel_big(weight_tensor)
        else:
            plot_filters_single_channel(weight_tensor)
    elif weight_tensor.shape[1] == 3:
        # Multi-channel plotting is restricted to three-channel kernels.
        plot_filters_multi_channel(weight_tensor)
    else:
        print("Can only plot weights with three channels with single channel = False")
# Visualize the weights of AlexNet's first conv layer (index 0 of `features`),
# requesting the multi-channel view instead of the single-channel one.
plot_weights(alexnet, layer_num=0, single_channel=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment