@sabineri
Created February 19, 2020 17:21
Created on Cognitive Class Labs
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"<a href=\"https://www.bigdatauniversity.com\"><img src = \"https://ibm.box.com/shared/static/cw2c7r3o20w9zn8gkecaeyjhgw3xdgbj.png\" width = 400, align = \"center\"></a>\n",
"\n",
"\n",
"<h1 align=center><font size = 5>K-Nearest Neighbors in R</font></h1>"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"In this lesson, we introduce the $K$-nearest neighbors (KNN) algorithm and show practical ways to use it in `R` with the `knn` function from the `class` package. Later, we show how to tune it with the `caret` package.\n",
"\n",
"Consider a binary classification task with two classes, $A$ and $B$. *Given* a training set, a test set and a positive integer $K$, KNN does the following for **each record in the test set**: it finds the $K$ records in the **training set** that are *closest* to the test record and *counts* how many of them belong to class $A$ and how many to class $B$. The test record is then assigned to the majority class, i.e. it is labeled as class $i$ if most of its $K$ nearest neighbors in the training set belong to class $i$.\n",
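"\n",
"As an illustration of the voting rule just described, here is a minimal from-scratch sketch in base `R`, assuming numeric features and Euclidean distance. The function name `knn_predict_sketch` is hypothetical: it is not part of the `class` package, and it breaks exact vote ties more simply than `class::knn` does.\n",
"\n",
"```r\n",
"# Hypothetical helper (illustration only): classify each test row by a\n",
"# majority vote among its k nearest training rows, using Euclidean distance.\n",
"knn_predict_sketch <- function(X_train, y_train, X_test, k = 3) {\n",
"  X_train <- as.matrix(X_train)\n",
"  X_test  <- as.matrix(X_test)\n",
"  y_train <- as.factor(y_train)\n",
"  preds <- apply(X_test, 1, function(x) {\n",
"    # Euclidean distance from this test row to every training row\n",
"    d <- sqrt(colSums((t(X_train) - x)^2))\n",
"    # class labels of the k closest training rows\n",
"    nn <- y_train[order(d)[seq_len(k)]]\n",
"    # count votes per class; which.max() picks the first level on an exact tie\n",
"    votes <- table(nn)\n",
"    names(votes)[which.max(votes)]\n",
"  })\n",
"  factor(preds, levels = levels(y_train))\n",
"}\n",
"\n",
"# Toy usage: two well-separated groups in 2-D\n",
"X_tr <- rbind(c(0, 0), c(0, 1), c(1, 0), c(5, 5), c(5, 6), c(6, 5))\n",
"y_tr <- factor(c(\"A\", \"A\", \"A\", \"B\", \"B\", \"B\"))\n",
"X_te <- rbind(c(0.5, 0.5), c(5.5, 5.5))\n",
"knn_predict_sketch(X_tr, y_tr, X_te, k = 3)  # A B\n",
"```\n",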
"\n",
"As can be seen, no parameters *need to be learned during training* to decide whether a new observation belongs to class $A$ or $B$. The only setting is $K$ itself, a hyperparameter chosen in advance. The algorithm simply stores the training samples and, for each new observation, calculates *distances* to them and finds the $K$ closest ones. Thus, KNN is a *non-parametric*, supervised (it needs training labels) learning algorithm.\n",
"\n",
"The following diagram illustrates the main idea of the k-nearest neighbors algorithm. As $K$ increases from $3$ to $6$, the predicted class of the new observation (red star) changes from $B$ to $A$ because the majority vote changes: for $K=3$, the neighbors are two observations of class $B$ and one of class $A$, while for $K=6$, they are two observations of class $B$ and four of class $A$.\n",
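"\n",
"The vote flip in the diagram can be reproduced in a few lines of `R`. The neighbor labels below are made up to match the counts stated above (two $B$ and one $A$ among the 3 nearest neighbors, two $B$ and four $A$ among the 6 nearest); they are not measured from the figure.\n",
"\n",
"```r\n",
"# Made-up labels of the 6 nearest neighbors, sorted from closest to farthest\n",
"nn_labels <- c(\"B\", \"A\", \"B\", \"A\", \"A\", \"A\")\n",
"\n",
"# Majority vote among the first k labels\n",
"majority_vote <- function(labels, k) {\n",
"  votes <- table(labels[seq_len(k)])\n",
"  names(votes)[which.max(votes)]\n",
"}\n",
"\n",
"majority_vote(nn_labels, k = 3)  # \"B\": 2 votes for B, 1 for A\n",
"majority_vote(nn_labels, k = 6)  # \"A\": 4 votes for A, 2 for B\n",
"```\n",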
"\n",
"<img src=\"http://bdewilde.github.io/assets/images/2012-10-26-knn-concept.png\" width=\"600\" align=\"center\">\n",
"<div style=\"text-align:center\"><a href=\"http://bdewilde.github.io/blog/blogger/2012/10/26/classification-of-hand-written-digits-3/\">KNN classifications for k=3 and k=6</a></div>\n",
"<br>\n",
"\n",
"For a regression task, the same method applies, but instead of taking a majority vote we can, for example, predict the mean of the response variable over the $K$ nearest neighbors of the new observation.\n",
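"\n",
"A minimal sketch of KNN regression in base `R`, assuming numeric features and Euclidean distance; `knn_regress_sketch` is a hypothetical helper, not a function from `class` or `caret`.\n",
"\n",
"```r\n",
"# Hypothetical helper (illustration only): predict the mean response of the\n",
"# k nearest training rows, using Euclidean distance.\n",
"knn_regress_sketch <- function(X_train, y_train, X_test, k = 3) {\n",
"  X_train <- as.matrix(X_train)\n",
"  X_test  <- as.matrix(X_test)\n",
"  apply(X_test, 1, function(x) {\n",
"    d <- sqrt(colSums((t(X_train) - x)^2))  # distances to all training rows\n",
"    mean(y_train[order(d)[seq_len(k)]])     # average response of the k nearest\n",
"  })\n",
"}\n",
"\n",
"# Toy usage with one feature: the 3 nearest neighbors of x = 2.1 are x = 2, 3, 1\n",
"X_tr <- matrix(c(1, 2, 3, 10, 11), ncol = 1)\n",
"y_tr <- c(1, 2, 3, 10, 11)\n",
"knn_regress_sketch(X_tr, y_tr, matrix(2.1, ncol = 1), k = 3)  # 2, the mean of 2, 3 and 1\n",
"```\n",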
"\n",
"KNN depends on 1) the choice of distance metric (for example, Euclidean distance in the example above) and 2) the choice of $K$. There is no universal choice for either; depending on the data, one has to try several options to find a suitable one.\n",
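"\n",
"To see why the metric matters, the sketch below (made-up points) shows that the nearest neighbor of a query point can change with the metric. Note that `class::knn` always uses Euclidean distance, so trying another metric requires a different implementation or transformed features.\n",
"\n",
"```r\n",
"# Query point and two candidate neighbors (made-up values)\n",
"query <- c(0, 0)\n",
"cands <- rbind(a = c(3, 3), b = c(0, 4.5))\n",
"\n",
"diffs     <- sweep(cands, 2, query)  # subtract the query from each row\n",
"euclidean <- sqrt(rowSums(diffs^2))  # a: about 4.24, b: 4.5\n",
"manhattan <- rowSums(abs(diffs))     # a: 6,          b: 4.5\n",
"\n",
"names(which.min(euclidean))  # \"a\" is nearest under Euclidean distance\n",
"names(which.min(manhattan))  # \"b\" is nearest under Manhattan distance\n",
"```\n",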
"\n",
"*Caveats:*\n",
"\n",
"* When using KNN, make sure that no categorical variables (factors) are among the **features**, because a numerical distance between categories is not defined. For example, when a categorical variable takes values from the set {apple, orange, banana, grapes, ...}, numerical distance functions cannot be applied, unless there is a predetermined way to measure the distance between these values.\n",
"\n",
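"* If the training set is high-dimensional, KNN suffers from [the curse of dimensionality](https://en.wikipedia.org/wiki/Curse_of_dimensionality): as the number of features grows, distances between points tend to become similar, so the nearest neighbors become less meaningful. In that case, a dimensionality reduction technique can be applied before using KNN.\n",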
"\n",
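"* *Standardize* the features before using KNN; otherwise features with large numeric ranges dominate the distance. Precisely, preprocess the *training* set so that each feature (column) has a mean of zero and a standard deviation of one, then standardize the test set with the **training** means and standard deviations. The order matters: compute the statistics on the training set first, and reuse them for the test set.\n",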
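"\n",
"A minimal sketch of this order of operations, on made-up data with two features on very different scales: `scale()` standardizes the training set and stores the means and standard deviations it used in the attributes `\"scaled:center\"` and `\"scaled:scale\"`, which are then reused for the test set. With the notebook's objects, the same pattern would apply to `X_train` and `X_test`.\n",
"\n",
"```r\n",
"set.seed(1)\n",
"# Made-up data: x1 is on a scale of tens, x2 on a scale of tenths\n",
"train <- data.frame(x1 = rnorm(20, 50, 10), x2 = rnorm(20, 0.5, 0.1))\n",
"test  <- data.frame(x1 = rnorm(5, 50, 10),  x2 = rnorm(5, 0.5, 0.1))\n",
"\n",
"# Standardize the training set and keep the statistics scale() used\n",
"train_sc <- scale(train)\n",
"mu    <- attr(train_sc, \"scaled:center\")\n",
"sigma <- attr(train_sc, \"scaled:scale\")\n",
"\n",
"# Standardize the test set with the TRAINING means and SDs (not its own)\n",
"test_sc <- scale(test, center = mu, scale = sigma)\n",
"\n",
"round(colMeans(train_sc), 10)  # 0 for each training feature\n",
"apply(train_sc, 2, sd)         # 1 for each training feature\n",
"```\n",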
"\n",
"We will see the **effect** of standardizing the training and test sets on the predicted values later.\n",
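"\n",
"Returning to the curse-of-dimensionality caveat above, one common dimensionality reduction technique is principal component analysis (PCA). The sketch below, on made-up correlated data, fits `prcomp()` on the training rows only and projects both sets with `predict()`, which reuses the training centering and scaling; the 95% variance threshold is an arbitrary illustrative choice.\n",
"\n",
"```r\n",
"set.seed(2)\n",
"# Made-up data: 35 rows, 10 features driven by 2 latent signals plus small noise\n",
"latent <- matrix(rnorm(35 * 2), ncol = 2)\n",
"X <- latent %*% matrix(rnorm(2 * 10), nrow = 2) + matrix(rnorm(35 * 10, sd = 0.1), ncol = 10)\n",
"X_tr <- X[1:30, ]\n",
"X_te <- X[31:35, ]\n",
"\n",
"# Fit PCA on the training rows only (with centering and scaling)\n",
"pca <- prcomp(X_tr, center = TRUE, scale. = TRUE)\n",
"\n",
"# Smallest number of components explaining at least 95% of the training variance\n",
"var_explained <- cumsum(pca$sdev^2) / sum(pca$sdev^2)\n",
"n_comp <- which(var_explained >= 0.95)[1]\n",
"\n",
"# Project training and test rows onto the first n_comp components\n",
"Z_tr <- pca$x[, seq_len(n_comp), drop = FALSE]\n",
"Z_te <- predict(pca, newdata = X_te)[, seq_len(n_comp), drop = FALSE]\n",
"c(n_comp = n_comp, train_rows = nrow(Z_tr), test_rows = nrow(Z_te))\n",
"```\n",
"\n",
"The projected matrices `Z_tr` and `Z_te` could then be passed to `knn()` in place of the original features.\n",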
"\n",
"-------"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"### Install and import libraries"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"__Notice:__ We need to install some packages for this notebook, which may take a few minutes."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"button": false,
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Updating HTML index of packages in '.Library'\n",
"Making 'packages.html' ... done\n",
"also installing the dependency ‘pROC’\n",
"\n",
"Updating HTML index of packages in '.Library'\n",
"Making 'packages.html' ... done\n",
"Updating HTML index of packages in '.Library'\n",
"Making 'packages.html' ... done\n",
"Updating HTML index of packages in '.Library'\n",
"Making 'packages.html' ... done\n",
"Loading required package: lattice\n",
"Loading required package: ggplot2\n",
"Loading required package: mlbench\n"
]
}
],
"source": [
"# install the packages (note: this may take some time)\n",
"install.packages(\"class\")\n",
"install.packages(\"caret\")\n",
"install.packages(\"mlbench\")\n",
"install.packages(\"e1071\")\n",
"\n",
"# load the packages (the base package is always attached, so it needs no library() call)\n",
"library(class)    # knn(), knn.cv()\n",
"library(caret)    # Classification And REgression Training utilities\n",
"library(mlbench)  # Sonar data set\n",
"library(e1071)"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"### Step 1 – Data collection\n",
"\n",
"For this lesson, we use the `Sonar` data set from the `mlbench` package. Sonar (sound navigation and ranging) detects objects under water by emitting sound pulses and detecting their echoes. Each row of `Sonar` describes one sonar signal bounced off either a metal cylinder (class $M$, for mine) or a roughly cylindrical rock (class $R$); the complete description can be found in the [mlbench documentation](https://cran.r-project.org/web/packages/mlbench/mlbench.pdf). For our purposes, this is a two-class classification task with numeric features.\n",
"\n",
"Let's look at the first six rows of `Sonar`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"button": false,
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"<caption>A data.frame: 6 × 61</caption>\n",
"<thead>\n",
"\t<tr><th scope=col>V1</th><th scope=col>V2</th><th scope=col>V3</th><th scope=col>V4</th><th scope=col>V5</th><th scope=col>V6</th><th scope=col>V7</th><th scope=col>V8</th><th scope=col>V9</th><th scope=col>V10</th><th scope=col>⋯</th><th scope=col>V52</th><th scope=col>V53</th><th scope=col>V54</th><th scope=col>V55</th><th scope=col>V56</th><th scope=col>V57</th><th scope=col>V58</th><th scope=col>V59</th><th scope=col>V60</th><th scope=col>Class</th></tr>\n",
"\t<tr><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>⋯</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;dbl&gt;</th><th scope=col>&lt;fct&gt;</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"\t<tr><td>0.0200</td><td>0.0371</td><td>0.0428</td><td>0.0207</td><td>0.0954</td><td>0.0986</td><td>0.1539</td><td>0.1601</td><td>0.3109</td><td>0.2111</td><td>⋯</td><td>0.0027</td><td>0.0065</td><td>0.0159</td><td>0.0072</td><td>0.0167</td><td>0.0180</td><td>0.0084</td><td>0.0090</td><td>0.0032</td><td>R</td></tr>\n",
"\t<tr><td>0.0453</td><td>0.0523</td><td>0.0843</td><td>0.0689</td><td>0.1183</td><td>0.2583</td><td>0.2156</td><td>0.3481</td><td>0.3337</td><td>0.2872</td><td>⋯</td><td>0.0084</td><td>0.0089</td><td>0.0048</td><td>0.0094</td><td>0.0191</td><td>0.0140</td><td>0.0049</td><td>0.0052</td><td>0.0044</td><td>R</td></tr>\n",
"\t<tr><td>0.0262</td><td>0.0582</td><td>0.1099</td><td>0.1083</td><td>0.0974</td><td>0.2280</td><td>0.2431</td><td>0.3771</td><td>0.5598</td><td>0.6194</td><td>⋯</td><td>0.0232</td><td>0.0166</td><td>0.0095</td><td>0.0180</td><td>0.0244</td><td>0.0316</td><td>0.0164</td><td>0.0095</td><td>0.0078</td><td>R</td></tr>\n",
"\t<tr><td>0.0100</td><td>0.0171</td><td>0.0623</td><td>0.0205</td><td>0.0205</td><td>0.0368</td><td>0.1098</td><td>0.1276</td><td>0.0598</td><td>0.1264</td><td>⋯</td><td>0.0121</td><td>0.0036</td><td>0.0150</td><td>0.0085</td><td>0.0073</td><td>0.0050</td><td>0.0044</td><td>0.0040</td><td>0.0117</td><td>R</td></tr>\n",
"\t<tr><td>0.0762</td><td>0.0666</td><td>0.0481</td><td>0.0394</td><td>0.0590</td><td>0.0649</td><td>0.1209</td><td>0.2467</td><td>0.3564</td><td>0.4459</td><td>⋯</td><td>0.0031</td><td>0.0054</td><td>0.0105</td><td>0.0110</td><td>0.0015</td><td>0.0072</td><td>0.0048</td><td>0.0107</td><td>0.0094</td><td>R</td></tr>\n",
"\t<tr><td>0.0286</td><td>0.0453</td><td>0.0277</td><td>0.0174</td><td>0.0384</td><td>0.0990</td><td>0.1201</td><td>0.1833</td><td>0.2105</td><td>0.3039</td><td>⋯</td><td>0.0045</td><td>0.0014</td><td>0.0038</td><td>0.0013</td><td>0.0089</td><td>0.0057</td><td>0.0027</td><td>0.0051</td><td>0.0062</td><td>R</td></tr>\n",
"</tbody>\n",
"</table>\n"
],
"text/latex": [
"A data.frame: 6 × 61\n",
"\\begin{tabular}{r|lllllllllllllllllllllllllllllllllllllllllllllllllllllllllllll}\n",
" V1 & V2 & V3 & V4 & V5 & V6 & V7 & V8 & V9 & V10 & V11 & V12 & V13 & V14 & V15 & V16 & V17 & V18 & V19 & V20 & V21 & V22 & V23 & V24 & V25 & V26 & V27 & V28 & V29 & V30 & V31 & V32 & V33 & V34 & V35 & V36 & V37 & V38 & V39 & V40 & V41 & V42 & V43 & V44 & V45 & V46 & V47 & V48 & V49 & V50 & V51 & V52 & V53 & V54 & V55 & V56 & V57 & V58 & V59 & V60 & Class\\\\\n",
" <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <dbl> & <fct>\\\\\n",
"\\hline\n",
"\t 0.0200 & 0.0371 & 0.0428 & 0.0207 & 0.0954 & 0.0986 & 0.1539 & 0.1601 & 0.3109 & 0.2111 & 0.1609 & 0.1582 & 0.2238 & 0.0645 & 0.0660 & 0.2273 & 0.3100 & 0.2999 & 0.5078 & 0.4797 & 0.5783 & 0.5071 & 0.4328 & 0.5550 & 0.6711 & 0.6415 & 0.7104 & 0.8080 & 0.6791 & 0.3857 & 0.1307 & 0.2604 & 0.5121 & 0.7547 & 0.8537 & 0.8507 & 0.6692 & 0.6097 & 0.4943 & 0.2744 & 0.0510 & 0.2834 & 0.2825 & 0.4256 & 0.2641 & 0.1386 & 0.1051 & 0.1343 & 0.0383 & 0.0324 & 0.0232 & 0.0027 & 0.0065 & 0.0159 & 0.0072 & 0.0167 & 0.0180 & 0.0084 & 0.0090 & 0.0032 & R\\\\\n",
"\t 0.0453 & 0.0523 & 0.0843 & 0.0689 & 0.1183 & 0.2583 & 0.2156 & 0.3481 & 0.3337 & 0.2872 & 0.4918 & 0.6552 & 0.6919 & 0.7797 & 0.7464 & 0.9444 & 1.0000 & 0.8874 & 0.8024 & 0.7818 & 0.5212 & 0.4052 & 0.3957 & 0.3914 & 0.3250 & 0.3200 & 0.3271 & 0.2767 & 0.4423 & 0.2028 & 0.3788 & 0.2947 & 0.1984 & 0.2341 & 0.1306 & 0.4182 & 0.3835 & 0.1057 & 0.1840 & 0.1970 & 0.1674 & 0.0583 & 0.1401 & 0.1628 & 0.0621 & 0.0203 & 0.0530 & 0.0742 & 0.0409 & 0.0061 & 0.0125 & 0.0084 & 0.0089 & 0.0048 & 0.0094 & 0.0191 & 0.0140 & 0.0049 & 0.0052 & 0.0044 & R\\\\\n",
"\t 0.0262 & 0.0582 & 0.1099 & 0.1083 & 0.0974 & 0.2280 & 0.2431 & 0.3771 & 0.5598 & 0.6194 & 0.6333 & 0.7060 & 0.5544 & 0.5320 & 0.6479 & 0.6931 & 0.6759 & 0.7551 & 0.8929 & 0.8619 & 0.7974 & 0.6737 & 0.4293 & 0.3648 & 0.5331 & 0.2413 & 0.5070 & 0.8533 & 0.6036 & 0.8514 & 0.8512 & 0.5045 & 0.1862 & 0.2709 & 0.4232 & 0.3043 & 0.6116 & 0.6756 & 0.5375 & 0.4719 & 0.4647 & 0.2587 & 0.2129 & 0.2222 & 0.2111 & 0.0176 & 0.1348 & 0.0744 & 0.0130 & 0.0106 & 0.0033 & 0.0232 & 0.0166 & 0.0095 & 0.0180 & 0.0244 & 0.0316 & 0.0164 & 0.0095 & 0.0078 & R\\\\\n",
"\t 0.0100 & 0.0171 & 0.0623 & 0.0205 & 0.0205 & 0.0368 & 0.1098 & 0.1276 & 0.0598 & 0.1264 & 0.0881 & 0.1992 & 0.0184 & 0.2261 & 0.1729 & 0.2131 & 0.0693 & 0.2281 & 0.4060 & 0.3973 & 0.2741 & 0.3690 & 0.5556 & 0.4846 & 0.3140 & 0.5334 & 0.5256 & 0.2520 & 0.2090 & 0.3559 & 0.6260 & 0.7340 & 0.6120 & 0.3497 & 0.3953 & 0.3012 & 0.5408 & 0.8814 & 0.9857 & 0.9167 & 0.6121 & 0.5006 & 0.3210 & 0.3202 & 0.4295 & 0.3654 & 0.2655 & 0.1576 & 0.0681 & 0.0294 & 0.0241 & 0.0121 & 0.0036 & 0.0150 & 0.0085 & 0.0073 & 0.0050 & 0.0044 & 0.0040 & 0.0117 & R\\\\\n",
"\t 0.0762 & 0.0666 & 0.0481 & 0.0394 & 0.0590 & 0.0649 & 0.1209 & 0.2467 & 0.3564 & 0.4459 & 0.4152 & 0.3952 & 0.4256 & 0.4135 & 0.4528 & 0.5326 & 0.7306 & 0.6193 & 0.2032 & 0.4636 & 0.4148 & 0.4292 & 0.5730 & 0.5399 & 0.3161 & 0.2285 & 0.6995 & 1.0000 & 0.7262 & 0.4724 & 0.5103 & 0.5459 & 0.2881 & 0.0981 & 0.1951 & 0.4181 & 0.4604 & 0.3217 & 0.2828 & 0.2430 & 0.1979 & 0.2444 & 0.1847 & 0.0841 & 0.0692 & 0.0528 & 0.0357 & 0.0085 & 0.0230 & 0.0046 & 0.0156 & 0.0031 & 0.0054 & 0.0105 & 0.0110 & 0.0015 & 0.0072 & 0.0048 & 0.0107 & 0.0094 & R\\\\\n",
"\t 0.0286 & 0.0453 & 0.0277 & 0.0174 & 0.0384 & 0.0990 & 0.1201 & 0.1833 & 0.2105 & 0.3039 & 0.2988 & 0.4250 & 0.6343 & 0.8198 & 1.0000 & 0.9988 & 0.9508 & 0.9025 & 0.7234 & 0.5122 & 0.2074 & 0.3985 & 0.5890 & 0.2872 & 0.2043 & 0.5782 & 0.5389 & 0.3750 & 0.3411 & 0.5067 & 0.5580 & 0.4778 & 0.3299 & 0.2198 & 0.1407 & 0.2856 & 0.3807 & 0.4158 & 0.4054 & 0.3296 & 0.2707 & 0.2650 & 0.0723 & 0.1238 & 0.1192 & 0.1089 & 0.0623 & 0.0494 & 0.0264 & 0.0081 & 0.0104 & 0.0045 & 0.0014 & 0.0038 & 0.0013 & 0.0089 & 0.0057 & 0.0027 & 0.0051 & 0.0062 & R\\\\\n",
"\\end{tabular}\n"
],
"text/markdown": [
"\n",
"A data.frame: 6 × 61\n",
"\n",
"| V1 &lt;dbl&gt; | V2 &lt;dbl&gt; | V3 &lt;dbl&gt; | V4 &lt;dbl&gt; | V5 &lt;dbl&gt; | V6 &lt;dbl&gt; | V7 &lt;dbl&gt; | V8 &lt;dbl&gt; | V9 &lt;dbl&gt; | V10 &lt;dbl&gt; | ⋯ ⋯ | V52 &lt;dbl&gt; | V53 &lt;dbl&gt; | V54 &lt;dbl&gt; | V55 &lt;dbl&gt; | V56 &lt;dbl&gt; | V57 &lt;dbl&gt; | V58 &lt;dbl&gt; | V59 &lt;dbl&gt; | V60 &lt;dbl&gt; | Class &lt;fct&gt; |\n",
"|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
"| 0.0200 | 0.0371 | 0.0428 | 0.0207 | 0.0954 | 0.0986 | 0.1539 | 0.1601 | 0.3109 | 0.2111 | ⋯ | 0.0027 | 0.0065 | 0.0159 | 0.0072 | 0.0167 | 0.0180 | 0.0084 | 0.0090 | 0.0032 | R |\n",
"| 0.0453 | 0.0523 | 0.0843 | 0.0689 | 0.1183 | 0.2583 | 0.2156 | 0.3481 | 0.3337 | 0.2872 | ⋯ | 0.0084 | 0.0089 | 0.0048 | 0.0094 | 0.0191 | 0.0140 | 0.0049 | 0.0052 | 0.0044 | R |\n",
"| 0.0262 | 0.0582 | 0.1099 | 0.1083 | 0.0974 | 0.2280 | 0.2431 | 0.3771 | 0.5598 | 0.6194 | ⋯ | 0.0232 | 0.0166 | 0.0095 | 0.0180 | 0.0244 | 0.0316 | 0.0164 | 0.0095 | 0.0078 | R |\n",
"| 0.0100 | 0.0171 | 0.0623 | 0.0205 | 0.0205 | 0.0368 | 0.1098 | 0.1276 | 0.0598 | 0.1264 | ⋯ | 0.0121 | 0.0036 | 0.0150 | 0.0085 | 0.0073 | 0.0050 | 0.0044 | 0.0040 | 0.0117 | R |\n",
"| 0.0762 | 0.0666 | 0.0481 | 0.0394 | 0.0590 | 0.0649 | 0.1209 | 0.2467 | 0.3564 | 0.4459 | ⋯ | 0.0031 | 0.0054 | 0.0105 | 0.0110 | 0.0015 | 0.0072 | 0.0048 | 0.0107 | 0.0094 | R |\n",
"| 0.0286 | 0.0453 | 0.0277 | 0.0174 | 0.0384 | 0.0990 | 0.1201 | 0.1833 | 0.2105 | 0.3039 | ⋯ | 0.0045 | 0.0014 | 0.0038 | 0.0013 | 0.0089 | 0.0057 | 0.0027 | 0.0051 | 0.0062 | R |\n",
"\n"
],
"text/plain": [
" V1 V2 V3 V4 V5 V6 V7 V8 V9 V10 ⋯\n",
"1 0.0200 0.0371 0.0428 0.0207 0.0954 0.0986 0.1539 0.1601 0.3109 0.2111 ⋯\n",
"2 0.0453 0.0523 0.0843 0.0689 0.1183 0.2583 0.2156 0.3481 0.3337 0.2872 ⋯\n",
"3 0.0262 0.0582 0.1099 0.1083 0.0974 0.2280 0.2431 0.3771 0.5598 0.6194 ⋯\n",
"4 0.0100 0.0171 0.0623 0.0205 0.0205 0.0368 0.1098 0.1276 0.0598 0.1264 ⋯\n",
"5 0.0762 0.0666 0.0481 0.0394 0.0590 0.0649 0.1209 0.2467 0.3564 0.4459 ⋯\n",
"6 0.0286 0.0453 0.0277 0.0174 0.0384 0.0990 0.1201 0.1833 0.2105 0.3039 ⋯\n",
" V52 V53 V54 V55 V56 V57 V58 V59 V60 Class\n",
"1 0.0027 0.0065 0.0159 0.0072 0.0167 0.0180 0.0084 0.0090 0.0032 R \n",
"2 0.0084 0.0089 0.0048 0.0094 0.0191 0.0140 0.0049 0.0052 0.0044 R \n",
"3 0.0232 0.0166 0.0095 0.0180 0.0244 0.0316 0.0164 0.0095 0.0078 R \n",
"4 0.0121 0.0036 0.0150 0.0085 0.0073 0.0050 0.0044 0.0040 0.0117 R \n",
"5 0.0031 0.0054 0.0105 0.0110 0.0015 0.0072 0.0048 0.0107 0.0094 R \n",
"6 0.0045 0.0014 0.0038 0.0013 0.0089 0.0057 0.0027 0.0051 0.0062 R "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data(Sonar)\n",
"head(Sonar)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Step 2 – Preparing and exploring the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`Sonar` is a data frame with 208 observations on 61 variables: 60 numerical features (`V1` to `V60`) and one nominal variable, `Class`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"button": false,
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of rows and columns are: 208 61"
]
}
],
"source": [
"cat(\"number of rows and columns are:\", nrow(Sonar), ncol(Sonar))"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Let's count how many observations belong to class $M$ and to class $R$, and check whether any column of `Sonar` contains missing values (`NA`)."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"button": false,
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"data": {
"text/plain": [
"\n",
" M R \n",
"111 97 "
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<dl class=dl-horizontal>\n",
"\t<dt>V1</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V2</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V3</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V4</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V5</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V6</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V7</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V8</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V9</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V10</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V11</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V12</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V13</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V14</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V15</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V16</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V17</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V18</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V19</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V20</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V21</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V22</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V23</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V24</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V25</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V26</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V27</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V28</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V29</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V30</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V31</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V32</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V33</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V34</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V35</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V36</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V37</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V38</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V39</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V40</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V41</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V42</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V43</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V44</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V45</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V46</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V47</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V48</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V49</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V50</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V51</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V52</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V53</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V54</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V55</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V56</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V57</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V58</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V59</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>V60</dt>\n",
"\t\t<dd>0</dd>\n",
"\t<dt>Class</dt>\n",
"\t\t<dd>0</dd>\n",
"</dl>\n"
],
"text/latex": [
"\\begin{description*}\n",
"\\item[V1] 0\n",
"\\item[V2] 0\n",
"\\item[V3] 0\n",
"\\item[V4] 0\n",
"\\item[V5] 0\n",
"\\item[V6] 0\n",
"\\item[V7] 0\n",
"\\item[V8] 0\n",
"\\item[V9] 0\n",
"\\item[V10] 0\n",
"\\item[V11] 0\n",
"\\item[V12] 0\n",
"\\item[V13] 0\n",
"\\item[V14] 0\n",
"\\item[V15] 0\n",
"\\item[V16] 0\n",
"\\item[V17] 0\n",
"\\item[V18] 0\n",
"\\item[V19] 0\n",
"\\item[V20] 0\n",
"\\item[V21] 0\n",
"\\item[V22] 0\n",
"\\item[V23] 0\n",
"\\item[V24] 0\n",
"\\item[V25] 0\n",
"\\item[V26] 0\n",
"\\item[V27] 0\n",
"\\item[V28] 0\n",
"\\item[V29] 0\n",
"\\item[V30] 0\n",
"\\item[V31] 0\n",
"\\item[V32] 0\n",
"\\item[V33] 0\n",
"\\item[V34] 0\n",
"\\item[V35] 0\n",
"\\item[V36] 0\n",
"\\item[V37] 0\n",
"\\item[V38] 0\n",
"\\item[V39] 0\n",
"\\item[V40] 0\n",
"\\item[V41] 0\n",
"\\item[V42] 0\n",
"\\item[V43] 0\n",
"\\item[V44] 0\n",
"\\item[V45] 0\n",
"\\item[V46] 0\n",
"\\item[V47] 0\n",
"\\item[V48] 0\n",
"\\item[V49] 0\n",
"\\item[V50] 0\n",
"\\item[V51] 0\n",
"\\item[V52] 0\n",
"\\item[V53] 0\n",
"\\item[V54] 0\n",
"\\item[V55] 0\n",
"\\item[V56] 0\n",
"\\item[V57] 0\n",
"\\item[V58] 0\n",
"\\item[V59] 0\n",
"\\item[V60] 0\n",
"\\item[Class] 0\n",
"\\end{description*}\n"
],
"text/markdown": [
"V1\n",
": 0V2\n",
": 0V3\n",
": 0V4\n",
": 0V5\n",
": 0V6\n",
": 0V7\n",
": 0V8\n",
": 0V9\n",
": 0V10\n",
": 0V11\n",
": 0V12\n",
": 0V13\n",
": 0V14\n",
": 0V15\n",
": 0V16\n",
": 0V17\n",
": 0V18\n",
": 0V19\n",
": 0V20\n",
": 0V21\n",
": 0V22\n",
": 0V23\n",
": 0V24\n",
": 0V25\n",
": 0V26\n",
": 0V27\n",
": 0V28\n",
": 0V29\n",
": 0V30\n",
": 0V31\n",
": 0V32\n",
": 0V33\n",
": 0V34\n",
": 0V35\n",
": 0V36\n",
": 0V37\n",
": 0V38\n",
": 0V39\n",
": 0V40\n",
": 0V41\n",
": 0V42\n",
": 0V43\n",
": 0V44\n",
": 0V45\n",
": 0V46\n",
": 0V47\n",
": 0V48\n",
": 0V49\n",
": 0V50\n",
": 0V51\n",
": 0V52\n",
": 0V53\n",
": 0V54\n",
": 0V55\n",
": 0V56\n",
": 0V57\n",
": 0V58\n",
": 0V59\n",
": 0V60\n",
": 0Class\n",
": 0\n",
"\n"
],
"text/plain": [
" V1 V2 V3 V4 V5 V6 V7 V8 V9 V10 V11 V12 V13 \n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
" V14 V15 V16 V17 V18 V19 V20 V21 V22 V23 V24 V25 V26 \n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
" V27 V28 V29 V30 V31 V32 V33 V34 V35 V36 V37 V38 V39 \n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
" V40 V41 V42 V43 V44 V45 V46 V47 V48 V49 V50 V51 V52 \n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 \n",
" V53 V54 V55 V56 V57 V58 V59 V60 Class \n",
" 0 0 0 0 0 0 0 0 0 "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"base::table(Sonar$Class)  # number of observations per class\n",
"colSums(is.na(Sonar))     # number of NAs in each column"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Next, we split `Sonar` manually into a training set (70% of the rows) and a test set (the remaining 30%): we shuffle the rows first and then cut the shuffled data at the 70% mark."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"button": false,
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of training and test samples are 145 63"
]
}
],
"source": [
"SEED <- 123\n",
"set.seed(SEED)\n",
"data <- Sonar[base::sample(nrow(Sonar)), ] # shuffle data first\n",
"bound <- floor(0.7 * nrow(data))\n",
"df_train <- data[1:bound, ] \n",
"df_test <- data[(bound + 1):nrow(data), ]\n",
"cat(\"number of training and test samples are \", nrow(df_train), nrow(df_test))"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Let's check whether the training and test sets have roughly the same proportions of the two `Class` labels."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"button": false,
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"class proportions (M, R) in training set: \n",
" 0.5310345 0.4689655\n",
"class proportions (M, R) in test set: \n",
" 0.5396825 0.4603175"
]
}
],
"source": [
"cat(\"class proportions (M, R) in training set: \\n\", prop.table(base::table(df_train$Class)))\n",
"cat(\"\\n\")\n",
"cat(\"class proportions (M, R) in test set: \\n\", prop.table(base::table(df_test$Class)))"
]
},
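{
"cell_type": "markdown",
"metadata": {},
"source": [
"The manual split above preserves the class proportions only approximately, by chance. As an alternative sketch (not what this notebook uses in the following steps), `caret`'s `createDataPartition()` samples within each level of `Class`, so both sets keep nearly the same proportions of $M$ and $R$; the selected rows differ from the manual split.\n",
"\n",
"```r\n",
"library(caret)\n",
"library(mlbench)\n",
"data(Sonar)\n",
"\n",
"set.seed(123)\n",
"# Row indices for a 70% training sample, drawn separately within each Class level\n",
"idx <- createDataPartition(Sonar$Class, p = 0.7, list = FALSE)\n",
"strat_train <- Sonar[idx, ]\n",
"strat_test  <- Sonar[-idx, ]\n",
"\n",
"# Class proportions should be nearly identical in both sets\n",
"prop.table(table(strat_train$Class))\n",
"prop.table(table(strat_test$Class))\n",
"```"
]
},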
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"To simplify the next steps, we separate the features (`X_train`, `X_test`) from the class labels (`y_train`, `y_test`):"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [],
"source": [
"X_train <- subset(df_train, select=-Class)\n",
"y_train <- df_train$Class\n",
"X_test <- subset(df_test, select=-Class) # exclude Class for prediction\n",
"y_test <- df_test$Class"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
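"### Step 3 – Training a model on data\n",
"Now we use the `knn` function from the `class` package with $k=3$. Because KNN has no separate training phase, `knn` takes the training features, the training labels and the test features together, and directly returns the predicted class for each test record."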
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"button": false,
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"data": {
"text/html": [
"<ol class=list-inline>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>R</li>\n",
"\t<li>R</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>R</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>R</li>\n",
"\t<li>M</li>\n",
"\t<li>R</li>\n",
"\t<li>R</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>R</li>\n",
"\t<li>M</li>\n",
"\t<li>R</li>\n",
"\t<li>R</li>\n",
"\t<li>M</li>\n",
"\t<li>R</li>\n",
"\t<li>M</li>\n",
"\t<li>R</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>R</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>R</li>\n",
"\t<li>R</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>R</li>\n",
"\t<li>R</li>\n",
"\t<li>R</li>\n",
"\t<li>R</li>\n",
"\t<li>R</li>\n",
"\t<li>M</li>\n",
"\t<li>M</li>\n",
"\t<li>R</li>\n",
"\t<li>M</li>\n",
"\t<li>R</li>\n",
"\t<li>R</li>\n",
"\t<li>R</li>\n",
"\t<li>R</li>\n",
"\t<li>R</li>\n",
"\t<li>R</li>\n",
"\t<li>M</li>\n",
"\t<li>R</li>\n",
"\t<li>M</li>\n",
"</ol>\n",
"\n",
"<details>\n",
"\t<summary style=display:list-item;cursor:pointer>\n",
"\t\t<strong>Levels</strong>:\n",
"\t</summary>\n",
"\t<ol class=list-inline>\n",
"\t\t<li>'M'</li>\n",
"\t\t<li>'R'</li>\n",
"\t</ol>\n",
"</details>"
],
"text/latex": [
"\\begin{enumerate*}\n",
"\\item M\n",
"\\item M\n",
"\\item M\n",
"\\item M\n",
"\\item R\n",
"\\item R\n",
"\\item M\n",
"\\item M\n",
"\\item M\n",
"\\item R\n",
"\\item M\n",
"\\item M\n",
"\\item M\n",
"\\item R\n",
"\\item M\n",
"\\item R\n",
"\\item R\n",
"\\item M\n",
"\\item M\n",
"\\item M\n",
"\\item M\n",
"\\item R\n",
"\\item M\n",
"\\item R\n",
"\\item R\n",
"\\item M\n",
"\\item R\n",
"\\item M\n",
"\\item R\n",
"\\item M\n",
"\\item M\n",
"\\item R\n",
"\\item M\n",
"\\item M\n",
"\\item M\n",
"\\item M\n",
"\\item M\n",
"\\item M\n",
"\\item R\n",
"\\item R\n",
"\\item M\n",
"\\item M\n",
"\\item M\n",
"\\item M\n",
"\\item M\n",
"\\item R\n",
"\\item R\n",
"\\item R\n",
"\\item R\n",
"\\item R\n",
"\\item M\n",
"\\item M\n",
"\\item R\n",
"\\item M\n",
"\\item R\n",
"\\item R\n",
"\\item R\n",
"\\item R\n",
"\\item R\n",
"\\item R\n",
"\\item M\n",
"\\item R\n",
"\\item M\n",
"\\end{enumerate*}\n",
"\n",
"\\emph{Levels}: \\begin{enumerate*}\n",
"\\item 'M'\n",
"\\item 'R'\n",
"\\end{enumerate*}\n"
],
"text/markdown": [
"1. M\n",
"2. M\n",
"3. M\n",
"4. M\n",
"5. R\n",
"6. R\n",
"7. M\n",
"8. M\n",
"9. M\n",
"10. R\n",
"11. M\n",
"12. M\n",
"13. M\n",
"14. R\n",
"15. M\n",
"16. R\n",
"17. R\n",
"18. M\n",
"19. M\n",
"20. M\n",
"21. M\n",
"22. R\n",
"23. M\n",
"24. R\n",
"25. R\n",
"26. M\n",
"27. R\n",
"28. M\n",
"29. R\n",
"30. M\n",
"31. M\n",
"32. R\n",
"33. M\n",
"34. M\n",
"35. M\n",
"36. M\n",
"37. M\n",
"38. M\n",
"39. R\n",
"40. R\n",
"41. M\n",
"42. M\n",
"43. M\n",
"44. M\n",
"45. M\n",
"46. R\n",
"47. R\n",
"48. R\n",
"49. R\n",
"50. R\n",
"51. M\n",
"52. M\n",
"53. R\n",
"54. M\n",
"55. R\n",
"56. R\n",
"57. R\n",
"58. R\n",
"59. R\n",
"60. R\n",
"61. M\n",
"62. R\n",
"63. M\n",
"\n",
"\n",
"\n",
"**Levels**: 1. 'M'\n",
"2. 'R'\n",
"\n",
"\n"
],
"text/plain": [
" [1] M M M M R R M M M R M M M R M R R M M M M R M R R M R M R M M R M M M M M M\n",
"[39] R R M M M M M R R R R R M M R M R R R R R R M R M\n",
"Levels: M R"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model_knn <- knn(train=X_train,\n",
" test=X_test,\n",
" cl=y_train, # class labels\n",
" k=3)\n",
"model_knn"
]
},
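{
"cell_type": "markdown",
"metadata": {},
"source": [
"`knn` can also report how decisive each vote was: with `prob = TRUE`, the returned factor carries an attribute `\"prob\"` holding the proportion of votes won by the predicted class. The sketch below recreates a 70/30 split with the same seed so that it runs on its own; depending on the R version, the shuffled rows may not match the notebook's split exactly.\n",
"\n",
"```r\n",
"library(class)\n",
"library(mlbench)\n",
"data(Sonar)\n",
"\n",
"# Recreate a shuffled 70/30 split (same recipe as in Step 2)\n",
"set.seed(123)\n",
"shuffled <- Sonar[sample(nrow(Sonar)), ]\n",
"bound    <- floor(0.7 * nrow(shuffled))\n",
"X_train  <- shuffled[1:bound, names(shuffled) != \"Class\"]\n",
"y_train  <- shuffled$Class[1:bound]\n",
"X_test   <- shuffled[(bound + 1):nrow(shuffled), names(shuffled) != \"Class\"]\n",
"\n",
"# prob = TRUE attaches the share of neighbor votes won by the predicted class\n",
"pred <- knn(train = X_train, test = X_test, cl = y_train, k = 3, prob = TRUE)\n",
"vote_share <- attr(pred, \"prob\")\n",
"\n",
"table(round(vote_share, 2))  # with k = 3, shares are typically 0.67 or 1\n",
"```"
]
},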
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"### Step 4 – Evaluate the model performance\n",
"The output above shows the predictions of the $k=3$ model for the test set `X_test`. To see how many test records were classified correctly or incorrectly, we cross-tabulate the predictions against the true labels (a confusion matrix):"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"button": false,
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"data": {
"text/plain": [
" model_knn\n",
"y_test M R\n",
" M 28 6\n",
" R 8 21"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"conf_mat <- base::table(y_test, model_knn)\n",
"conf_mat"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"To compute the accuracy, we sum the correctly classified observations (on the diagonal of the confusion matrix) and divide by the total number of test observations:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"button": false,
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy: 0.7777778"
]
}
],
"source": [
"cat(\"Test accuracy: \", sum(diag(conf_mat))/sum(conf_mat))"
]
},
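{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a worked check of this arithmetic, the sketch below rebuilds the confusion matrix from the output above by hand (rows are true classes, columns are predictions) and computes the accuracy. It also shows the equivalent label-vector form, `mean(predicted == truth)`, using label vectors constructed only to reproduce the same four counts.\n",
"\n",
"```r\n",
"# Confusion matrix copied from the output above\n",
"conf_mat <- matrix(c(28, 8, 6, 21), nrow = 2,\n",
"                   dimnames = list(y_test = c(\"M\", \"R\"), model_knn = c(\"M\", \"R\")))\n",
"\n",
"correct  <- sum(diag(conf_mat))  # 28 + 21 = 49 correctly classified\n",
"total    <- sum(conf_mat)        # 28 + 6 + 8 + 21 = 63 test observations\n",
"accuracy <- correct / total      # 49 / 63, about 0.778\n",
"\n",
"# Equivalent form: label vectors that reproduce the same four counts\n",
"truth     <- factor(rep(c(\"M\", \"R\", \"M\", \"R\"), times = c(28, 8, 6, 21)))\n",
"predicted <- factor(rep(c(\"M\", \"M\", \"R\", \"R\"), times = c(28, 8, 6, 21)))\n",
"mean(predicted == truth)         # also 49 / 63\n",
"```\n",
"\n",
"For a fuller report, `caret::confusionMatrix(data = predicted, reference = truth)` returns the same accuracy together with per-class statistics such as sensitivity and specificity."
]
},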
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"To assess whether $k=3$ is a good choice, and whether it overfits or underfits the data, we can use `knn.cv`, which performs leave-one-out cross-validation on the training set: it holds out one training sample at a time, classifies it as if it were a new example using the remaining samples, and records the predicted class label."
]
},
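{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same leave-one-out idea can be used to compare several values of $k$ rather than checking $k=3$ alone. The sketch below recreates the 70% training set so that it runs on its own (the shuffled rows may differ from the notebook's across R versions) and uses odd values of $k$, which reduce tied votes between the two classes; `knn.cv` breaks remaining ties at random, so results can vary slightly without a fixed seed.\n",
"\n",
"```r\n",
"library(class)\n",
"library(mlbench)\n",
"data(Sonar)\n",
"\n",
"# Recreate the shuffled 70% training set (same recipe as in Step 2)\n",
"set.seed(123)\n",
"shuffled <- Sonar[sample(nrow(Sonar)), ]\n",
"bound    <- floor(0.7 * nrow(shuffled))\n",
"X_train  <- shuffled[1:bound, names(shuffled) != \"Class\"]\n",
"y_train  <- shuffled$Class[1:bound]\n",
"\n",
"# Leave-one-out accuracy on the training set for odd k = 1, 3, ..., 15\n",
"k_values  <- seq(1, 15, by = 2)\n",
"loocv_acc <- sapply(k_values, function(k) {\n",
"  pred <- knn.cv(train = X_train, cl = y_train, k = k)\n",
"  mean(pred == y_train)\n",
"})\n",
"\n",
"data.frame(k = k_values, loocv_accuracy = round(loocv_acc, 3))\n",
"k_values[which.max(loocv_acc)]  # best k (the smallest one if several tie)\n",
"```"
]
},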
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Below are the predicted classes for the training set obtained with leave-one-out cross-validation. Let's then examine their accuracy."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"button": false,
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"data": {
"text/plain": [
" [1] R M R M M R M M M R M R M M M R R R R M M M M M M M R M R M M M M R M R M\n",
" [38] R R R R M R R R R M R R M M M M R R R M M M R R R R M M R M M R R M M M M\n",
" [75] R R M M M R M M M M M R M M M M R R M M R R R R R R M R M M M R M R R M M\n",
"[112] M M M R M M M M M R R M M M R M M R M R R M M R R R R M R M M R R M\n",
"Levels: M R"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"knn_loocv <- knn.cv(train=X_train, cl=y_train, k=3)\n",
"knn_loocv"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Let's build a confusion matrix from the training labels `y_train` and the cross-validated predictions `knn_loocv`, and compute the accuracy as before. How does the LOOCV accuracy compare with the test accuracy above?"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"button": false,
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"data": {
"text/plain": [
" knn_loocv\n",
"y_train M R\n",
" M 67 10\n",
" R 15 53"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"LOOCV accuracy: 0.8275862"
]
}
],
"source": [
"conf_mat_cv <- base::table(y_train, knn_loocv)\n",
"conf_mat_cv\n",
"cat(\"LOOCV accuracy: \", sum(diag(conf_mat_cv)) / sum(conf_mat_cv))"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"The LOOCV accuracy on the training set (about 0.83) is higher than the accuracy on the held-out test set (about 0.78), which suggests that $k=3$ may be overfitting the training data. Trying other values of $k$ could reduce this gap."
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"### Step 5 – Improve the performance of the model\n",
"As noted earlier, we have *not* standardized our training and test sets as part of preprocessing. In the rest of the tutorial, we standardize the predictors and choose a suitable $k$ through repeated *cross-validation* using the `caret` library.\n",
"\n",
"In a *cross-validation* procedure:\n",
"1. The data is divided into a finite number of mutually exclusive subsets (folds).\n",
"2. In each iteration, one subset is set aside and the remaining subsets are used as the training set.\n",
"3. The subset that was set aside is used as the test set: the model predicts its labels and the predictions are scored.\n",
"\n",
"Averaging the scores over all iterations estimates how well the model generalizes to unseen data, using only the data available for training."
]
},
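{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three steps above can be sketched by hand. The snippet below is a minimal, illustrative 5-fold cross-validation of `class::knn` with $k=3$. It uses the built-in `iris` data instead of `Sonar` so it runs on its own, and it assigns folds with `sample`, which is just one simple way to build mutually exclusive subsets (`caret::createFolds` is a stratified alternative). Unlike the `caret` workflow that follows, it does not standardize the predictors.\n",
"\n",
"```r\n",
"library(class)  # knn()\n",
"\n",
"set.seed(2016)\n",
"X <- as.matrix(iris[, 1:4])\n",
"y <- iris$Species\n",
"n_folds <- 5\n",
"\n",
"# Step 1: randomly assign every row to one of n_folds mutually exclusive folds\n",
"fold_id <- sample(rep(seq_len(n_folds), length.out = nrow(X)))\n",
"\n",
"fold_acc <- sapply(seq_len(n_folds), function(f) {\n",
"  held_out <- fold_id == f\n",
"  # Step 2: use every fold except f as the training set\n",
"  pred <- knn(train = X[!held_out, ], test = X[held_out, ], cl = y[!held_out], k = 3)\n",
"  # Step 3: score the predictions on the held-out fold f\n",
"  mean(pred == y[held_out])\n",
"})\n",
"\n",
"print(round(fold_acc, 3))  # accuracy on each held-out fold\n",
"print(mean(fold_acc))      # cross-validated accuracy estimate\n",
"```\n",
"\n",
"`caret`'s `repeatedcv` method repeats this whole procedure several times with different random fold assignments and averages the results."
]
},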
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [],
"source": [
"SEED <- 2016\n",
"set.seed(SEED)\n",
"# use a stratified random 70% of the Sonar data for training and the rest for testing\n",
"in_train <- createDataPartition(Sonar$Class, p=0.7, list=FALSE) # create training indices\n",
"ndf_train <- Sonar[in_train, ]\n",
"ndf_test <- Sonar[-in_train, ]"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Here, we specify the resampling scheme (5-fold cross-validation, repeated twice) and the grid of candidate $k$ values for the grid search. The fitted `train` object can also be passed to `plot` to see how the accuracy changes for different choices of $k.$"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"button": false,
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"data": {
"text/plain": [
" k\n",
"1 1\n",
"2 3\n",
"3 5\n",
"4 7"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# resampling setup: 5-fold cross-validation, repeated 2 times\n",
"ctrl <- trainControl(method=\"repeatedcv\", number=5, repeats=2)\n",
"\n",
"nn_grid <- expand.grid(k=c(1,3,5,7))\n",
"nn_grid"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"button": false,
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"data": {
"text/plain": [
"k-Nearest Neighbors \n",
"\n",
"146 samples\n",
" 60 predictor\n",
" 2 classes: 'M', 'R' \n",
"\n",
"Pre-processing: centered (60), scaled (60) \n",
"Resampling: Cross-Validated (5 fold, repeated 2 times) \n",
"Summary of sample sizes: 116, 116, 117, 117, 118, 116, ... \n",
"Resampling results across tuning parameters:\n",
"\n",
" k Accuracy Kappa \n",
" 1 0.8593432 0.7152417\n",
" 3 0.8329310 0.6601456\n",
" 5 0.7846305 0.5602652\n",
" 7 0.7608210 0.5081680\n",
"\n",
"Accuracy was used to select the optimal model using the largest value.\n",
"The final value used for the model was k = 1."
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"set.seed(SEED)\n",
"\n",
"best_knn <- train(Class~., data=ndf_train,\n",
" method=\"knn\",\n",
" trControl=ctrl, \n",
" preProcess = c(\"center\", \"scale\"), # standardize\n",
" tuneGrid=nn_grid)\n",
"best_knn"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Among the candidate values, $k=1$ has the highest repeated cross-validation accuracy (about 0.859), so `caret` selects it as the final model."
]
},
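{
"cell_type": "markdown",
"metadata": {},
"source": [
"The grid-search results can also be inspected visually: `caret` provides a `plot` method for fitted `train` objects that draws resampled accuracy against the tuning parameter, so `plot(best_knn)` shows the trend in the table above. The self-contained sketch below repeats the same `trainControl`/`tuneGrid` setup on the built-in `iris` data (used only so it runs without `Sonar`) and extracts the per-$k$ results and the selected $k$.\n",
"\n",
"```r\n",
"library(caret)\n",
"\n",
"set.seed(2016)\n",
"ctrl <- trainControl(method = 'repeatedcv', number = 5, repeats = 2)\n",
"knn_fit <- train(Species ~ ., data = iris, method = 'knn',\n",
"                 trControl = ctrl,\n",
"                 preProcess = c('center', 'scale'),\n",
"                 tuneGrid = expand.grid(k = c(1, 3, 5, 7)))\n",
"\n",
"print(knn_fit$results[, c('k', 'Accuracy', 'AccuracySD')])  # resampled accuracy per k\n",
"print(knn_fit$bestTune)                                      # k with the largest accuracy\n",
"plot(knn_fit)                                                # accuracy vs. k\n",
"```"
]
},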
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"### <span style=\"color:red\">(Optional) Exercise</span>\n",
"\n",
"Try to add dimensionality reduction to the preprocessing steps to achieve a *higher* test accuracy than above. There may not be a definite solution; it depends on how hard you try!\n",
"\n",
"If you are going to use `caret`, [here](https://github.com/topepo/caret/blob/master/pkg/caret/R/preProcess.R) are the available preprocessing options, with [explanations](http://topepo.github.io/caret/preprocess.html#pp).\n",
"\n",
"Use the above `best_knn` to make predictions on the test set (remember to remove the `Class` column before predicting). Then create a more informative confusion matrix with the `confusionMatrix` function from `caret` and examine the accuracy and its 95% confidence interval."
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"In fact, the above result suggests that $k=1$ (as might be expected) also overfits, though it may still be a better option than $k=3.$ Since our data is high-dimensional ($60$ predictors, or $61$ columns including `Class`, for only $146$ training samples), you might suspect that the better approach, as we said at the beginning of the tutorial, is to perform *dimensionality reduction* as part of preprocessing."
]
},
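{
"cell_type": "markdown",
"metadata": {},
"source": [
"One possible approach to the dimensionality-reduction exercise, offered as a sketch rather than a reference solution, is to add `'pca'` to the `preProcess` steps so that `caret` projects the standardized predictors onto their leading principal components inside each resampling iteration. The 90% variance threshold passed through `preProcOptions` is an arbitrary assumption worth tuning, the snippet assumes the `mlbench` package (which provides `Sonar`) is installed, and it does not guarantee a higher test accuracy. Note that `confusionMatrix` expects the predictions first and the true labels second.\n",
"\n",
"```r\n",
"library(caret)\n",
"data(Sonar, package = 'mlbench')  # assumes mlbench is installed\n",
"\n",
"set.seed(2016)\n",
"in_train  <- createDataPartition(Sonar$Class, p = 0.7, list = FALSE)\n",
"ndf_train <- Sonar[in_train, ]\n",
"ndf_test  <- Sonar[-in_train, ]\n",
"\n",
"# Keep as many principal components as needed to explain 90% of the variance\n",
"ctrl_pca <- trainControl(method = 'repeatedcv', number = 5, repeats = 2,\n",
"                         preProcOptions = list(thresh = 0.9))\n",
"\n",
"knn_pca <- train(Class ~ ., data = ndf_train, method = 'knn',\n",
"                 trControl = ctrl_pca,\n",
"                 preProcess = c('center', 'scale', 'pca'),\n",
"                 tuneGrid = expand.grid(k = c(1, 3, 5, 7)))\n",
"\n",
"print(knn_pca$preProcess)  # reports how many components were retained\n",
"\n",
"# Predict from the predictors only, then compare with the true labels\n",
"pred_pca <- predict(knn_pca, newdata = subset(ndf_test, select = -Class))\n",
"print(confusionMatrix(data = pred_pca, reference = ndf_test$Class))\n",
"```"
]
},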
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"### <span style=\"color:red\">(Optional) Exercise</span>\n",
"\n",
"Try to do dimensionality reduction as part of preprocess to achieve *higher* testing accuracy than above. This may not have a definite solution and it depends on how hard you try!\n",
"\n",
"If you are going to use `caret`, [here](https://github.com/topepo/caret/blob/master/pkg/caret/R/preProcess.R) are the available preprocess options with [explanations](http://topepo.github.io/caret/preprocess.html#pp)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"button": false,
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [],
"source": [
"SEED <- 123\n",
"set.seed(SEED)\n",
"ctrl <- trainControl(method=\"repeatedcv\", number=5, repeats=5)\n",
"nn_grid <- expand.grid(k=c(1, 3, 5, 7))\n",
"# preprocessing: Yeo-Johnson transformation, centering and scaling of the predictors\n",
"best_knn_reduced <- train(Class~., data=ndf_train, method=\"knn\",\n",
"                          trControl=ctrl, tuneGrid=nn_grid,\n",
"                          preProcess=c(\"center\", \"scale\", \"YeoJohnson\"))\n",
"X_test <- subset(ndf_test, select=-Class) # predictors only\n",
"pred_reduced <- predict(best_knn_reduced, newdata=X_test)\n",
"# confusionMatrix() expects the predictions first and the reference (true) labels second\n",
"conf_mat_best_reduced <- confusionMatrix(data=pred_reduced, reference=ndf_test$Class)\n",
"conf_mat_best_reduced"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"-----------------"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Want to learn more?\n",
"\n",
"IBM SPSS Modeler is a comprehensive analytics platform that has many machine learning algorithms. It has been designed to bring predictive intelligence to decisions made by individuals, by groups, by systems – by your enterprise as a whole. A free trial is available through this course: [SPSS Modeler for Mac users](https://cocl.us/ML0151EN_SPSSMod_mac) and [SPSS Modeler for Windows users](https://cocl.us/ML0151EN_SPSSMod_win).\n",
"\n",
"Also, you can use Data Science Experience to run these notebooks faster with bigger datasets. Data Science Experience is IBM's leading cloud solution for data scientists, built by data scientists. With Jupyter notebooks, RStudio, Apache Spark and popular libraries pre-packaged in the cloud, DSX enables data scientists to collaborate on their projects without having to install anything. Join the fast-growing community of DSX users today with a free account at [Data Science Experience](https://cocl.us/ML0151EN_DSX)"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"### Thanks for completing this lesson!\n",
"\n",
"Notebook created by: [Ehsan M. Kermani](https://www.linkedin.com/in/ehsanmkermani)"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"### References:\n",
"\n",
"* [K-nearest neighbors algorithm](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm)\n",
"* [Short introduction to caret](https://cran.r-project.org/web/packages/caret/vignettes/caret.pdf)\n",
"* [Predictive modeling with caret](https://www.r-project.org/nosvn/conferences/useR-2013/Tutorials/kuhn/user_caret_2up.pdf)"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"<hr>\n",
"Copyright &copy; 2016 [Cognitive Class](https://cognitiveClass.ai/?utm_source=bducopyrightlink&utm_medium=dswb&utm_campaign=bdu). This notebook and its source code are released under the terms of the [MIT License](https://bigdatauniversity.com/mit-license/)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"jupyter": {
"outputs_hidden": true
}
},
"outputs": [],
"source": []
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "R",
"language": "R",
"name": "conda-env-r-r"
},
"language_info": {
"codemirror_mode": "r",
"file_extension": ".r",
"mimetype": "text/x-r-source",
"name": "R",
"pygments_lexer": "r",
"version": "3.5.1"
},
"widgets": {
"state": {},
"version": "1.1.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}