@ayushidalmia
Created November 6, 2017 11:06
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.variable_scope / tf.get_variable)

def load_initial_weights(self, session):
    """
    The weights from http://www.cs.toronto.edu/~guerzhoy/tf_alexnet/ come
    as a dict of lists (e.g. weights['conv1'] is a list) and not as a dict
    of dicts (e.g. weights['conv1'] is a dict with keys 'weights' &
    'biases'), so we need a special load function.
    """
    # Load the weights into memory
    weights_dict = np.load(self.WEIGHTS_PATH, encoding='bytes').item()

    # Loop over all layer names stored in the weights dict
    for op_name in weights_dict:
        # Layers listed in SKIP_LAYER stay trainable; all others are frozen
        train_bool = op_name in self.SKIP_LAYER

        with tf.variable_scope(op_name, reuse=True):
            # Loop over the list of weights/biases and assign them to
            # their corresponding tf variable
            for data in weights_dict[op_name]:
                # Biases are 1-D vectors
                if len(data.shape) == 1:
                    var = tf.get_variable('biases', trainable=train_bool)
                # Weights have rank >= 2
                else:
                    var = tf.get_variable('weights', trainable=train_bool)
                session.run(var.assign(data))
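For context, a minimal sketch of the dict-of-lists layout the docstring describes, and of how the loader distinguishes weights from biases by array rank alone. The layer names and shapes below are illustrative placeholders, not the actual contents of the pretrained file:

```python
import numpy as np

# Hypothetical sketch of the weight file's layout: each layer name maps
# to a *list* [weights, biases], not to a nested dict. Shapes are
# illustrative only.
weights_dict = {
    'conv1': [np.zeros((11, 11, 3, 96), dtype=np.float32),
              np.zeros(96, dtype=np.float32)],
    'fc8':   [np.zeros((4096, 1000), dtype=np.float32),
              np.zeros(1000, dtype=np.float32)],
}

def param_kind(data):
    # The loader tells the two entries apart purely by rank:
    # bias vectors are 1-D, weight tensors have rank >= 2.
    return 'biases' if len(data.shape) == 1 else 'weights'

for op_name, params in weights_dict.items():
    for data in params:
        print(op_name, param_kind(data), data.shape)
```

Because the list carries no labels, the rank check is the only thing telling the loader which tf variable each array belongs to.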
kratzert commented Nov 6, 2017

I think there is an indentation error starting from

        # Biases
        if len(data.shape) == 1:
            var = tf.get_variable('biases', trainable=train_bool)
            session.run(var.assign(data))
        # Weights
        else:
            var = tf.get_variable('weights', trainable=train_bool)
            session.run(var.assign(data))

This whole block needs to be indented one level further, so that it sits inside the for data in weights_dict[op_name] loop.
Everything else looks good to me.

@ayushidalmia (Author)

Yeah. My bad.
