@shadiakiki1986
Last active August 30, 2021 06:22
SVM-RBF sensitivity to translation.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "SVM-RBF sensitivity to translation.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true,
"authorship_tag": "ABX9TyOF2Z+MPhxOHaV1DPw1vbgU",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shadiakiki1986/689980135fe9dde1d892127bde40a5a1/svm-rbf-sensitivity-to-translation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8GaAZFCMWaw4"
},
"source": [
"Testing whether an SVM with an RBF kernel is invariant to translating the digit within the image:\n",
"\n",
"- sklearn issue: https://github.com/scikit-learn/scikit-learn/issues/18432\n",
"- gist: https://gist.github.com/xtomasch/84d1d8574ef51eb8d42e77560d647e06\n",
"\n",
"\n",
"Uses my digits dataset with jitter: https://github.com/shadiakiki1986/mnist-digits-jitter\n",
"\n",
"Published as gist at https://gist.github.com/shadiakiki1986/689980135fe9dde1d892127bde40a5a1"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HGkCzTmwXS4o"
},
"source": [
"# get data"
]
},
{
"cell_type": "code",
"metadata": {
"id": "wxWb-jzbXTdv"
},
"source": [
"# first the original data\n",
"from sklearn.datasets import load_digits\n",
"digits = load_digits()"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RrAkpWbdfPMu",
"outputId": "0ac4e1d5-748f-407d-95ce-622f09c8e8db"
},
"source": [
"# then the padded and jittered data\n",
"!git clone https://github.com/shadiakiki1986/mnist-digits-jitter\n",
"\n",
"# Update: no need to gunzip, since np.loadtxt decompresses .gz files automatically\n",
"#!gunzip mnist-digits-jitter/digits_padded.csv.gz\n",
"#!gunzip mnist-digits-jitter/digits_jitter.csv.gz"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Cloning into 'mnist-digits-jitter'...\n",
"remote: Enumerating objects: 47, done.\u001b[K\n",
"remote: Counting objects: 100% (47/47), done.\u001b[K\n",
"remote: Compressing objects: 100% (44/44), done.\u001b[K\n",
"remote: Total 47 (delta 20), reused 8 (delta 1), pack-reused 0\u001b[K\n",
"Unpacking objects: 100% (47/47), done.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "66oFvhVEfZMk",
"outputId": "db7a154b-a160-47c1-ee81-502f210d5e0b"
},
"source": [
"import numpy as np\n",
"# np.loadtxt can decompress the files on read\n",
"digpad = {\"data\": np.loadtxt(\"mnist-digits-jitter/digits_padded.csv.gz\", delimiter=\",\", dtype=int)}\n",
"digjit = {\"data\": np.loadtxt(\"mnist-digits-jitter/digits_jitter.csv.gz\", delimiter=\",\", dtype=int)}\n",
"digpad[\"data\"].shape, digjit[\"data\"].shape"
],
"execution_count": 3,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((1797, 225), (1797, 225))"
]
},
"metadata": {},
"execution_count": 3
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "nBjvLW_9f3YM"
},
"source": [
"# reshape flat rows (n, s*s) into square images (n, s, s)\n",
"#def im2data(digxxx_img):\n",
"# return np.vstack([img.reshape((-1,1)).squeeze() for img in digxxx_img])\n",
"\n",
"def data2im(digxxx_data):\n",
" s = int(digxxx_data.shape[1]**.5)\n",
" l = digxxx_data.reshape((-1,s,s))\n",
" return l\n",
"\n",
"digjit[\"images\"] = data2im(digjit[\"data\"])\n",
"digpad[\"images\"] = data2im(digpad[\"data\"])\n",
"\n",
"assert digjit[\"data\"].shape == (1797, 225)\n",
"assert digpad[\"data\"].shape == (1797, 225)\n",
"assert digjit[\"images\"].shape == (1797, 15, 15)\n",
"assert digpad[\"images\"].shape == (1797, 15, 15)"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "N8Ukyc9Twqve"
},
"source": [
"# run SVM and KNN"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VIT0GtY0ww5a",
"outputId": "2630bfe0-c5d1-4bfb-f92d-4d5094d7dfbb"
},
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn import metrics, svm, model_selection\n",
"import statistics\n",
"\n",
"clf_l = [\n",
" # n neighbors = 3 is better than any of 1,2,4,5\n",
" (\"KNN\", KNeighborsClassifier(n_neighbors=3)),\n",
" # default is kernel=rbf\n",
" # gamma=1e-3 is better than any of (1e-5, 1e-4, 1e-2, 1e-1)\n",
" # Try linear, poly, and RBF kernels, as in the xtomasch gist\n",
" # https://gist.github.com/xtomasch/84d1d8574ef51eb8d42e77560d647e06\n",
" (\"SVM linear\", svm.SVC(kernel=\"linear\")),\n",
" (\"SVM poly\", svm.SVC(kernel=\"poly\")),\n",
" (\"SVM RBF\", svm.SVC(kernel=\"rbf\", gamma=0.001)),\n",
"]\n",
"\n",
"X_l = [\n",
" (\"no jitter\", digpad[\"data\"]),\n",
" (\"with jitter\", digjit[\"data\"]),\n",
" ]\n",
"\n",
"\n",
"for X_name, X_i in X_l:\n",
" for clf_name, clf_i in clf_l:\n",
" results = model_selection.cross_val_score(clf_i, X_i, digits.target)\n",
" print(f\"{clf_name}, {X_name}: {statistics.mean(results).round(2)}\")"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"KNN, no jitter: 0.97\n",
"SVM linear, no jitter: 0.95\n",
"SVM poly, no jitter: 0.96\n",
"SVM RBF, no jitter: 0.97\n",
"KNN, with jitter: 0.35\n",
"SVM linear, with jitter: 0.1\n",
"SVM poly, with jitter: 0.27\n",
"SVM RBF, with jitter: 0.32\n"
],
"name": "stdout"
}
]
}
]
}
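A note on what "invariant to translation" can mean here. The RBF kernel k(x, y) = exp(-gamma * ||x - y||^2) depends only on the difference x - y, so it is unchanged when the same vector is added to both inputs. Jitter is a different kind of translation: it moves the digit's pixels inside the frame, which changes the feature vector itself. A minimal check of both, assuming sklearn's `rbf_kernel` on the original 8x8 digits (the constant offset `c = 5`, the one-pixel `np.roll`, and the choice of samples 0 and 1 are illustrative assumptions, not part of the notebook):

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics.pairwise import rbf_kernel

digits = load_digits()
gamma = 0.001  # same gamma as the notebook's "SVM RBF" classifier

# Two flattened 8x8 digits, kept 2-D as (1, 64) for rbf_kernel
x = digits.data[0:1]
y = digits.data[1:2]

# 1) Kernel-level invariance: adding the same vector c to both inputs
#    leaves ||x - y||, and therefore k(x, y), unchanged.
c = np.full_like(x, 5.0)
k_xy = rbf_kernel(x, y, gamma=gamma)[0, 0]
k_shifted_inputs = rbf_kernel(x + c, y + c, gamma=gamma)[0, 0]

# 2) Image-level translation: move the digit one pixel to the right
#    inside its frame. Padding first gives a zero border, so np.roll
#    only wraps zeros around and no ink is lost.
padded = np.pad(digits.images[0], 1)   # 10x10 canvas, zero border
moved = np.roll(padded, 1, axis=1)     # content shifted one column right
k_self_shift = rbf_kernel(
    padded.reshape(1, -1), moved.reshape(1, -1), gamma=gamma
)[0, 0]

print(f"k(x, y)         = {k_xy:.4f}")
print(f"k(x + c, y + c) = {k_shifted_inputs:.4f}")
print(f"k(img, shifted) = {k_self_shift:.4f}")
```

The first two values match while the third is below 1, so invariance in the kernel sense does not imply invariance to shifting a digit inside its frame. That is consistent with the drop from about 0.97 to 0.32 for "SVM RBF" in the results.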
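The notebook's CSVs have 225 columns, i.e. a 15x15 canvas built from sklearn's 8x8 digits. The following is a hypothetical reconstruction of one way such "padded" and "jittered" variants could be produced; it is not necessarily how the mnist-digits-jitter repository built its files. The centered offset of 3, the uniform random offsets 0 to 7, the seed, and the helper names `place`, `make_padded`, and `make_jittered` are all assumptions:

```python
import numpy as np
from sklearn.datasets import load_digits

# Hypothetical sketch, NOT the mnist-digits-jitter generation code.
CANVAS = 15                  # 15 * 15 = 225 columns, as in the CSVs
DIGIT = 8                    # sklearn digits are 8x8
MAX_OFFSET = CANVAS - DIGIT  # offsets 0..7 keep the digit fully inside


def place(img, row, col):
    """Copy an 8x8 image onto a zero 15x15 canvas at (row, col)."""
    canvas = np.zeros((CANVAS, CANVAS), dtype=img.dtype)
    canvas[row:row + DIGIT, col:col + DIGIT] = img
    return canvas


def make_padded(images, offset=MAX_OFFSET // 2):
    """Every digit at the same (offset, offset) position."""
    return np.stack([place(im, offset, offset) for im in images])


def make_jittered(images, seed=0):
    """Every digit at an independent random (row, col) offset."""
    rng = np.random.default_rng(seed)
    offsets = rng.integers(0, MAX_OFFSET + 1, size=(len(images), 2))
    return np.stack([place(im, r, c) for im, (r, c) in zip(images, offsets)])


digits = load_digits()
pad_images = make_padded(digits.images)
jit_images = make_jittered(digits.images)

# Flatten to (n_samples, 225) rows, the same layout data2im expects
pad_data = pad_images.reshape(len(pad_images), -1)
jit_data = jit_images.reshape(len(jit_images), -1)
print(pad_data.shape, jit_data.shape)  # (1797, 225) (1797, 225)
```

Feeding `pad_data` and `jit_data` through the same `cross_val_score` loop is one way to probe how much the accuracy drop depends on the particular jitter distribution chosen.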