Skip to content

Instantly share code, notes, and snippets.

@jaderberg
Created December 17, 2014 11:33
Show Gist options
  • Save jaderberg/2d4c8d30582a88da97d0 to your computer and use it in GitHub Desktop.
Save jaderberg/2d4c8d30582a88da97d0 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# Max Jaderberg 21/5/14
# This replaces the weights of one layer in a serialized Caffe net with weights loaded from a MATLAB .mat file.
from proto import caffe_pb2
import numpy as np
import scipy.io
import os
# Configuration:
#   net_path        - serialized caffe NetParameter to read
#   new_net_path    - where the modified net is written
#   replace_name    - name of the layer whose first blob is replaced
#   replace_weights - .mat file holding the replacement weights in variable 'w'
net_path = './caffe_imagenet_train_iter_755000'
new_net_path = './caffe_imagenet_train_iter_755000_approxconv2_2x'
replace_name = 'conv2'
replace_weights = 'conv2_scheme12x.mat'

# Load the serialized net.
net = caffe_pb2.NetParameter()
with open(net_path, 'rb') as fid:
    net.ParseFromString(fid.read())

# Load the replacement weights; the .mat file must define a variable 'w'.
w = scipy.io.loadmat(replace_weights)['w']

# Locate the target layer. If several layers share the name, only the first
# one is patched.
net_layers = [conn.layer.name for conn in net.layers]
try:
    layer_idx = net_layers.index(replace_name)
except ValueError:
    # Fail loudly instead of writing an unmodified copy of the net.
    raise ValueError('layer %r not found in %s' % (replace_name, net_path))

# Only blobs[0] of the layer is overwritten; any other blobs are untouched.
# NOTE(review): presumably blobs[0] holds the filter weights (and blobs[1] the
# biases) -- confirm for the proto version in use.
origblob = net.layers[layer_idx].layer.blobs[0]
if w.size != len(origblob.data):
    # The blob's shape fields are not rewritten here, so a size mismatch would
    # produce a model whose data length disagrees with its declared shape.
    raise ValueError(
        'weights in %s have %d elements but layer %r blob 0 has %d'
        % (replace_weights, w.size, replace_name, len(origblob.data)))
del origblob.data[:]
# NOTE(review): .flat walks w in C (row-major) order; this assumes the axes of
# the array stored in the .mat file are already in the order the blob expects.
origblob.data.extend(w.astype(float).flat)

# Write the modified net back to disk.
with open(new_net_path, 'wb') as f:
    f.write(net.SerializeToString())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment