Skip to content

Instantly share code, notes, and snippets.

@thomasbrandon
Created September 27, 2019 14:30
Show Gist options
  • Save thomasbrandon/2e8e365086ed20cef655dac11cd8c365 to your computer and use it in GitHub Desktop.
Save thomasbrandon/2e8e365086ed20cef655dac11cd8c365 to your computer and use it in GitHub Desktop.
NB for MNIST Stats update in Fastai
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"# MNIST Stats"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"from fastai.vision import *\n",
"from itertools import islice"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"DATA = untar_data(URLs.MNIST)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"src = (ImageList.from_folder(DATA)\n",
" .split_by_folder(train='training', valid='testing')\n",
" .label_from_folder())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"([0.15, 0.15, 0.15], [0.15, 0.15, 0.15])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mnist_stats"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"So the stats look odd: the means and stds are identical, which looks like a copy-paste error."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiYAAAI4CAYAAABeEiKtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAaR0lEQVR4nO3de6zdZb3n8e9TNqUU7PSUGqRcrIZLSXGkpIEgEMCp9HBtjqgFqpzKJG1KiDJGTAwgiLcWabQCDlGJo+IkQIpF8AKWpsNIpVBBsBJoCsitYI/tMD21LUzpb/6oJzMiez+/w157r+/e+/VK+OesT37riZd13v4sj6VpmgAAyGBUtw8AAPBvhAkAkIYwAQDSECYAQBrCBABIQ5gAAGkIEwAgDWHC3ymlHFlKWVFK+d+llPWllH/q9pmAoa+Uslcp5eZSynOllH8tpTxaSjm92+ciF2HC3yil9ETEnRFxd0RMiIh5EXFLKeXwrh4MGA56IuKFiDg5Iv5DRFwZEbeVUiZ38UwkU9z8yv+vlHJURDwYEe9o/vovjlLKvRGxummaK7t6OGDYKaU8HhFfbJpmabfPQg7emPBmpZf/21GDfRBgeCul7B8Rh0fEH7p9FvIQJrzZkxGxMSIuK6XsWUo5LXa/dh3b3WMBw0kpZc+I+HFE/KBpmie7fR7y8F/l8HdKKf8xIq6P3W9J1kTEv0TEa03T/OeuHgwYFkopoyLiv0fEuIiY1TTN/+nykUikp9sHIJ+maR6P3W9JIiKilLIqIn7QvRMBw0UppUTEzRGxf0ScIUp4M2HC3/nrG5N1sfu/6rs4Ig6IiP/WzTMBw8Z/jYgjI2JG0zTbu30Y8vFnTHgrn4iIl2P3nzX5TxHxoaZpXuvukYChrpTy7oiYHxFHR8QrpZStf/1rTpePRiL+jAkAkIY3JgBAGsIEAEhDmAAAaQgTACCNPv924VKKPxkLI1TTNG/1P0/QMX5fYOTq6/fFGxMAIA1hAgCkIUwAgDSECQCQhjABANIQJgBAGsIEAEhDmAAAaQgTACANYQIApCFMAIA0hAkAkIYwAQDSECYAQBrCBABIQ5gAAGkIEwAgDWECAKQhTACANIQJAJCGMAEA0hAmAEAawgQASEOYAABpCBMAIA1hAgCkIUwAgDSECQCQhjABANIQJgBAGsIEAEhDmAAAaQgTACANYQIApCFMAIA0hAkAkEZPtw8AMNxMnDixulmyZEl1c8EFF1Q3a9asqW6+8pWvVDcREb/85S+rmx07drR6Frxd3pgAAGkIEwAgDWECAKQhTACANIQJAJCGMAEA0hAmAEAawgQASKM0TdP7h6X0/iEwrDVNUwby+cP59+Uzn/lMdfP1r399EE7y7/PQQw9VNwsWLKhufve733XiOAxjff2+eGMCAKQhTACANIQJAJCGMAEA0hAmAEAawgQASEOYAABpCBMAIA0XrAFvyQVrb99+++1X3bzyyivVzfLly6ubT33qU9XN5ZdfXt1ERMyZM6e62bJlS3Vzww03VDdXX311ddPX/39iaHPBGgAwJAgTACANYQIApCFMAIA0hAkAkIYwAQDSECYAQBrCBABIwwVrSeyzzz6tdocddlh1c8opp1Q3RxxxRHUzb9686uanP/1pdXPTTTdVN/fcc091w+BywdrA+sIXvlDdzJ8/v7qZNm1adbNx48ZWZ5o1a1Z1881vfrO6OeSQQ6qbZcuWVTcf//jHq5vt27dXN+TjgjUAYEgQJgBAGsIEAEhDmAAAaQgTACANYQIApCFMAIA0hAkAkIYL1gbB+973vurmmmuuafWsc845p7/HGXTPP/98dXPzzTe3etbixYurGxcudYYL1rrvoIMOqm5efPHFQTjJ//PRj360urn22murmzaXsH31q1+tbq688srqhnxcsAYADAnCBABIQ5gAAGkIEwAgDWECAKQhTACANIQJAJCGMAEA0nDBWj
9dffXV1c3nP//56qanp6cDp9lt/fr1HXnOE088Ud2cdtpp1c2YMWM6cZyIiDj44IOrmw0bNnTs+0YyF6zxdk2ZMqW6+cUvflHdtLlgbvbs2dXNHXfcUd0wuFywBgAMCcIEAEhDmAAAaQgTACANYQIApCFMAIA0hAkAkIYwAQDScMFaH4499tjq5q677qpuJk6cWN089NBDrc503XXXVTdLly5t9axOuOyyy6qbhQsXduz7XLA2eFywxkA69NBDq5tHH320ulm3bl11c+KJJ7Y60/bt21vt6D8XrAEAQ4IwAQDSECYAQBrCBABIQ5gAAGkIEwAgDWECAKQhTACANIQJAJCGm1/78MILL1Q3kyZNqm5WrFhR3Zx55pmtzvT666+32g2W0aNHVze33357dXPWWWe1+r4jjjiiulm/fn2rZ9E3N7/SbatXr65upk+fXt2cd955rb6vzW8VneHmVwBgSBAmAEAawgQASEOYAABpCBMAIA1hAgCkIUwAgDSECQCQRk+3D9AtJ5xwQnUzceLE6mbUqHrbffnLX65usl2c1tbHPvax6mbs2LEd+75PfOIT1c1VV13Vse8DuqfNb+eyZcuqm5kzZ7b6Phes5eCNCQCQhjABANIQJgBAGsIEAEhDmAAAaQgTACANYQIApCFMAIA0RuwFawcccEB1M3r06Opm165d1c0HP/jB6uaxxx6rbjrpIx/5SHUza9as6mbGjBnVTZt/HAHebPXq1R15zjvf+c5WuzYXZrb5zad/vDEBANIQJgBAGsIEAEhDmAAAaQgTACANYQIApCFMAIA0hAkAkMaIvWCtzcU9W7durW723Xff6uaKK67oyGawbdu2rbp55plnqpspU6Z04jjACLN58+bqZvny5dXNWWed1er79tprr+pm+/btrZ7F2+eNCQCQhjABANIQJgBAGsIEAEhDmAAAaQgTACANYQIApCFMAIA0RuwFay+88EJ1c9lll1U3119/fXXT09O5f5gfeeSR6mbVqlUd+a4VK1ZUN2PHjq1ubrnllk4cJyIi1q5d27FnAbnt3LmzumlzCRtDizcmAEAawgQASEOYAABpCBMAIA1hAgCkIUwAgDSECQCQhjABANIYsRestfGd73ynI5vZs2dXN1u3bm11pp/97GetdoPl/PPPH9Tve+CBBwb1+4Du2XPPPaubiRMnDsJJGEzemAAAaQgTACANYQIApCFMAIA0hAkAkIYwAQDSECYAQBrCBABIwwVrg+DWW2/t9hEAhpwxY8ZUN5MnTx74gzCovDEBANIQJgBAGsIEAEhDmAAAaQgTACANYQIApCFMAIA0hAkAkIYwAQDScPMrABER8YEPfKC6Ofvss6ub5cuXVzf3339/dXPOOedUN4ccckh1s23btuomIqJpmlY7BpY3JgBAGsIEAEhDmAAAaQgTACANYQIApCFMAIA0hAkAkIYwAQDScMEawDB3ySWXtNpdddVV1c2ECROqm8997nPVzcqVK6ubE088sbrZunVrdbNw4cLqJiJix44drXYMLG9MAIA0hAkAkIYwAQDSECYAQBrCBABIQ5gAAGkIEwAgDWECAKThgjX65R3veEe3jwBUTJs2rdWuzeVpnXLKKad05Dl//vOfq5vvf//7HfkuBoc3JgBAGsIEAEhDmAAAaQgTACANYQIApCFMAIA0hAkAkIYwAQDScMEa/TJnzpyOPOfBBx9stdu0aVNHvg9GkhUrVrTanXfeedXNmDFj+nuciIjYsGFDdbNz587q5tBDD61ufvWrX7U607nnnlvdPPPMM62exdvnjQkAkIYwAQDSECYAQBrCBABIQ5gAAGkIEwAgDWECAKQhTACANFywRr9s27atI8+ZMGFCq92ee+5Z3bz22mv9PQ4MKz/+8Y9b7Z599tnq5pRTTqluHnvssermgQceqG7a/L7cfPPN1c0FF1xQ3URE/P73v69uPvvZz1Y3Y8eOrW6efvrp6mbZsmXVzXDkjQkAkIYwAQ
DSECYAQBrCBABIQ5gAAGkIEwAgDWECAKQhTACANErTNL1/WErvH0JEnH/++dXNLbfc0rHvO/jgg6ubDRs2dOz7RrKmacpAPt/vC50wZsyY6mbmzJmtnnXFFVdUNz099XtJ165dW91ceuml1c2mTZuqm6Gqr98Xb0wAgDSECQCQhjABANIQJgBAGsIEAEhDmAAAaQgTACANYQIApFG/KQYAktqxY0d1c+edd7Z6VtsdA8sbEwAgDWECAKQhTACANIQJAJCGMAEA0hAmAEAawgQASEOYAABpCBMAIA03v9IvDz74YHXT5mbGMWPGdOI4AAxx3pgAAGkIEwAgDWECAKQhTACANIQJAJCGMAEA0hAmAEAawgQASKM0TdP7h6X0/iG0dOmll1Y3ixcvbvWs733ve9XN/PnzWz2LvjVNUwby+X5fYOTq6/fFGxMAIA1hAgCkIUwAgDSECQCQhjABANIQJgBAGsIEAEhDmAAAabhgDXhLLlgDBooL1gCAIUGYAABpCBMAIA1hAgCkIUwAgDSECQCQhjABANIQJgBAGn1esAYAMJi8MQEA0hAmAEAawgQASEOYAABpCBMAIA1hAgCkIUwAgDSECQCQhjABANIQJgBAGsIEAEhDmAAAaQgT/kYpZeub/nqjlHJ9t88FDA+llMmllJ+XUv5XKeWVUsoNpZSebp+LPIQJf6Npmn3/7a+I2D8itkfE7V0+FjB8fDsiNkbEARFxdEScHBEXd/VEpCJM6MtHYvcPyP/s9kGAYeM9EXFb0zQ7mqZ5JSJ+GRFTu3wmEhEm9OWfI+KHTdM03T4IMGwsiYjzSiljSykHRsTpsTtOICKECb0opRwSu1+x/qDbZwGGlf8Ru9+QbImIFyNiTUQs6+qJSEWY0JsLI+LXTdM82+2DAMNDKWVURNwTEXdExD4RMTEi/iEiFnXzXOQiTOjNheFtCdBZEyLi4Ii4oWma15qm2RQR34+IM7p7LDIRJvydUsoHIuLA8HfjAB3UNM2fI+LZiFhQSukppYyP3X+W7bHunoxMhAlv5Z8j4o6maf612wcBhp0PR8Q/RsS/RMT6iNgZEf+lqycileJvuAAAsvDGBABIQ5gAAGkIEwAgDWECAKTR5/+iYynFn4yFEappmjKQz/f7AiNXX78v3pgAAGkIEwAgDWECAKQhTACANIQJAJCGMAEA0hAmAEAawgQASEOYAABpCBMAIA1hAgCkIUwAgDSECQCQhjABANIQJgBAGsIEAEhDmAAAaQgTACANYQIApCFMAIA0hAkAkIYwAQDSECYAQBrCBABIQ5gAAGkIEwAgDWECAKQhTACANIQJAJCGMAEA0hAmAEAawgQASEOYAABpCBMAIA1hAgCkIUwAgDSECQCQhjABANIQJgBAGsIEAEhDmAAAaQgTACANYQIApCFMAIA0hAkAkEZPtw8AwNAxefLk6mbBggXVzfTp0zvyXW0269atq24iItasWdNqV/OlL32puml7ppHIGxMAIA1hAgCkIUwAgDSECQCQhjABANIQJgBAGsIEAEhDmAAAabhgDWCY22effVrt5s2bV91cc8011c3YsWOrm1dffbW6efLJJ6ubV155pbpp69hjj61uDj300OrmjTfeqG7mzp3b5kgjkjcmAEAawgQASEOYAABpCBMAIA1hAgCkIUwAgDSECQCQhjABANJwwVo/lVKqmz322KO6OfXUU1t933777VfdTJkypbqZOnVqq+/rhKOOOqq6aXPmiIi77767upkzZ051s2XLllbfB9m1ucxswYIFrZ61aNGi6mbjxo3VzXXXXVfdfPe7361uNmzYUN100r777lvd3HnnndXNYYcd1onjjFjemAAAaQgTACANYQIApCFMAIA0hAkAkIYwAQDSECYAQBrCBABIozRN0/uHpfT+4QjQ5lKiuXPnVjfTp0/vwGna27VrV3Xz+uuvD8JJdtu8eXN1M378+FbPanOZ1Nlnn13d/PznP2/1fSNZ0zT12w
P7YaT/vowbN6662b59e3UzY8aM6qbNxYQR7S5PO/PMM6ubRx55pNX3ZTN69Ojqps0/Jz/84Q+rm09+8pOtzjRc9fX74o0JAJCGMAEA0hAmAEAawgQASEOYAABpCBMAIA1hAgCkIUwAgDR6un2AzM4444zqps3laY8//nh1c9ttt7U6U5tn/eUvf6luVq5c2er7Bstzzz3XatfmgrWHH364v8eBfjnwwAOrmzb/Xl6+fHl188c//rHNkVqZPXt2dTNUL09r48Ybb+zIc5YuXdqR54xU3pgAAGkIEwAgDWECAKQhTACANIQJAJCGMAEA0hAmAEAawgQASMMFa3349re/Xd20uYTt05/+dHVz//33tzoTkN9LL71U3Xzta1+rbhYtWlTdjBpV/8+Xu3btqm4iIt7znvdUN0Pxt2r+/PmtdhdddFF1c9VVV1U3d999d6vv4615YwIApCFMAIA0hAkAkIYwAQDSECYAQBrCBABIQ5gAAGkIEwAgDWECAKTh5tc+nHTSSR15znHHHVfdDMXbFIG3b/HixdXNs88+W920uaF64sSJrc60cOHC6ubll1+ubu69995W39cJF198cXXzjW98o9Wznnzyyerm2muvbfUs3j5vTACANIQJAJCGMAEA0hAmAEAawgQASEOYAABpCBMAIA1hAgCk4YK1PkybNq0jz1m9enVHnkPE66+/Xt3s2rVrEE4C/dM0TXWzdOnS6uaZZ56pbubOndvmSDFv3rzq5q677qpuVq5cWd1cdNFF1c0JJ5xQ3Xzxi1+sbjZv3lzdRETMmjWrumnzG0T/eGMCAKQhTACANIQJAJCGMAEA0hAmAEAawgQASEOYAABpCBMAIA0XrPXTtm3bqpuXXnppEE6S1+jRo6ubnp52/1JcsWJFdbNp06ZWz4Lh4NFHH+3IJiLi9ttvr25+9KMfVTczZsyobh5//PHqZt99961u2lyeduaZZ1Y3ERHr169vtWNgeWMCAKQhTACANIQJAJCGMAEA0hAmAEAawgQASEOYAABpCBMAIA0XrPVh48aN1c2vf/3r6ubpp5/uxHGGrNmzZ1c373rXu1o9a8mSJf09DtCLNr9nM2fOrG7uu+++6mbSpEmtzlTT5vflkUce6ch3MTi8MQEA0hAmAEAawgQASEOYAABpCBMAIA1hAgCkIUwAgDSECQCQRmmapvcPS+n9wxFgjz32qG723nvv6mbr1q2dOM6Q9dvf/ra6Ofroo1s966CDDqpuXn755VbPom9N05SBfP5I/30Zqg4//PDqps0Fa23+vbxr167qps3v6zHHHFPdRLgMczD19fvijQkAkIYwAQDSECYAQBrCBABIQ5gAAGkIEwAgDWECAKQhTACANHq6fYDM3njjjepmpF+e1sb48eOrm23btrV61s6dO/t7HKAX733ve6ubNhcmjh07trpZvHhxdXPyySdXN20uT1u1alV1ExFx6qmnVjdPPPFEq2fx9nljAgCkIUwAgDSECQCQhjABANIQJgBAGsIEAEhDmAAAaQgTACANF6wl8e53v7vV7qKLLqpupkyZUt0cddRR1c3atWurm8mTJ1c3kyZNqm7aXpx25JFHVjfz5s2rbtpc7rRjx45WZ4Ls9tprr1a7a665proZPXp0dXP55ZdXNwsXLqxu2pz7oYceqm7a/N5FREydOrW6ccHawPPGBABIQ5gAAGkIEwAgDWECAKQhTACANIQJAJCGMAEA0hAmAEAawgQASKM0TdP7h6X0/iEdtWTJkla7D3/4w9XNqFH13hw/fnx1M2bMmFZnGore//73Vzdtbr4dzpqmKQP5fL8vg+ekk05qtVu5cmV1s2rVqo59Xye0uen6vvvua/Wsnp76Zej7779/q2fRt75+X7wxAQDSECYAQBrCBABIQ5gAAGkIEwAgDWECAKQhTACANIQJAJCGC9aGmDYXo7W5YG3vvfeubkaPHl3dzJw5s7q58cYbq5vnn3++uomIWLRoUXVzzz33VD
fPPfdcdbNr165WZxquXLA2fLS9wPGSSy6pbs4999zqZtmyZa2+b7D85je/abU79thjq5sPfehD1c2KFStafd9I5oI1AGBIECYAQBrCBABIQ5gAAGkIEwAgDWECAKQhTACANIQJAJBGT7cPwL/Pq6++2u0j/I0jjjiiI89pc1FbRMS6des68n3A31u/fn11k+3ytME2adKkbh9h2PPGBABIQ5gAAGkIEwAgDWECAKQhTACANIQJAJCGMAEA0hAmAEAaLlijX84444zqZsuWLdVNtovjYDjZtGlTq93kyZOrm9NOO626uffee1t9Xyecfvrp1c0xxxzTse+77bbbOvYs3po3JgBAGsIEAEhDmAAAaQgTACANYQIApCFMAIA0hAkAkIYwAQDSKE3T9P5hKb1/CBHx1FNPVTebN2+ubo4//vhOHIcOapqmDOTz/b4MnuOOO67VbtWqVdXN888/X91MnTq11ffVzJo1q7r51re+Vd2MGzeu1fddeOGF1c2tt97a6ln0ra/fF29MAIA0hAkAkIYwAQDSECYAQBrCBABIQ5gAAGkIEwAgDWECAKTR0+0DADCw1q5d22q3fPny6mbGjBnVzcMPP1zdjBpV/8/Fhx9+eHXz6quvVjdz586tbiJcnpaFNyYAQBrCBABIQ5gAAGkIEwAgDWECAKQhTACANIQJAJCGMAEA0ihN0/T+YSm9fwgR8dRTT1U3EyZMqG6mTJnS6vs2bdrUakf/NU1TBvL5fl/ymTZtWnVz0003VTfTp0/vxHFizZo11c0555xT3fzpT3/qxHHooL5+X7wxAQDSECYAQBrCBABIQ5gAAGkIEwAgDWECAKQhTACANIQJAJCGMAEA0nDzK70aN25cdfOHP/yhunnxxRerm+OPP77VmRg8bn4FBoqbXwGAIUGYAABpCBMAIA1hAgCkIUwAgDSECQCQhjABANIQJgBAGj3dPgB5jRpV79Y2m5/85CedOA4AI4A3JgBAGsIEAEhDmAAAaQgTACANYQIApCFMAIA0hAkAkIYwAQDSKE3T9P5hKb1/CAxrTdOUgXy+3xcYufr6ffHGBABIQ5gAAGkIEwAgDWECAKQhTACANIQJAJCGMAEA0hAmAEAafV6wBgAwmLwxAQDSECYAQBrCBABIQ5gAAGkIEwAgDWECAKTxfwFP0ayXeJygGAAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 576x576 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"data = src.databunch(bs=4, num_workers=0).normalize(mnist_stats)\n",
"data.show_batch()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"[(tensor(0.1296), tensor(0.3091)),\n",
" (tensor(0.1280), tensor(0.3092)),\n",
" (tensor(0.2026), tensor(0.3676)),\n",
" (tensor(0.1413), tensor(0.3152)),\n",
" (tensor(0.1778), tensor(0.3480))]"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[(im.px.mean(), im.px.std()) for im in src.train.x[:5]]"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"Not the same as the fastai stats."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean= 0.024, std=2.227\n",
"mean= 0.007, std=2.202\n",
"mean=-0.153, std=2.062\n",
"mean=-0.178, std=2.000\n",
"mean=-0.050, std=2.137\n"
]
}
],
"source": [
"for xb,_ in islice(data.train_dl, 5):\n",
" print(f\"mean={xb.mean().item(): 0.3f}, std={xb.std().item():0.3f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"The resulting std is around 2 rather than the 1 we're going for."
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"## Calculate stats"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"class RunningStatistics:\n",
"    '''Records mean and variance of the final `n_dims` dimensions over other dimensions across items. So collecting across `(l,m,n,o)` sized\n",
"    items with `n_dims=1` will collect `(l,m,n)` sized statistics while with `n_dims=2` the collected statistics will be of size `(l,m)`.\n",
"\n",
"    Per-item statistics are merged with the pairwise update from Chan, Golub, and LeVeque in \"Algorithms for computing the sample variance:\n",
"    analysis and recommendations\":\n",
"\n",
"    `S = S1 + S2 + n/(m*(m+n)) * pow(((m/n)*t1 - t2), 2)`\n",
"\n",
"    This combines the sums of squared deviations from the mean of 2 blocks: block 1 having `n` elements with squared deviations `S1` and a sum\n",
"    of `t1`, and block 2 having `m` elements with squared deviations `S2` and a sum of `t2`. `S` is kept in `_nvar`, and `var`/`std` report the\n",
"    population statistics `S/n` and `sqrt(S/n)`.\n",
"\n",
"    Note that collecting minimum and maximum values is reasonably inefficient, adding about 80% to the running time, and hence is disabled by default.\n",
"    '''\n",
"    def __init__(self, n_dims:int=2, record_range=False):\n",
"        self._n_dims,self._range = n_dims,record_range\n",
"        self.n,self.sum,self.min,self.max = 0,None,None,None\n",
"        self._shape,self._nvar = None,None # both are set by the first `update`\n",
"\n",
"    def update(self, data:Tensor):\n",
"        '''Fold one item into the running statistics, reducing over its final `n_dims` dimensions.'''\n",
"        # Flatten the reduced dimensions into one; `reshape` (unlike `view`) also accepts non-contiguous tensors\n",
"        data = data.reshape(*list(data.shape[:-self._n_dims]) + [-1])\n",
"        with torch.no_grad():\n",
"            # Population (biased) variance, so that `new_n * new_var` is exactly the item's sum of squared deviations.\n",
"            # The default unbiased variance divides by `new_n - 1` and would inflate every item's term by `new_n/(new_n-1)`.\n",
"            new_n,new_var,new_sum = data.shape[-1],data.var(-1, unbiased=False),data.sum(-1)\n",
"            if self.n == 0:\n",
"                self.n = new_n\n",
"                self._shape = data.shape[:-1]\n",
"                self.sum = new_sum\n",
"                self._nvar = new_var.mul_(new_n)\n",
"                if self._range:\n",
"                    self.min = data.min(-1)[0]\n",
"                    self.max = data.max(-1)[0]\n",
"            else:\n",
"                assert data.shape[:-1] == self._shape, f\"Mismatched shapes, expected {self._shape} but got {data.shape[:-1]}.\"\n",
"                # Block 1 is everything seen so far (`self.n` elements), block 2 is the new item (`new_n` elements)\n",
"                ratio = self.n / new_n\n",
"                t = (self.sum / ratio).sub_(new_sum).pow_(2)\n",
"                # Keyword `alpha` replaces the deprecated `add_(scalar, tensor)` overload\n",
"                self._nvar.add_(new_var, alpha=new_n).add_(t, alpha=ratio / (self.n + new_n))\n",
"                self.sum.add_(new_sum)\n",
"                self.n += new_n\n",
"                if self._range:\n",
"                    self.min = torch.min(self.min, data.min(-1)[0])\n",
"                    self.max = torch.max(self.max, data.max(-1)[0])\n",
"\n",
"    # `mean`, `var` and `std` are None until the first `update`\n",
"    @property\n",
"    def mean(self): return self.sum / self.n if self.n > 0 else None\n",
"    @property\n",
"    def var(self): return self._nvar / self.n if self.n > 0 else None\n",
"    @property\n",
"    def std(self): return self.var.sqrt() if self.n > 0 else None\n",
"\n",
"    def __repr__(self):\n",
"        def _fmt_t(t):\n",
"            # Statistics are None before the first `update`\n",
"            if t is None: return \"None\"\n",
"            if t.numel() > 5: return f\"tensor of ({','.join(map(str,t.shape))})\"\n",
"            # 0-d tensors cannot be iterated, so format them directly\n",
"            if t.ndim == 0: return f\"{t.item():.3g}\"\n",
"            return '[' + ','.join(_fmt_t(v) for v in t) + ']'\n",
"        rng_str = f\", min={_fmt_t(self.min)}, max={_fmt_t(self.max)}\" if self._range else \"\"\n",
"        return f\"RunningStatistics(n={self.n}, mean={_fmt_t(self.mean)}, std={_fmt_t(self.std)}{rng_str})\"\n",
"\n",
"def collect_stats(items:Iterable, n_dims:int=2, record_range:bool=False):\n",
"    '''Run every element of `items` through a new `RunningStatistics(n_dims, record_range)`, showing a progress bar, and return it.'''\n",
"    running = RunningStatistics(n_dims, record_range)\n",
"    for item in progress_bar(items):\n",
"        # Use the `data` attribute when the item has one (e.g. a fastai Image), otherwise the item itself\n",
"        running.update(getattr(item, 'data', item))\n",
"    return running"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='60000' class='' max='60000', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [60000/60000 00:15<00:00]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"RunningStatistics(n=47040000, mean=[0.131,0.131,0.131], std=[0.308,0.308,0.308])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"stats = collect_stats(src.train.x)\n",
"stats"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"The algorithm seems to work: there is a slight inaccuracy, but it seems numerically stable (in my local version I `assert_allclose(..., rtol=0.001, atol=0.01)` in unit tests on `torch.randn` data)."
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"## Apply stats"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"([0.131, 0.131, 0.131], [0.308, 0.308, 0.308])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"true_mnist_stats = ([0.131]*3, [0.308]*3)\n",
"true_mnist_stats"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mean= 0.021, std=1.016\n",
"mean= 0.032, std=1.022\n",
"mean= 0.046, std=1.050\n",
"mean= 0.060, std=1.058\n",
"mean= 0.047, std=1.045\n"
]
}
],
"source": [
"data = src.databunch(bs=4, num_workers=0).normalize(true_mnist_stats)\n",
"for xb,_ in islice(data.train_dl, 5):\n",
" print(f\"mean={xb.mean().item(): 0.3f}, std={xb.std().item():0.3f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"Looks right."
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"However, for grayscale:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 3, 28, 28])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = (ImageList.from_folder(DATA, convert_mode='L')\n",
" .split_by_folder(train='training', valid='testing')\n",
" .label_from_folder()\n",
" .databunch(bs=8, num_workers=0)\n",
" .normalize(true_mnist_stats))\n",
"xb,_ = next(iter(data.train_dl))\n",
"xb.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"Normalizing with the 3-channel stats has broadcast the single-channel image up to 3 channels. So use single-channel stats:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"good_mnist_stats = ([0.131], [0.308])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 1, 28, 28])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = (ImageList.from_folder(DATA, convert_mode='L')\n",
" .split_by_folder(train='training', valid='testing')\n",
" .label_from_folder()\n",
" .databunch(bs=8, num_workers=0)\n",
" .normalize(good_mnist_stats))\n",
"xb,_ = next(iter(data.train_dl))\n",
"xb.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"Collapsed": "false"
},
"source": [
"These also work for RGB images, where the single-channel stats are broadcast across the 3 channels:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([8, 3, 28, 28])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = (ImageList.from_folder(DATA)\n",
" .split_by_folder(train='training', valid='testing')\n",
" .label_from_folder()\n",
" .databunch(bs=8, num_workers=0)\n",
" .normalize(good_mnist_stats))\n",
"xb,_ = next(iter(data.train_dl))\n",
"xb.shape"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:.conda-fastai-dev]",
"language": "python",
"name": "conda-env-.conda-fastai-dev-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment