Last active
July 26, 2018 05:21
-
-
Save oraoto/8cfcbb92c3482f4f00aaf5947b3b6868 to your computer and use it in GitHub Desktop.
Paddle Fluid pre-trained model fine-tuning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-07-20T10:03:22.529185Z", | |
"start_time": "2018-07-20T10:03:21.747660Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"import paddle.fluid as fluid\n", | |
"import paddle\n", | |
"#from se_resnext import SE_ResNeXt50_32x4d\n", | |
"import numpy as np\n", | |
"import os\n", | |
"import math\n", | |
"from paddle.fluid.debugger import draw_block_graphviz" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-07-20T10:03:22.537785Z", | |
"start_time": "2018-07-20T10:03:22.533273Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"pretrained_model_path = \"models/se_resnext_50/129\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-07-20T10:03:22.616251Z", | |
"start_time": "2018-07-20T10:03:22.541650Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# 修改自 https://github.com/PaddlePaddle/models/blob/develop/fluid/image_classification/models/se_resnext.py\n", | |
"# 1. 增加了self.variables记录中间变量\n", | |
"# 2. 去掉 101, 152层的支持\n", | |
"# 3. 修改最后的fc层名字,避免加载参数\n", | |
"\n", | |
"class SE_ResNeXt50():\n", | |
" \n", | |
" def __init__(self):\n", | |
" # 记录中间变量\n", | |
" self.variables = []\n", | |
" \n", | |
" def net(self, input, class_dim=1000):\n", | |
" cardinality = 32\n", | |
" reduction_ratio = 16\n", | |
" depth = [3, 4, 6, 3]\n", | |
" num_filters = [128, 256, 512, 1024]\n", | |
" \n", | |
" conv = self.conv_bn_layer(\n", | |
" input=input,\n", | |
" num_filters=64,\n", | |
" filter_size=7,\n", | |
" stride=2,\n", | |
" act='relu')\n", | |
" conv = fluid.layers.pool2d(\n", | |
" input=conv,\n", | |
" pool_size=3,\n", | |
" pool_stride=2,\n", | |
" pool_padding=1,\n", | |
" pool_type='max')\n", | |
" \n", | |
" for block in range(len(depth)):\n", | |
" for i in range(depth[block]):\n", | |
" conv = self.bottleneck_block(\n", | |
" input=conv,\n", | |
" num_filters=num_filters[block],\n", | |
" stride=2 if i == 0 and block != 0 else 1,\n", | |
" cardinality=cardinality,\n", | |
" reduction_ratio=reduction_ratio)\n", | |
" \n", | |
" pool = fluid.layers.pool2d(\n", | |
" input=conv, pool_size=7, pool_type='avg', global_pooling=True)\n", | |
" drop = fluid.layers.dropout(x=pool, dropout_prob=0.5)\n", | |
" stdv = 1.0 / math.sqrt(drop.shape[1] * 1.0)\n", | |
" out = fluid.layers.fc(input=drop,\n", | |
" size=class_dim,\n", | |
" name='se_res_out',\n", | |
" act='softmax',\n", | |
" param_attr=fluid.param_attr.ParamAttr(\n", | |
" initializer=fluid.initializer.Uniform(-stdv,\n", | |
" stdv)))\n", | |
" return out\n", | |
"\n", | |
" def shortcut(self, input, ch_out, stride):\n", | |
" ch_in = input.shape[1]\n", | |
" if ch_in != ch_out or stride != 1:\n", | |
" filter_size = 1\n", | |
" return self.conv_bn_layer(input, ch_out, filter_size, stride)\n", | |
" else:\n", | |
" return input\n", | |
"\n", | |
" def bottleneck_block(self, input, num_filters, stride, cardinality,\n", | |
" reduction_ratio):\n", | |
" conv0 = self.conv_bn_layer(\n", | |
" input=input, num_filters=num_filters, filter_size=1, act='relu')\n", | |
" conv1 = self.conv_bn_layer(\n", | |
" input=conv0,\n", | |
" num_filters=num_filters,\n", | |
" filter_size=3,\n", | |
" stride=stride,\n", | |
" groups=cardinality,\n", | |
" act='relu')\n", | |
" conv2 = self.conv_bn_layer(\n", | |
" input=conv1, num_filters=num_filters * 2, filter_size=1, act=None)\n", | |
" scale = self.squeeze_excitation(\n", | |
" input=conv2,\n", | |
" num_channels=num_filters * 2,\n", | |
" reduction_ratio=reduction_ratio)\n", | |
"\n", | |
" short = self.shortcut(input, num_filters * 2, stride)\n", | |
"\n", | |
" return fluid.layers.elementwise_add(x=short, y=scale, act='relu')\n", | |
"\n", | |
" def conv_bn_layer(self,\n", | |
" input,\n", | |
" num_filters,\n", | |
" filter_size,\n", | |
" stride=1,\n", | |
" groups=1,\n", | |
" act=None):\n", | |
" conv = fluid.layers.conv2d(\n", | |
" input=input,\n", | |
" num_filters=num_filters,\n", | |
" filter_size=filter_size,\n", | |
" stride=stride,\n", | |
" padding=(filter_size - 1) / 2,\n", | |
" groups=groups,\n", | |
" act=None,\n", | |
" bias_attr=False)\n", | |
" self.variables.append(conv)\n", | |
" bn = fluid.layers.batch_norm(input=conv, act=act)\n", | |
" self.variables.append(bn)\n", | |
" return bn\n", | |
"\n", | |
" def squeeze_excitation(self, input, num_channels, reduction_ratio):\n", | |
" pool = fluid.layers.pool2d(\n", | |
" input=input, pool_size=0, pool_type='avg', global_pooling=True)\n", | |
" stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)\n", | |
" squeeze = fluid.layers.fc(input=pool,\n", | |
" size=num_channels / reduction_ratio,\n", | |
" act='relu',\n", | |
" param_attr=fluid.param_attr.ParamAttr(\n", | |
" initializer=fluid.initializer.Uniform(\n", | |
" -stdv, stdv)))\n", | |
" self.variables.append(squeeze)\n", | |
" stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)\n", | |
" excitation = fluid.layers.fc(input=squeeze,\n", | |
" size=num_channels,\n", | |
" act='sigmoid',\n", | |
" param_attr=fluid.param_attr.ParamAttr(\n", | |
" initializer=fluid.initializer.Uniform(\n", | |
" -stdv, stdv)))\n", | |
" self.variables.append(excitation)\n", | |
" scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)\n", | |
" return scale" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-07-20T10:03:22.860628Z", | |
"start_time": "2018-07-20T10:03:22.619990Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"image = fluid.layers.data(name='image', shape=[3, 32, 32], dtype='float32')\n", | |
"label = fluid.layers.data(name='label', shape=[-1, 1], dtype='int64')\n", | |
"\n", | |
"base_model = SE_ResNeXt50()\n", | |
"#base_model = SE_ResNeXt50_32x4d()\n", | |
"predict = base_model.net(image, class_dim=10)\n", | |
"\n", | |
"inference_program = fluid.default_main_program().clone(for_test=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-07-20T10:03:22.932634Z", | |
"start_time": "2018-07-20T10:03:22.864467Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(0, 'conv2d_0.tmp_0')\n", | |
"(1, 'batch_norm_0.tmp_2')\n", | |
"(2, 'conv2d_1.tmp_0')\n", | |
"(3, 'batch_norm_1.tmp_2')\n", | |
"(4, 'conv2d_2.tmp_0')\n", | |
"(5, 'batch_norm_2.tmp_2')\n", | |
"(6, 'conv2d_3.tmp_0')\n", | |
"(7, 'batch_norm_3.tmp_2')\n", | |
"(8, 'fc_0.tmp_1')\n", | |
"(9, 'fc_1.tmp_1')\n", | |
"(10, 'conv2d_4.tmp_0')\n", | |
"(11, 'batch_norm_4.tmp_2')\n", | |
"(12, 'conv2d_5.tmp_0')\n", | |
"(13, 'batch_norm_5.tmp_2')\n", | |
"(14, 'conv2d_6.tmp_0')\n", | |
"(15, 'batch_norm_6.tmp_2')\n", | |
"(16, 'conv2d_7.tmp_0')\n", | |
"(17, 'batch_norm_7.tmp_2')\n", | |
"(18, 'fc_2.tmp_1')\n", | |
"(19, 'fc_3.tmp_1')\n", | |
"(20, 'conv2d_8.tmp_0')\n", | |
"(21, 'batch_norm_8.tmp_2')\n", | |
"(22, 'conv2d_9.tmp_0')\n", | |
"(23, 'batch_norm_9.tmp_2')\n", | |
"(24, 'conv2d_10.tmp_0')\n", | |
"(25, 'batch_norm_10.tmp_2')\n", | |
"(26, 'fc_4.tmp_1')\n", | |
"(27, 'fc_5.tmp_1')\n", | |
"(28, 'conv2d_11.tmp_0')\n", | |
"(29, 'batch_norm_11.tmp_2')\n", | |
"(30, 'conv2d_12.tmp_0')\n", | |
"(31, 'batch_norm_12.tmp_2')\n", | |
"(32, 'conv2d_13.tmp_0')\n", | |
"(33, 'batch_norm_13.tmp_2')\n", | |
"(34, 'fc_6.tmp_1')\n", | |
"(35, 'fc_7.tmp_1')\n", | |
"(36, 'conv2d_14.tmp_0')\n", | |
"(37, 'batch_norm_14.tmp_2')\n", | |
"(38, 'conv2d_15.tmp_0')\n", | |
"(39, 'batch_norm_15.tmp_2')\n", | |
"(40, 'conv2d_16.tmp_0')\n", | |
"(41, 'batch_norm_16.tmp_2')\n", | |
"(42, 'conv2d_17.tmp_0')\n", | |
"(43, 'batch_norm_17.tmp_2')\n", | |
"(44, 'fc_8.tmp_1')\n", | |
"(45, 'fc_9.tmp_1')\n", | |
"(46, 'conv2d_18.tmp_0')\n", | |
"(47, 'batch_norm_18.tmp_2')\n", | |
"(48, 'conv2d_19.tmp_0')\n", | |
"(49, 'batch_norm_19.tmp_2')\n", | |
"(50, 'conv2d_20.tmp_0')\n", | |
"(51, 'batch_norm_20.tmp_2')\n", | |
"(52, 'fc_10.tmp_1')\n", | |
"(53, 'fc_11.tmp_1')\n", | |
"(54, 'conv2d_21.tmp_0')\n", | |
"(55, 'batch_norm_21.tmp_2')\n", | |
"(56, 'conv2d_22.tmp_0')\n", | |
"(57, 'batch_norm_22.tmp_2')\n", | |
"(58, 'conv2d_23.tmp_0')\n", | |
"(59, 'batch_norm_23.tmp_2')\n", | |
"(60, 'fc_12.tmp_1')\n", | |
"(61, 'fc_13.tmp_1')\n", | |
"(62, 'conv2d_24.tmp_0')\n", | |
"(63, 'batch_norm_24.tmp_2')\n", | |
"(64, 'conv2d_25.tmp_0')\n", | |
"(65, 'batch_norm_25.tmp_2')\n", | |
"(66, 'conv2d_26.tmp_0')\n", | |
"(67, 'batch_norm_26.tmp_2')\n", | |
"(68, 'fc_14.tmp_1')\n", | |
"(69, 'fc_15.tmp_1')\n", | |
"(70, 'conv2d_27.tmp_0')\n", | |
"(71, 'batch_norm_27.tmp_2')\n", | |
"(72, 'conv2d_28.tmp_0')\n", | |
"(73, 'batch_norm_28.tmp_2')\n", | |
"(74, 'conv2d_29.tmp_0')\n", | |
"(75, 'batch_norm_29.tmp_2')\n", | |
"(76, 'conv2d_30.tmp_0')\n", | |
"(77, 'batch_norm_30.tmp_2')\n", | |
"(78, 'fc_16.tmp_1')\n", | |
"(79, 'fc_17.tmp_1')\n", | |
"(80, 'conv2d_31.tmp_0')\n", | |
"(81, 'batch_norm_31.tmp_2')\n", | |
"(82, 'conv2d_32.tmp_0')\n", | |
"(83, 'batch_norm_32.tmp_2')\n", | |
"(84, 'conv2d_33.tmp_0')\n", | |
"(85, 'batch_norm_33.tmp_2')\n", | |
"(86, 'fc_18.tmp_1')\n", | |
"(87, 'fc_19.tmp_1')\n", | |
"(88, 'conv2d_34.tmp_0')\n", | |
"(89, 'batch_norm_34.tmp_2')\n", | |
"(90, 'conv2d_35.tmp_0')\n", | |
"(91, 'batch_norm_35.tmp_2')\n", | |
"(92, 'conv2d_36.tmp_0')\n", | |
"(93, 'batch_norm_36.tmp_2')\n", | |
"(94, 'fc_20.tmp_1')\n", | |
"(95, 'fc_21.tmp_1')\n", | |
"(96, 'conv2d_37.tmp_0')\n", | |
"(97, 'batch_norm_37.tmp_2')\n", | |
"(98, 'conv2d_38.tmp_0')\n", | |
"(99, 'batch_norm_38.tmp_2')\n", | |
"(100, 'conv2d_39.tmp_0')\n", | |
"(101, 'batch_norm_39.tmp_2')\n", | |
"(102, 'fc_22.tmp_1')\n", | |
"(103, 'fc_23.tmp_1')\n", | |
"(104, 'conv2d_40.tmp_0')\n", | |
"(105, 'batch_norm_40.tmp_2')\n", | |
"(106, 'conv2d_41.tmp_0')\n", | |
"(107, 'batch_norm_41.tmp_2')\n", | |
"(108, 'conv2d_42.tmp_0')\n", | |
"(109, 'batch_norm_42.tmp_2')\n", | |
"(110, 'fc_24.tmp_1')\n", | |
"(111, 'fc_25.tmp_1')\n", | |
"(112, 'conv2d_43.tmp_0')\n", | |
"(113, 'batch_norm_43.tmp_2')\n", | |
"(114, 'conv2d_44.tmp_0')\n", | |
"(115, 'batch_norm_44.tmp_2')\n", | |
"(116, 'conv2d_45.tmp_0')\n", | |
"(117, 'batch_norm_45.tmp_2')\n", | |
"(118, 'fc_26.tmp_1')\n", | |
"(119, 'fc_27.tmp_1')\n", | |
"(120, 'conv2d_46.tmp_0')\n", | |
"(121, 'batch_norm_46.tmp_2')\n", | |
"(122, 'conv2d_47.tmp_0')\n", | |
"(123, 'batch_norm_47.tmp_2')\n", | |
"(124, 'conv2d_48.tmp_0')\n", | |
"(125, 'batch_norm_48.tmp_2')\n", | |
"(126, 'conv2d_49.tmp_0')\n", | |
"(127, 'batch_norm_49.tmp_2')\n", | |
"(128, 'fc_28.tmp_1')\n", | |
"(129, 'fc_29.tmp_1')\n", | |
"(130, 'conv2d_50.tmp_0')\n", | |
"(131, 'batch_norm_50.tmp_2')\n", | |
"(132, 'conv2d_51.tmp_0')\n", | |
"(133, 'batch_norm_51.tmp_2')\n", | |
"(134, 'conv2d_52.tmp_0')\n", | |
"(135, 'batch_norm_52.tmp_2')\n", | |
"(136, 'fc_30.tmp_1')\n", | |
"(137, 'fc_31.tmp_1')\n" | |
] | |
} | |
], | |
"source": [ | |
"for i, v in enumerate(base_model.variables):\n", | |
" print(i, v.name)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-07-20T10:03:22.971445Z", | |
"start_time": "2018-07-20T10:03:22.938370Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"#只训练最后n层\n", | |
"for v in base_model.variables[:-10]:\n", | |
" v.stop_gradient = True\n", | |
"for v in base_model.variables[-10:]:\n", | |
" v.stop_gradient = False" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-07-20T10:03:23.039162Z", | |
"start_time": "2018-07-20T10:03:22.975037Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"loss = fluid.layers.cross_entropy(input=predict, label=label)\n", | |
"loss = fluid.layers.mean(loss)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-07-20T10:03:23.143387Z", | |
"start_time": "2018-07-20T10:03:23.045525Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"opt = fluid.optimizer.Adam(learning_rate=0.001)\n", | |
"_ = opt.minimize(loss)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-07-20T10:03:23.667592Z", | |
"start_time": "2018-07-20T10:03:23.147225Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[]" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"exe = fluid.executor.Executor(fluid.CPUPlace())\n", | |
"exe.run(fluid.default_startup_program())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-07-20T10:03:23.759366Z", | |
"start_time": "2018-07-20T10:03:23.670510Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"models/se_resnext_50/129/conv2d_0.w_0\n", | |
"models/se_resnext_50/129/batch_norm_0.w_0\n", | |
"models/se_resnext_50/129/batch_norm_0.b_0\n", | |
"models/se_resnext_50/129/batch_norm_0.w_1\n", | |
"models/se_resnext_50/129/batch_norm_0.w_2\n", | |
"models/se_resnext_50/129/conv2d_1.w_0\n", | |
"models/se_resnext_50/129/batch_norm_1.w_0\n", | |
"models/se_resnext_50/129/batch_norm_1.b_0\n", | |
"models/se_resnext_50/129/batch_norm_1.w_1\n", | |
"models/se_resnext_50/129/batch_norm_1.w_2\n", | |
"models/se_resnext_50/129/conv2d_2.w_0\n", | |
"models/se_resnext_50/129/batch_norm_2.w_0\n", | |
"models/se_resnext_50/129/batch_norm_2.b_0\n", | |
"models/se_resnext_50/129/batch_norm_2.w_1\n", | |
"models/se_resnext_50/129/batch_norm_2.w_2\n", | |
"models/se_resnext_50/129/conv2d_3.w_0\n", | |
"models/se_resnext_50/129/batch_norm_3.w_0\n", | |
"models/se_resnext_50/129/batch_norm_3.b_0\n", | |
"models/se_resnext_50/129/batch_norm_3.w_1\n", | |
"models/se_resnext_50/129/batch_norm_3.w_2\n", | |
"models/se_resnext_50/129/fc_0.w_0\n", | |
"models/se_resnext_50/129/fc_0.b_0\n", | |
"models/se_resnext_50/129/fc_1.w_0\n", | |
"models/se_resnext_50/129/fc_1.b_0\n", | |
"models/se_resnext_50/129/conv2d_4.w_0\n", | |
"models/se_resnext_50/129/batch_norm_4.w_0\n", | |
"models/se_resnext_50/129/batch_norm_4.b_0\n", | |
"models/se_resnext_50/129/batch_norm_4.w_1\n", | |
"models/se_resnext_50/129/batch_norm_4.w_2\n", | |
"models/se_resnext_50/129/conv2d_5.w_0\n", | |
"models/se_resnext_50/129/batch_norm_5.w_0\n", | |
"models/se_resnext_50/129/batch_norm_5.b_0\n", | |
"models/se_resnext_50/129/batch_norm_5.w_1\n", | |
"models/se_resnext_50/129/batch_norm_5.w_2\n", | |
"models/se_resnext_50/129/conv2d_6.w_0\n", | |
"models/se_resnext_50/129/batch_norm_6.w_0\n", | |
"models/se_resnext_50/129/batch_norm_6.b_0\n", | |
"models/se_resnext_50/129/batch_norm_6.w_1\n", | |
"models/se_resnext_50/129/batch_norm_6.w_2\n", | |
"models/se_resnext_50/129/conv2d_7.w_0\n", | |
"models/se_resnext_50/129/batch_norm_7.w_0\n", | |
"models/se_resnext_50/129/batch_norm_7.b_0\n", | |
"models/se_resnext_50/129/batch_norm_7.w_1\n", | |
"models/se_resnext_50/129/batch_norm_7.w_2\n", | |
"models/se_resnext_50/129/fc_2.w_0\n", | |
"models/se_resnext_50/129/fc_2.b_0\n", | |
"models/se_resnext_50/129/fc_3.w_0\n", | |
"models/se_resnext_50/129/fc_3.b_0\n", | |
"models/se_resnext_50/129/conv2d_8.w_0\n", | |
"models/se_resnext_50/129/batch_norm_8.w_0\n", | |
"models/se_resnext_50/129/batch_norm_8.b_0\n", | |
"models/se_resnext_50/129/batch_norm_8.w_1\n", | |
"models/se_resnext_50/129/batch_norm_8.w_2\n", | |
"models/se_resnext_50/129/conv2d_9.w_0\n", | |
"models/se_resnext_50/129/batch_norm_9.w_0\n", | |
"models/se_resnext_50/129/batch_norm_9.b_0\n", | |
"models/se_resnext_50/129/batch_norm_9.w_1\n", | |
"models/se_resnext_50/129/batch_norm_9.w_2\n", | |
"models/se_resnext_50/129/conv2d_10.w_0\n", | |
"models/se_resnext_50/129/batch_norm_10.w_0\n", | |
"models/se_resnext_50/129/batch_norm_10.b_0\n", | |
"models/se_resnext_50/129/batch_norm_10.w_1\n", | |
"models/se_resnext_50/129/batch_norm_10.w_2\n", | |
"models/se_resnext_50/129/fc_4.w_0\n", | |
"models/se_resnext_50/129/fc_4.b_0\n", | |
"models/se_resnext_50/129/fc_5.w_0\n", | |
"models/se_resnext_50/129/fc_5.b_0\n", | |
"models/se_resnext_50/129/conv2d_11.w_0\n", | |
"models/se_resnext_50/129/batch_norm_11.w_0\n", | |
"models/se_resnext_50/129/batch_norm_11.b_0\n", | |
"models/se_resnext_50/129/batch_norm_11.w_1\n", | |
"models/se_resnext_50/129/batch_norm_11.w_2\n", | |
"models/se_resnext_50/129/conv2d_12.w_0\n", | |
"models/se_resnext_50/129/batch_norm_12.w_0\n", | |
"models/se_resnext_50/129/batch_norm_12.b_0\n", | |
"models/se_resnext_50/129/batch_norm_12.w_1\n", | |
"models/se_resnext_50/129/batch_norm_12.w_2\n", | |
"models/se_resnext_50/129/conv2d_13.w_0\n", | |
"models/se_resnext_50/129/batch_norm_13.w_0\n", | |
"models/se_resnext_50/129/batch_norm_13.b_0\n", | |
"models/se_resnext_50/129/batch_norm_13.w_1\n", | |
"models/se_resnext_50/129/batch_norm_13.w_2\n", | |
"models/se_resnext_50/129/fc_6.w_0\n", | |
"models/se_resnext_50/129/fc_6.b_0\n", | |
"models/se_resnext_50/129/fc_7.w_0\n", | |
"models/se_resnext_50/129/fc_7.b_0\n", | |
"models/se_resnext_50/129/conv2d_14.w_0\n", | |
"models/se_resnext_50/129/batch_norm_14.w_0\n", | |
"models/se_resnext_50/129/batch_norm_14.b_0\n", | |
"models/se_resnext_50/129/batch_norm_14.w_1\n", | |
"models/se_resnext_50/129/batch_norm_14.w_2\n", | |
"models/se_resnext_50/129/conv2d_15.w_0\n", | |
"models/se_resnext_50/129/batch_norm_15.w_0\n", | |
"models/se_resnext_50/129/batch_norm_15.b_0\n", | |
"models/se_resnext_50/129/batch_norm_15.w_1\n", | |
"models/se_resnext_50/129/batch_norm_15.w_2\n", | |
"models/se_resnext_50/129/conv2d_16.w_0\n", | |
"models/se_resnext_50/129/batch_norm_16.w_0\n", | |
"models/se_resnext_50/129/batch_norm_16.b_0\n", | |
"models/se_resnext_50/129/batch_norm_16.w_1\n", | |
"models/se_resnext_50/129/batch_norm_16.w_2\n", | |
"models/se_resnext_50/129/conv2d_17.w_0\n", | |
"models/se_resnext_50/129/batch_norm_17.w_0\n", | |
"models/se_resnext_50/129/batch_norm_17.b_0\n", | |
"models/se_resnext_50/129/batch_norm_17.w_1\n", | |
"models/se_resnext_50/129/batch_norm_17.w_2\n", | |
"models/se_resnext_50/129/fc_8.w_0\n", | |
"models/se_resnext_50/129/fc_8.b_0\n", | |
"models/se_resnext_50/129/fc_9.w_0\n", | |
"models/se_resnext_50/129/fc_9.b_0\n", | |
"models/se_resnext_50/129/conv2d_18.w_0\n", | |
"models/se_resnext_50/129/batch_norm_18.w_0\n", | |
"models/se_resnext_50/129/batch_norm_18.b_0\n", | |
"models/se_resnext_50/129/batch_norm_18.w_1\n", | |
"models/se_resnext_50/129/batch_norm_18.w_2\n", | |
"models/se_resnext_50/129/conv2d_19.w_0\n", | |
"models/se_resnext_50/129/batch_norm_19.w_0\n", | |
"models/se_resnext_50/129/batch_norm_19.b_0\n", | |
"models/se_resnext_50/129/batch_norm_19.w_1\n", | |
"models/se_resnext_50/129/batch_norm_19.w_2\n", | |
"models/se_resnext_50/129/conv2d_20.w_0\n", | |
"models/se_resnext_50/129/batch_norm_20.w_0\n", | |
"models/se_resnext_50/129/batch_norm_20.b_0\n", | |
"models/se_resnext_50/129/batch_norm_20.w_1\n", | |
"models/se_resnext_50/129/batch_norm_20.w_2\n", | |
"models/se_resnext_50/129/fc_10.w_0\n", | |
"models/se_resnext_50/129/fc_10.b_0\n", | |
"models/se_resnext_50/129/fc_11.w_0\n", | |
"models/se_resnext_50/129/fc_11.b_0\n", | |
"models/se_resnext_50/129/conv2d_21.w_0\n", | |
"models/se_resnext_50/129/batch_norm_21.w_0\n", | |
"models/se_resnext_50/129/batch_norm_21.b_0\n", | |
"models/se_resnext_50/129/batch_norm_21.w_1\n", | |
"models/se_resnext_50/129/batch_norm_21.w_2\n", | |
"models/se_resnext_50/129/conv2d_22.w_0\n", | |
"models/se_resnext_50/129/batch_norm_22.w_0\n", | |
"models/se_resnext_50/129/batch_norm_22.b_0\n", | |
"models/se_resnext_50/129/batch_norm_22.w_1\n", | |
"models/se_resnext_50/129/batch_norm_22.w_2\n", | |
"models/se_resnext_50/129/conv2d_23.w_0\n", | |
"models/se_resnext_50/129/batch_norm_23.w_0\n", | |
"models/se_resnext_50/129/batch_norm_23.b_0\n", | |
"models/se_resnext_50/129/batch_norm_23.w_1\n", | |
"models/se_resnext_50/129/batch_norm_23.w_2\n", | |
"models/se_resnext_50/129/fc_12.w_0\n", | |
"models/se_resnext_50/129/fc_12.b_0\n", | |
"models/se_resnext_50/129/fc_13.w_0\n", | |
"models/se_resnext_50/129/fc_13.b_0\n", | |
"models/se_resnext_50/129/conv2d_24.w_0\n", | |
"models/se_resnext_50/129/batch_norm_24.w_0\n", | |
"models/se_resnext_50/129/batch_norm_24.b_0\n", | |
"models/se_resnext_50/129/batch_norm_24.w_1\n", | |
"models/se_resnext_50/129/batch_norm_24.w_2\n", | |
"models/se_resnext_50/129/conv2d_25.w_0\n", | |
"models/se_resnext_50/129/batch_norm_25.w_0\n", | |
"models/se_resnext_50/129/batch_norm_25.b_0\n", | |
"models/se_resnext_50/129/batch_norm_25.w_1\n", | |
"models/se_resnext_50/129/batch_norm_25.w_2\n", | |
"models/se_resnext_50/129/conv2d_26.w_0\n", | |
"models/se_resnext_50/129/batch_norm_26.w_0\n", | |
"models/se_resnext_50/129/batch_norm_26.b_0\n", | |
"models/se_resnext_50/129/batch_norm_26.w_1\n", | |
"models/se_resnext_50/129/batch_norm_26.w_2\n", | |
"models/se_resnext_50/129/fc_14.w_0\n", | |
"models/se_resnext_50/129/fc_14.b_0\n", | |
"models/se_resnext_50/129/fc_15.w_0\n", | |
"models/se_resnext_50/129/fc_15.b_0\n", | |
"models/se_resnext_50/129/conv2d_27.w_0\n", | |
"models/se_resnext_50/129/batch_norm_27.w_0\n", | |
"models/se_resnext_50/129/batch_norm_27.b_0\n", | |
"models/se_resnext_50/129/batch_norm_27.w_1\n", | |
"models/se_resnext_50/129/batch_norm_27.w_2\n", | |
"models/se_resnext_50/129/conv2d_28.w_0\n", | |
"models/se_resnext_50/129/batch_norm_28.w_0\n", | |
"models/se_resnext_50/129/batch_norm_28.b_0\n", | |
"models/se_resnext_50/129/batch_norm_28.w_1\n", | |
"models/se_resnext_50/129/batch_norm_28.w_2\n", | |
"models/se_resnext_50/129/conv2d_29.w_0\n", | |
"models/se_resnext_50/129/batch_norm_29.w_0\n", | |
"models/se_resnext_50/129/batch_norm_29.b_0\n", | |
"models/se_resnext_50/129/batch_norm_29.w_1\n", | |
"models/se_resnext_50/129/batch_norm_29.w_2\n", | |
"models/se_resnext_50/129/conv2d_30.w_0\n", | |
"models/se_resnext_50/129/batch_norm_30.w_0\n", | |
"models/se_resnext_50/129/batch_norm_30.b_0\n", | |
"models/se_resnext_50/129/batch_norm_30.w_1\n", | |
"models/se_resnext_50/129/batch_norm_30.w_2\n", | |
"models/se_resnext_50/129/fc_16.w_0\n", | |
"models/se_resnext_50/129/fc_16.b_0\n", | |
"models/se_resnext_50/129/fc_17.w_0\n", | |
"models/se_resnext_50/129/fc_17.b_0\n", | |
"models/se_resnext_50/129/conv2d_31.w_0\n", | |
"models/se_resnext_50/129/batch_norm_31.w_0\n", | |
"models/se_resnext_50/129/batch_norm_31.b_0\n", | |
"models/se_resnext_50/129/batch_norm_31.w_1\n", | |
"models/se_resnext_50/129/batch_norm_31.w_2\n", | |
"models/se_resnext_50/129/conv2d_32.w_0\n", | |
"models/se_resnext_50/129/batch_norm_32.w_0\n", | |
"models/se_resnext_50/129/batch_norm_32.b_0\n", | |
"models/se_resnext_50/129/batch_norm_32.w_1\n", | |
"models/se_resnext_50/129/batch_norm_32.w_2\n", | |
"models/se_resnext_50/129/conv2d_33.w_0\n", | |
"models/se_resnext_50/129/batch_norm_33.w_0\n", | |
"models/se_resnext_50/129/batch_norm_33.b_0\n", | |
"models/se_resnext_50/129/batch_norm_33.w_1\n", | |
"models/se_resnext_50/129/batch_norm_33.w_2\n", | |
"models/se_resnext_50/129/fc_18.w_0\n", | |
"models/se_resnext_50/129/fc_18.b_0\n", | |
"models/se_resnext_50/129/fc_19.w_0\n", | |
"models/se_resnext_50/129/fc_19.b_0\n", | |
"models/se_resnext_50/129/conv2d_34.w_0\n", | |
"models/se_resnext_50/129/batch_norm_34.w_0\n", | |
"models/se_resnext_50/129/batch_norm_34.b_0\n", | |
"models/se_resnext_50/129/batch_norm_34.w_1\n", | |
"models/se_resnext_50/129/batch_norm_34.w_2\n", | |
"models/se_resnext_50/129/conv2d_35.w_0\n", | |
"models/se_resnext_50/129/batch_norm_35.w_0\n", | |
"models/se_resnext_50/129/batch_norm_35.b_0\n", | |
"models/se_resnext_50/129/batch_norm_35.w_1\n", | |
"models/se_resnext_50/129/batch_norm_35.w_2\n", | |
"models/se_resnext_50/129/conv2d_36.w_0\n", | |
"models/se_resnext_50/129/batch_norm_36.w_0\n", | |
"models/se_resnext_50/129/batch_norm_36.b_0\n", | |
"models/se_resnext_50/129/batch_norm_36.w_1\n", | |
"models/se_resnext_50/129/batch_norm_36.w_2\n", | |
"models/se_resnext_50/129/fc_20.w_0\n", | |
"models/se_resnext_50/129/fc_20.b_0\n", | |
"models/se_resnext_50/129/fc_21.w_0\n", | |
"models/se_resnext_50/129/fc_21.b_0\n", | |
"models/se_resnext_50/129/conv2d_37.w_0\n", | |
"models/se_resnext_50/129/batch_norm_37.w_0\n", | |
"models/se_resnext_50/129/batch_norm_37.b_0\n", | |
"models/se_resnext_50/129/batch_norm_37.w_1\n", | |
"models/se_resnext_50/129/batch_norm_37.w_2\n", | |
"models/se_resnext_50/129/conv2d_38.w_0\n", | |
"models/se_resnext_50/129/batch_norm_38.w_0\n", | |
"models/se_resnext_50/129/batch_norm_38.b_0\n", | |
"models/se_resnext_50/129/batch_norm_38.w_1\n", | |
"models/se_resnext_50/129/batch_norm_38.w_2\n", | |
"models/se_resnext_50/129/conv2d_39.w_0\n", | |
"models/se_resnext_50/129/batch_norm_39.w_0\n", | |
"models/se_resnext_50/129/batch_norm_39.b_0\n", | |
"models/se_resnext_50/129/batch_norm_39.w_1\n", | |
"models/se_resnext_50/129/batch_norm_39.w_2\n", | |
"models/se_resnext_50/129/fc_22.w_0\n", | |
"models/se_resnext_50/129/fc_22.b_0\n", | |
"models/se_resnext_50/129/fc_23.w_0\n", | |
"models/se_resnext_50/129/fc_23.b_0\n", | |
"models/se_resnext_50/129/conv2d_40.w_0\n", | |
"models/se_resnext_50/129/batch_norm_40.w_0\n", | |
"models/se_resnext_50/129/batch_norm_40.b_0\n", | |
"models/se_resnext_50/129/batch_norm_40.w_1\n", | |
"models/se_resnext_50/129/batch_norm_40.w_2\n", | |
"models/se_resnext_50/129/conv2d_41.w_0\n", | |
"models/se_resnext_50/129/batch_norm_41.w_0\n", | |
"models/se_resnext_50/129/batch_norm_41.b_0\n", | |
"models/se_resnext_50/129/batch_norm_41.w_1\n", | |
"models/se_resnext_50/129/batch_norm_41.w_2\n", | |
"models/se_resnext_50/129/conv2d_42.w_0\n", | |
"models/se_resnext_50/129/batch_norm_42.w_0\n", | |
"models/se_resnext_50/129/batch_norm_42.b_0\n", | |
"models/se_resnext_50/129/batch_norm_42.w_1\n", | |
"models/se_resnext_50/129/batch_norm_42.w_2\n", | |
"models/se_resnext_50/129/fc_24.w_0\n", | |
"models/se_resnext_50/129/fc_24.b_0\n", | |
"models/se_resnext_50/129/fc_25.w_0\n", | |
"models/se_resnext_50/129/fc_25.b_0\n", | |
"models/se_resnext_50/129/conv2d_43.w_0\n", | |
"models/se_resnext_50/129/batch_norm_43.w_0\n", | |
"models/se_resnext_50/129/batch_norm_43.b_0\n", | |
"models/se_resnext_50/129/batch_norm_43.w_1\n", | |
"models/se_resnext_50/129/batch_norm_43.w_2\n", | |
"models/se_resnext_50/129/conv2d_44.w_0\n", | |
"models/se_resnext_50/129/batch_norm_44.w_0\n", | |
"models/se_resnext_50/129/batch_norm_44.b_0\n", | |
"models/se_resnext_50/129/batch_norm_44.w_1\n", | |
"models/se_resnext_50/129/batch_norm_44.w_2\n", | |
"models/se_resnext_50/129/conv2d_45.w_0\n", | |
"models/se_resnext_50/129/batch_norm_45.w_0\n", | |
"models/se_resnext_50/129/batch_norm_45.b_0\n", | |
"models/se_resnext_50/129/batch_norm_45.w_1\n", | |
"models/se_resnext_50/129/batch_norm_45.w_2\n", | |
"models/se_resnext_50/129/fc_26.w_0\n", | |
"models/se_resnext_50/129/fc_26.b_0\n", | |
"models/se_resnext_50/129/fc_27.w_0\n", | |
"models/se_resnext_50/129/fc_27.b_0\n", | |
"models/se_resnext_50/129/conv2d_46.w_0\n", | |
"models/se_resnext_50/129/batch_norm_46.w_0\n", | |
"models/se_resnext_50/129/batch_norm_46.b_0\n", | |
"models/se_resnext_50/129/batch_norm_46.w_1\n", | |
"models/se_resnext_50/129/batch_norm_46.w_2\n", | |
"models/se_resnext_50/129/conv2d_47.w_0\n", | |
"models/se_resnext_50/129/batch_norm_47.w_0\n", | |
"models/se_resnext_50/129/batch_norm_47.b_0\n", | |
"models/se_resnext_50/129/batch_norm_47.w_1\n", | |
"models/se_resnext_50/129/batch_norm_47.w_2\n", | |
"models/se_resnext_50/129/conv2d_48.w_0\n", | |
"models/se_resnext_50/129/batch_norm_48.w_0\n", | |
"models/se_resnext_50/129/batch_norm_48.b_0\n", | |
"models/se_resnext_50/129/batch_norm_48.w_1\n", | |
"models/se_resnext_50/129/batch_norm_48.w_2\n", | |
"models/se_resnext_50/129/conv2d_49.w_0\n", | |
"models/se_resnext_50/129/batch_norm_49.w_0\n", | |
"models/se_resnext_50/129/batch_norm_49.b_0\n", | |
"models/se_resnext_50/129/batch_norm_49.w_1\n", | |
"models/se_resnext_50/129/batch_norm_49.w_2\n", | |
"models/se_resnext_50/129/fc_28.w_0\n", | |
"models/se_resnext_50/129/fc_28.b_0\n", | |
"models/se_resnext_50/129/fc_29.w_0\n", | |
"models/se_resnext_50/129/fc_29.b_0\n", | |
"models/se_resnext_50/129/conv2d_50.w_0\n", | |
"models/se_resnext_50/129/batch_norm_50.w_0\n", | |
"models/se_resnext_50/129/batch_norm_50.b_0\n", | |
"models/se_resnext_50/129/batch_norm_50.w_1\n", | |
"models/se_resnext_50/129/batch_norm_50.w_2\n", | |
"models/se_resnext_50/129/conv2d_51.w_0\n", | |
"models/se_resnext_50/129/batch_norm_51.w_0\n", | |
"models/se_resnext_50/129/batch_norm_51.b_0\n", | |
"models/se_resnext_50/129/batch_norm_51.w_1\n", | |
"models/se_resnext_50/129/batch_norm_51.w_2\n", | |
"models/se_resnext_50/129/conv2d_52.w_0\n", | |
"models/se_resnext_50/129/batch_norm_52.w_0\n", | |
"models/se_resnext_50/129/batch_norm_52.b_0\n", | |
"models/se_resnext_50/129/batch_norm_52.w_1\n", | |
"models/se_resnext_50/129/batch_norm_52.w_2\n", | |
"models/se_resnext_50/129/fc_30.w_0\n", | |
"models/se_resnext_50/129/fc_30.b_0\n", | |
"models/se_resnext_50/129/fc_31.w_0\n", | |
"models/se_resnext_50/129/fc_31.b_0\n" | |
] | |
} | |
], | |
"source": [ | |
"# 加载参数\n", | |
"def if_exist(var):\n", | |
" path = os.path.join(pretrained_model_path, var.name)\n", | |
" exist = os.path.exists(path)\n", | |
" if exist:\n", | |
" print(path)\n", | |
" return exist\n", | |
"\n", | |
"fluid.io.load_vars(exe, pretrained_model_path, predicate=if_exist)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-07-20T10:03:24.157400Z", | |
"start_time": "2018-07-20T10:03:23.763261Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"feeder = fluid.data_feeder.DataFeeder([image, label], fluid.CPUPlace())\n", | |
"reader = feeder.decorate_reader(paddle.batch(paddle.dataset.cifar.train10(), 64), None)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-07-20T10:03:37.037340Z", | |
"start_time": "2018-07-20T10:03:24.161237Z" | |
}, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(0, [array([2.3287988], dtype=float32)])\n", | |
"(1, [array([2.3302152], dtype=float32)])\n", | |
"(2, [array([2.5670676], dtype=float32)])\n", | |
"(3, [array([2.2559876], dtype=float32)])\n", | |
"(4, [array([2.3306024], dtype=float32)])\n", | |
"(5, [array([2.3891745], dtype=float32)])\n", | |
"(6, [array([2.3429651], dtype=float32)])\n", | |
"(7, [array([2.2461793], dtype=float32)])\n", | |
"(8, [array([2.2587867], dtype=float32)])\n", | |
"(9, [array([2.3550396], dtype=float32)])\n", | |
"(10, [array([2.3980129], dtype=float32)])\n", | |
"(11, [array([2.3221207], dtype=float32)])\n" | |
] | |
} | |
], | |
"source": [ | |
"for batch_id, data in enumerate(reader()):\n", | |
" result = exe.run(fetch_list=[loss], feed=data)\n", | |
" print(batch_id, result)\n", | |
" if batch_id > 10:\n", | |
" break" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"!ls models/fine.model | wc -l\n", | |
"!ls models/se_resnext_50/129 | wc -l" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-07-20T10:03:37.169253Z", | |
"start_time": "2018-07-20T10:03:37.041280Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"fluid.io.save_params(exe, \"models/fine.model\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-07-20T10:06:20.966974Z", | |
"start_time": "2018-07-20T10:06:20.458865Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"333\n", | |
"331\n", | |
"125\n" | |
] | |
} | |
], | |
"source": [ | |
"# 只要部分参数改变了\n", | |
"!ls models/fine.model | wc -l\n", | |
"!ls models/se_resnext_50/129 | wc -l\n", | |
"!diff models/se_resnext_50/129 models/fine.model | grep differ | wc -l" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.15" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
更完整版本在 https://github.com/oraoto/learn_ml/blob/master/paddle/pretrained.ipynb