# Example usage: look up the 'weights' variable of the layer scoped as
# 'layer_conv1' and plot its filters for the default input channel (0).
# NOTE(review): both helpers are defined further down in this file, so these
# two calls only work once those definitions have been executed -- confirm
# the intended cell/execution order.
weights_conv1weights_ = get_weights_variable(layer_name='layer_conv1')
plot_conv_weightsplot_con(weights_conv1weights_)
Last active
January 19, 2019 07:44
-
-
Save wut0n9/0347f5985fcc879820c449b0a28e68c1 to your computer and use it in GitHub Desktop.
对特定卷积层权重可视化
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# https://colab.research.google.com/github/Hvass-Labs/TensorFlow-Tutorials/blob/master/04_Save_Restore.ipynb#scrollTo=WTQRVlJU_1NN | |
# https://github.com/Hvass-Labs/TensorFlow-Tutorials/blob/master/04_Save_Restore.ipynb | |
%matplotlib inline | |
import matplotlib.pyplot as plt | |
import tensorflow as tf | |
import numpy as np | |
from sklearn.metrics import confusion_matrix | |
import time | |
from datetime import timedelta | |
import math | |
import os | |
# Use PrettyTensor to simplify Neural Network construction. | |
import prettytensor as pt | |
def plot_conv_weightsplot_con(weights, input_channel=0):
    """Plot the filter weights of a convolutional layer as a grid of images.

    The weights are evaluated in the module-level TensorFlow ``session``,
    their mean and standard deviation are printed, and one image per filter
    is drawn for the selected input channel.

    Args:
        weights: TensorFlow op/variable holding the layer weights. Assumed to
            be 4-dimensional with the filters on the last axis, because the
            code reads ``w.shape[3]`` and ``w[:, :, input_channel, i]``.
        input_channel: Index along the third axis selecting which input
            channel's slice of every filter is plotted.
    """
    # Retrieve the values of the weight-variables from TensorFlow.
    # A feed-dict is not necessary because nothing is calculated.
    w = session.run(weights)

    # Print mean and standard deviation.
    print("Mean: {0:.5f}, Stdev: {1:.5f}".format(w.mean(), w.std()))

    # Lowest and highest weight values; shared by every image so the
    # colour intensity is comparable across filters.
    w_min = np.min(w)
    w_max = np.max(w)

    # Number of filters used in the conv. layer.
    num_filters = w.shape[3]

    # Side length of the square plot grid: rounded-up square root of the
    # number of filters. int() keeps the value integral even where
    # math.ceil returns a float.
    num_grids = int(math.ceil(math.sqrt(num_filters)))

    # squeeze=False guarantees a 2-D array of axes. Without it a 1x1 grid
    # (single filter) yields a bare Axes object that has no ``.flat``.
    _, axes = plt.subplots(num_grids, num_grids, squeeze=False)

    # Plot all the filter-weights.
    for i, ax in enumerate(axes.flat):
        # Only plot the valid filter-weights; trailing grid cells stay empty.
        if i < num_filters:
            # Weights of the i'th filter for the chosen input channel.
            img = w[:, :, input_channel, i]

            ax.imshow(img, vmin=w_min, vmax=w_max,
                      interpolation='nearest', cmap='seismic')

        # Remove ticks from every sub-plot, including the empty ones.
        ax.set_xticks([])
        ax.set_yticks([])

    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()
def get_weights_variable(layer_name):
    """Return the existing variable named 'weights' in the given scope.

    Args:
        layer_name: Name of the variable scope (layer) to look in.

    Returns:
        The object returned by ``tf.get_variable('weights')`` inside that
        scope.
    """
    # Entering the scope with reuse=True is a slightly awkward way to fetch
    # an already-created variable; the TensorFlow function was really
    # intended for another purpose.
    with tf.variable_scope(layer_name, reuse=True):
        return tf.get_variable('weights')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment