Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save izmailovpavel/438ed2fcf46b2ea5f8a8e7fac3daffc3 to your computer and use it in GitHub Desktop.
Save izmailovpavel/438ed2fcf46b2ea5f8a8e7fac3daffc3 to your computer and use it in GitHub Desktop.
AlexNet CIFAR-10 test run
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: CUDA_VISIBLE_DEVICES=0\n"
]
}
],
"source": [
"%env CUDA_VISIBLE_DEVICES=0"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"# uncomment order to re-install the starter kit\n",
"# !rm -rf neurips_bdl_starter_kit\n",
"\n",
"# !git clone https://github.com/izmailovpavel/neurips_bdl_starter_kit\n",
"\n",
"import sys\n",
"import math\n",
"import matplotlib\n",
"import numpy as onp\n",
"import einops\n",
"import torch \n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"import torch.optim as optim\n",
"from torchvision.datasets import CIFAR10\n",
"from torchvision import transforms\n",
"\n",
"from matplotlib import pyplot as plt\n",
"import tqdm\n",
"\n",
"sys.path.append(\"../neurips_bdl_starter_kit\")\n",
"import pytorch_models as models"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"# train_dataset = TensorDataset(torch.from_numpy(x_train_).float(),\n",
"# torch.from_numpy(y_train).long())\n",
"# test_dataset = TensorDataset(torch.from_numpy(x_test_).float(),\n",
"# torch.from_numpy(y_test).long())\n",
"transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
" ])\n",
"train_dataset = CIFAR10(root=\"/datasets/cifar10/\", train=True, transform=transform, download=False)\n",
"test_dataset = CIFAR10(root=\"/datasets/cifar10/\", train=False, transform=transform, download=False)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7f601bc6cb20>"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAATEAAAD5CAYAAABPqQIFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAjWklEQVR4nO2df4xc1ZXnP6d/d/v3b4xt1gnjTEIyEwd5SHbIZkmiZAhi5WR3hchqs5kRiqMVaDdS9g+GlTastJEyq4SI1a4YOQEFRkkIO8CGjdgkhJ2EMJkAhhhjMAQDNnbT/tG22227f1bV2T/qedXtfud2dXe5ul77+7Geuuqeuu/duu/V8b33nHuOuTtCCFFUWua7AUIIMRekxIQQhUZKTAhRaKTEhBCFRkpMCFFopMSEEIWmbS6Vzex64G6gFfiuu38j9fnVq1f75s2b53JJ0WAqlUooK5VKoaytrTW33CuxS09LS/x/qrVYKINYFl0tdbYic+DAAfr7++f09f7s44v8xMlyTZ99fs/oz9z9+rlcb67MWomZWSvwP4BPAYeB58zsMXd/JaqzefNmdu3alStL/VhEHUi4A5rFz/zwuaFQduJkfyhbuXJFbnl5bCSs093TE8paOzpDmVus/CqBuspXscXnmmuumfM5+k+WeeZnG2v6bPv6N1an5Ga2CXgAWEf1Kdzp7neb2Z3Al4Dj2UfvcPfHszp/CdwClIF/5+4/S11jLiOxa4D97v5mduEHge1AqMSEEEXAKXvdBhUl4Kvu/oKZLQGeN7MnMtm33f2bEz9sZlcBNwPvBy4HfmFm73H3cGg4lzWxDcChCe8PZ2VCiALjQAWv6Zj2XO597v5C9voMsI+0ntgOPOjuo+7+FrCf6oAp5KIv7JvZDjPbZWa7jh8/Pn0FIcS8U6nx30wws83Ah4BnsqLbzGyPmd1nZufXH2Y8OJqLEusFNk14vzErm4S773T3be6+bc2aNXO4nBCiETjOuFdqOoDV5wcp2bEj75xmthh4GPiKuw8C9wBXAluBPuBbs23vXNbEngO2mNm7qCqvm4F/NYfzCSGaAAfKNUwVM/rdfVvqA2bWTlWBfd/dHwFw96MT5N8BfpK9rWlwNJFZKzF3L5nZbcDPqBp77nP3l2d7vpR5Xcwfo0OnQ9nJw2+GskP78uudHjwX1rn2E58MZUu7u0JZakJhgXVST1uaWta7asGqpu97gX3ufteE8vXu3pe9/RywN3v9GPADM7uL6sL+FuDZ1DXm5CeWmUQfn8s5hBDNhQPl+oXouhb4AvCSme3Oyu4APm9mW7PLHQC+DODuL5vZQ1S9HErArSnLJMxRiQkhFib1crBw96fJ9y0OBz/u/nXg67VeQ0pMCDEJx2eyJjbvSIkJISbhDuPF0WFSYkKICzHKBdpdKiUmhJiEA4l9+k1H0ygxJSy5uKT6t8Vi2ZFDb4WyPf/wVCgbH87fON6+OH9jOMDwYOzOsXTlylAWbfKGeHO4nrY0GokJIQpL1dlVSkwIUVAcGPfiuANLiQkhJuEY5QLtaZASE0JMoeKaTgohCorWxGZJKkSymDue2EgyPhqHoH7n0MFQtrSnO5T1LF+SW37s1Jmwzom+OFjBuk1XhDJa4mDTYYz9ZMz+Sx2jrDUxIURRqUZ2lRITQhQUd2PMi5NKRUpMCDGFlANxsyElJoSYRHVhX9NJIURh0cK+EKLAaGFfzCvRRu/UJu/jJ0+EsgMH3g5lo4l6S7o6csuHzg6GdV598Xeh7LLNV4ay5ZclMnoF/ZGKNyB3HyjL2VUIUVQcY9yLoxqK01IhREPQwr4QotA4pumkEKLYaGFfCFFY3JGLhRCiuFQX9i+RbUdmdgA4A5SBkrtvq0ejxFyIXAriJMq9hw+HsrfejmWH9r8ZylYvWZxbvnH1orBO39txxIyXdj0XyrZdtzyU9Sxdli8ozpLPvHCpLex/3N3763AeIUQT4JiCIgohik2RRmJzbakDPzez581sRz0aJISYX6p5J1tqOpqBuY7EPuruvWa2FnjCzF5190nJCDPltgPgiisS0TmFEE1CsTKAz0mVuntv9vcY8ChwTc5ndrr7NnfftmbNmrlcTgjRAKop21prOpqBWSsxM1tkZkvOvwY+DeytV8OEEPODu10y08l1wKPZjv824Afu/tPZny5OZDE7e/hFGA4HkQ88TEcBeOJ7JaIl2Kz/f8k/Z6VSCmuMl8ZD2ZmhkVB2+OjJUHY0kJXLa8M6G9fG3/nV554NZWsvWx/K3vMnUyYHGfGj35KwzCWCgSSHBCljn6WekXniknB2dfc3gQ/WsS1CiCagGk+sOGticrEQQlyAIrsKIQpM1cVCIzEhREG5pPZOCiEWJgrFI4QoLNVQPJpOzoKU7Xo2Z5vlTUg1I0w6EVdyYteGpBtF0v0iJZu55IrNm0NZz5KloWzw3HAow/K/295Dx8Iq3W2doaxtZCyUvfybX4WyVRvW5Zav2PjusI6V4vtpiR936pmrtMTnTIjmjXqtiZnZJuABqi5ZDux097vNbCXwI2AzcAC4yd1PWdVn627gBmAI+HN3fyF1jeKMGYUQDaEaxaJuzq4l4KvufhXwEeBWM7sKuB140t23AE9m7wE+A2zJjh3APdNdQEpMCDGJ6rajlpqOac/l3nd+JOXuZ4B9wAZgO3B/9rH7gc9mr7cDD3iV3wLLzSz2ZqapppNCiObAZrKlaLWZ7Zrwfqe778w9q9lm4EPAM8A6d+/LREeoTjehquAOTah2OCvrI0BKTAgxhRl47PfXEtHZzBYDDwNfcffBiQmK3d3Nkhu6kkiJCSEmUW/rpJm1U1Vg33f3R7Lio2a23t37sunieYtPL7BpQvWNWVlIEymx+i7PzVavpyyNVPJllUT8+vFSbFXr6OgIZen/mFIWsqhK7Ly4YsXqUPbRj10Xyl7a/WooO/BWfrz8cinuq/2tR0JZ1+bLQ1n5tddD2Uu/+vvc8g//szgsVHdPfn4AgHJqI3dKFosozcIyH1mo62XorFeEiszaeC+wz93vmiB6DPgi8I3s748nlN9mZg8CHwZOT5h25tJESkwI0QzUOcb+tcAXgJfMbHdWdgdV5fWQmd0CHARuymSPU3Wv2E/VxeIvpruAlJgQYhIOlOo0EnP3p4kHop/M+bwDt87kGlJiQogpNEvAw1qQEhNCTMaVsk0IUWAUFFEIUXg0EpsNySDkszlfalN2YoNv4pQlz9/M/fr+2MQ/PHwulL33fe8LZZ2dsUtES8qWH1BJxIeqJB6DP732n4Syt9+K3Xe++9ffzS0vDccuJ28fHwhlnT3x5vAtK+P1m9d+vSu3fE1iA/h7r43i8sNQYkN/eyVuR0finp0cOp1bPjo2GtaJXFXGxuM6taKgiEKIQuMYpYRCbjakxIQQU9CamBCiuLimk0KIAqM1MSFE4ZESE0IUFscoL6SFfTO7D7gROObuH8jKcuNjz6UhlYRLRBTQIRnbvpyIbZ+6PwlT+KHet3PL//fjPwnrDA7mm88B/rQ/jjf/8X/6iVDW2Rm7G0T9WAlrQKkcSxcvWRLKbtx+Yyjb/9rvc8t/8X+eCOsMjsf37NXeOMLFCusOZV0j+Tf7tz/9eVinbVUcxaJl3fJQdm4gvtftlTh6R9/g4dzy02fi842MjOSWnx0aDOvMhCIt7Neibr8HXH9BWRQfWwhRcDxb2K/laAamVWLu/hRw8oLiKD62EGIB4G41Hc3AbNfEovjYQojC0zyjrFqY88L+dPGxzWwH1dRLXHHFFXO9nBCiATTLKKsWZmuCOHo+jdIF8bGn4O473X2bu29bsyYOCSyEaA7coVyxmo5mYLZK7Hx8bJgcH1sIsQCoYDUdzUAtLhY/BK6jml/uMPA14vjYcyA2QUc+EadOnQirnD51oS1iwula484/cjx2e/iHXc/mlj//8othncGTA6FsdDyO6PD+P/pAKFu7Jk7s0dqaf0sHzwyFdQYGBkLZ5o0bQ9nlG9eGsj//0r/OLT/U+0ZY55kX94Sy0XNxFI7XD8fuFz2X5dc7sXdvWGfokVDElddeHcpOnT0TnzPh+jBqA7nlqYgUlSBpTSoxTa04xZpOTqvE3P3zgWhKfGwhxELgElvYF0IsPFKZC5sNKTEhxBQW1HRSCHFpUbVOLqC9k0KISw9NJ4UQhUbTyRAH8s3GlcQu/8gd5fRgf1jl1795OpQdfCc/agBA/+BAKDt1Lt+E3rKoI6zTNboolB07kWr/r0PZ5s2bQlkU4aL38PGwzvhYbJYfHhoIZWfPxLL24Ml635/ECTp2738plI2diYcGhwdi94Wejvz+2LisK6zz1q4XQllrZzzNarl8ZSg7XYpdXELnEY+fq9HR/N+Rp8KV1IjTPPsia0EjMSHEFAo0m5QSE0JcgIM3yZaiWpASE0JMQdNJIUShkXVSCFFYFtzeSSHEJYYDUmL5DI8M8fK+/IgPbW3tYb3IBeBUIvrCwNk4ycLbfb2hbNnaVaFs5bL8hBSrVsdx0o6/0RfK9u2NXQqe+EWcUGPZ0jgxRmtbvsF+dCyeH4yN5iedAPjpz2JZe8KpO4pw0bM6vs8f3PreUPa7p18LZUOJNCi/P3E0t7y7HLu+rCjFyVH2//b5UDawJnbbONkSt7F9LL9eKZE4ZWgo32XjzOBwWGcmaDophCgwJuukEKLgaCQmhCgsroV9IUTRKdBIrDjxNoQQDcRqPKY5i9l9ZnbMzPZOKLvTzHrNbHd23DBB9pdmtt/MXjOzP6ulpQ0diZ07d5bfPPubXNnw4Lmw3qKufEvSjTduD+uUPH/jL8DzL70aypYtWRHKhiv5lrrL18ZpN8ePxtai0+fiTcFDr8fWuBWJTciLluX31eIVsQW1a1FsOVu2PI5tv2zp0lC2dOni3PLuxT1hnes+8eFQdro/tjbv3ftmKCuP5//Q3h5IWF3bYwtq25HYYnjmVCwrLYktyi3d+TkTeg/Flu3B4PcyNjL3GPsACYPvTPke8N+BBy4o/7a7f3NigZldBdwMvB+4HPiFmb3H3RPRITQSE0JcyHk/sVqO6U7l/hQQZ+2ZzHbgQXcfdfe3gP3ANdNVkhITQkzBvbZjDtxmZnuy6eb56c8G4NCEzxzOypJIiQkhpuI1HtVUjrsmHDtqOPs9wJXAVqAP+NZcmirrpBBiKrW7WPS7+7YZndr9/2+jMLPvAD/J3vYCEyN+bszKkmgkJoSYgnltx6zObbZ+wtvPAectl48BN5tZp5m9C9gC5GesnoBGYkKIybhBnbYdmdkPgeuoTjsPA18DrjOzrVQnpAeALwO4+8tm9hDwClACbp3OMgk1KDEzuw+4ETjm7h/Iyu4EvgScD9x+h7s/Pt25RkfHePNAvjn89LFTYb0t79qSW97dHW/ifeedY6Hs4Ftvh7LFi2JT+Oh4vkuEJTbdDg/EZnda4gflD66MY9FfuWZZKFuyIt/t4dix2EVhxcp4QL5+U9zHZwZjF5GOwETfVYldNpYmvtenrv94KDt5Ko6xf/Rw/nPQPxr7EPScjs+3NuFW0pYYmmxYEsffX7Tustzy3gMHwjpjQ/n5HjyVq2Im1MnZ1d0/n1N8b+LzXwe+PpNr1DKd/B5wfU75t919a3ZMq8CEEAWi9oX9eWdaJTZDPw8hxEJgISmxBHl+HkKIolNHZ9dGMFslVrOfh5ntOO9DMjRUn4BtQoiLy8W0TtabWSkxdz/q7mV3rwDfIbE1wN13uvs2d9/W0xMvmgshmoiFPp1M+HkIIRYARRqJ1eJiUbOfx3RUymXOnc439Q+NxFPNzp78GOSnz8RuAwcPHQhly5fFZvLyuTi6gY3kp47vO7I/rNP3Tn98vpb88wHc9C/+eSirnI3tLP/36V/mlh/cEzs+r1rWEcqOvB6ve2y4/IpQdno8P7Y97bHry8pVcTSQP/rDD4Sysc/Gj/F99/5Nbvnwmfg+vzNwNpTRFvfV6FjstnG2/0Qouzx4Hju642gaq9cuzy3vPxb0+0xpkvWuWphWic3Uz0MIUXCaaKpYC/LYF0JMRUpMCFFkrH5BES86UmJCiKloJCaEKCrNZHmsBSkxIcRUFpJ1sp5UvMLYaL4rxdBonChk/1v5LgyP/q+HwzpP/+pXocwSN+joYGxeP37wUG55e2L9YDwRVaDjsjhqw98/9etQNjoYu2288vrvc8vPHY2jaQwcj9u4fFW+ewvA8UTSjMHT+fdzxfLY4XmsnN92gF/+8oVQ1r10VShbsXptbnn/eOzyMDQaf6/ehGuGd8bPVU/QHwCtx/PdTpavip+P1tb8n+4br8dJU2aERmJCiCKj6aQQori4rJNCiKKjkZgQotBIiQkhikyR1sSU7UgIUWgaOhJrbWtl2cp8s/F4Qp0Ons1P3PDK7t1hnaNvvRXKWhJfu6ctjhzQ0ZIfwcDHxhLXis3uG9fHyY1XLomD5Z5KBJd89+Y/zC0/WI4TsQycjN0Nyp3LQ9nRRMSPoaF8t42Bk3GUBWuNk4iMWKL9Q2+EspaOfJeOSmscjcI74nYMEa94l0uxbFHQDoDFy/LvdWtr/KOoBEmAWhN9OCMKNBLTdFIIMRlZJ4UQhUcjMSFEUTGKtbAvJSaEmIqUmBCisCiKRUxrayuLA+tk25JFYb2xE/mbZ/t/n78hG2DT4njzrAVWRoAzw7HFbaQlf2OwdcebpDstthYdPxrHyn/+mRdD2bolS0LZiVMDueWnh2OL5tnEIu5wf75luEpseW0LrH/d7fGvYyRh5T0+MBDKyi1xH/e05VsFrSW2/LV0pSx8ic7y8VB07lzc/4OD+bIVq5YnmhH1fZ2iT2hhXwhRZDQSE0IUGykxIURhUbYjIUTR0XRSCFFspMSEEEVmQW07MrNNwAPAOqr6eae7321mK4EfAZuBA8BN7h7v0qWae6DSkW/a9nJsGu4INsK2j8ex4a9YujKUlRIm+TMJV4TWpYtzy1s6YheL4aOnQ9nowFDcjhNnQll/JXYPGBjNP+fmq/84rHPkeLwBfOBU3P7Fi2O3mJGhfLeY8fa4r0YSse2Hx+NfVUtL/Ox0BffGLXaHKCf8C1rb4p9MSykevlQq8TmPHR/ILS/FjzdtHfnfuVSug/Yp2JpYLaF4SsBX3f0q4CPArWZ2FXA78KS7bwGezN4LIQqOzeBoBqZVYu7e5+4vZK/PAPuADcB24P7sY/cDn71IbRRCNBqv8WgCZrQmZmabgQ8BzwDr3L0vEx2hOt0UQiwAimSdrDmyq5ktBh4GvuLuk/aiuHuol81sh5ntMrNdQ2fj9SYhRBNRp5GYmd1nZsfMbO+EspVm9oSZvZ79XZGVm5n9NzPbb2Z7zOzqWppakxIzs3aqCuz77v5IVnzUzNZn8vVAbgZQd9/p7tvcfVvP4ji6pRCiSciCItZy1MD3gOsvKIvW0z8DbMmOHcA9tVxgWiVmZgbcC+xz97smiB4Dvpi9/iLw41ouKIQoAHUaibn7U8CFkQ6i9fTtwANe5bfA8vMDpRS1rIldC3wBeMnMdmdldwDfAB4ys1uAg8BN052oXK4wMJDvOjA6FEcwWDSW7xKx5rLLwzonDuanhgfYf+BgKDs+HkexWLky322jpSseYZ6rxF4n5fHYvlMaGg1lI6Ox7b0ULGYcP9If1jl3Nnb18PH4Se3p7AllY0E0EOvsDOuURuLv3LEodufwhFvByGj+c1Vpib/XWCl+Fjvb4wgoHV3xd1vck++eA9AdyMYTfd8SReGo01rWRV4Ti9bTNwATQ9Mczsr6SDCtEnP3p4mtqZ+crr4QooDUrsRWm9muCe93uvvOmi/j7mZzU5ny2BdCTGEGaqXf3bfN8PRHzWy9u/ddsJ7eC2ya8LmNWVkS5Z0UQkzGqQZFrOWYHdF6+mPAv8mslB8BTk+YdoZoJCaEmEQ9E4WY2Q+B66hOOw8DXyNeT38cuAHYDwwBf1HLNaTEhBBTqZMSc/fPB6Ip6+mZv+mtM72GlJgQYgrmxXHZb6wSqxgMt+fLYus6Jcs3a59L5HPoSyTo6Eukmz87lpjon8iP6NDaHrsoDCWiF3iY7AGGS3FEBw9S2AN0BC4AvcdjF4tU5ANLbPM9fioRtMTy63k5bnt7d+yqsrQjdm0oJ8I9ePBjbG2Ll4O7CZ5RoCWIqALQnnC/sET7PXhGLHGtFgt+ukG/z4gm2hdZCxqJCSGmUKS9k1JiQogpLKigiEKISxCNxIQQhUUZwIUQhUdKTAhRVOrp7NoIGqrEzIw2yzdfjyf8Us4O5/tfnBwczC0HODkW+2yU2uOv7aXYNWMkiswQREoAGPdUgov4WouWLQ1lra1xvSiRhSc2mEVuCNNeKyGLkndEwRcAKglhS/I7x31cruS7X3giuUjqWmH0CKrPdyyM61WCNia8bChFwjr5d1mlOFpMIzEhxGTkJyaEKDpysRBCFBuNxIQQRUYL+0KI4uLUzUDQCBqqxCrlMmfPnM2VDQ7mp70HOBekejt3Lo6HnzIULV0eW/46u+M46eG1Ehar7rZ44297R3ytlOWvPWFdjayT5dRG9OQDG8tS1VqjPkn8F19ObA4PrXGk2z8e1CsnvldrW9z3bUH/TteOrq6uUNYZ3E8PrJYAnUGugqSFdAZoTUwIUVjkJyaEKDbumk4KIYqNRmJCiGIjJSaEKDIaiQkhiosD5eJosWmVmJltAh6gmmrcqWb4vdvM7gS+BBzPPnqHuz+eOlepVKL/xIlc2fhYbE4eGcnfYD02Fm+8bu+K46S3d8VuD8PD+e4cEMdXT23kJiFzj83hpXLsUtCSig/fE5jeUzuvE4u4KdeMFJGpPxWzP8XQUJzHIOWa0Ra5LyQ2gKf6KuXCkHZVSXzvoFpXV5xzIHKxSG1QnwkLbSRWAr7q7i+Y2RLgeTN7IpN9292/efGaJ4SYFxaSdTLLwNuXvT5jZvuADRe7YUKI+aNII7EZjT3NbDPwIeCZrOg2M9tjZveZ2Yp6N04IMQ/4DI4moGYlZmaLgYeBr7j7IHAPcCWwlepI7VtBvR1mtsvMdo2OJpJLCiGaAgOs7DUdzUBNSszM2qkqsO+7+yMA7n7U3cvuXgG+A1yTV9fdd7r7NnffFi1GCiGaC3Ov6WgGplViVjXH3Avsc/e7JpSvn/CxzwF76988IUTDKdh0shbr5LXAF4CXzGx3VnYH8Hkz20r1qxwAvjzdiSrujI8HbhGJIPBtbfnuEqmBXWd3bJ5OWbuj7PAQR5ZIhSMvJ9woUq4BrQnXjNaORAz49vx+7Aj6ENKuAak2pl0K8kkEZki6ByxfvjyUjY+Ph7LRwA2nnFi5nq0bRSrSRqkUt5FyJJv5fSmX6xF+YoHtnXT3p8n/2Sd9woQQxaVI1kl57AshprKQRmJCiEsMp2ksj7UgJSaEmEpxdJiUmBBiKs3iPlELUmJCiKlIiQUXa2tj1apVubIWYheAcjA/Hy8l0tcnzCsjI3GkCmtNRDcIUtFXEpEexhIm79ZKIvpFglQSkYrnm95TfTXbyBKpnBSVwO+kVIp9LCqJdZhU8o6Ua0OUKGS8kogSkujf2bpfpO5ZSzB3S7m3RM9c1fd8jjigRCFCiKJiNI83fi1IiQkhpjLLOHJ5mNkB4AxQBkruvs3MVgI/AjZTdZa/yd1Pzeb89YmgJoRYOJyfTtZy1M7H3X2ru2/L3t8OPOnuW4Ans/ezQkpMCDGFBmwA3w7cn72+H/jsbE8kJSaEmMr53JPTHTWeDfi5mT1vZjuysnVZwFWAI1TD388KrYkJIS5gRgpqtZntmvB+p7vvvOAzH3X3XjNbCzxhZq9Oupq7m81+t2ZDlVhraytLly7NlVXKqUQK+QPG0bE4MsDg0NlQ1taeiBCRkIUm70RkhvZEZIZSYvG0kjKvB24UAARuIJaIppEMw5GgknjQK4FriScG/5WEe8DYcJwUJhXFohK5nicShaR6I+VO44maPV1doawjcB9pSbhztLXl/3TrkihkZtmO+iesc+Wfzr03+3vMzB6lGnvwqJmtd/e+LKzXsdk2V9NJIcQU6rUmZmaLsgRDmNki4NNUYw8+Bnwx+9gXgR/Ptq2aTgohplI/P7F1wKOZk3Ab8AN3/6mZPQc8ZGa3AAeBm2Z7ASkxIcRknFkvMUw5lfubwAdzyk8An6zHNaTEhBAXsMAiuwohLkGkxIQQhcWBusTqbwwNV2IWGERTbiJj4/n5KkdG42gUYUIS0lEK2hImag9u7FgiisJoImqDJcz8lmhHyvQemdgrpbh/U//npuJbpB5zD9pYTrkoWCxraYtb0t4aR0CJr5WQJROnJNxKUh2ZcB9pCdxiUnVK4/nPVV2iWODJazcbGokJIaai6aQQorDU0TrZCKTEhBBT0UhMCFFopMSEEIXFHRJ7d5uNaZWYmXUBTwGd2ef/1t2/ZmbvAh4EVgHPA19w99gkCODxBtrR0dQG33zZ2NhIWGcscb6x8diamNqEHMWiT8VP7+rsDGUtibjx5YTFM2U9i/rXWhJx4xM2yNSG4o7E944YGYnvWSpWfmuiHan+j/pqdDTf4g0wNJTIwZCwDHclNnmn2l8ay29LaLUEurryn6tU+2ZEgUZitWwAHwU+4e4fBLYC15vZR4C/Ar7t7n8AnAJuuWitFEI0lvrGE7uoTKvEvMr5uDbt2eHAJ4C/zcrnFJlRCNFMeNU6WcvRBNQUisfMWs1sN9WYP08AbwAD7n5+/H8Y2HBRWiiEaCxedZqt5WgGalrYd/cysNXMlgOPAu+t9QJZONodAEuWLplFE4UQDadA245mFBTR3QeAvwP+MbDczM4rwY1Ab1Bnp7tvc/dt3d3dc2mrEKIRuFdTttVyNAHTKjEzW5ONwDCzbuBTwD6qyuxfZh+bU2RGIUSTUaCF/Vqmk+uB+82slarSe8jdf2JmrwAPmtl/AX4H3Dvdidw9jIee2rAdmt4TnRjFIAcg6W4QE5nyU24IntjkPZ5wKUi1P5Xe3oLt3K2JTdItqf5ImOxTrh4e/C/d0dGRaEfcj7N1zWhvz//eKbeMVDtSfZ9qR0fgEgHQ09mTW556FqP7knKXmQnR/WtGplVi7r4H+FBO+ZtUA/4LIRYUzTPKqgV57AshJqMN4EKIIuOAL6RtR0KISwxXUEQhRMFxTSeFEIWmQCMxS5nJ634xs+NUE2UCrAb6G3bxGLVjMmrHZIrWjn/k7mvmciEz+2l2vVrod/fr53K9udJQJTbpwma73H3bvFxc7VA71I4Fw4y2HQkhRLMhJSaEKDTzqcR2zuO1J6J2TEbtmIza0eTM25qYEELUA00nhRCFZl6UmJldb2avmdl+M7t9PtqQteOAmb1kZrvNbFcDr3ufmR0zs70Tylaa2RNm9nr2d8U8teNOM+vN+mS3md3QgHZsMrO/M7NXzOxlM/v3WXlD+yTRjob2iZl1mdmzZvZi1o7/nJW/y8yeyX43PzKzOCTIpYS7N/QAWqmGt3430AG8CFzV6HZkbTkArJ6H634MuBrYO6HsvwK3Z69vB/5qntpxJ/AfGtwf64Grs9dLgN8DVzW6TxLtaGifUI3Cszh73Q48A3wEeAi4OSv/a+DfNvI+NesxHyOxa4D97v6mV1O8PQhsn4d2zBvu/hRw8oLi7VQTrkCDEq8E7Wg47t7n7i9kr89QDbq5gQb3SaIdDcWrKDlPjcyHEtsAHJrwfj6TjDjwczN7PssFMJ+sc/e+7PURYN08tuU2M9uTTTcv+rR2Ima2mWr8umeYxz65oB3Q4D5Rcp7audQX9j/q7lcDnwFuNbOPzXeDoPo/MQQhWi8+9wBXUs0x2gd8q1EXNrPFwMPAV9x9cKKskX2S046G94m7l919K9X8Fdcwg+Q8lxrzocR6gU0T3odJRi427t6b/T1GNYvTfEaqPWpm6wGyv8fmoxHufjT7AVWA79CgPjGzdqqK4/vu/khW3PA+yWvHfPVJdu0BZpic51JjPpTYc8CWzNLSAdwMPNboRpjZIjNbcv418Glgb7rWReUxqglXYB4Tr5xXGhmfowF9YtWA8fcC+9z9rgmihvZJ1I5G94mS88yQ+bAmADdQtfy8AfzHeWrDu6laRl8EXm5kO4AfUp2WjFNd27gFWAU8CbwO/AJYOU/t+BvgJWAPVSWyvgHt+CjVqeIeYHd23NDoPkm0o6F9Avwx1eQ7e6gqzP804Zl9FtgP/E+gs1HPbDMf8tgXQhSaS31hXwhRcKTEhBCFRkpMCFFopMSEEIVGSkwIUWikxIQQhUZKTAhRaKTEhBCF5v8BAUzAj0XIlRoAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(train_dataset.data[2])\n",
"plt.colorbar()"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"model = models.get_model(\"cifar_alexnet\", data_info={\"num_classes\": 10})\n",
"model.cuda();\n",
"# prior_variance = 0.05\n",
"\n",
"# def log_likelihood_fn(model, batch):\n",
"# \"\"\"Computes the log-likelihood.\"\"\"\n",
"# x, y = batch\n",
"# if torch.cuda.is_available():\n",
"# x = x.cuda()\n",
"# y = y.cuda().long()\n",
"# model.zero_grad()\n",
"# logits = model(x)\n",
"# softmax_xent = F.cross_entropy(logits, y)\n",
"# # num_classes = 2 #logits.shape[-1]\n",
"# # labels = F.one_hot(y, num_classes=num_classes)\n",
"# # print(logits, labels)\n",
"# # return 0.\n",
"# # softmax_xent = torch.sum(labels * F.log_softmax(logits, dim=-1))\n",
"# return softmax_xent\n",
"\n",
"\n",
"# def log_prior_fn(model):\n",
"# \"\"\"Computes the Gaussian prior log-density.\"\"\"\n",
"# n_params = sum(p.numel() for p in model.parameters())\n",
"# exp_term = sum((-p**2 / (2 * prior_variance)).sum()\n",
"# for p in model.parameters())\n",
"# norm_constant = -0.5 * n_params * math.log((2 * math.pi * prior_variance))\n",
"# return exp_term + norm_constant\n",
"\n",
"\n",
"# def log_posterior_fn(model, batch):\n",
"# log_lik = log_likelihood_fn(model, batch)\n",
"# # log_prior = log_prior_fn(model)\n",
"# return log_lik #+ log_prior\n",
"\n",
"\n",
"def get_accuracy_fn(model, batch):\n",
" x, y = batch\n",
" x = x.cuda()\n",
" y = y.cuda()\n",
"\n",
" logits = model(x)\n",
" probs = F.softmax(logits, dim=1)\n",
" preds = torch.argmax(logits, dim=1)\n",
"# print(preds, y)\n",
"# print(y)\n",
" accuracy = (preds == y).float().mean()\n",
" return accuracy, probs\n",
"\n",
"\n",
"def evaluate_fn(model, data_loader):\n",
" model.eval()\n",
" sum_accuracy = 0\n",
" all_probs = []\n",
" with torch.no_grad():\n",
" for bacth in data_loader: \n",
" batch_accuracy, batch_probs = get_accuracy_fn(model, bacth)\n",
" sum_accuracy += batch_accuracy.item()\n",
" all_probs.append(batch_probs.detach())\n",
" all_probs = torch.cat(all_probs, dim=0)\n",
" model.train()\n",
" return sum_accuracy / len(data_loader), all_probs"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 100\n",
"test_batch_size = 100\n",
"num_epochs = 20\n",
"momentum_decay = 0.9\n",
"init_lr = 0.03\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
"test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)\n",
"\n",
"optimizer = optim.SGD(model.parameters(), lr=init_lr, momentum=momentum_decay)\n",
"scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
" optimizer, T_max=len(train_loader)*num_epochs)"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:13, 37.93it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t0 loss: 2.25559139251709\n",
"\t0 test_acc: 0.271499991863966\n",
"Epoch 0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.56it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t1 loss: 1.611702561378479\n",
"\t1 test_acc: 0.5068999880552292\n",
"Epoch 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.52it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t2 loss: 1.1714857816696167\n",
"\t2 test_acc: 0.6321999886631966\n",
"Epoch 2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.64it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t3 loss: 0.9047054648399353\n",
"\t3 test_acc: 0.680299985408783\n",
"Epoch 3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.71it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t4 loss: 0.7031428813934326\n",
"\t4 test_acc: 0.7346999812126159\n",
"Epoch 4\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.73it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t5 loss: 0.5433331727981567\n",
"\t5 test_acc: 0.7538999825716018\n",
"Epoch 5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.73it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t6 loss: 0.40039512515068054\n",
"\t6 test_acc: 0.7570999819040298\n",
"Epoch 6\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.68it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t7 loss: 0.2642303705215454\n",
"\t7 test_acc: 0.7646999776363372\n",
"Epoch 7\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.71it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t8 loss: 0.14901365339756012\n",
"\t8 test_acc: 0.7622999817132949\n",
"Epoch 8\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.70it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t9 loss: 0.07879100739955902\n",
"\t9 test_acc: 0.7709999805688859\n",
"Epoch 9\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.68it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t10 loss: 0.036166418343782425\n",
"\t10 test_acc: 0.7703999787569046\n",
"Epoch 10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.68it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t11 loss: 0.009763226844370365\n",
"\t11 test_acc: 0.7810999780893326\n",
"Epoch 11\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.71it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t12 loss: 0.0017603568267077208\n",
"\t12 test_acc: 0.7866999781131745\n",
"Epoch 12\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.58it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t13 loss: 0.0006528312806040049\n",
"\t13 test_acc: 0.7871999788284302\n",
"Epoch 13\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.61it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t14 loss: 0.00046996952733024955\n",
"\t14 test_acc: 0.7871999776363373\n",
"Epoch 14\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.55it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t15 loss: 0.00040141792851500213\n",
"\t15 test_acc: 0.7865999788045883\n",
"Epoch 15\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.56it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t16 loss: 0.0003653575840871781\n",
"\t16 test_acc: 0.7864999788999557\n",
"Epoch 16\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.55it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t17 loss: 0.0003474488912615925\n",
"\t17 test_acc: 0.7864999788999557\n",
"Epoch 17\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.60it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t18 loss: 0.00033885688753798604\n",
"\t18 test_acc: 0.7865999794006348\n",
"Epoch 18\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"500it [00:12, 38.70it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\t19 loss: 0.0003360138216521591\n",
"\t19 test_acc: 0.7867999792098999\n",
"Epoch 19\n",
"Train accuracy: 1.0\n"
]
}
],
"source": [
"for epoch in range(num_epochs):\n",
" sum_loss = 0.\n",
" for i, batch in tqdm.tqdm(enumerate(train_loader)):\n",
" x, y = batch\n",
" x = x.cuda()\n",
" y = y.cuda()\n",
" optimizer.zero_grad()\n",
" logits = model(x)\n",
" loss = F.cross_entropy(logits, y)\n",
" loss.backward()\n",
" optimizer.step()\n",
" scheduler.step()\n",
" sum_loss += loss\n",
" # if i % 50 == 0:\n",
" print(f\"\\t{epoch} loss: {sum_loss / len(train_loader)}\")\n",
" test_acc, _ = evaluate_fn(model, test_loader)\n",
" print(f\"\\t{epoch} test_acc: {test_acc}\")\n",
" print(\"Epoch {}\".format(epoch))\n",
" \n",
"\n",
"_, all_test_probs = evaluate_fn(model, test_loader)\n",
"train_acc, _ = evaluate_fn(model, train_loader)\n",
"print(f\"Train accuracy: {train_acc}\")\n",
"all_test_probs = all_test_probs.cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
"# torch.save(model.state_dict(), \"model3\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Ensemble"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
"y_test = onp.array(test_dataset.targets)"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"all_test_probs = []\n",
"for i in range(1, 4):\n",
" model.load_state_dict(torch.load(f\"model{i}\"))\n",
" _, test_probs = evaluate_fn(model, test_loader)\n",
" test_probs = test_probs.cpu().numpy()\n",
" all_test_probs.append(test_probs)\n",
"all_test_probs = onp.stack(all_test_probs)"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.7948, 0.7781, 0.7868])"
]
},
"execution_count": 96,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(onp.argmax(all_test_probs, axis=-1) == y_test[None, :]).mean(-1)\n",
"# y_test"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.8138])"
]
},
"execution_count": 98,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(all_test_probs.mean(0).argmax(axis=-1) == y_test[None, :]).mean(-1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py38",
"language": "python",
"name": "py38"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment