Last active
January 25, 2018 21:12
-
-
Save hushell/de5b694e968cfcf11e7e9ae92f119b3d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"from torch.autograd import Variable\n", | |
"import torch\n", | |
"from tqdm import tqdm_notebook as tqdm\n", | |
"from torchvision.models import resnet50\n", | |
"from torchvision import transforms\n", | |
"from torchvision import datasets\n", | |
"from torchnet.meter import ClassErrorMeter" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Downloading: \"https://download.pytorch.org/models/resnet50-19c8e357.pth\" to /home/hus/.torch/models/resnet50-19c8e357.pth\n", | |
"100%|██████████| 102502400/102502400 [00:13<00:00, 7460954.44it/s]\n" | |
] | |
} | |
], | |
"source": [ | |
"model = resnet50(pretrained=True)\n", | |
"#checkpoint = torch.load('/home/zagoruys/hdd2/symmetry_imagenet_logs/official/resnet_50/model_best.pth.tar')\n", | |
"#model.load_state_dict({k[7:]: v for k, v in checkpoint['state_dict'].items()})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": true, | |
"scrolled": false | |
}, | |
"outputs": [], | |
"source": [ | |
"weights = {k: v for k, v in model.state_dict().items() if k.endswith('.conv2.weight')}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[('layer1.0.conv2.weight', 64, 64, 3, 3),\n", | |
" ('layer1.1.conv2.weight', 64, 64, 3, 3),\n", | |
" ('layer1.2.conv2.weight', 64, 64, 3, 3),\n", | |
" ('layer2.0.conv2.weight', 128, 128, 3, 3),\n", | |
" ('layer2.1.conv2.weight', 128, 128, 3, 3),\n", | |
" ('layer2.2.conv2.weight', 128, 128, 3, 3),\n", | |
" ('layer2.3.conv2.weight', 128, 128, 3, 3),\n", | |
" ('layer3.0.conv2.weight', 256, 256, 3, 3),\n", | |
" ('layer3.1.conv2.weight', 256, 256, 3, 3),\n", | |
" ('layer3.2.conv2.weight', 256, 256, 3, 3),\n", | |
" ('layer3.3.conv2.weight', 256, 256, 3, 3),\n", | |
" ('layer3.4.conv2.weight', 256, 256, 3, 3),\n", | |
" ('layer3.5.conv2.weight', 256, 256, 3, 3),\n", | |
" ('layer4.0.conv2.weight', 512, 512, 3, 3),\n", | |
" ('layer4.1.conv2.weight', 512, 512, 3, 3),\n", | |
" ('layer4.2.conv2.weight', 512, 512, 3, 3)]" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"[tuple((k,) + w.shape) for k, w in weights.items()]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import tensorly as tl\n", | |
"import numpy as np\n", | |
"from tensorly.decomposition import parafac\n", | |
"\n", | |
"def cp_decomp_approx(X):\n", | |
" N,C,H,W = X.shape\n", | |
" rank = (N*C*H*W) // (2*(N+C+H+W))\n", | |
" print('rank(W) = %d' % rank)\n", | |
" factors = parafac(X.cpu().numpy(), rank=rank)\n", | |
" full_tensor = tl.kruskal_to_tensor(factors)\n", | |
" return torch.FloatTensor(full_tensor)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(64, 64, 3, 3)" | |
] | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"X = weights['layer1.0.conv2.weight']\n", | |
"Xnumpy = X.cpu().numpy()\n", | |
"factors = parafac(Xnumpy, rank = 137)\n", | |
"full_tensor = tl.kruskal_to_tensor(factors)\n", | |
"full_tensor.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"rank(W) = 137\n" | |
] | |
} | |
], | |
"source": [ | |
"X = weights['layer1.0.conv2.weight']\n", | |
"X.shape\n", | |
"XX = cp_decomp_approx(X)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": { | |
"collapsed": true, | |
"scrolled": false | |
}, | |
"outputs": [], | |
"source": [ | |
"def low_rank(W):\n", | |
" U, S, V = torch.svd(W)\n", | |
" n = S.shape[0] // 2\n", | |
" return (U[:,:n] @ S[:n].diag()) @ V[:,:n].t()\n", | |
"\n", | |
"def low_rank3x3(W):\n", | |
" W = W.clone()\n", | |
" for i in range(3):\n", | |
" for j in range(3):\n", | |
" W[:,:,i,j] = low_rank(W[:,:,i,j])\n", | |
" return W" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"rank(W) = 137\n", | |
"rank(W) = 137\n", | |
"rank(W) = 137\n", | |
"rank(W) = 281\n", | |
"rank(W) = 281\n", | |
"rank(W) = 281\n", | |
"rank(W) = 281\n", | |
"rank(W) = 569\n" | |
] | |
}, | |
{ | |
"ename": "MemoryError", | |
"evalue": "", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mMemoryError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-36-7a0ad0cc284b>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#weights_star = {k: low_rank3x3(w) for k, w in weights.items()}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mweights_star\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcp_decomp_approx\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mw\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mweights\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m<ipython-input-36-7a0ad0cc284b>\u001b[0m in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#weights_star = {k: low_rank3x3(w) for k, w in weights.items()}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mweights_star\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcp_decomp_approx\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mw\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mweights\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m<ipython-input-32-6522eac9dc60>\u001b[0m in \u001b[0;36mcp_decomp_approx\u001b[0;34m(X)\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mrank\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mN\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mC\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mH\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mW\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m//\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mN\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mC\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mH\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mW\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'rank(W) = %d'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mrank\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mfactors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparafac\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrank\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrank\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0mfull_tensor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkruskal_to_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfactors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFloatTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/tensorly/decomposition/candecomp_parafac.py\u001b[0m in \u001b[0;36mparafac\u001b[0;34m(tensor, rank, n_iter_max, init, tol, random_state, verbose)\u001b[0m\n\u001b[1;32m 52\u001b[0m \u001b[0mfactors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 54\u001b[0;31m \u001b[0mU\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mT\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpartial_svd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0munfold\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_eigenvecs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrank\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 55\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mrank\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/tensorly/backend/numpy_backend.py\u001b[0m in \u001b[0;36mpartial_svd\u001b[0;34m(matrix, n_eigenvecs)\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mn_eigenvecs\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mn_eigenvecs\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mmin_dim\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0;31m# Default on standard SVD\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 164\u001b[0;31m \u001b[0mU\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mS\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mV\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mscipy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msvd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmatrix\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 165\u001b[0m \u001b[0mU\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mS\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mV\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mU\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0mn_eigenvecs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mS\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mn_eigenvecs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mV\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mn_eigenvecs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mU\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mS\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mV\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/anaconda3/lib/python3.6/site-packages/scipy/linalg/decomp_svd.py\u001b[0m in \u001b[0;36msvd\u001b[0;34m(a, full_matrices, compute_uv, overwrite_a, check_finite, lapack_driver)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;31m# perform decomposition\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m u, s, v, info = gesXd(a1, compute_uv=compute_uv, lwork=lwork,\n\u001b[0;32m--> 116\u001b[0;31m full_matrices=full_matrices, overwrite_a=overwrite_a)\n\u001b[0m\u001b[1;32m 117\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minfo\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mMemoryError\u001b[0m: " | |
] | |
} | |
], | |
"source": [ | |
"#weights_star = {k: low_rank3x3(w) for k, w in weights.items()}\n", | |
"weights_star = {k: cp_decomp_approx(w) for k, w in weights.items()}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"| setting up data loader...\n" | |
] | |
} | |
], | |
"source": [ | |
"imagenetpath = '/home/zagoruys/ILSVRC2012/'\n", | |
"numthreads = 8\n", | |
"\n", | |
"tr_center_crop = transforms.Compose([\n", | |
"# transforms.ToPILImage(),\n", | |
" transforms.Resize(256),\n", | |
" transforms.CenterCrop(224),\n", | |
" transforms.ToTensor(),\n", | |
" transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", | |
" ])\n", | |
"\n", | |
"print(\"| setting up data loader...\")\n", | |
"valdir = os.path.join(imagenetpath, 'val')\n", | |
"ds = datasets.ImageFolder(valdir, tr_center_crop)\n", | |
"\n", | |
"train_loader = torch.utils.data.DataLoader(ds,\n", | |
" batch_size=256, shuffle=False,\n", | |
" num_workers=numthreads, pin_memory=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "eb140d1fc4aa4e43a5bcbc510a8dfa98", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/html": [ | |
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n", | |
"<p>\n", | |
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", | |
" that the widgets JavaScript is still loading. If this message persists, it\n", | |
" likely means that the widgets JavaScript library is either not installed or\n", | |
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n", | |
" Widgets Documentation</a> for setup instructions.\n", | |
"</p>\n", | |
"<p>\n", | |
" If you're reading this message in another frontend (for example, a static\n", | |
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n", | |
" it may mean that your frontend doesn't currently support widgets.\n", | |
"</p>\n" | |
], | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, max=196), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"Validation top1/top5 accuracy:\n", | |
"[76.484, 93.138]\n" | |
] | |
} | |
], | |
"source": [ | |
"class_err = ClassErrorMeter(topk=[1,5], accuracy=True)\n", | |
"\n", | |
"model.cuda().eval()\n", | |
"\n", | |
"with torch.no_grad():\n", | |
" for x, t in tqdm(train_loader):\n", | |
" class_err.add(model(Variable(x.cuda())).data, t)\n", | |
"\n", | |
"print('Validation top1/top5 accuracy:')\n", | |
"print(class_err.value())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "935b0dc23e524391a24f3251043ddaaf", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/html": [ | |
"<p>Failed to display Jupyter Widget of type <code>HBox</code>.</p>\n", | |
"<p>\n", | |
" If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean\n", | |
" that the widgets JavaScript is still loading. If this message persists, it\n", | |
" likely means that the widgets JavaScript library is either not installed or\n", | |
" not enabled. See the <a href=\"https://ipywidgets.readthedocs.io/en/stable/user_install.html\">Jupyter\n", | |
" Widgets Documentation</a> for setup instructions.\n", | |
"</p>\n", | |
"<p>\n", | |
" If you're reading this message in another frontend (for example, a static\n", | |
" rendering on GitHub or <a href=\"https://nbviewer.jupyter.org/\">NBViewer</a>),\n", | |
" it may mean that your frontend doesn't currently support widgets.\n", | |
"</p>\n" | |
], | |
"text/plain": [ | |
"HBox(children=(IntProgress(value=0, max=196), HTML(value='')))" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"Validation top1/top5 accuracy:\n", | |
"[76.194, 92.976]\n" | |
] | |
} | |
], | |
"source": [ | |
"class_err = ClassErrorMeter(topk=[1,5], accuracy=True)\n", | |
"\n", | |
"model.cuda().eval()\n", | |
"for k, w in weights_star.items():\n", | |
" model.state_dict()[k].copy_(w)\n", | |
"\n", | |
"with torch.no_grad():\n", | |
" for x, t in tqdm(train_loader):\n", | |
" class_err.add(model(Variable(x.cuda())).data, t)\n", | |
"\n", | |
"print('Validation top1/top5 accuracy:')\n", | |
"print(class_err.value())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment