Skip to content

Instantly share code, notes, and snippets.

@oraoto
Last active July 26, 2018 05:21
Show Gist options
  • Save oraoto/8cfcbb92c3482f4f00aaf5947b3b6868 to your computer and use it in GitHub Desktop.
Save oraoto/8cfcbb92c3482f4f00aaf5947b3b6868 to your computer and use it in GitHub Desktop.
Paddle Fluid pre-trained model fine-tuning
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2018-07-20T10:03:22.529185Z",
"start_time": "2018-07-20T10:03:21.747660Z"
}
},
"outputs": [],
"source": [
"import paddle.fluid as fluid\n",
"import paddle\n",
"#from se_resnext import SE_ResNeXt50_32x4d\n",
"import numpy as np\n",
"import os\n",
"import math\n",
"from paddle.fluid.debugger import draw_block_graphviz"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2018-07-20T10:03:22.537785Z",
"start_time": "2018-07-20T10:03:22.533273Z"
}
},
"outputs": [],
"source": [
"pretrained_model_path = \"models/se_resnext_50/129\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2018-07-20T10:03:22.616251Z",
"start_time": "2018-07-20T10:03:22.541650Z"
}
},
"outputs": [],
"source": [
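"# Modified from https://github.com/PaddlePaddle/models/blob/develop/fluid/image_classification/models/se_resnext.py\n",
"# 1. Added self.variables to record intermediate variables.\n",
"# 2. Removed support for the 101- and 152-layer variants.\n",
"# 3. Renamed the final fc layer so that its pretrained parameters are not loaded.\n",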
"\n",
"class SE_ResNeXt50():\n",
" \n",
" def __init__(self):\n",
" # Records intermediate outputs (conv, batch_norm, SE fc) so they can be listed or frozen later\n",
" self.variables = []\n",
" \n",
" def net(self, input, class_dim=1000):\n",
" cardinality = 32\n",
" reduction_ratio = 16\n",
" depth = [3, 4, 6, 3]\n",
" num_filters = [128, 256, 512, 1024]\n",
" \n",
" conv = self.conv_bn_layer(\n",
" input=input,\n",
" num_filters=64,\n",
" filter_size=7,\n",
" stride=2,\n",
" act='relu')\n",
" conv = fluid.layers.pool2d(\n",
" input=conv,\n",
" pool_size=3,\n",
" pool_stride=2,\n",
" pool_padding=1,\n",
" pool_type='max')\n",
" \n",
" for block in range(len(depth)):\n",
" for i in range(depth[block]):\n",
" conv = self.bottleneck_block(\n",
" input=conv,\n",
" num_filters=num_filters[block],\n",
" stride=2 if i == 0 and block != 0 else 1,\n",
" cardinality=cardinality,\n",
" reduction_ratio=reduction_ratio)\n",
" \n",
" pool = fluid.layers.pool2d(\n",
" input=conv, pool_size=7, pool_type='avg', global_pooling=True)\n",
" drop = fluid.layers.dropout(x=pool, dropout_prob=0.5)\n",
" stdv = 1.0 / math.sqrt(drop.shape[1] * 1.0)\n",
" out = fluid.layers.fc(input=drop,\n",
" size=class_dim,\n",
" name='se_res_out',\n",
" act='softmax',\n",
" param_attr=fluid.param_attr.ParamAttr(\n",
" initializer=fluid.initializer.Uniform(-stdv,\n",
" stdv)))\n",
" return out\n",
"\n",
" def shortcut(self, input, ch_out, stride):\n",
" ch_in = input.shape[1]\n",
" if ch_in != ch_out or stride != 1:\n",
" filter_size = 1\n",
" return self.conv_bn_layer(input, ch_out, filter_size, stride)\n",
" else:\n",
" return input\n",
"\n",
" def bottleneck_block(self, input, num_filters, stride, cardinality,\n",
" reduction_ratio):\n",
" conv0 = self.conv_bn_layer(\n",
" input=input, num_filters=num_filters, filter_size=1, act='relu')\n",
" conv1 = self.conv_bn_layer(\n",
" input=conv0,\n",
" num_filters=num_filters,\n",
" filter_size=3,\n",
" stride=stride,\n",
" groups=cardinality,\n",
" act='relu')\n",
" conv2 = self.conv_bn_layer(\n",
" input=conv1, num_filters=num_filters * 2, filter_size=1, act=None)\n",
" scale = self.squeeze_excitation(\n",
" input=conv2,\n",
" num_channels=num_filters * 2,\n",
" reduction_ratio=reduction_ratio)\n",
"\n",
" short = self.shortcut(input, num_filters * 2, stride)\n",
"\n",
" return fluid.layers.elementwise_add(x=short, y=scale, act='relu')\n",
"\n",
" def conv_bn_layer(self,\n",
" input,\n",
" num_filters,\n",
" filter_size,\n",
" stride=1,\n",
" groups=1,\n",
" act=None):\n",
" conv = fluid.layers.conv2d(\n",
" input=input,\n",
" num_filters=num_filters,\n",
" filter_size=filter_size,\n",
" stride=stride,\n",
" padding=(filter_size - 1) / 2,\n",
" groups=groups,\n",
" act=None,\n",
" bias_attr=False)\n",
" self.variables.append(conv)\n",
" bn = fluid.layers.batch_norm(input=conv, act=act)\n",
" self.variables.append(bn)\n",
" return bn\n",
"\n",
" def squeeze_excitation(self, input, num_channels, reduction_ratio):\n",
" pool = fluid.layers.pool2d(\n",
" input=input, pool_size=0, pool_type='avg', global_pooling=True)\n",
" stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)\n",
" squeeze = fluid.layers.fc(input=pool,\n",
" size=num_channels / reduction_ratio,\n",
" act='relu',\n",
" param_attr=fluid.param_attr.ParamAttr(\n",
" initializer=fluid.initializer.Uniform(\n",
" -stdv, stdv)))\n",
" self.variables.append(squeeze)\n",
" stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)\n",
" excitation = fluid.layers.fc(input=squeeze,\n",
" size=num_channels,\n",
" act='sigmoid',\n",
" param_attr=fluid.param_attr.ParamAttr(\n",
" initializer=fluid.initializer.Uniform(\n",
" -stdv, stdv)))\n",
" self.variables.append(excitation)\n",
" scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)\n",
" return scale"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2018-07-20T10:03:22.860628Z",
"start_time": "2018-07-20T10:03:22.619990Z"
}
},
"outputs": [],
"source": [
"image = fluid.layers.data(name='image', shape=[3, 32, 32], dtype='float32')\n",
"label = fluid.layers.data(name='label', shape=[-1, 1], dtype='int64')\n",
"\n",
"base_model = SE_ResNeXt50()\n",
"#base_model = SE_ResNeXt50_32x4d()\n",
"predict = base_model.net(image, class_dim=10)\n",
"\n",
"inference_program = fluid.default_main_program().clone(for_test=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2018-07-20T10:03:22.932634Z",
"start_time": "2018-07-20T10:03:22.864467Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0, 'conv2d_0.tmp_0')\n",
"(1, 'batch_norm_0.tmp_2')\n",
"(2, 'conv2d_1.tmp_0')\n",
"(3, 'batch_norm_1.tmp_2')\n",
"(4, 'conv2d_2.tmp_0')\n",
"(5, 'batch_norm_2.tmp_2')\n",
"(6, 'conv2d_3.tmp_0')\n",
"(7, 'batch_norm_3.tmp_2')\n",
"(8, 'fc_0.tmp_1')\n",
"(9, 'fc_1.tmp_1')\n",
"(10, 'conv2d_4.tmp_0')\n",
"(11, 'batch_norm_4.tmp_2')\n",
"(12, 'conv2d_5.tmp_0')\n",
"(13, 'batch_norm_5.tmp_2')\n",
"(14, 'conv2d_6.tmp_0')\n",
"(15, 'batch_norm_6.tmp_2')\n",
"(16, 'conv2d_7.tmp_0')\n",
"(17, 'batch_norm_7.tmp_2')\n",
"(18, 'fc_2.tmp_1')\n",
"(19, 'fc_3.tmp_1')\n",
"(20, 'conv2d_8.tmp_0')\n",
"(21, 'batch_norm_8.tmp_2')\n",
"(22, 'conv2d_9.tmp_0')\n",
"(23, 'batch_norm_9.tmp_2')\n",
"(24, 'conv2d_10.tmp_0')\n",
"(25, 'batch_norm_10.tmp_2')\n",
"(26, 'fc_4.tmp_1')\n",
"(27, 'fc_5.tmp_1')\n",
"(28, 'conv2d_11.tmp_0')\n",
"(29, 'batch_norm_11.tmp_2')\n",
"(30, 'conv2d_12.tmp_0')\n",
"(31, 'batch_norm_12.tmp_2')\n",
"(32, 'conv2d_13.tmp_0')\n",
"(33, 'batch_norm_13.tmp_2')\n",
"(34, 'fc_6.tmp_1')\n",
"(35, 'fc_7.tmp_1')\n",
"(36, 'conv2d_14.tmp_0')\n",
"(37, 'batch_norm_14.tmp_2')\n",
"(38, 'conv2d_15.tmp_0')\n",
"(39, 'batch_norm_15.tmp_2')\n",
"(40, 'conv2d_16.tmp_0')\n",
"(41, 'batch_norm_16.tmp_2')\n",
"(42, 'conv2d_17.tmp_0')\n",
"(43, 'batch_norm_17.tmp_2')\n",
"(44, 'fc_8.tmp_1')\n",
"(45, 'fc_9.tmp_1')\n",
"(46, 'conv2d_18.tmp_0')\n",
"(47, 'batch_norm_18.tmp_2')\n",
"(48, 'conv2d_19.tmp_0')\n",
"(49, 'batch_norm_19.tmp_2')\n",
"(50, 'conv2d_20.tmp_0')\n",
"(51, 'batch_norm_20.tmp_2')\n",
"(52, 'fc_10.tmp_1')\n",
"(53, 'fc_11.tmp_1')\n",
"(54, 'conv2d_21.tmp_0')\n",
"(55, 'batch_norm_21.tmp_2')\n",
"(56, 'conv2d_22.tmp_0')\n",
"(57, 'batch_norm_22.tmp_2')\n",
"(58, 'conv2d_23.tmp_0')\n",
"(59, 'batch_norm_23.tmp_2')\n",
"(60, 'fc_12.tmp_1')\n",
"(61, 'fc_13.tmp_1')\n",
"(62, 'conv2d_24.tmp_0')\n",
"(63, 'batch_norm_24.tmp_2')\n",
"(64, 'conv2d_25.tmp_0')\n",
"(65, 'batch_norm_25.tmp_2')\n",
"(66, 'conv2d_26.tmp_0')\n",
"(67, 'batch_norm_26.tmp_2')\n",
"(68, 'fc_14.tmp_1')\n",
"(69, 'fc_15.tmp_1')\n",
"(70, 'conv2d_27.tmp_0')\n",
"(71, 'batch_norm_27.tmp_2')\n",
"(72, 'conv2d_28.tmp_0')\n",
"(73, 'batch_norm_28.tmp_2')\n",
"(74, 'conv2d_29.tmp_0')\n",
"(75, 'batch_norm_29.tmp_2')\n",
"(76, 'conv2d_30.tmp_0')\n",
"(77, 'batch_norm_30.tmp_2')\n",
"(78, 'fc_16.tmp_1')\n",
"(79, 'fc_17.tmp_1')\n",
"(80, 'conv2d_31.tmp_0')\n",
"(81, 'batch_norm_31.tmp_2')\n",
"(82, 'conv2d_32.tmp_0')\n",
"(83, 'batch_norm_32.tmp_2')\n",
"(84, 'conv2d_33.tmp_0')\n",
"(85, 'batch_norm_33.tmp_2')\n",
"(86, 'fc_18.tmp_1')\n",
"(87, 'fc_19.tmp_1')\n",
"(88, 'conv2d_34.tmp_0')\n",
"(89, 'batch_norm_34.tmp_2')\n",
"(90, 'conv2d_35.tmp_0')\n",
"(91, 'batch_norm_35.tmp_2')\n",
"(92, 'conv2d_36.tmp_0')\n",
"(93, 'batch_norm_36.tmp_2')\n",
"(94, 'fc_20.tmp_1')\n",
"(95, 'fc_21.tmp_1')\n",
"(96, 'conv2d_37.tmp_0')\n",
"(97, 'batch_norm_37.tmp_2')\n",
"(98, 'conv2d_38.tmp_0')\n",
"(99, 'batch_norm_38.tmp_2')\n",
"(100, 'conv2d_39.tmp_0')\n",
"(101, 'batch_norm_39.tmp_2')\n",
"(102, 'fc_22.tmp_1')\n",
"(103, 'fc_23.tmp_1')\n",
"(104, 'conv2d_40.tmp_0')\n",
"(105, 'batch_norm_40.tmp_2')\n",
"(106, 'conv2d_41.tmp_0')\n",
"(107, 'batch_norm_41.tmp_2')\n",
"(108, 'conv2d_42.tmp_0')\n",
"(109, 'batch_norm_42.tmp_2')\n",
"(110, 'fc_24.tmp_1')\n",
"(111, 'fc_25.tmp_1')\n",
"(112, 'conv2d_43.tmp_0')\n",
"(113, 'batch_norm_43.tmp_2')\n",
"(114, 'conv2d_44.tmp_0')\n",
"(115, 'batch_norm_44.tmp_2')\n",
"(116, 'conv2d_45.tmp_0')\n",
"(117, 'batch_norm_45.tmp_2')\n",
"(118, 'fc_26.tmp_1')\n",
"(119, 'fc_27.tmp_1')\n",
"(120, 'conv2d_46.tmp_0')\n",
"(121, 'batch_norm_46.tmp_2')\n",
"(122, 'conv2d_47.tmp_0')\n",
"(123, 'batch_norm_47.tmp_2')\n",
"(124, 'conv2d_48.tmp_0')\n",
"(125, 'batch_norm_48.tmp_2')\n",
"(126, 'conv2d_49.tmp_0')\n",
"(127, 'batch_norm_49.tmp_2')\n",
"(128, 'fc_28.tmp_1')\n",
"(129, 'fc_29.tmp_1')\n",
"(130, 'conv2d_50.tmp_0')\n",
"(131, 'batch_norm_50.tmp_2')\n",
"(132, 'conv2d_51.tmp_0')\n",
"(133, 'batch_norm_51.tmp_2')\n",
"(134, 'conv2d_52.tmp_0')\n",
"(135, 'batch_norm_52.tmp_2')\n",
"(136, 'fc_30.tmp_1')\n",
"(137, 'fc_31.tmp_1')\n"
]
}
],
"source": [
"for i, v in enumerate(base_model.variables):\n",
" print(i, v.name)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2018-07-20T10:03:22.971445Z",
"start_time": "2018-07-20T10:03:22.938370Z"
}
},
"outputs": [],
"source": [
"# Train only the tail of the network: stop gradients at every recorded variable except the last 10\n",
"for v in base_model.variables[:-10]:\n",
" v.stop_gradient = True\n",
"for v in base_model.variables[-10:]:\n",
" v.stop_gradient = False"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2018-07-20T10:03:23.039162Z",
"start_time": "2018-07-20T10:03:22.975037Z"
}
},
"outputs": [],
"source": [
"loss = fluid.layers.cross_entropy(input=predict, label=label)\n",
"loss = fluid.layers.mean(loss)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2018-07-20T10:03:23.143387Z",
"start_time": "2018-07-20T10:03:23.045525Z"
}
},
"outputs": [],
"source": [
"opt = fluid.optimizer.Adam(learning_rate=0.001)\n",
"_ = opt.minimize(loss)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2018-07-20T10:03:23.667592Z",
"start_time": "2018-07-20T10:03:23.147225Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"exe = fluid.executor.Executor(fluid.CPUPlace())\n",
"exe.run(fluid.default_startup_program())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2018-07-20T10:03:23.759366Z",
"start_time": "2018-07-20T10:03:23.670510Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"models/se_resnext_50/129/conv2d_0.w_0\n",
"models/se_resnext_50/129/batch_norm_0.w_0\n",
"models/se_resnext_50/129/batch_norm_0.b_0\n",
"models/se_resnext_50/129/batch_norm_0.w_1\n",
"models/se_resnext_50/129/batch_norm_0.w_2\n",
"models/se_resnext_50/129/conv2d_1.w_0\n",
"models/se_resnext_50/129/batch_norm_1.w_0\n",
"models/se_resnext_50/129/batch_norm_1.b_0\n",
"models/se_resnext_50/129/batch_norm_1.w_1\n",
"models/se_resnext_50/129/batch_norm_1.w_2\n",
"models/se_resnext_50/129/conv2d_2.w_0\n",
"models/se_resnext_50/129/batch_norm_2.w_0\n",
"models/se_resnext_50/129/batch_norm_2.b_0\n",
"models/se_resnext_50/129/batch_norm_2.w_1\n",
"models/se_resnext_50/129/batch_norm_2.w_2\n",
"models/se_resnext_50/129/conv2d_3.w_0\n",
"models/se_resnext_50/129/batch_norm_3.w_0\n",
"models/se_resnext_50/129/batch_norm_3.b_0\n",
"models/se_resnext_50/129/batch_norm_3.w_1\n",
"models/se_resnext_50/129/batch_norm_3.w_2\n",
"models/se_resnext_50/129/fc_0.w_0\n",
"models/se_resnext_50/129/fc_0.b_0\n",
"models/se_resnext_50/129/fc_1.w_0\n",
"models/se_resnext_50/129/fc_1.b_0\n",
"models/se_resnext_50/129/conv2d_4.w_0\n",
"models/se_resnext_50/129/batch_norm_4.w_0\n",
"models/se_resnext_50/129/batch_norm_4.b_0\n",
"models/se_resnext_50/129/batch_norm_4.w_1\n",
"models/se_resnext_50/129/batch_norm_4.w_2\n",
"models/se_resnext_50/129/conv2d_5.w_0\n",
"models/se_resnext_50/129/batch_norm_5.w_0\n",
"models/se_resnext_50/129/batch_norm_5.b_0\n",
"models/se_resnext_50/129/batch_norm_5.w_1\n",
"models/se_resnext_50/129/batch_norm_5.w_2\n",
"models/se_resnext_50/129/conv2d_6.w_0\n",
"models/se_resnext_50/129/batch_norm_6.w_0\n",
"models/se_resnext_50/129/batch_norm_6.b_0\n",
"models/se_resnext_50/129/batch_norm_6.w_1\n",
"models/se_resnext_50/129/batch_norm_6.w_2\n",
"models/se_resnext_50/129/conv2d_7.w_0\n",
"models/se_resnext_50/129/batch_norm_7.w_0\n",
"models/se_resnext_50/129/batch_norm_7.b_0\n",
"models/se_resnext_50/129/batch_norm_7.w_1\n",
"models/se_resnext_50/129/batch_norm_7.w_2\n",
"models/se_resnext_50/129/fc_2.w_0\n",
"models/se_resnext_50/129/fc_2.b_0\n",
"models/se_resnext_50/129/fc_3.w_0\n",
"models/se_resnext_50/129/fc_3.b_0\n",
"models/se_resnext_50/129/conv2d_8.w_0\n",
"models/se_resnext_50/129/batch_norm_8.w_0\n",
"models/se_resnext_50/129/batch_norm_8.b_0\n",
"models/se_resnext_50/129/batch_norm_8.w_1\n",
"models/se_resnext_50/129/batch_norm_8.w_2\n",
"models/se_resnext_50/129/conv2d_9.w_0\n",
"models/se_resnext_50/129/batch_norm_9.w_0\n",
"models/se_resnext_50/129/batch_norm_9.b_0\n",
"models/se_resnext_50/129/batch_norm_9.w_1\n",
"models/se_resnext_50/129/batch_norm_9.w_2\n",
"models/se_resnext_50/129/conv2d_10.w_0\n",
"models/se_resnext_50/129/batch_norm_10.w_0\n",
"models/se_resnext_50/129/batch_norm_10.b_0\n",
"models/se_resnext_50/129/batch_norm_10.w_1\n",
"models/se_resnext_50/129/batch_norm_10.w_2\n",
"models/se_resnext_50/129/fc_4.w_0\n",
"models/se_resnext_50/129/fc_4.b_0\n",
"models/se_resnext_50/129/fc_5.w_0\n",
"models/se_resnext_50/129/fc_5.b_0\n",
"models/se_resnext_50/129/conv2d_11.w_0\n",
"models/se_resnext_50/129/batch_norm_11.w_0\n",
"models/se_resnext_50/129/batch_norm_11.b_0\n",
"models/se_resnext_50/129/batch_norm_11.w_1\n",
"models/se_resnext_50/129/batch_norm_11.w_2\n",
"models/se_resnext_50/129/conv2d_12.w_0\n",
"models/se_resnext_50/129/batch_norm_12.w_0\n",
"models/se_resnext_50/129/batch_norm_12.b_0\n",
"models/se_resnext_50/129/batch_norm_12.w_1\n",
"models/se_resnext_50/129/batch_norm_12.w_2\n",
"models/se_resnext_50/129/conv2d_13.w_0\n",
"models/se_resnext_50/129/batch_norm_13.w_0\n",
"models/se_resnext_50/129/batch_norm_13.b_0\n",
"models/se_resnext_50/129/batch_norm_13.w_1\n",
"models/se_resnext_50/129/batch_norm_13.w_2\n",
"models/se_resnext_50/129/fc_6.w_0\n",
"models/se_resnext_50/129/fc_6.b_0\n",
"models/se_resnext_50/129/fc_7.w_0\n",
"models/se_resnext_50/129/fc_7.b_0\n",
"models/se_resnext_50/129/conv2d_14.w_0\n",
"models/se_resnext_50/129/batch_norm_14.w_0\n",
"models/se_resnext_50/129/batch_norm_14.b_0\n",
"models/se_resnext_50/129/batch_norm_14.w_1\n",
"models/se_resnext_50/129/batch_norm_14.w_2\n",
"models/se_resnext_50/129/conv2d_15.w_0\n",
"models/se_resnext_50/129/batch_norm_15.w_0\n",
"models/se_resnext_50/129/batch_norm_15.b_0\n",
"models/se_resnext_50/129/batch_norm_15.w_1\n",
"models/se_resnext_50/129/batch_norm_15.w_2\n",
"models/se_resnext_50/129/conv2d_16.w_0\n",
"models/se_resnext_50/129/batch_norm_16.w_0\n",
"models/se_resnext_50/129/batch_norm_16.b_0\n",
"models/se_resnext_50/129/batch_norm_16.w_1\n",
"models/se_resnext_50/129/batch_norm_16.w_2\n",
"models/se_resnext_50/129/conv2d_17.w_0\n",
"models/se_resnext_50/129/batch_norm_17.w_0\n",
"models/se_resnext_50/129/batch_norm_17.b_0\n",
"models/se_resnext_50/129/batch_norm_17.w_1\n",
"models/se_resnext_50/129/batch_norm_17.w_2\n",
"models/se_resnext_50/129/fc_8.w_0\n",
"models/se_resnext_50/129/fc_8.b_0\n",
"models/se_resnext_50/129/fc_9.w_0\n",
"models/se_resnext_50/129/fc_9.b_0\n",
"models/se_resnext_50/129/conv2d_18.w_0\n",
"models/se_resnext_50/129/batch_norm_18.w_0\n",
"models/se_resnext_50/129/batch_norm_18.b_0\n",
"models/se_resnext_50/129/batch_norm_18.w_1\n",
"models/se_resnext_50/129/batch_norm_18.w_2\n",
"models/se_resnext_50/129/conv2d_19.w_0\n",
"models/se_resnext_50/129/batch_norm_19.w_0\n",
"models/se_resnext_50/129/batch_norm_19.b_0\n",
"models/se_resnext_50/129/batch_norm_19.w_1\n",
"models/se_resnext_50/129/batch_norm_19.w_2\n",
"models/se_resnext_50/129/conv2d_20.w_0\n",
"models/se_resnext_50/129/batch_norm_20.w_0\n",
"models/se_resnext_50/129/batch_norm_20.b_0\n",
"models/se_resnext_50/129/batch_norm_20.w_1\n",
"models/se_resnext_50/129/batch_norm_20.w_2\n",
"models/se_resnext_50/129/fc_10.w_0\n",
"models/se_resnext_50/129/fc_10.b_0\n",
"models/se_resnext_50/129/fc_11.w_0\n",
"models/se_resnext_50/129/fc_11.b_0\n",
"models/se_resnext_50/129/conv2d_21.w_0\n",
"models/se_resnext_50/129/batch_norm_21.w_0\n",
"models/se_resnext_50/129/batch_norm_21.b_0\n",
"models/se_resnext_50/129/batch_norm_21.w_1\n",
"models/se_resnext_50/129/batch_norm_21.w_2\n",
"models/se_resnext_50/129/conv2d_22.w_0\n",
"models/se_resnext_50/129/batch_norm_22.w_0\n",
"models/se_resnext_50/129/batch_norm_22.b_0\n",
"models/se_resnext_50/129/batch_norm_22.w_1\n",
"models/se_resnext_50/129/batch_norm_22.w_2\n",
"models/se_resnext_50/129/conv2d_23.w_0\n",
"models/se_resnext_50/129/batch_norm_23.w_0\n",
"models/se_resnext_50/129/batch_norm_23.b_0\n",
"models/se_resnext_50/129/batch_norm_23.w_1\n",
"models/se_resnext_50/129/batch_norm_23.w_2\n",
"models/se_resnext_50/129/fc_12.w_0\n",
"models/se_resnext_50/129/fc_12.b_0\n",
"models/se_resnext_50/129/fc_13.w_0\n",
"models/se_resnext_50/129/fc_13.b_0\n",
"models/se_resnext_50/129/conv2d_24.w_0\n",
"models/se_resnext_50/129/batch_norm_24.w_0\n",
"models/se_resnext_50/129/batch_norm_24.b_0\n",
"models/se_resnext_50/129/batch_norm_24.w_1\n",
"models/se_resnext_50/129/batch_norm_24.w_2\n",
"models/se_resnext_50/129/conv2d_25.w_0\n",
"models/se_resnext_50/129/batch_norm_25.w_0\n",
"models/se_resnext_50/129/batch_norm_25.b_0\n",
"models/se_resnext_50/129/batch_norm_25.w_1\n",
"models/se_resnext_50/129/batch_norm_25.w_2\n",
"models/se_resnext_50/129/conv2d_26.w_0\n",
"models/se_resnext_50/129/batch_norm_26.w_0\n",
"models/se_resnext_50/129/batch_norm_26.b_0\n",
"models/se_resnext_50/129/batch_norm_26.w_1\n",
"models/se_resnext_50/129/batch_norm_26.w_2\n",
"models/se_resnext_50/129/fc_14.w_0\n",
"models/se_resnext_50/129/fc_14.b_0\n",
"models/se_resnext_50/129/fc_15.w_0\n",
"models/se_resnext_50/129/fc_15.b_0\n",
"models/se_resnext_50/129/conv2d_27.w_0\n",
"models/se_resnext_50/129/batch_norm_27.w_0\n",
"models/se_resnext_50/129/batch_norm_27.b_0\n",
"models/se_resnext_50/129/batch_norm_27.w_1\n",
"models/se_resnext_50/129/batch_norm_27.w_2\n",
"models/se_resnext_50/129/conv2d_28.w_0\n",
"models/se_resnext_50/129/batch_norm_28.w_0\n",
"models/se_resnext_50/129/batch_norm_28.b_0\n",
"models/se_resnext_50/129/batch_norm_28.w_1\n",
"models/se_resnext_50/129/batch_norm_28.w_2\n",
"models/se_resnext_50/129/conv2d_29.w_0\n",
"models/se_resnext_50/129/batch_norm_29.w_0\n",
"models/se_resnext_50/129/batch_norm_29.b_0\n",
"models/se_resnext_50/129/batch_norm_29.w_1\n",
"models/se_resnext_50/129/batch_norm_29.w_2\n",
"models/se_resnext_50/129/conv2d_30.w_0\n",
"models/se_resnext_50/129/batch_norm_30.w_0\n",
"models/se_resnext_50/129/batch_norm_30.b_0\n",
"models/se_resnext_50/129/batch_norm_30.w_1\n",
"models/se_resnext_50/129/batch_norm_30.w_2\n",
"models/se_resnext_50/129/fc_16.w_0\n",
"models/se_resnext_50/129/fc_16.b_0\n",
"models/se_resnext_50/129/fc_17.w_0\n",
"models/se_resnext_50/129/fc_17.b_0\n",
"models/se_resnext_50/129/conv2d_31.w_0\n",
"models/se_resnext_50/129/batch_norm_31.w_0\n",
"models/se_resnext_50/129/batch_norm_31.b_0\n",
"models/se_resnext_50/129/batch_norm_31.w_1\n",
"models/se_resnext_50/129/batch_norm_31.w_2\n",
"models/se_resnext_50/129/conv2d_32.w_0\n",
"models/se_resnext_50/129/batch_norm_32.w_0\n",
"models/se_resnext_50/129/batch_norm_32.b_0\n",
"models/se_resnext_50/129/batch_norm_32.w_1\n",
"models/se_resnext_50/129/batch_norm_32.w_2\n",
"models/se_resnext_50/129/conv2d_33.w_0\n",
"models/se_resnext_50/129/batch_norm_33.w_0\n",
"models/se_resnext_50/129/batch_norm_33.b_0\n",
"models/se_resnext_50/129/batch_norm_33.w_1\n",
"models/se_resnext_50/129/batch_norm_33.w_2\n",
"models/se_resnext_50/129/fc_18.w_0\n",
"models/se_resnext_50/129/fc_18.b_0\n",
"models/se_resnext_50/129/fc_19.w_0\n",
"models/se_resnext_50/129/fc_19.b_0\n",
"models/se_resnext_50/129/conv2d_34.w_0\n",
"models/se_resnext_50/129/batch_norm_34.w_0\n",
"models/se_resnext_50/129/batch_norm_34.b_0\n",
"models/se_resnext_50/129/batch_norm_34.w_1\n",
"models/se_resnext_50/129/batch_norm_34.w_2\n",
"models/se_resnext_50/129/conv2d_35.w_0\n",
"models/se_resnext_50/129/batch_norm_35.w_0\n",
"models/se_resnext_50/129/batch_norm_35.b_0\n",
"models/se_resnext_50/129/batch_norm_35.w_1\n",
"models/se_resnext_50/129/batch_norm_35.w_2\n",
"models/se_resnext_50/129/conv2d_36.w_0\n",
"models/se_resnext_50/129/batch_norm_36.w_0\n",
"models/se_resnext_50/129/batch_norm_36.b_0\n",
"models/se_resnext_50/129/batch_norm_36.w_1\n",
"models/se_resnext_50/129/batch_norm_36.w_2\n",
"models/se_resnext_50/129/fc_20.w_0\n",
"models/se_resnext_50/129/fc_20.b_0\n",
"models/se_resnext_50/129/fc_21.w_0\n",
"models/se_resnext_50/129/fc_21.b_0\n",
"models/se_resnext_50/129/conv2d_37.w_0\n",
"models/se_resnext_50/129/batch_norm_37.w_0\n",
"models/se_resnext_50/129/batch_norm_37.b_0\n",
"models/se_resnext_50/129/batch_norm_37.w_1\n",
"models/se_resnext_50/129/batch_norm_37.w_2\n",
"models/se_resnext_50/129/conv2d_38.w_0\n",
"models/se_resnext_50/129/batch_norm_38.w_0\n",
"models/se_resnext_50/129/batch_norm_38.b_0\n",
"models/se_resnext_50/129/batch_norm_38.w_1\n",
"models/se_resnext_50/129/batch_norm_38.w_2\n",
"models/se_resnext_50/129/conv2d_39.w_0\n",
"models/se_resnext_50/129/batch_norm_39.w_0\n",
"models/se_resnext_50/129/batch_norm_39.b_0\n",
"models/se_resnext_50/129/batch_norm_39.w_1\n",
"models/se_resnext_50/129/batch_norm_39.w_2\n",
"models/se_resnext_50/129/fc_22.w_0\n",
"models/se_resnext_50/129/fc_22.b_0\n",
"models/se_resnext_50/129/fc_23.w_0\n",
"models/se_resnext_50/129/fc_23.b_0\n",
"models/se_resnext_50/129/conv2d_40.w_0\n",
"models/se_resnext_50/129/batch_norm_40.w_0\n",
"models/se_resnext_50/129/batch_norm_40.b_0\n",
"models/se_resnext_50/129/batch_norm_40.w_1\n",
"models/se_resnext_50/129/batch_norm_40.w_2\n",
"models/se_resnext_50/129/conv2d_41.w_0\n",
"models/se_resnext_50/129/batch_norm_41.w_0\n",
"models/se_resnext_50/129/batch_norm_41.b_0\n",
"models/se_resnext_50/129/batch_norm_41.w_1\n",
"models/se_resnext_50/129/batch_norm_41.w_2\n",
"models/se_resnext_50/129/conv2d_42.w_0\n",
"models/se_resnext_50/129/batch_norm_42.w_0\n",
"models/se_resnext_50/129/batch_norm_42.b_0\n",
"models/se_resnext_50/129/batch_norm_42.w_1\n",
"models/se_resnext_50/129/batch_norm_42.w_2\n",
"models/se_resnext_50/129/fc_24.w_0\n",
"models/se_resnext_50/129/fc_24.b_0\n",
"models/se_resnext_50/129/fc_25.w_0\n",
"models/se_resnext_50/129/fc_25.b_0\n",
"models/se_resnext_50/129/conv2d_43.w_0\n",
"models/se_resnext_50/129/batch_norm_43.w_0\n",
"models/se_resnext_50/129/batch_norm_43.b_0\n",
"models/se_resnext_50/129/batch_norm_43.w_1\n",
"models/se_resnext_50/129/batch_norm_43.w_2\n",
"models/se_resnext_50/129/conv2d_44.w_0\n",
"models/se_resnext_50/129/batch_norm_44.w_0\n",
"models/se_resnext_50/129/batch_norm_44.b_0\n",
"models/se_resnext_50/129/batch_norm_44.w_1\n",
"models/se_resnext_50/129/batch_norm_44.w_2\n",
"models/se_resnext_50/129/conv2d_45.w_0\n",
"models/se_resnext_50/129/batch_norm_45.w_0\n",
"models/se_resnext_50/129/batch_norm_45.b_0\n",
"models/se_resnext_50/129/batch_norm_45.w_1\n",
"models/se_resnext_50/129/batch_norm_45.w_2\n",
"models/se_resnext_50/129/fc_26.w_0\n",
"models/se_resnext_50/129/fc_26.b_0\n",
"models/se_resnext_50/129/fc_27.w_0\n",
"models/se_resnext_50/129/fc_27.b_0\n",
"models/se_resnext_50/129/conv2d_46.w_0\n",
"models/se_resnext_50/129/batch_norm_46.w_0\n",
"models/se_resnext_50/129/batch_norm_46.b_0\n",
"models/se_resnext_50/129/batch_norm_46.w_1\n",
"models/se_resnext_50/129/batch_norm_46.w_2\n",
"models/se_resnext_50/129/conv2d_47.w_0\n",
"models/se_resnext_50/129/batch_norm_47.w_0\n",
"models/se_resnext_50/129/batch_norm_47.b_0\n",
"models/se_resnext_50/129/batch_norm_47.w_1\n",
"models/se_resnext_50/129/batch_norm_47.w_2\n",
"models/se_resnext_50/129/conv2d_48.w_0\n",
"models/se_resnext_50/129/batch_norm_48.w_0\n",
"models/se_resnext_50/129/batch_norm_48.b_0\n",
"models/se_resnext_50/129/batch_norm_48.w_1\n",
"models/se_resnext_50/129/batch_norm_48.w_2\n",
"models/se_resnext_50/129/conv2d_49.w_0\n",
"models/se_resnext_50/129/batch_norm_49.w_0\n",
"models/se_resnext_50/129/batch_norm_49.b_0\n",
"models/se_resnext_50/129/batch_norm_49.w_1\n",
"models/se_resnext_50/129/batch_norm_49.w_2\n",
"models/se_resnext_50/129/fc_28.w_0\n",
"models/se_resnext_50/129/fc_28.b_0\n",
"models/se_resnext_50/129/fc_29.w_0\n",
"models/se_resnext_50/129/fc_29.b_0\n",
"models/se_resnext_50/129/conv2d_50.w_0\n",
"models/se_resnext_50/129/batch_norm_50.w_0\n",
"models/se_resnext_50/129/batch_norm_50.b_0\n",
"models/se_resnext_50/129/batch_norm_50.w_1\n",
"models/se_resnext_50/129/batch_norm_50.w_2\n",
"models/se_resnext_50/129/conv2d_51.w_0\n",
"models/se_resnext_50/129/batch_norm_51.w_0\n",
"models/se_resnext_50/129/batch_norm_51.b_0\n",
"models/se_resnext_50/129/batch_norm_51.w_1\n",
"models/se_resnext_50/129/batch_norm_51.w_2\n",
"models/se_resnext_50/129/conv2d_52.w_0\n",
"models/se_resnext_50/129/batch_norm_52.w_0\n",
"models/se_resnext_50/129/batch_norm_52.b_0\n",
"models/se_resnext_50/129/batch_norm_52.w_1\n",
"models/se_resnext_50/129/batch_norm_52.w_2\n",
"models/se_resnext_50/129/fc_30.w_0\n",
"models/se_resnext_50/129/fc_30.b_0\n",
"models/se_resnext_50/129/fc_31.w_0\n",
"models/se_resnext_50/129/fc_31.b_0\n"
]
}
],
"source": [
"# Load pretrained parameters: only variables with a matching file under pretrained_model_path are loaded\n",
"def if_exist(var):\n",
" path = os.path.join(pretrained_model_path, var.name)\n",
" exist = os.path.exists(path)\n",
" if exist:\n",
" print(path)\n",
" return exist\n",
"\n",
"fluid.io.load_vars(exe, pretrained_model_path, predicate=if_exist)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2018-07-20T10:03:24.157400Z",
"start_time": "2018-07-20T10:03:23.763261Z"
}
},
"outputs": [],
"source": [
"feeder = fluid.data_feeder.DataFeeder([image, label], fluid.CPUPlace())\n",
"reader = feeder.decorate_reader(paddle.batch(paddle.dataset.cifar.train10(), 64), None)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2018-07-20T10:03:37.037340Z",
"start_time": "2018-07-20T10:03:24.161237Z"
},
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(0, [array([2.3287988], dtype=float32)])\n",
"(1, [array([2.3302152], dtype=float32)])\n",
"(2, [array([2.5670676], dtype=float32)])\n",
"(3, [array([2.2559876], dtype=float32)])\n",
"(4, [array([2.3306024], dtype=float32)])\n",
"(5, [array([2.3891745], dtype=float32)])\n",
"(6, [array([2.3429651], dtype=float32)])\n",
"(7, [array([2.2461793], dtype=float32)])\n",
"(8, [array([2.2587867], dtype=float32)])\n",
"(9, [array([2.3550396], dtype=float32)])\n",
"(10, [array([2.3980129], dtype=float32)])\n",
"(11, [array([2.3221207], dtype=float32)])\n"
]
}
],
"source": [
"for batch_id, data in enumerate(reader()):\n",
" result = exe.run(fetch_list=[loss], feed=data)\n",
" print(batch_id, result)\n",
" if batch_id > 10:\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!ls models/fine.model | wc -l\n",
"!ls models/se_resnext_50/129 | wc -l"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2018-07-20T10:03:37.169253Z",
"start_time": "2018-07-20T10:03:37.041280Z"
}
},
"outputs": [],
"source": [
"fluid.io.save_params(exe, \"models/fine.model\")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"ExecuteTime": {
"end_time": "2018-07-20T10:06:20.966974Z",
"start_time": "2018-07-20T10:06:20.458865Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"333\n",
"331\n",
"125\n"
]
}
],
"source": [
"# Only some of the parameter files changed after fine-tuning\n",
"!ls models/fine.model | wc -l\n",
"!ls models/se_resnext_50/129 | wc -l\n",
"!diff models/se_resnext_50/129 models/fine.model | grep differ | wc -l"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@oraoto
Copy link
Author

oraoto commented Jul 26, 2018

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment