sklearn preprocessor - image centering - improves SVM-RBF against translation jitter.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "sklearn preprocessor - image centering - improves SVM-RBF against translation jitter.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyPKcZri36zDOzJcE2a56Lx2",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shadiakiki1986/1ee234b4e18924dbf4a49f9c4895e345/sklearn-preprocessor-image-centering-improves-svm-rbf-against-translation-jitter.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8GaAZFCMWaw4"
},
"source": [
"Based on sklearn's [digit classification example](https://scikit-learn.org/stable/auto_examples/classification/plot_digits_classification.html#sphx-glr-auto-examples-classification-plot-digits-classification-py)\n",
"\n",
"Uses my digits dataset with translation jitter: https://github.com/shadiakiki1986/mnist-digits-jitter\n",
"\n",
"Published as gist: [sklearn preprocessor - image centering - improves SVM-RBF against translation jitter.ipynb](https://gist.github.com/shadiakiki1986/1ee234b4e18924dbf4a49f9c4895e345)\n",
"\n",
"Related gist but without centering: [SVM-RBF sensitivity to translation.ipynb](https://gist.github.com/shadiakiki1986/689980135fe9dde1d892127bde40a5a1)\n",
"\n",
"Proposed the image centering preprocessor as an sklearn feature: https://github.com/scikit-learn/scikit-learn/issues/20888"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6iuapqgYnf9w"
},
"source": [
"# dependencies"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RrAkpWbdfPMu",
"outputId": "d905aa91-8c74-423b-b296-cbd5f025c444"
},
"source": [
"# clone the repository holding the padded and jittered digits data\n",
"!git clone https://github.com/shadiakiki1986/mnist-digits-jitter\n",
"\n",
"# Update: no need to gunzip, since np.loadtxt decompresses .gz files automatically\n",
"#!gunzip mnist-digits-jitter/digits_padded.csv.gz\n",
"#!gunzip mnist-digits-jitter/digits_jitter.csv.gz"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Cloning into 'mnist-digits-jitter'...\n",
"remote: Enumerating objects: 47, done.\u001b[K\n",
"remote: Counting objects: 100% (47/47), done.\u001b[K\n",
"remote: Compressing objects: 100% (44/44), done.\u001b[K\n",
"remote: Total 47 (delta 20), reused 8 (delta 1), pack-reused 0\u001b[K\n",
"Unpacking objects: 100% (47/47), done.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "qln7PdsNnbGq"
},
"source": [
"# if uploading the .gz files manually instead of cloning from GitHub\n",
"#!mkdir -p mnist-digits-jitter\n",
"#!mv *gz mnist-digits-jitter/"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "HGkCzTmwXS4o"
},
"source": [
"# get data"
]
},
{
"cell_type": "code",
"metadata": {
"id": "wxWb-jzbXTdv"
},
"source": [
"# first the original 8x8 digits from sklearn (their labels, digits.target, are reused for the jittered images)\n",
"from sklearn.datasets import load_digits\n",
"digits = load_digits()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "66oFvhVEfZMk",
"outputId": "6d81d0f1-9d7c-4bc6-df76-854744bda83d"
},
"source": [
"import numpy as np\n",
"# np.loadtxt can decompress the files on read\n",
"digpad = {\"data\": np.loadtxt(\"mnist-digits-jitter/digits_padded.csv.gz\", delimiter=\",\", dtype=int)}\n",
"digjit = {\"data\": np.loadtxt(\"mnist-digits-jitter/digits_jitter.csv.gz\", delimiter=\",\", dtype=int)}\n",
"digpad[\"data\"].shape, digjit[\"data\"].shape"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((1797, 225), (1797, 225))"
]
},
"metadata": {},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "nBjvLW_9f3YM"
},
"source": [
"# convert data to image (not flat)\n",
"#def im2data(digxxx_img):\n",
"# return np.vstack([img.reshape((-1,1)).squeeze() for img in digxxx_img])\n",
"\n",
"def data2im(digxxx_data):\n",
" \"\"\"\n",
" Reshape flat rows of shape (n_samples, s*s) into square images of shape (n_samples, s, s).\n",
" \"\"\"\n",
" s = int(digxxx_data.shape[1]**.5)\n",
" l = digxxx_data.reshape((-1,s,s))\n",
" return l\n",
"\n",
"digjit[\"images\"] = data2im(digjit[\"data\"])\n",
"digpad[\"images\"] = data2im(digpad[\"data\"])\n",
"\n",
"assert digjit[\"data\"].shape == (1797, 225)\n",
"assert digpad[\"data\"].shape == (1797, 225)\n",
"assert digjit[\"images\"].shape == (1797, 15, 15)\n",
"assert digpad[\"images\"].shape == (1797, 15, 15)"
],
"execution_count": null,
"outputs": []
},
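{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick round-trip check (an illustrative addition, not part of the original analysis): flattening the square images back into rows with `reshape(n_samples, -1)`, which is what the commented-out `im2data` helper was meant to do, should reproduce the loaded `data` arrays exactly."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# round trip (n_samples, 15, 15) -> (n_samples, 225): must reproduce the loaded rows\n",
"for dig in (digjit, digpad):\n",
"    flat_again = dig[\"images\"].reshape((dig[\"images\"].shape[0], -1))\n",
"    assert (flat_again == dig[\"data\"]).all()"
],
"execution_count": null,
"outputs": []
},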
{
"cell_type": "code",
"metadata": {
"id": "-lWgeNt6oWVl"
},
"source": [
"# utility plotting code\n",
"# Adapted from https://www.codespeedy.com/image-augmentation-using-skimage-in-python/\n",
"# Display two images side by side, each with its own colorbar\n",
"from matplotlib import pyplot as plt\n",
"def plot_side(img1, img2, title1, title2, cmap=None):\n",
"    fig = plt.figure(tight_layout=True, figsize=(5,5))\n",
"    fig.add_subplot(221)\n",
"    plt.title(title1)\n",
"    plt.imshow(img1, cmap=cmap)\n",
"    plt.colorbar()\n",
"\n",
"    fig.add_subplot(222)\n",
"    plt.title(title2)\n",
"    plt.imshow(img2, cmap=cmap)\n",
"    plt.colorbar()\n",
"    return fig"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 199
},
"id": "oNH6xGZnoXnJ",
"outputId": "1269d88b-022f-4064-98bc-664882e117a0"
},
"source": [
"plot_side(digpad[\"images\"][0], digjit[\"images\"][0], \"no jitter\", \"with jitter\")\n",
"plt.show()"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "srQUt0XE1Hc3"
},
"source": [
"# util: translate function"
]
},
{
"cell_type": "code",
"metadata": {
"id": "mCoaqbHz1Iw4"
},
"source": [
"from skimage.transform import warp\n",
"#from skimage.transform import SimilarityTransform\n",
"from skimage.transform import EuclideanTransform\n",
"\n",
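"# use scikit-image's warp (not sklearn); warp treats the transform as an inverse map,\n",
"# so the image content moves by (-jx, -jy), where x is the column and y the row\n",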
"# https://scikit-image.org/docs/dev/api/skimage.transform.html\n",
"my_translate = lambda img, jx, jy: warp(img, EuclideanTransform(translation=(jx, jy)), preserve_range=True).astype(int)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "5iv7_IKT1LJg"
},
"source": [
"# test: a (0, 0) translation must leave the image unchanged\n",
"#print(digpad[\"images\"][0])\n",
"#print(my_translate(digpad[\"images\"][0], 0, 0))\n",
"\n",
"assert (digpad[\"images\"][0] == my_translate(digpad[\"images\"][0], 0, 0)).all()"
],
"execution_count": null,
"outputs": []
},
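{
"cell_type": "markdown",
"metadata": {},
"source": [
"A direction check for `my_translate` (an illustrative sketch; the single-pixel `probe` image is made up for this test): `warp()` applies the transform as an *inverse* map, so `my_translate(img, jx, jy)` moves the content by `-jx` columns and `-jy` rows. This is why the centering function can pass the centroid-minus-center offset straight in."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# a single bright pixel at (row=7, col=7) in a 15x15 image\n",
"probe = np.zeros((15, 15), dtype=int)\n",
"probe[7, 7] = 1\n",
"# jx=2 (columns), jy=1 (rows): the pixel should end up at (row=6, col=5)\n",
"moved = my_translate(probe, 2, 1)\n",
"assert moved[6, 5] == 1\n",
"assert moved.sum() == 1"
],
"execution_count": null,
"outputs": []
},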
{
"cell_type": "markdown",
"metadata": {
"id": "BvDdpxZIUVXe"
},
"source": [
"# preprocessing: center images\n",
"\n",
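"Take 2 below defines the centering function; take 4 wraps it as an sklearn `FunctionTransformer` so it can be used in a pipeline."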
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5z30q5YbXn6e"
},
"source": [
"## take 2: function"
]
},
{
"cell_type": "code",
"metadata": {
"id": "tXW80H44UXqg"
},
"source": [
"# Function that centers each image on its digit's centroid (wrapped as an sklearn transformer in take 4)\n",
"# http://scipy-lectures.org/packages/scikit-image/#measuring-regions-properties\n",
"# https://scikit-image.org/docs/dev/auto_examples/segmentation/plot_regionprops.html\n",
"# https://kapernikov.com/tutorial-image-classification-with-scikit-learn/\n",
"from skimage import measure\n",
"from matplotlib import pyplot as plt\n",
"\n",
"def my_flatarray_to_square(x):\n",
" assert len(x.shape) == 1\n",
" s = x.shape[0]**.5\n",
" assert s == int(s)\n",
" s = int(s)\n",
" x = x.reshape((s,s))\n",
" return x\n",
"\n",
"\n",
"def my_center(X, flat_in, flat_out, verbose):\n",
"    \"\"\"\n",
"    Center each image on the centroid of its (single) non-zero region.\n",
"\n",
"    X : if flat_in=False, one of the following:\n",
"        - list of length n_samples where each entry is an array-like of shape (image dim 1, image dim 2)\n",
"        - array-like of shape (n_samples, image dim 1, image dim 2)\n",
"        - array-like of shape (image dim 1, image dim 2)\n",
"      If flat_in=True, one of:\n",
"        - list of length n_samples where each entry is an array-like of shape (image dim 1 * image dim 2, )\n",
"        - array-like of shape (n_samples, image dim 1 * image dim 2)\n",
"        - array-like of shape (image dim 1 * image dim 2,)\n",
"\n",
"    flat_in: boolean. True => input is a flat image, or a list or array of flat images (each must reshape to a square)\n",
"    flat_out: boolean. True => return a flat array whose shape depends on the input:\n",
"        - (dim 1 * dim 2, ) for a single image\n",
"        - (n_samples, dim 1 * dim 2) for several images\n",
"    verbose: boolean. True => print the labels and centroid, and plot the image before and after centering\n",
"\n",
"    Raises an AssertionError if an image does not contain exactly one connected non-zero region.\n",
"    \"\"\"\n",
"    if type(X) not in [list, np.ndarray]:\n",
"        raise ValueError(f\"Unsupported type: {type(X)}\")\n",
"\n",
"    do_multi = False\n",
"    if type(X) == list:\n",
"        do_multi = True\n",
"        for i, x in enumerate(X):\n",
"            if flat_in:\n",
"                # note: replaces the list entries in place\n",
"                X[i] = my_flatarray_to_square(x)\n",
"            else:\n",
"                assert len(x.shape) == 2\n",
"    elif type(X) == np.ndarray:\n",
"        if flat_in:\n",
"            if len(X.shape) == 1:\n",
"                do_multi = False\n",
"                X = my_flatarray_to_square(X)\n",
"            elif len(X.shape) == 2:\n",
"                do_multi = True\n",
"                X = np.array([my_flatarray_to_square(x) for x in X])\n",
"            else:\n",
"                raise ValueError(f\"When flat_in=True, require len(X.shape) in [1,2]. Found {X.shape}.\")\n",
"        else:\n",
"            if len(X.shape) == 2:\n",
"                do_multi = False\n",
"            elif len(X.shape) == 3:\n",
"                do_multi = True\n",
"            else:\n",
"                raise ValueError(f\"Unsupported X.shape={X.shape}. Need len(X.shape) in [2,3]\")\n",
"\n",
"    if do_multi:\n",
"        # center each image separately, then stack the results\n",
"        o = [my_center(img, False, flat_out, verbose) for img in X]\n",
"        o = np.array(o)\n",
"        return o\n",
"\n",
"    img_jit = X\n",
"    # label the connected non-zero pixels and measure each region\n",
"    img_lab = measure.label(img_jit>0)\n",
"    regions = measure.regionprops(img_lab)\n",
"\n",
"    if verbose:\n",
"        print(\"labels\")\n",
"        print(img_lab)\n",
"        print(\"regions\")\n",
"        print([r.centroid for r in regions])\n",
"        plt.imshow(img_jit)\n",
"        for r in regions:\n",
"            c = r.centroid\n",
"            plt.scatter(c[1], c[0], color=\"red\")\n",
"\n",
"        plt.title(\"input and centroid\")\n",
"        plt.show()\n",
"\n",
"    assert len(regions)==1\n",
"\n",
"    # regionprops returns the centroid as (row, col); swap it to (x, y) = (col, row),\n",
"    # the order that EuclideanTransform's translation (and plt.scatter) expect\n",
"    c_dig = (regions[0].centroid[1], regions[0].centroid[0])\n",
"\n",
"    # center of image (the same in x and y for square images)\n",
"    c_img = tuple([x//2 for x in img_jit.shape])\n",
"\n",
"    # offset in whole pixels from the image center to the digit centroid\n",
"    d = tuple([round(x1-x2) for x1, x2 in zip(c_dig, c_img)])\n",
"\n",
"    # warp() uses the transform as an inverse map, so translating by d moves the digit by -d, onto the center\n",
"    img_cent = my_translate(img_jit, d[0], d[1])\n",
"\n",
"    if verbose:\n",
"        plot_side(img_jit, img_cent, \"with jitter\", \"centered\")\n",
"        plt.scatter(c_dig[0], c_dig[1], color=\"red\")\n",
"        plt.scatter(c_img[0], c_img[1], color=\"green\")\n",
"        plt.show()\n",
"\n",
"    if flat_out:\n",
"        assert len(img_cent.shape)==2\n",
"        img_cent = img_cent.reshape((img_cent.shape[0]*img_cent.shape[1], 1)).squeeze()\n",
"\n",
"    return img_cent"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "D_l9jVZ4boNq",
"outputId": "0e1c35e0-ef05-441f-a342-09cd1643d410"
},
"source": [
"# test on single image\n",
"# Update: Digit 7 uncovered a major bug in flipping the x,y of the centroid\n",
"i = 7\n",
"print(digjit[\"images\"][i])\n",
"img_cent = my_center(digjit[\"images\"][i], flat_in=False, flat_out=False, verbose=True)\n",
"assert img_cent.shape == (15, 15)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"[[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [ 0 0 0 0 0 0 0 0 7 8 13 16 15 1 0]\n",
" [ 0 0 0 0 0 0 0 0 7 7 4 11 12 0 0]\n",
" [ 0 0 0 0 0 0 0 0 0 0 8 13 1 0 0]\n",
" [ 0 0 0 0 0 0 0 4 8 8 15 15 6 0 0]\n",
" [ 0 0 0 0 0 0 0 2 11 15 15 4 0 0 0]\n",
" [ 0 0 0 0 0 0 0 0 0 16 5 0 0 0 0]\n",
" [ 0 0 0 0 0 0 0 0 9 15 1 0 0 0 0]\n",
" [ 0 0 0 0 0 0 0 0 13 5 0 0 0 0 0]\n",
" [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]\n",
"labels\n",
"[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 1 1 1 1 1 1 0]\n",
" [0 0 0 0 0 0 0 0 1 1 1 1 1 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 1 1 1 0 0]\n",
" [0 0 0 0 0 0 0 1 1 1 1 1 1 0 0]\n",
" [0 0 0 0 0 0 0 1 1 1 1 1 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 1 1 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 1 1 1 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 1 1 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]\n",
"regions\n",
"[(5.84375, 9.71875)]\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAARx0lEQVR4nO3dfbBcdX3H8feH3IRAiCERBfJEeJIOOlboBcFSyjSISKkRx07xoQVJZSxVwMrQIA5Qh7aordZOndoMoKiIVCWKDigBH1GIQgpCCIQAIRBCggVDJCAJfPvH+V3YbHfvvTkPezf5fV4zd+7Zc35nz/ee3c89D3v2/BQRmNmOb6exLsDMesNhN8uEw26WCYfdLBMOu1kmHHazTDjsNZC0TNIxY11H3SR9UdLFY13HtpA0W9JvJY3rMv0iSV/pdV39wGGvQUS8NiJ+1PRytsfwbStJqyQdW3b+iFgdEbtFxAt11rUjcNhtuyJpYKxr2F457DVo3Rql3cT/lvQlSRvTLv5gW9vzJN0j6SlJX5A0MU07VdLNbc8dkg6QdDrwHuDctJv6nS61fFbSI5KelnS7pD9qmTZSbYdIWpqmXQ1MHOHvfr+k5an9PZIOTeOnS/qmpCckPSTpzNHUIOnLwGzgO+lvPFfSnLQO5ktaDfxA0k6SPibpYUnr03NNSc8x1H4gPd5X0o/TshYDe4z4gu6oIsI/FX+AVcCxafgi4DngBGAc8M/ArW1t7wZmAdOAnwEXp2mnAje3PXcAB6ThLw61HaaW9wKvBAaAjwCPAxNHqg2YADwMfBgYD7wT2NxtecCfA2uAwwABBwD7UGxAbgcuSM+5H/Ag8JZtWD/Htjyek9bBl4BJwC7AacDK9Ny7AdcAX25rP5Ae3wJ8GtgZOBrYCHxlrN8zY/I+HesCdoSfDmG/sWXawcCzbW0/0PL4BOCBNFw57B1qewr4/ZFqS0F4DFDL9J8PE/bvA2d1GP9GYHXbuPOAL2zD+ukU9v1axt0EnNHy+CCKf0wDrWGn2EvYAkxqafvVXMPu459mPN4yvAmYKGkgIrakcY+0TH8YmF7XgiWdA8xPzxnAK9h617Vjban9mkiJaKmtm1nAAx3G7wNMl/SblnHjgJ+OVEPL+umkdZ1Nb6vtYYpw79k2z3TgqYh4pq3trGGWs8Ny2MdG65ttNsUWFeAZYNehCZL2aptv2K8opuPzc4G5wLKIeFHSUxS72SNZC8yQpJbAz6ZzoKEI3/5dxj8UEQeOYpmddPsbW8c/RvFPZcjQFnwdMLNl/FpgqqRJLYGfPcwydmg+QTc2/lbSTEnTgPOBq9P4O4HXSnpDOml3Udt86yiOU7uZTPGmfwIYkHQBxZZ9NG5J854pabykdwCHD9P+UuAcSX+gwgGS9gF+AWyU9PeSdpE0TtLrJB02yjpG+hsBrgI+nE6+7Qb8E3B1+55BRDwM3Ab8g6QJko4C/myUdexwHPax8VXgBooTVw8AFwNExArg48CNwP3AzW3zXQYcLOk3kr7V4Xm/D3wPWEGxu/ocW+/+dhURzwPvoDhv8CTwFxQnvrq1/zrwj+lv2Qh8C5gWxefbJwJvAB4Cfk3xj2HKaOqgOGH3sfQ3ntOlzeXAl4GfpGU8B3yoS9t3U5xHeBK4kOJEX5a09SGaNU3SKuCvI+LGsa7F8uItu1kmHHazTHg33iwT3rKbZaKnn7NP0M4xkUm9XKRZVp7jGZ6P33W8rqKnYZ/IJN6oub1cpFlWlsRNXad5N94sEw67WSYqhV3S8ZLuk7RS0oK6ijKz+pUOe7rH1+eAt1J8TfFdkg6uqzAzq1eVLfvhwMqIeDBdV/01YF49ZZlZ3aqEfQZbf8ni0TTOzPpQ4x+9pXunnQ4w8eWvaptZj1XZsq9h65swzEzjthIRCyNiMCIGx7NzhcWZWRVVwv5L4MB0A4EJwMnAtfWUZWZ1K70bHxFbJH2Q4oYJ44
DLI2JZbZWZWa0qHbNHxHXAdTXVYmYN8hV0Zplw2M0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTDjsZplw2M0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTDjsZplw2M0y4bCbZcJhN8tElb7eZkn6oaR7JC2TdFadhZlZvarcXXYL8JGIWCppMnC7pMURcU9NtZlZjUpv2SNibUQsTcMbgeW4rzezvlXLMbukOcAhwJI6ns/M6le5Y0dJuwHfBM6OiKc7THfHjmZ9oNKWXdJ4iqBfGRHXdGrjjh3N+kOVs/ECLgOWR8Sn6yvJzJpQZcv+h8BfAn8i6Y70c0JNdZlZzar04nozoBprMbMG+Qo6s0w47GaZqPzRm+Vr9UVvKj3v81NeLDXf/Lk/LL3Mj+5xX6n5Vmx+pvQyzz78pFLzvbBufellduMtu1kmHHazTDjsZplw2M0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTDjsZplw2M0y4bCbZcLferMxMWFDue3M9RceU3qZi8/4vVLzzZn8ZOllNvHttbK8ZTfLhMNulgmH3SwTlcMuaZyk/5H03ToKMrNm1LFlP4uinzcz62NVe4SZCfwpcGk95ZhZU6pu2f8NOBcod/dAM+uZKt0/nQisj4jbR2h3uqTbJN22md+VXZyZVVS1+6e3SVoFfI2iG6ivtDdyx45m/aF02CPivIiYGRFzgJOBH0TEe2urzMxq5c/ZzTJRy7XxEfEj4Ed1PJeZNcNbdrNMOOxmmfBXXK202Rf9vOfLXPmZI0rPO3/Pe0vNd/Ob9ym9TNhYYd56ectulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZ8LfejE0nvbHUfI8drZorGdn17/jXni/z6nfPLT3vXp9xx45m1mMOu213Ji/axP5HruOgfday/5HrmLxo01iXtF3wbrxtVyYv2sTeCzaw07PF4/FrXmTvBRsA2HjSrmNYWf+r2v3T7pK+IeleScslHVlXYWadvPqTG18K+pCdni3G2/Cqbtk/C3wvIt4paQLgf63WqIHHOvc01m28vaxK909TgKOBywAi4vmI+E1dhZl1smV657dst/H2sipraF/gCeALqX/2SyVNqqkus47WnzuZF3fZetyLuxTjbXhVwj4AHAr8Z0QcAjwDLGhv5I4drU4bT9qVtZdMYfOMnQjB5hk7sfaSKT45NwpVjtkfBR6NiCXp8TfoEPaIWAgsBHiFpkWF5ZkBReAd7m1XpWPHx4FHJB2URs0F7qmlKjOrXdWz8R8Crkxn4h8E3le9JDNrQqWwR8QdwGBNtZhZg/x5hVkmHHazTPjaeGPyinLXQs0+47nSy/yv13y19LxlzT/770rNt9ei3ndg2QRv2c0y4bCbZcJhN8uEw26WCYfdLBMOu1kmHHazTDjsZplw2M0y4bCbZcJhN8uEw26WCYfdLBP+1pvxwrL7Ss034c3ll/max8rdiPiw8/+m9DKnLbql9Lw7Am/ZzTLhsJtlwmE3y0TVjh0/LGmZpLslXSVpYl2FmVm9qvT1NgM4ExiMiNcB44CT6yrMzOpVdTd+ANhF0gBFD66PVS/JzJpQpUeYNcC/AKuBtcCGiLihrsLMrF5VduOnAvMoenOdDkyS9N4O7dyxo1kfqLIbfyzwUEQ8ERGbgWuAN7U3ioiFETEYEYPj2bnC4sysiiphXw0cIWlXSaLo2HF5PWWZWd2qHLMvoeimeSlwV3quhTXVZWY1q9qx44XAhTXVYmYN8hV0Zplw2M0y4a+4WmkrLh8sP+/mn5Wa71XXPVB6mS+UnnPH4C27WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlwt96s9LeP/jT0vO+56JzSs03bV3enTNW4S27WSYcdrNMOOxmmRgx7JIul7Re0t0t46ZJWizp/vR7arNlml
lVo9myfxE4vm3cAuCmiDgQuCk9NrM+NmLYI+InwJNto+cBV6ThK4C311yXmdWs7DH7nhGxNg0/DuxZUz1m1pDKJ+giIoDoNt0dO5r1h7JhXydpb4D0e323hu7Y0aw/lA37tcApafgU4Nv1lGNmTRnNR29XAbcAB0l6VNJ84BLgzZLup+i6+ZJmyzSzqka8Nj4i3tVl0tyaazGzBvkKOrNMOOxmmfBXXK20H79+l9LzTsNfVe01b9nNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMuGwm2XCYTfLhMNulgmH3SwTDrtZJhx2s0w47GaZcNjNMlG2Y8dPSbpX0q8kLZK0e7NlmllVZTt2XAy8LiJeD6wAzqu5LjOrWamOHSPihojYkh7eCsxsoDYzq1Edx+ynAdfX8Dxm1qBKd5eVdD6wBbhymDanA6cDTGTXKoszswpKh13SqcCJwNzUk2tHEbEQWAjwCk3r2s7MmlUq7JKOB84F/jgiNtVbkpk1oWzHjv8BTAYWS7pD0ucbrtPMKirbseNlDdRiZg3yFXRmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y4TDbpYJh90sEw67WSYcdrNMOOxmmXDYzTLhsJtlwmE3y0Spjh1bpn1EUkjao5nyzKwuZTt2RNIs4Dhgdc01mVkDSnXsmHyGoqMI9/Jith0odcwuaR6wJiLurLkeM2vINnf/JGlX4KMUu/Cjae+OHc36QJkt+/7AvsCdklZR9M2+VNJenRpHxMKIGIyIwfHsXL5SM6tkm7fsEXEX8OqhxynwgxHx6xrrMrOale3Y0cy2M2U7dmydPqe2asysMb6CziwTDrtZJhTRu2tiJD0BPNxl8h5AP53k67d6oP9qcj3DG4t69omIV3Wa0NOwD0fSbRExONZ1DOm3eqD/anI9w+u3erwbb5YJh90sE/0U9oVjXUCbfqsH+q8m1zO8vqqnb47ZzaxZ/bRlN7MGOexmmeh52CUdL+k+SSslLegwfWdJV6fpSyTNabCWWZJ+KOkeScskndWhzTGSNki6I/1c0FQ9LctcJemutLzbOkyXpH9P6+hXkg5tsJaDWv72OyQ9LenstjaNrqNOt0aTNE3SYkn3p99Tu8x7Smpzv6RTGqznU5LuTa/HIkm7d5l32Ne2URHRsx9gHPAAsB8wAbgTOLitzRnA59PwycDVDdazN3BoGp4MrOhQzzHAd3u8nlYBewwz/QTgekDAEcCSHr5+j1NcuNGzdQQcDRwK3N0y7pPAgjS8APhEh/mmAQ+m31PT8NSG6jkOGEjDn+hUz2he2yZ/er1lPxxYGREPRsTzwNeAeW1t5gFXpOFvAHMlqYliImJtRCxNwxuB5cCMJpZVs3nAl6JwK7C7pL17sNy5wAMR0e0qyEZE51ujtb5PrgDe3mHWtwCLI+LJiHgKWEyH+ynWUU9E3BARW9LDWynu89BXeh32GcAjLY8f5f+H66U2aeVtAF7ZdGHpcOEQYEmHyUdKulPS9ZJe23QtFPf1u0HS7elOP+1Gsx6bcDJwVZdpvV5He0bE2jT8OLBnhzZjtZ5Oo9jz6mSk17Yx23zzih2RpN2AbwJnR8TTbZOXUuy2/lbSCcC3gAMbLumoiFgj6dXAYkn3pq3JmJE0AXgbcF6HyWOxjl4SESGpLz5DlnQ+sAW4skuTMXtte71lXwPMank8M43r2EbSADAF+N+mCpI0niLoV0bENe3TI+LpiPhtGr4OGN/0ffIjYk36vR5YRHH402o067FubwWWRsS69gljsY6AdUOHLun3+g5terqeJJ0KnAi8J9IBertRvLaN6XXYfwkcKGnftKU4Gbi2rc21wNBZ03cCP+i24qpK5wIuA5ZHxKe7tNlr6JyBpMMp1lmT/3wmSZo8NExx4qe9g45rgb9KZ+WPADa07NI25V102YXv9TpKWt8npwDf7tDm+8Bxkqams/XHpX
G1k3Q8xa3V3xYRm7q0Gc1r25xenxGkOJO8guKs/Plp3McpVhLARODrwErgF8B+DdZyFMUx1K+AO9LPCcAHgA+kNh8EllF8cnAr8KaG189+aVl3puUOraPWmgR8Lq3DuyjuAdhkTZMowjulZVzP1hHFP5m1wGaK4+75FOdxbgLuB24EpqW2g8ClLfOelt5LK4H3NVjPSorzA0Pvo6FPlKYD1w332vbqx5fLmmXCV9CZZcJhN8uEw26WCYfdLBMOu1kmHHazTDjsZpn4P/6RcSZqfqNsAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "AGjOrHTopgNF"
},
"source": [
"# test single, flat_out\n",
"assert my_center(digjit[\"images\"][0], flat_in=False, flat_out=True, verbose=False).shape == (225,)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "RjqjBDb1slwu"
},
"source": [
"# test single, flat_in, flat_out\n",
"assert my_center(digjit[\"data\"][0], flat_in=True, flat_out=True, verbose=False).shape == (225,)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "kXSQk_iEs1q_"
},
"source": [
"# test batch (3-D array of images), flat_out\n",
"assert my_center(digjit[\"images\"][:3], flat_in=False, flat_out=True, verbose=False).shape == (3, 225)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2I09sdTytYR3"
},
"source": [
"# test batch (2-D array of flat rows), flat_in, flat_out\n",
"assert digjit[\"data\"][:3].shape == (3, 225)\n",
"assert my_center(digjit[\"data\"][:3], flat_in=True, flat_out=True, verbose=False).shape == (3, 225)"
],
"execution_count": null,
"outputs": []
},
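{
"cell_type": "markdown",
"metadata": {},
"source": [
"An alternative worth comparing (a sketch that swaps in a different technique; it is not the notebook's method): `my_center` asserts that each image holds exactly one connected non-zero region, so a digit drawn with disconnected strokes would fail. Centering on the intensity-weighted center of mass from `scipy.ndimage.center_of_mass`, then applying a whole-pixel `scipy.ndimage.shift`, avoids that assumption. `center_by_mass` is a hypothetical helper name used only here."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Alternative centering sketch: intensity-weighted center of mass instead of the\n",
"# binary region centroid, so disconnected strokes are allowed (hypothetical helper)\n",
"from scipy import ndimage\n",
"\n",
"def center_by_mass(img):\n",
"    \"\"\"Shift a 2-D image so its intensity center of mass lands on the image center.\"\"\"\n",
"    com_row, com_col = ndimage.center_of_mass(img)\n",
"    c_row, c_col = img.shape[0] // 2, img.shape[1] // 2\n",
"    # whole-pixel shift (row, col); order=0 copies pixel values without interpolation\n",
"    shift = (round(c_row - com_row), round(c_col - com_col))\n",
"    return ndimage.shift(img, shift, order=0, mode=\"constant\", cval=0)\n",
"\n",
"img_j7 = digjit[\"images\"][7]\n",
"img_mass = center_by_mass(img_j7)\n",
"assert img_mass.shape == (15, 15)\n",
"# no pixel is pushed out of the 15x15 frame, so the total intensity is unchanged\n",
"assert img_mass.sum() == img_j7.sum()\n",
"plot_side(img_j7, img_mass, \"with jitter\", \"centered (mass)\")\n",
"plt.show()"
],
"execution_count": null,
"outputs": []
},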
{
"cell_type": "markdown",
"metadata": {
"id": "3nzOrrDCgxtq"
},
"source": [
"## take 4: sklearn function transformer"
]
},
{
"cell_type": "code",
"metadata": {
"id": "aG4T2b6SgztF"
},
"source": [
"# https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.FunctionTransformer.html\n",
"from sklearn.preprocessing import FunctionTransformer\n",
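"# passing my_center with kw_args (instead of a lambda) keeps the transformer picklable with the standard pickle module\n",
"image_center_transformer = FunctionTransformer(my_center, kw_args={\"flat_in\": True, \"flat_out\": True, \"verbose\": False})"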
],
"execution_count": null,
"outputs": []
},
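{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of how the centering transformer could sit in front of an RBF SVM in a single `Pipeline` (illustrative; not necessarily how the rest of this notebook evaluates it). The 50/50 split with `shuffle=False` and `gamma=0.001` are borrowed from sklearn's digit classification example linked at the top, and the jittered rows are assumed to share `digits.target` as labels, as the plots below also assume."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.svm import SVC\n",
"\n",
"# same split as sklearn's digits example: first half for training, second half for testing\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    digjit[\"data\"], digits.target, test_size=0.5, shuffle=False\n",
")\n",
"\n",
"# centering runs inside the pipeline, so it is applied identically at fit and predict time\n",
"clf = make_pipeline(image_center_transformer, SVC(gamma=0.001))\n",
"clf.fit(X_train, y_train)\n",
"print(\"accuracy on the held-out jittered digits:\", clf.score(X_test, y_test))"
],
"execution_count": null,
"outputs": []
},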
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "t-A5wYvThMzp",
"outputId": "1ada57b6-71c2-490f-fd0f-5842246d6077"
},
"source": [
"# test: center the first n_test jittered images and plot each one before and after centering\n",
"n_test = 10\n",
"img_cent = image_center_transformer.fit_transform(digjit[\"data\"][:n_test])\n",
"\n",
"for i in range(n_test):\n",
" print(f\"Label #{i}: {digits.target[i]}\")\n",
" plot_side(digjit[\"images\"][i], img_cent[i, :].reshape((15,15)), \"with jitter\", \"centered\")\n",
" plt.show()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Label #0: 0\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWIAAAC2CAYAAADjhIf3AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAYjUlEQVR4nO3df7RcZX3v8feHAFIDohiIGDDQa7TFH9BbimWV24ZSIbDgQu+iGJZFFHpTXfWue29Ri9YKy9YW29raLqxZqaZQiyDXGkWNQqStiEUl4SKCoAQKJhESQwIEFcI553v/2M/kToZ59syZmTMzzzmf11p7ZWbvPXs/J/mcb57ZP56tiMDMzEZnn1E3wMxsrnMhNjMbMRdiM7MRcyE2MxsxF2IzsxFzITYzGzEXYpsxklZL2ibp7qZ5l0vaIunONJ2R+ewySd+TtFHSpcNrtdnwsytfR2wzRdKvAk8B/xgRr07zLgeeioi/rPncPOD7wOuBzcDtwPkR8d0Zb7QZw8+ue8Q2YyLiFmBHDx89AdgYEQ9GxG7gOuDsgTbOrMaws7tvDzuyOeK0k+fHYzsms8s33PXMPcDTTbNWRcSqLjb9dklvAtYDl0TEzpbli4BNTe83A6/rrtVm5WXXhdiyHtsxybdufFl2+bzD7386Io6f5mY/CvwxEOnPDwEX9dxIszZKy64LsWUFwbMxMdhtRmxtvJb098AX2qy2BTiy6f0RaZ5ZV0rLro8RW1YAE0xmp15IOrzp7W8Cd7dZ7XZgiaSjJe0PLAdu6GmHNieVll33iC0rCCb7uKpG0rXAUmCBpM3AZcBSScdR/a48BPxuWvelwMci4oyImJD0duBGYB6wOiLu6ednsbmltOz68jXLOu7Y/ePmLx2WXb5g0ZYNPRxnM5txpWXXPWKrNYX/o7YylZRdF2LLCuBZf2OyApWWXRdiy4oIdhcUZrOG0rLrQmxZAUyNuhFmPSgtuy7ElhWIZ0OjbobZtJWWXRdiywpgty81twKVll0XYqs1VVCvwqxZSdkt57+MAZH0lKSfrVn+kKTf6HJbb5R0U7fbLs0UYjfzspNZq+n8/syk0rI75wpxRBwYEQ8CSLpK0p/0sa1rIuLUbrc9LiGdjqlQdrLxJunNkm4ddTtGpaTs+tBEISSJ6k7IoZ0MDsTuGL/egw2HpH0jBjxyzpCUlt1Z0SOW9BZJn296f7+k/9P0flO6RxxJIenlklYAbwTelQ4pfL5pk8dJukvSE5I+JemAzH736nHUbVvSJ4CXAZ9P896VPvPLkv5d0uOSvi1padP2/k3SByR9HfgJMNTDHgE8y7zsZIMl6UhJn5H0I0mPSboyzb9I0r2Sdkq6UdLips+EpLemzD8u6SOq/DywEjgx5e3xtP7zJP2lpB9I2ipppaSfScuWStos6Q8kPQr8g6R9JF0q6YHUpuslHdK0/wskPZyW/eFQ/8JqlJbdWVGIga8C/yWF5qXA/sCJAOmY7YHAXc0fSINAXwP8eTqkcFbT4vOAZcDRwGuBN0+nMe22HREXAD8Azkrz/lzSIuCLwJ8AhwDvAP5Z0qFNm7sAWAEcBDw8nXb0K0JMxj7ZyQYnPWLnC1T/xkdRDTB+naSzgfcA/w04FPgacG3Lx88Efokqq+cBp0XEvcBbgdtS3l6Y1r0CeAVwHPDytJ/3NW3rJVRZXEyVu/8BnAP8GvBSYCfwkdTmY6jG6L0gLXsx1bCPI1dadsevRT1Ix2V3UYXrV6lGPvqhpJ+jCtDXpvmV/m8j4ocRsQP4fNruTPhtYG1ErI2IqYhYRzXyf/NDCa+KiHsiYiIinp2hdrRVWq+icCdQFbN3RsSPI+LpiLiVqpj+WUTcmw4T/CnVN7bFTZ+9IiIej4gfAP9KJq/p8NYK4H9HxI6I2JW2t7xptSngsoh4JiJ+mvb/hxGxOSKeAS4HzpW0L3Au8IWIuCUt+yPG5D6K0rI7m4
4Rf5Vq2LqXp9ePUxXhE9P76Xi06fVPqH5BZsJi4LckNffG96P6ZWrYxMhoLHsPs9SRwMNtjskuBv5G0oea5omqJ9v4htSa1wMz+zgUeD6woarJe7bVXJl+FBHNjxBaDKyR1FxgJ4GFVL8Xe/IZET+W9Fhm30NWVnZnWyE+i+pwwp9SFeI3UhXiKzOfmcmb0dttu3XeJuATEfHfp7mdoagGThm/3sMstQl4WZsTZJuAD0TENT1sszU724GfAq+KiNxTI9pl9KKI+HrripIeAX6+6f3zqQ5PjFxp2S3nv4zOvgqcDPxMRGymOpa2jCoY/zfzma3M3AmwdttunfdPwFmSTpM0T9IB6YTJeBxnQzwb+2YnG6hvAY8AV0ian7LwK1Qn3N4t6VUAkg6W9FtdbnMrcISqJ0WQDs/9PfDXkg5L21sk6bSabawEPtA4FCLp0HTcGuDTwJmSTkr7eD9jUlNKy+5Y/KUNQkR8H3iKqgATEU8CDwJfj4jcs1E+DhyTzjZ/dsBNarftPwPem+a9IyI2UT1q+z3Aj6h6H+9kTP5dAphkn+xkg5MyehbVobUfUD399w0RsQb4INWJuyepHs9zepeb/RfgHuBRSdvTvD8ANgLfSNv7CvDKmm38DdWjfm6StAv4BumpxOnJE78HfJLqP5Gdqd0jV1p2/YSOPki6CPjtiPj1UbdlJhz9mgPj8s+8Nrv8za+4bayecmDWUFp2x6+PXpZXAf8x6kbMlAj1dZxN0mqqS6u2RcSr07y/oOr57QYeAN4SEY+3+exDVFfCTAIT4/RLY+OvtOyOXx+9EOlwwzLgQ53WLVVAv9diXkX1d9RsHfDqiHgt8H3g3TWfPzkijnMRtukqLbvuEfcoIs4ZdRtmWnXCo/deRUTcIumolnk3Nb39BtW1qGYDVVp23SO2Wh1OeCyQtL5pWjHNzV8EfCmzLKhOEG3oYbtmRWXXPWLL6qJXsb3XwwZpXIIJqlvB2zkpIraky6zWSbovIm7pZV8295SW3b4KsaRlVJe3zAM+FhFX1K2/v54XBzC/n13aDHiaH7M7nnnO2IAzdVG8pDdTnQg5JTKX7TRuOIiIbZLWUN0CPLBCPJ3sOrfjaxc7t0fEoa3zS8tuz4U4DVLyEeD1VNcO3i7phoj4bu4zBzCf1+mUXndpM+SbcXPb+RFiasC3iaYC+C7g1yLiJ5l15gP7RMSu9PpUqpsFBtWGaWXXuR1fX4lPtx0Iq7Ts9tPSE4CNEfFgROwGrqO6OcFmiUavIjd1Iula4DbglWl4xYupbjc/iOor252SVqZ1XyppbfroQuBWSd+muuPsixHx5QH+aM7uLFdadvs5NLGIvQek2Uy646blB1pBNeITB/D8PnZnw9ffwCkRcX6b2R/PrPtD0qhzaTS9Y3vecWcds+vclq6s7M74ybo0Nu8qgBfoEN/GV5DSBk4ZJOe2bKVlt59CvIVq6L6GI9I8myUCMVFQmKfB2Z3lSstuP8eIbweWSDo6jby0nGpwEJslImAylJ0K5uzOcqVlt+cecURMSHo71dMw5gGr02hMNksEYmKqnF5Ft5zd2a+07PZ1jDgi1gJrO65oRaqOs83Omy+d3dmttOz6zjqrMfhrMc2Go6zsuhBbVkRZvQqzhtKy60JstUrqVZg1Kym7LsSWVV0CVE6YzRpKy64LsWUFFHXm2ayhtOy6EFteiKkxvObSrKPCsutCbFkBRX29M2soLbsuxJYVUFSvwqyhtOy6EFtWdXdSOb0Ks4bSsutCbHlR1tc7sz0Ky64LsWWV9vXOrKG07LoQW1ZpX+/MGkrLbjkttZGYjH2yUyeSVkvaJunupnmHSFon6f7054syn70wrXO/pAsH+CPZHFFSdl2ILSui+nqXm7pwFbCsZd6lwM0RsQS4Ob3fi6RDgMuoHl90AnBZLvRm7ZSWXRdiqyEmp/bJTp1ExC3AjpbZZwNXp9dXA+e0+ehpwLqI2BERO4F1PPeXwqxGWdn1MWKrFf
W9hwWS1je9X5We9VZnYUQ8kl4/SvXU21btHu65qFNbzZqVlF0XYsuKgMmp2jBvj4jje99+hCQ/mNMGrrTsuhBbVkBfjyTP2Crp8Ih4RNLhwLY262wBlja9PwL4t0E3ZDbZvuLE7LJ3XnJddtkfbTg7u+wVv/9IdtnEo1u7a9iIlJZdHyO2GvmTHX1co3kD0DiTfCHwuTbr3AicKulF6UTHqWmeWZfKyq4LsdWamlJ26kTStcBtwCslbZZ0MXAF8HpJ9wO/kd4j6XhJHwOIiB3AH1M9bfl24P1pnlnXSsquD01YVnWcrff/qyPi/MyiU9qsux74nab3q4HVPe/c5rTSsutCbLXCp9KsUCVl14XYsgIxVdBtomYNpWXXhdhqFdSpMNtLSdl1Iba8gOjixIaNXt0lassP2pld9uEXPpVd9sU78if7f/Hyt2WXLVh1W3bZ0BSW3b4KsaSHgF3AJDDRzwXSNp66OcNcImd39ispu4PoEZ8cEdsHsB0bM0HH20RL5+zOUqVl14cmLK+wr3dmexSW3X5PKwZwk6QNkla0W0HSCknrJa1/lmf63J0NXdRMZavNrnM7CxSU3X57xCdFxBZJhwHrJN2Xho/bI41otArgBTpkDP8KLE9F9SqmqTa7zm3pyspuXz3iiNiS/twGrKEaCNlmi/T1LjeVzNmd5QrLbs89YknzgX0iYld6fSrw/oG1rDDDHP1quCNfjV9o+1Vqdid+/Rezy5YfdGd22enLlmeXHXzXfdll5936nLt599jxC5PZZQuyS4atnOz2c2hiIbBGUmM7n4yILw+kVTY+pkbdgBnh7M4FBWW350IcEQ8Cxw6wLTZuCjvz3C1ndw4oLLu+fM3q+TSVlaqg7LoQWy0V1Kswa1ZSdl2ILW9Mr7k066iw7LoQWw1BQb0Ks/+vrOy6EA/IMEe/GurIV32ceZb0SuBTTbN+FnhfRHy4aZ2lVM/++o806zMRMfaXko3C0y/O/7q+d9trssumai5Rq3P7d/5TT58bGwVl14XY8gLoY+CUiPgecByApHlUT7hd02bVr0XEmT3vyKxVYdl1IbZaGty1mKcAD0TEwwPbolmNkrJbzrNErHTLgWszy06U9G1JX5L0qmE2yqwLM55d94itVodLgBZIWt/0flUaLGfvbUj7A/8VeHebbdwBLI6IpySdAXwWWNJHk82AsrLrQmx5QacTHtu7fLLF6cAdEfGcQTIi4smm12sl/Z2kBR6w3fpSWHZ9aMJqKfLTNJxP5qudpJcoDfog6QSqTD7Wb7vNSsque8TTMC6jXw115Ks+T3ik0c1eD/xu07y3AkTESuBc4G2SJoCfAssjoqBL8Yfn6Rfl+03X3JYf/e8VfKun/e178O7ssokn9u9pm0NVUHZdiC1L0f9tohHxY+DFLfNWNr2+Eriyr52YtSgtuy7EVs99UytVQdl1IbZaA7wW02yoSsquC7HlTf/Ehtl4KCy7LsRWr6BehdleCsquC7HVKqlXYdaspOy6EE/DnBz9qqAwz3YH7Mx38X7pNQ9klz1Rs819X7Iwu+wNx2zILrv+SyfVbHVMFJRdF2LLi7JOeJjtUVh2XYitXkG9CrO9FJRdF2LLEmX1KswaSsuuC7HlFfb1zmyPwrLrQmz1Cvp6Z7aXgrLrQmy1SupVmDUrKbsdC7Gk1cCZwLaIeHWadwjVg/WOAh4CzouI/BMyZ4lxGf1qaCNfFfb1rtVsy+4Lvpe/EO2yI76QXfamFb+fXbbfOT/qqS1Hv3vAD6kdtMKy2814xFcBy1rmXQrcHBFLgJvTe5uNomYaf1fh7M5dBWW3YyGOiFuAHS2zzwauTq+vBs4ZcLtsTGgqP407Z3duKym7vR4jXhgRj6TXjwLZ23MkrQBWABzA83vcnY3EmPYe+tRVdp3bwhWW3b4flZRGpM/+yBGxKiKOj4jj9+N5/e7OhkgM7HEzY6kuu85t2UrLbq+FeKukwwHSn9sG1yQbJyV9veuSsztHlJ
TdXgvxDcCF6fWFwOcG0xwbO32e8JD0kKTvSLqz5fHljeWS9LeSNkq6S9J/Hlzj23J254qCstvN5WvXAkuBBZI2A5cBVwDXS7oYeBg4r9cGlGRcRr8a2shXg7sE6OSaR4yfDixJ0+uAj6Y/+zbbsls3it8bPnpJdtl7L2n7EGIAPvxA+wfUAtx+3LzuGjaOCstux0IcEednFuX/BW3WGMLXuLOBf0zHa78h6YWSDm86odYzZ3duKym7fZ+ss9mtwwmPBZLWN00r2mwigJskbcgsXwRsanq/Oc0z60tJ2fUtzpYXdHrczPaIOL7DVk6KiC2SDgPWSbovXd9rNnMKy657xJY1iEuAImJL+nMbsAY4oWWVLcCRTe+PSPPMelZadl2IrZamIjt1/Kw0X9JBjdfAqcDdLavdALwpnYH+ZeCJQRwfNispuz40MQ3jMujK0AZc6f/M80JgjSSosvbJiPiypLcCRMRKYC1wBrAR+Anwlr72OEct+uC/Z5f9wwcXZ5cdzMaZaM7oFZZdF2Kr18ddSBHxIHBsm/krm14H8Hu978Uso6DsuhBbrXG8C8msGyVl14XY8sb0vnyzjgrLrguxZZX2AEazhtKy60Jstbo5w2w2jkrKrgux5RU2pqvZHoVl14V4GubioCuaHHULzHpTUnZdiC0vyvp6Z7ZHYdl1IbZaJZ15NmtWUnZdiC2rtDPPZg2lZdeF2PIiqsmsNIVl14XYapXUqzBrVlJ2XYgtL0CT5fQqzPYoLLsuxAMya0e/KifLZnsrKLsuxFarpEuAzJqVlF0XYqtV0iVAZs1Kyq4LsWWpsIvizRpKy64LsdUq6YSHWbOSsutn1lledJg6kHSkpH+V9F1J90j6n23WWSrpCUl3pul9A/0ZbG4qLLvuEVuN7h60WGMCuCQi7kgPYtwgaV1EfLdlva9FxJn97Mhsb2Vld6iFeBc7t38lPv1wersA2D7M/We4HdD++ro+r8VMT7R9JL3eJeleYBHQGuax1pJbcGZaObt9GmohjohDG68lrY+I44e5/3bcjg7qbxNdIGl90/tVEbGq3YqSjgJ+Afhmm8UnSvo28EPgHRFxT2+NnRnNuYXx+bdyOzooKLs+NGG1Ony9297NL6CkA4F/Bv5XRDzZsvgOYHFEPCXpDOCzwJJe22vWUFJ2fbLO6jUGT2k3dUHSflRBviYiPvPczceTEfFUer0W2E/SgkH+CDZHFZTdUfaI234NGAG3I0MRfR1nkyTg48C9EfFXmXVeAmyNiJB0AlXn4LGedzoc4/Jv5XZklJbdkRXi3PGYYXM7OpjqawirXwEuAL4j6c407z3AywAiYiVwLvA2SRPAT4HlEeM9fuG4/Fu5HR0UlF0fI7a8APrIckTcSjVGd906VwJX9r4XszYKy+5IjhFLWibpe5I2Srp0FG1I7XhI0nfSxdjrO39iYPtdLWmbpLub5h0iaZ2k+9OfLxpWe+poaio7zTVzPbdp387uDBh6IZY0D/gIcDpwDHC+pGOG3Y4mJ0fEcUO+/OYqYFnLvEuBmyNiCXBzej9iNSc7xvvowcA5t3tchbM7cKPoEZ8AbIyIByNiN3AdcPYI2jEyEXELsKNl9tnA1en11cA5Q21UOwFMRn6aW+Z8bsHZnSmjKMSLgE1N7zeneaMQwE2SNkhaMaI2NCxMd/MAPAosHGVjGkr6ejfDnNs8Z7dPc/1k3UkRsUXSYcA6Sfel//FHKl0OM/r/tgMoaCjBOWQscwvObq9G0SPeAhzZ9P6ING/oImJL+nMbsIbq6+eobJV0OED6c9sI25JEdQlQbppbnNs8Z7dPoyjEtwNLJB0taX9gOXDDsBshaX4aVQlJ84FTgbvrPzWjbgAuTK8vBD43wrZUApicyk9zi3Ob5+z2aeiHJiJiQtLbgRuBecDqEQ3yshBYU91Aw77AJyPiy8PYsaRrgaVUA49sBi4DrgCul3Qx8DBw3jDaUi8gxi+0o+DcVpzdmTGSY8Tpvuy1o9
h3UxseBI4d0b7Pzyw6ZagN6aTRqzDAuU37d3ZnwFw/WWedjOE1l2ZdKSi7LsRWI8byxIZZZ2Vl14XY8gKYnBx1K8ymr7DsuhBbvYK+3pntpaDsuhBbXgRRUK/CbI/CsutCbPUKOvNstpeCsutCbHlR1gkPsz0Ky64LsdUq6eudWbOSsutCbDXGc+xWs87Kyq6f4mx5jUuAclMXOj3VQtLzJH0qLf+mpKMG+0PYnNRndoedWxdiy4p05jk3ddLlUy0uBnZGxMuBvwY+OOAfw+agfrI7ity6EFutmIrs1IVunmrR/HSHTwOnpEeZm/Wlj+wOPbc+RmxZu9h541emrl9Qs8oBLQ+vXNXyaPV2T7V4Xcs29qyTRjh7AngxsL33lttc12d2h55bF2LLiojWh0SaFaG07PrQhM2kbp5qsWcdSfsCBwOPDaV1Zu0NPbcuxDaTunmqRfPTHc4F/iWioOuObDYaem59aMJmTO6pFpLeD6yPiBuAjwOfkLSR6jHty0fXYrPR5FbufJiZjZYPTZiZjZgLsZnZiLkQm5mNmAuxmdmIuRCbmY2YC7GZ2Yi5EJuZjdj/A/64/1/6Vd35AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 360x360 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "stream",
"text": [
"Label #1: 1\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "stream",
"text": [
"Label #2: 2\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "stream",
"text": [
"Label #3: 3\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "stream",
"text": [
"Label #4: 4\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "stream",
"text": [
"Label #5: 5\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "stream",
"text": [
"Label #6: 6\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "stream",
"text": [
"Label #7: 7\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "stream",
"text": [
"Label #8: 8\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "stream",
"text": [
"Label #9: 9\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 360x360 with 4 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 282
},
"id": "c3hlcfUPyfu3",
"outputId": "cfe62c90-9b26-4bc0-f0d1-a67b597ed9e7"
},
"source": [
"plt.imshow(digjit[\"images\"][9])\n",
"plt.show()\n",
"digits.target[9]"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAANZ0lEQVR4nO3df6xkZX3H8ffHhQUXUaBURKCCLSUBYwLZIFqrplspUuP6h39gakWxIaa1ldbEQE1q0qRGa2NrU1ND1BZToqaIhRhUtqi0TcoqriDye6WILD8tDfgjFdBv/5iz5no7d3c9P+bO8rxfyeSemfOc+3z3zP3sM+fMmXlSVUh66nvaehcgaTEMu9QIwy41wrBLjTDsUiMOWGRnG3NQHcwhi+xSasr/8gMerx9l3rqFhv1gDuFF2bLILqWmbK9r1lzny3ipEYZdasSgsCc5K8ntSXYmuXCsoiSNr3fYk2wAPgS8CjgZeH2Sk8cqTNK4hozspwM7q+quqnoc+CSwdZyyJI1tSNiPAb6z4v693WOSltDkb70lOR84H+BgNk3dnaQ1DBnZdwHHrbh/bPfYz6iqi6tqc1VtPpCDBnQnaYghYf8qcGKSE5JsBM4BrhynLElj6/0yvqqeTPI24AvABuBjVXXzaJVJGtWgY/aqugq4aqRaJE3IK+ikRhh2qREL/dTbenh82/N6b3v8oY/02u6+M77Xu09pKo7sUiMMu9QIwy41wrBLjTDsUiMMu9QIwy41wrBLjTDsUiMMu9QIwy41wrBLjTDsUiP2i0+9bTjlpN7bfumUT41YyT66r/+m7/luv3/rtS98ev9O1QRHdqkRhl1qhGGXGjFkrrfjknwpyS1Jbk7y9jELkzSuISfongTeUVU7khwKfC3Jtqq6ZaTaJI2o98heVfdX1Y5u+XvArTjXm7S0RjlmT3I8cCqwfYzfJ2l8g99nT/IM4NPABVX12Jz1TuwoLYFBI3uSA5kF/dKqunxeGyd2lJbDkLPxAT4K3FpVHxivJElTGDKy/xrwu8BvJLmhu509Ul2SRjZkFtf/ADJiLZIm5BV0UiMMu9SI/eIjrk8cuT5v2b35nl/vtd1Xdv1S7z7/4oVX9NruWn6ld59qgyO71AjDLjXCsEuNMOxSIwy71AjDLjXCsEuNMOxSIwy71AjDLjXCsEuNMOxSIwy71Ij94lNvB962a136fXBrv8kST7/int59nrzxwZ5b+qk37Zkju9QIwy41wrBLjRgc9iQbknw9yWfHKEjSNMYY2d/ObJ43SUts6IwwxwK/DXxknHIkTWXoyP43wDuBn4xQi6QJDZn+6dXAQ1X1tb20Oz/J9Umuf4If9e1O0kBDp396TZK7gU8ymwbqn1Y3cmJHaTn0DntVXVRVx1bV8cA5wBer6g2jVSZpVL7PLjVilGvjq+rLwJfH+F2SpuHILjXCsEuN2C8+4vrjBx/qve17vntS722v+vrVvbY74fO/17vPi47+fK/tNpzS/9/545tv772t9h+O7FIjDLvUCMMuNcKwS40w7FIjDLvUCMMuNcKwS40w7FIjDLvUCMMuNcKwS40w7FIjUlUL6+yZOaJelC0L62+on7z81F7bPe3ar/fu846Pbe613fHHPdy7z42v/HbvbbVcttc1PFaPZN46R3apEYZdaoRhlxoxdPqnw5JcluS2JLcmefFYhUka19Cvpfog8Pmqel2SjcCmEWqSNIHeYU/yLOBlwJsAqupx4PFxypI0tiEv408AHgb+oZuf/SNJDhmpLkkjGxL2A4DTgL+vqlOBHwAXrm7kxI7SchgS9nuBe6tqe3f/Mmbh/xlO7CgthyETOz4AfCfJ7i8s3wLcMkpVkkY39Gz8HwKXdmfi7wLePLwkSVMYFPaqugHodzG3pIXyCjqpEYZdasR+MbHjeun7UdW+H1MF+MKWD/ba7i0X/EnvPjfiR1xb4MguNcKwS40w7FIjDL
vUCMMuNcKwS40w7FIjDLvUCMMuNcKwS40w7FIjDLvUCMMuNeIp/6m3IZ9Ae8XJt/fa7uWb/r13n3/wxrf12m7Ttdv33khNc2SXGmHYpUYYdqkRQyd2/OMkNyf5ZpJPJDl4rMIkjat32JMcA/wRsLmqXgBsAM4ZqzBJ4xr6Mv4A4OlJDmA2g+t9w0uSNIUhM8LsAv4KuAe4H3i0qq4eqzBJ4xryMv5wYCuz2VyfCxyS5A1z2jmxo7QEhryM/03gv6rq4ap6ArgceMnqRk7sKC2HIWG/BzgjyaYkYTax463jlCVpbEOO2bczm6Z5B3BT97suHqkuSSMbOrHju4F3j1SLpAl5BZ3UCMMuNeIp/xHXXz3v+t7b9r1C6D6e3rvPp9FvMklpbxzZpUYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qhGGXGmHYpUYYdqkRhl1qhGGXGmHYpUbsNexJPpbkoSTfXPHYEUm2Jbmz+3n4tGVKGmpfRvZ/BM5a9diFwDVVdSJwTXdf0hLba9ir6t+AR1Y9vBW4pFu+BHjtyHVJGlnfY/ajqur+bvkB4KiR6pE0kcEn6KqqgFprvRM7Ssuhb9gfTHI0QPfzobUaOrGjtBz6hv1K4Nxu+VzginHKkTSVfXnr7RPAfwInJbk3yVuA9wKvTHIns6mb3zttmZKG2uuMMFX1+jVWbRm5FkkT8go6qRGGXWqEYZcaYdilRhh2qRGGXWqEYZcaYdilRhh2qRGGXWqEYZcaYdilRhh2qRGGXWqEYZcaYdilRhh2qRGGXWqEYZcaYdilRvSd2PH9SW5L8o0kn0ly2LRlShqq78SO24AXVNULgTuAi0auS9LIek3sWFVXV9WT3d3rgGMnqE3SiMY4Zj8P+NwIv0fShPY6ScSeJHkX8CRw6R7anA+cD3Awm4Z0J2mA3mFP8ibg1cCWbibXuarqYuBigGfmiDXbSZpWr7AnOQt4J/DyqvrhuCVJmkLfiR3/DjgU2JbkhiQfnrhOSQP1ndjxoxPUImlCXkEnNcKwS40w7FIjDLvUCMMuNcKwS40w7FIjDLvUCMMuNcKwS40w7FIjDLvUCMMuNcKwS40w7FIjDLvUCMMuNcKwS40w7FIjDLvUiF4TO65Y944kleTIacqTNJa+EzuS5DjgTOCekWuSNIFeEzt2/prZRBHO8iLtB3odsyfZCuyqqhtHrkfSRH7u6Z+SbAL+lNlL+H1p78SO0hLoM7L/MnACcGOSu5nNzb4jyXPmNa6qi6tqc1VtPpCD+lcqaZCfe2SvqpuAZ+++3wV+c1V9d8S6JI2s78SOkvYzfSd2XLn++NGqkTQZr6CTGmHYpUakanHXxCR5GPj2GquPBJbpJN+y1QPLV5P17Nl61PO8qvrFeSsWGvY9SXJ9VW1e7zp2W7Z6YPlqsp49W7Z6fBkvNcKwS41YprBfvN4FrLJs9cDy1WQ9e7ZU9SzNMbukaS3TyC5pQoZdasTCw57krCS3J9mZ5MI56w9K8qlu/fYkx09Yy3FJvpTkliQ3J3n7nDavSPJokhu6259NVc+KPu9OclPX3/Vz1ifJ33b76BtJTpuwlpNW/NtvSPJYkgtWtZl0H837arQkRyTZluTO7ufha2x7btfmziTnTljP+5Pc1j0fn0ly2Brb7vG5nVRVLewGbAC+BTwf2AjcCJy8qs3vAx/uls8BPjVhPUcDp3XLhwJ3zKnnFcBnF7yf7gaO3MP6s4HPAQHOALYv8Pl7gNmFGwvbR8DLgNOAb6547C+BC7vlC4H3zdnuCOCu7ufh3fLhE9VzJnBAt/y+efXsy3M75W3RI/vpwM6ququqHgc+CWxd1WYrcEm3fBmwJUmmKKaq7q+qHd3y94BbgWOm6GtkW4GP18x1wGFJjl5Av1uAb1XVWldBTqLmfzXayr+TS4DXztn0t4BtVfVIVf0PsI0536c4Rj1VdXVVPdndvY7Z9zwslUWH/RjgOyvu38v/D9
dP23Q771HgF6YurDtcOBXYPmf1i5PcmORzSU6ZuhZm3+t3dZKvdd/0s9q+7McpnAN8Yo11i95HR1XV/d3yA8BRc9qs1346j9krr3n29txO5uf+8oqnoiTPAD4NXFBVj61avYPZy9bvJzkb+BfgxIlLemlV7UrybGBbktu60WTdJNkIvAa4aM7q9dhHP1VVlWQp3kNO8i7gSeDSNZqs23O76JF9F3DcivvHdo/NbZPkAOBZwH9PVVCSA5kF/dKqunz1+qp6rKq+3y1fBRw49ffkV9Wu7udDwGeYHf6stC/7cWyvAnZU1YOrV6zHPgIe3H3o0v18aE6bhe6nJG8CXg38TnUH6Kvtw3M7mUWH/avAiUlO6EaKc4ArV7W5Eth91vR1wBfX2nFDdecCPgrcWlUfWKPNc3afM0hyOrN9NuV/PockOXT3MrMTP6sn6LgSeGN3Vv4M4NEVL2mn8nrWeAm/6H3UWfl3ci5wxZw2XwDOTHJ4d7b+zO6x0SU5i9lXq7+mqn64Rpt9eW6ns+gzgszOJN/B7Kz8u7rH/pzZTgI4GPhnYCfwFeD5E9byUmbHUN8AbuhuZwNvBd7atXkbcDOzdw6uA14y8f55ftfXjV2/u/fRypoCfKjbhzcx+w7AKWs6hFl4n7XisYXtI2b/ydwPPMHsuPstzM7jXAPcCfwrcETXdjPwkRXbntf9Le0E3jxhPTuZnR/Y/Xe0+x2l5wJX7em5XdTNy2WlRngFndQIwy41wrBLjTDsUiMMu9QIwy41wrBLjfg/Blni2oZ6v/sAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"9"
]
},
"metadata": {},
"execution_count": 18
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qF3yiSeGzKFg",
"outputId": "ef628d66-059b-4a46-b425-2900ecfabe9e"
},
"source": [
"# test on original un-jittered data\n",
"n_test = 10 # number of samples to check: the first 10 images, digits 0 to 9\n",
"img_cent = image_center_transformer.fit_transform(digpad[\"data\"][:n_test])\n",
"assert img_cent.shape == (n_test, 225) # 15x15 = 225 pixels per image\n",
"\n",
"for i in range(n_test):\n",
" is_eq = (digpad[\"images\"][i] == img_cent[i, :].reshape((15,15))).all()\n",
" print(f\"Label #{i}: {digits.target[i]}. unjittered ==? centered: {is_eq}\")\n",
" #plot_side(digpad[\"images\"][i], img_cent[i, :].reshape((15,15)), \"unjittered\", \"centered\")\n",
" #plt.show()"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Label #0: 0. unjittered ==? centered: False\n",
"Label #1: 1. unjittered ==? centered: True\n",
"Label #2: 2. unjittered ==? centered: False\n",
"Label #3: 3. unjittered ==? centered: False\n",
"Label #4: 4. unjittered ==? centered: True\n",
"Label #5: 5. unjittered ==? centered: True\n",
"Label #6: 6. unjittered ==? centered: True\n",
"Label #7: 7. unjittered ==? centered: False\n",
"Label #8: 8. unjittered ==? centered: True\n",
"Label #9: 9. unjittered ==? centered: False\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GoEc1Sur0U62",
"outputId": "aa899e37-fd0f-459f-c2ad-25ae2324cfe8"
},
"source": [
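"# i = 9 is left over from the loop above; this sample's centered version differs from the original\n",
"digpad[\"images\"][i]"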
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 2, 16, 16, 16, 13, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 3, 16, 12, 10, 14, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 1, 16, 1, 12, 15, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 13, 16, 9, 15, 2, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 3, 0, 9, 11, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 9, 15, 4, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 9, 12, 13, 3, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])"
]
},
"metadata": {},
"execution_count": 20
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bJSoC3NH0V-k",
"outputId": "39fded48-a235-4ebd-d67d-371da9ac1d6a"
},
"source": [
"#(img_cent[i, :].reshape((15,15)) * 1e19).astype(int)\n",
"img_cent[i, :].reshape((15,15))"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 2, 16, 16, 16, 13, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 3, 16, 12, 10, 14, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 1, 16, 1, 12, 15, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 13, 16, 9, 15, 2, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 3, 0, 9, 11, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 9, 15, 4, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 9, 12, 13, 3, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])"
]
},
"metadata": {},
"execution_count": 21
}
]
},
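{
"cell_type": "markdown",
"metadata": {},
"source": [
"Centering is not always a no-op on the un-jittered digits. For sample #9, the two arrays above show that the centered image is the padded original shifted down by one row. The sketch below shows how such a one-pixel move can arise. It assumes that centering places the intensity-weighted center of mass at the image center and rounds the shift to whole pixels. `center_by_mass` is a hypothetical helper built on `scipy.ndimage.center_of_mass` and `scipy.ndimage.shift`; it is not necessarily what `image_center_transformer` does.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy import ndimage\n",
"\n",
"def center_by_mass(img):\n",
"    # Hypothetical helper (assumes img is not all zeros): move the\n",
"    # intensity-weighted center of mass to the geometric center,\n",
"    # rounding the shift to whole pixels.\n",
"    cy, cx = ndimage.center_of_mass(img)\n",
"    ty, tx = (img.shape[0] - 1) / 2.0, (img.shape[1] - 1) / 2.0\n",
"    dy, dx = int(round(ty - cy)), int(round(tx - cx))\n",
"    # order=0 moves pixels without interpolation; vacated pixels become 0\n",
"    return ndimage.shift(img, (dy, dx), order=0, mode='constant', cval=0)\n",
"\n",
"# A 7-row bar in rows 3..9 of a 15x15 frame: center-of-mass row 6.0, target row 7.0\n",
"img = np.zeros((15, 15), dtype=int)\n",
"img[3:10, 6:9] = 1\n",
"out = center_by_mass(img)\n",
"rows = out.nonzero()[0]\n",
"print(rows.min(), rows.max())  # 4 10: the bar moved down by one row\n",
"```\n",
"\n",
"With whole-pixel shifts, a digit whose center of mass is more than half a pixel off the frame center along an axis is moved along that axis, even if it was never jittered."
]
},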
{
"cell_type": "markdown",
"metadata": {
"id": "EjiCcTu0jp0_"
},
"source": [
"# run svm and knn"
]
},
{
"cell_type": "code",
"metadata": {
"id": "CIg193J-muvs"
},
"source": [
"#%pdb off"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VIT0GtY0ww5a",
"outputId": "9a86ff31-1d1d-49a9-de1e-e177951e1e28"
},
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn import metrics, svm, model_selection\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.decomposition import PCA\n",
"#import datetime as dt\n",
"import time\n",
"from sklearn.base import clone as sk_clone\n",
"\n",
"# Try adding PCA because auto-sklearn used it with k-nearest neighbors\n",
"# Update: PCA does not improve accuracy, but it speeds up execution\n",
"pca = PCA(n_components = .8, svd_solver = \"full\")\n",
"\n",
"clf_l = [\n",
" # n_neighbors=3 is better than any of 1, 2, 4, 5\n",
" (\"KNN #1\", KNeighborsClassifier(n_neighbors=3)),\n",
" #(\"KNN #2\", KNeighborsClassifier(n_neighbors=4, weights=\"distance\")), # same arguments as auto-sklearn\n",
" # default is kernel=rbf\n",
" # gamma=1e-3 is better than any of (1e-5, 1e-4, 1e-2, 1e-1)\n",
" # Try linear, poly, and RBF kernels, as in the xtomasch gist\n",
" # https://gist.github.com/xtomasch/84d1d8574ef51eb8d42e77560d647e06\n",
" #(\"SVM linear\", svm.SVC(kernel=\"linear\")), # slow and bad performance anyway, so skip\n",
" (\"SVM poly\", svm.SVC(kernel=\"poly\")),\n",
" (\"SVM RBF\", svm.SVC(kernel=\"rbf\", gamma=0.001)),\n",
"]\n",
"\n",
"X_l = [\n",
" (\"no jitter\", digpad[\"data\"]),\n",
" (\"with jitter\", digjit[\"data\"]),\n",
" ]\n",
"\n",
"print(\"(data, center, pca, classifier): accuracy (duration)\")\n",
"print(\"------------------------------------\")\n",
"for X_name, X_i in X_l:\n",
"#for X_name, X_i in X_l[-1:]: # only jitter\n",
" for with_center in [False, True]:\n",
" for with_pca in [False, True]:\n",
" print(\"\")\n",
" for clf_name, clf_i in clf_l:\n",
" #for clf_name, clf_i in clf_l[:1]: # only KNN\n",
" dt_start = time.time()\n",
"\n",
" steps = []\n",
" if with_center: steps.append(('centerer', image_center_transformer))\n",
" if with_pca: steps.append(('PCA', pca))\n",
" steps.append(('classifier', clf_i))\n",
"\n",
" pipe = Pipeline(steps=steps)\n",
"\n",
" # Clone so that each run starts from an unfitted pipeline.\n",
" # Update: tests show no change in results; cross_val_score already clones the estimator for each split, so this is likely redundant, but it is kept to be safe.\n",
" pipe = sk_clone(pipe)\n",
"\n",
" results = model_selection.cross_val_score(pipe, X_i, digits.target)\n",
"\n",
" dt_end = time.time() # the printed duration is rounded to 0.1 s, so very short runs show as 0.0\n",
"\n",
" # FIXME\n",
" # Note: pipe.fit mixes between (n_features, n_samples) and (n_samples, n_features) ?\n",
" #results = pipe.fit(X_i.T, digits.target).score(digits.target)\n",
" print(f\"({X_name}, {'with center' if with_center else 'no center'}, {'with pca' if with_pca else 'without pca'}, {clf_name}): \\t {np.mean(results).round(2)} ({round(dt_end - dt_start,1)} s)\")\n",
" #break # FIXME\n",
"\n",
" if np.std(results) >= .05:\n",
" print(f\"\\t ******\")\n",
" print(f\"\\t Detailed scores per cross-validation run: {[round(x,2) for x in results]}\")\n",
" print(f\"\\t std scores: {np.std(results).round(3)}\")\n",
" print(\"\")\n",
"\n",
" "
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"(data, center, pca, classifier): accuracy (duration)\n",
"------------------------------------\n",
"\n",
"(no jitter, no center, without pca, KNN #1): \t 0.97 (1.4 s)\n",
"(no jitter, no center, without pca, SVM poly): \t 0.96 (1.3 s)\n",
"(no jitter, no center, without pca, SVM RBF): \t 0.97 (2.3 s)\n",
"\n",
"(no jitter, no center, with pca, KNN #1): \t 0.95 (0.5 s)\n",
"(no jitter, no center, with pca, SVM poly): \t 0.94 (0.7 s)\n",
"(no jitter, no center, with pca, SVM RBF): \t 0.96 (0.8 s)\n",
"\n",
"(no jitter, with center, without pca, KNN #1): \t 0.95 (3.5 s)\n",
"(no jitter, with center, without pca, SVM poly): \t 0.95 (3.8 s)\n",
"(no jitter, with center, without pca, SVM RBF): \t 0.96 (5.2 s)\n",
"\n",
"(no jitter, with center, with pca, KNN #1): \t 0.94 (3.2 s)\n",
"(no jitter, with center, with pca, SVM poly): \t 0.94 (3.2 s)\n",
"(no jitter, with center, with pca, SVM RBF): \t 0.95 (3.3 s)\n",
"\n",
"(with jitter, no center, without pca, KNN #1): \t 0.35 (1.1 s)\n",
"(with jitter, no center, without pca, SVM poly): \t 0.27 (5.5 s)\n",
"(with jitter, no center, without pca, SVM RBF): \t 0.32 (8.2 s)\n",
"\n",
"(with jitter, no center, with pca, KNN #1): \t 0.32 (0.7 s)\n",
"\t ******\n",
"\t Detailed scores per cross-validation run: [0.25, 0.3, 0.37, 0.38, 0.28]\n",
"\t std scores: 0.053\n",
"(with jitter, no center, with pca, SVM poly): \t 0.27 (1.4 s)\n",
"(with jitter, no center, with pca, SVM RBF): \t 0.28 (2.2 s)\n",
"\t ******\n",
"\t Detailed scores per cross-validation run: [0.22, 0.26, 0.32, 0.36, 0.25]\n",
"\t std scores: 0.051\n",
"\n",
"(with jitter, with center, without pca, KNN #1): \t 0.95 (3.5 s)\n",
"(with jitter, with center, without pca, SVM poly): \t 0.95 (3.9 s)\n",
"(with jitter, with center, without pca, SVM RBF): \t 0.96 (5.2 s)\n",
"\n",
"(with jitter, with center, with pca, KNN #1): \t 0.94 (3.0 s)\n",
"(with jitter, with center, with pca, SVM poly): \t 0.93 (3.1 s)\n",
"(with jitter, with center, with pca, SVM RBF): \t 0.95 (3.2 s)\n"
],
"name": "stdout"
}
]
},
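{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the centering step sits inside the `Pipeline`, `cross_val_score` applies it within every cross-validation split, to both the training and the test folds. A minimal, self-contained sketch of the same pattern follows. It assumes center-of-mass centering and wraps a hypothetical function `center_rows` in scikit-learn's `FunctionTransformer`. It uses the bundled 8x8 `load_digits()` images instead of this notebook's `image_center_transformer` and padded 15x15 data, so its scores are not comparable to the results above.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy import ndimage\n",
"from sklearn.datasets import load_digits\n",
"from sklearn.model_selection import cross_val_score\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import FunctionTransformer\n",
"from sklearn.svm import SVC\n",
"\n",
"SIDE = 8  # load_digits() images are 8x8, flattened to 64 features\n",
"\n",
"def center_rows(X):\n",
"    # Hypothetical stateless transform: unflatten each row, shift its center\n",
"    # of mass to the image center (whole pixels), and flatten it again.\n",
"    out = np.array(X, dtype=float, copy=True)\n",
"    target = (SIDE - 1) / 2.0\n",
"    for k in range(out.shape[0]):\n",
"        img = out[k].reshape(SIDE, SIDE)\n",
"        if img.sum() == 0:\n",
"            continue  # the center of mass is undefined for a blank image\n",
"        cy, cx = ndimage.center_of_mass(img)\n",
"        shift = (int(round(target - cy)), int(round(target - cx)))\n",
"        out[k] = ndimage.shift(img, shift, order=0, mode='constant', cval=0).ravel()\n",
"    return out\n",
"\n",
"pipe = Pipeline([\n",
"    ('centerer', FunctionTransformer(center_rows)),\n",
"    ('classifier', SVC(kernel='rbf', gamma=0.001)),\n",
"])\n",
"bunch = load_digits()\n",
"scores = cross_val_score(pipe, bunch.data, bunch.target, cv=5)\n",
"print(scores.round(2), scores.mean().round(2))\n",
"```\n",
"\n",
"`FunctionTransformer` is used here only because the sketched centering has nothing to learn from the training data. A dedicated `TransformerMixin` class, like the one this notebook defines, serves the same role in the pipeline."
]
},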
{
"cell_type": "code",
"metadata": {
"id": "zVJwX_4ylZy_"
},
"source": [
"#len(X_i)\n",
"#X_i[0].shape\n",
"#image_center_transformer.fit_transform(X_i).shape\n",
"\n",
"# chain with fit_transform\n",
"# image_flatten_transformer.fit_transform(image_center_transformer.fit_transform(X_i)).shape\n",
"\n",
"# chain with separate calls to fit and transform\n",
"#X_cent = image_center_transformer.fit(X_i).transform(X_i)\n",
"#image_flatten_transformer.fit(X_cent).transform(X_cent).shape"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "0qH4IzNpmXyu"
},
"source": [
"# try auto-sklearn\n",
"\n",
"https://www.automl.org/automl/auto-sklearn/\n",
"\n",
"https://automl.github.io/auto-sklearn/master/\n",
"\n",
"Note: the github.io page states, \"This will run for one hour and should result in an accuracy above 0.98.\" Given that, auto-sklearn is not run on the un-jittered digits dataset for now."
]
},
{
"cell_type": "code",
"metadata": {
"id": "kr-XX4ebnM4p",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7076d2b7-ac28-46af-80bf-ccab57469924"
},
"source": [
"!pip install --quiet auto-sklearn\n",
"# RESTART kernel and load data again"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"\u001b[K |████████████████████████████████| 6.3 MB 8.1 MB/s \n",
"\u001b[K |████████████████████████████████| 28.5 MB 52 kB/s \n",
"\u001b[K |████████████████████████████████| 22.3 MB 70.3 MB/s \n",
"\u001b[K |████████████████████████████████| 722 kB 60.1 MB/s \n",
"\u001b[K |████████████████████████████████| 4.2 MB 38.0 MB/s \n",
"\u001b[K |████████████████████████████████| 4.0 MB 52.4 MB/s \n",
"\u001b[K |████████████████████████████████| 208 kB 59.5 MB/s \n",
"\u001b[K |████████████████████████████████| 973 kB 44.5 MB/s \n",
"\u001b[K |████████████████████████████████| 118 kB 53.3 MB/s \n",
"\u001b[?25h Building wheel for auto-sklearn (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for pynisher (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for liac-arff (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.\u001b[0m\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "22lYRMGLmx-n"
},
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Split data into 50% train and 50% test subsets\n",
"X_pad_train, X_pad_test, X_jit_train, X_jit_test, y_train, y_test = train_test_split(\n",
" digpad[\"data\"], digjit[\"data\"], digits.target,\n",
" test_size=0.5,\n",
" #shuffle=False # cannot stratify without shuffle=True\n",
" shuffle=True, stratify=digits.target\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 392
},
"id": "5BB7fuHSl9nQ",
"outputId": "23c94da6-be4b-462f-b0c9-193b48b294c4"
},
"source": [
"import sklearn.metrics\n",
"\n",
"# Got error about NVM uninitialized, so trying version 2\n",
"# Update: got the same. Maybe because I'm testing with time_left...=30 or n_jobs=2\n",
"from autosklearn.classification import AutoSklearnClassifier\n",
"cls = AutoSklearnClassifier(n_jobs=1, time_left_for_this_task=120)\n",
"\n",
"#from autosklearn.experimental.askl2 import AutoSklearn2Classifier\n",
"#cls = AutoSklearn2Classifier(n_jobs=1, time_left_for_this_task=30)\n",
"\n",
"cls.fit(X_jit_train, y_train)\n",
"predictions = cls.predict(X_jit_test)\n",
"acc = sklearn.metrics.accuracy_score(y_test, predictions).round(2)\n",
"#print(f\"X_train = with jitter, accuracy = {acc}\")\n",
"acc"
],
"execution_count": null,
"outputs": [
{
"output_type": "error",
"ename": "IncorrectPackageVersionError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mIncorrectPackageVersionError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-27-072a5e48bba5>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m# Got error about NVM uninitialized, so trying version 2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;31m# Update: got the same. Maybe because I'm testing with time_left...=30 or n_jobs=2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mautosklearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclassification\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAutoSklearnClassifier\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mcls\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mAutoSklearnClassifier\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime_left_for_this_task\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m120\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/autosklearn/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mrequirements\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrequirements\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'utf-8'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0mdependencies\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mverify_packages\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrequirements\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;34m'posix'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/autosklearn/util/dependencies.py\u001b[0m in \u001b[0;36mverify_packages\u001b[0;34m(packages)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0moperation\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'operation1'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0mversion\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgroup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'version1'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0m_verify_package\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperation\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mversion\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Unable to read requirement: %s'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mpackage\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/autosklearn/util/dependencies.py\u001b[0m in \u001b[0;36m_verify_package\u001b[0;34m(name, operation, version)\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcheck\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m raise IncorrectPackageVersionError(name, installed_version, operation,\n\u001b[0;32m---> 62\u001b[0;31m required_version)\n\u001b[0m\u001b[1;32m 63\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mIncorrectPackageVersionError\u001b[0m: found 'scipy' version 1.4.1 but requires scipy version >=1.7.0"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "nh6HOT2F88lj"
},
"source": [
"cls"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "aeVAeOFB9lhU"
},
"source": [
"Inspect results as per https://automl.github.io/auto-sklearn/master/manual.html#manual"
]
},
{
"cell_type": "code",
"metadata": {
"id": "K8qajCOQ8_Kk"
},
"source": [
"cls.cv_results_"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-LkmR2mz9OOg"
},
"source": [
"print(cls.sprint_statistics())"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "RfMKiVNz9Sg8"
},
"source": [
"print(cls.show_models())"
],
"execution_count": null,
"outputs": []
}
]
}