Skip to content

Instantly share code, notes, and snippets.

@jhjin
Created August 29, 2016 23:31
Show Gist options
  • Save jhjin/8e67b5af4b31553e94c4e14f2ee81736 to your computer and use it in GitHub Desktop.
Save jhjin/8e67b5af4b31553e94c4e14f2ee81736 to your computer and use it in GitHub Desktop.
Simple script to spatialize a Caffe model by copying its pretrained parameters into a reshaped network.
#!/usr/bin/env python
import os

import numpy as np
import matplotlib.pyplot as plt
import caffe
# Run on the CPU; copying parameter arrays needs no GPU.
caffe.set_mode_cpu()

# Pretrained source weights, and the file the updated destination net is
# written to (it is also read from, if it already exists).
path_src = 'src.caffemodel'
path_dst = 'dst.caffemodel'

net_src = caffe.Net('src.prototxt', path_src, caffe.TEST)
# Seed the destination net from an existing dst.caffemodel when there is one.
# Otherwise build it from dst.prototxt alone, because pycaffe fails when asked
# to load a weights file that does not exist. Layers listed in list.txt are
# overwritten below either way; unlisted layers keep their initial values.
if os.path.isfile(path_dst):
    net_dst = caffe.Net('dst.prototxt', path_dst, caffe.TEST)
else:
    print("==> " + path_dst + " not found; initializing destination from dst.prototxt")
    net_dst = caffe.Net('dst.prototxt', caffe.TEST)

# list.txt names the layers whose parameters are copied, one per line.
# It is generated with the following plus manual modification:
# cat src.prototxt | grep "name:" | grep -v "relu" | sed '1d' | sed 's/"//g' | sed 's/name:\ //g' > list.txt
with open('list.txt', 'r') as list_file:
    # strip() also removes '\r' and stray spaces. Blank lines are skipped so
    # they do not become an empty (nonexistent) layer name.
    layer_names = [line.strip() for line in list_file if line.strip()]
# Copy each listed layer's parameters from the source net into the destination
# net. The blob shapes may differ between the two prototxts, so each array is
# reshaped to the destination blob's shape. np.reshape raises ValueError if
# the element counts do not match.
for layer in layer_names:
    print("==> overwriting " + layer + "...")
    src_blobs = net_src.params[layer]
    dst_blobs = net_dst.params[layer]
    # Walk every parameter blob (weight, bias, ...) instead of assuming exactly
    # two, so layers without a bias term are handled as well.
    if len(src_blobs) != len(dst_blobs):
        raise ValueError(
            "layer %s has %d parameter blobs in the source net but %d in the destination"
            % (layer, len(src_blobs), len(dst_blobs)))
    for src_blob, dst_blob in zip(src_blobs, dst_blobs):
        # Assign through [...] so the values are written into the net's own
        # buffer. Merely rebinding a local name would leave net_dst unchanged.
        dst_blob.data[...] = np.reshape(src_blob.data, dst_blob.data.shape)
# Write the updated destination parameters to path_dst, the same path the
# message reports, instead of a duplicated literal that could drift from it.
print("==> overwriting " + path_dst)
net_dst.save(path_dst)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment