@okwrtdsh
Created May 23, 2018 07:05
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import caffe.proto.caffe_pb2 as caffe\n",
"# import caffe_pb2 as caffe\n",
"import numpy as np\n",
"\n",
"p = caffe.NetParameter()\n",
"with open('/src/conv3d_deepnetA_sport1m_iter_1900000', 'rb') as f:\n",
" p.ParseFromString(\n",
" f.read()\n",
" )\n",
"def rot90(W):\n",
" for i in range(W.shape[0]):\n",
" for j in range(W.shape[1]):\n",
" for k in range(W.shape[2]):\n",
" W[i, j, k] = np.rot90(W[i, j, k], 2)\n",
" return W"
]
},
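{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Optional sanity check (added for illustration, not part of the original conversion\n",
"# and not executed here): list the index and name of every layer in the parsed\n",
"# NetParameter that carries weight blobs. This is how the hard-coded layer indices\n",
"# used in the next two cells (conv: 1, 4, 7, 9, 12, 14, 17, 19 / fc: 22, 25, 28)\n",
"# can be checked against the downloaded .caffemodel.\n",
"for idx, layer in enumerate(p.layers):\n",
"    if len(layer.blobs) > 0:\n",
"        print(idx, layer.name, len(layer.blobs))"
]
},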
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"params = []\n",
"conv_layers_indx = [1, 4, 7, 9, 12, 14, 17, 19]\n",
"for i in conv_layers_indx:\n",
" layer = p.layers[i]\n",
" weights_b = np.array(layer.blobs[1].diff, dtype=np.float32)\n",
" weights_p = np.array(layer.blobs[0].diff, dtype=np.float32).reshape(\n",
" layer.blobs[0].num, layer.blobs[0].channels, -1,\n",
" layer.blobs[0].height, layer.blobs[0].width\n",
" )\n",
" weights_p = rot90(weights_p)\n",
" params.append([weights_p, weights_b])"
]
},
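{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Optional check (added for illustration, not executed): each extracted conv\n",
"# kernel should be (out_channels, in_channels, depth, height, width),\n",
"# e.g. conv1 -> (64, 3, 3, 3, 3), and each bias a flat vector.\n",
"for idx, (w, b) in zip(conv_layers_indx, params):\n",
"    print(idx, w.shape, b.shape)"
]
},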
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"fc_layers_indx = [22, 25, 28]\n",
"for i in fc_layers_indx:\n",
" layer = p.layers[i]\n",
" weights_b = np.array(layer.blobs[1].diff, dtype=np.float32)\n",
" weights_p = np.array(layer.blobs[0].diff, dtype=np.float32).reshape(\n",
" layer.blobs[0].num, layer.blobs[0].channels, \n",
" layer.blobs[0].height,\n",
" {22:8192, 25:4096, 28:4096}[i], layer.blobs[0].width)[0,0,0,:,:].T\n",
" params.append([weights_p, weights_b])"
]
},
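{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Optional check (added for illustration, not executed): at this point the fully\n",
"# connected weights are (out_features, in_features): fc6 (4096, 8192),\n",
"# fc7 (4096, 4096), fc8 (487, 4096). They are transposed into Keras'\n",
"# (in, out) Dense layout when copied into the model further down.\n",
"for idx, (w, b) in zip(fc_layers_indx, params[len(conv_layers_indx):]):\n",
"    print(idx, w.shape, b.shape)"
]
},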
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import pickle\n",
"with open('/src/params.pkl', mode='wb') as f:\n",
" pickle.dump(params,f)\n",
"# with open('/src/params.pickle', 'rb') as f:\n",
"# params = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python2.7/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n",
"Using TensorFlow backend.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"conv1 (Conv3D) (None, 16, 112, 112, 64) 5248 \n",
"_________________________________________________________________\n",
"pool1 (MaxPooling3D) (None, 16, 56, 56, 64) 0 \n",
"_________________________________________________________________\n",
"conv2 (Conv3D) (None, 16, 56, 56, 128) 221312 \n",
"_________________________________________________________________\n",
"pool2 (MaxPooling3D) (None, 8, 28, 28, 128) 0 \n",
"_________________________________________________________________\n",
"conv3a (Conv3D) (None, 8, 28, 28, 256) 884992 \n",
"_________________________________________________________________\n",
"conv3b (Conv3D) (None, 8, 28, 28, 256) 1769728 \n",
"_________________________________________________________________\n",
"pool3 (MaxPooling3D) (None, 4, 14, 14, 256) 0 \n",
"_________________________________________________________________\n",
"conv4a (Conv3D) (None, 4, 14, 14, 512) 3539456 \n",
"_________________________________________________________________\n",
"conv4b (Conv3D) (None, 4, 14, 14, 512) 7078400 \n",
"_________________________________________________________________\n",
"pool4 (MaxPooling3D) (None, 2, 7, 7, 512) 0 \n",
"_________________________________________________________________\n",
"conv5a (Conv3D) (None, 2, 7, 7, 512) 7078400 \n",
"_________________________________________________________________\n",
"conv5b (Conv3D) (None, 2, 7, 7, 512) 7078400 \n",
"_________________________________________________________________\n",
"zeropad5 (ZeroPadding3D) (None, 2, 8, 8, 512) 0 \n",
"_________________________________________________________________\n",
"pool5 (MaxPooling3D) (None, 1, 4, 4, 512) 0 \n",
"_________________________________________________________________\n",
"flatten_1 (Flatten) (None, 8192) 0 \n",
"_________________________________________________________________\n",
"fc6 (Dense) (None, 4096) 33558528 \n",
"_________________________________________________________________\n",
"dropout_1 (Dropout) (None, 4096) 0 \n",
"_________________________________________________________________\n",
"fc7 (Dense) (None, 4096) 16781312 \n",
"_________________________________________________________________\n",
"dropout_2 (Dropout) (None, 4096) 0 \n",
"_________________________________________________________________\n",
"fc8 (Dense) (None, 487) 1995239 \n",
"=================================================================\n",
"Total params: 79,991,015\n",
"Trainable params: 79,991,015\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"None\n"
]
}
],
"source": [
"from keras.models import Sequential\n",
"from keras.layers import Conv3D, MaxPool3D, ZeroPadding3D, Dense, Dropout, Flatten\n",
"def get_model(summary=False, backend='tf'):\n",
" \"\"\" Return the Keras model of the network\n",
" \"\"\"\n",
" model = Sequential()\n",
" if backend == 'tf':\n",
" input_shape = (16, 112, 112, 3) # l, h, w, c\n",
" else:\n",
" input_shape = (3, 16, 112, 112) # c, l, h, w\n",
" model.add(Conv3D(64, (3, 3, 3), activation='relu',\n",
" padding='same', name='conv1',\n",
" input_shape=input_shape))\n",
" model.add(MaxPool3D(pool_size=(1, 2, 2), strides=(1, 2, 2),\n",
" padding='valid', name='pool1'))\n",
" # 2nd layer group\n",
" model.add(Conv3D(128, (3, 3, 3), activation='relu',\n",
" padding='same', name='conv2'))\n",
" model.add(MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2),\n",
" padding='valid', name='pool2'))\n",
" # 3rd layer group\n",
" model.add(Conv3D(256, (3, 3, 3), activation='relu',\n",
" padding='same', name='conv3a'))\n",
" model.add(Conv3D(256, (3, 3, 3), activation='relu',\n",
" padding='same', name='conv3b'))\n",
" model.add(MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2),\n",
" padding='valid', name='pool3'))\n",
" # 4th layer group\n",
" model.add(Conv3D(512, (3, 3, 3), activation='relu',\n",
" padding='same', name='conv4a'))\n",
" model.add(Conv3D(512, (3, 3, 3), activation='relu',\n",
" padding='same', name='conv4b'))\n",
" model.add(MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2),\n",
" padding='valid', name='pool4'))\n",
" # 5th layer group\n",
" model.add(Conv3D(512, (3, 3, 3), activation='relu',\n",
" padding='same', name='conv5a'))\n",
" model.add(Conv3D(512, (3, 3, 3), activation='relu',\n",
" padding='same', name='conv5b'))\n",
" model.add(ZeroPadding3D(padding=((0, 0), (0, 1), (0, 1)),\n",
" name='zeropad5'))\n",
" model.add(MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2),\n",
" padding='valid', name='pool5'))\n",
" model.add(Flatten())\n",
" # FC layers group\n",
" model.add(Dense(4096, activation='relu', name='fc6'))\n",
" model.add(Dropout(.5))\n",
" model.add(Dense(4096, activation='relu', name='fc7'))\n",
" model.add(Dropout(.5))\n",
" model.add(Dense(487, activation='softmax', name='fc8'))\n",
" model.compile(optimizer='sgd', loss='mean_squared_error')\n",
"\n",
" if summary:\n",
" print(model.summary())\n",
" return model\n",
"model = get_model(summary=True)"
]
},
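{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Optional (added for illustration, not executed): list the Keras layers that\n",
"# actually carry weights. This is where the hard-coded model_layers_indx in the\n",
"# next cell ([0, 2, 4, 5, 7, 8, 10, 11] + [15, 17, 19]) come from: only the\n",
"# Conv3D and Dense layers have kernels and biases to copy.\n",
"for idx, l in enumerate(model.layers):\n",
"    if l.get_weights():\n",
"        print(idx, l.name, [w.shape for w in l.get_weights()])"
]
},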
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0, 0, 'conv1')\n",
"((3, 3, 3, 3, 64), (64,), (3, 3, 3, 3, 64), (64,))\n",
"(2, 1, 'conv2')\n",
"((3, 3, 3, 64, 128), (128,), (3, 3, 3, 64, 128), (128,))\n",
"(4, 2, 'conv3a')\n",
"((3, 3, 3, 128, 256), (256,), (3, 3, 3, 128, 256), (256,))\n",
"(5, 3, 'conv3b')\n",
"((3, 3, 3, 256, 256), (256,), (3, 3, 3, 256, 256), (256,))\n",
"(7, 4, 'conv4a')\n",
"((3, 3, 3, 256, 512), (512,), (3, 3, 3, 256, 512), (512,))\n",
"(8, 5, 'conv4b')\n",
"((3, 3, 3, 512, 512), (512,), (3, 3, 3, 512, 512), (512,))\n",
"(10, 6, 'conv5a')\n",
"((3, 3, 3, 512, 512), (512,), (3, 3, 3, 512, 512), (512,))\n",
"(11, 7, 'conv5b')\n",
"((3, 3, 3, 512, 512), (512,), (3, 3, 3, 512, 512), (512,))\n",
"(15, 8, 'fc6')\n",
"((8192, 4096), (4096,), (8192, 4096), (4096,))\n",
"(17, 9, 'fc7')\n",
"((4096, 4096), (4096,), (4096, 4096), (4096,))\n",
"(19, 10, 'fc8')\n",
"((4096, 487), (487,), (4096, 487), (487,))\n"
]
}
],
"source": [
"model_layers_indx = [0, 2, 4, 5, 7, 8, 10, 11] + [15, 17, 19]\n",
"for j,i in enumerate(model_layers_indx):\n",
" print(i,j, model.layers[i].name)\n",
" wp,wb = model.layers[i].get_weights()\n",
" pp, pb = params[j]\n",
" print(wp.shape, wb.shape, pp.transpose().shape, pb.shape)\n",
" model.layers[i].set_weights([pp.transpose(), pb])"
]
},
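{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Optional spot check (added for illustration, not executed): confirm that at\n",
"# least one blob really landed in the model, e.g. the conv1 bias in the Keras\n",
"# model should now equal the bias extracted from the Caffe snapshot.\n",
"print(np.allclose(model.layers[0].get_weights()[1], params[0][1]))"
]
},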
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import h5py\n",
"\n",
"model.save_weights('/src/sports1M_weights.h5', overwrite=True)\n",
"json_string = model.to_json()\n",
"with open('/src/sports1M_model.json', 'w') as f:\n",
" f.write(json_string)"
]
},
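{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Optional round-trip check (added for illustration, not executed): rebuild the\n",
"# model from the saved JSON and reload the saved weights to make sure both\n",
"# files are usable on their own.\n",
"from keras.models import model_from_json\n",
"with open('/src/sports1M_model.json') as f:\n",
"    reloaded = model_from_json(f.read())\n",
"reloaded.load_weights('/src/sports1M_weights.h5')\n",
"print(reloaded.layers[-1].name)  # fc8"
]
},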
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}