{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Missing Data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Real world datasets often have missing values. Datasets with missing values are incomplete, which is a problem because not all machine learning algorithms can handle missing data. Accordingly, we need to find ways to transform an incomplete dataset into a complete dataset.\n",
"\n",
"Two of the most **common solutions** to transform an incomplete dataset into a complete dataset are:\n",
"\n",
"1. **Use only valid data.** In these cases we remove all the observations with missing data.\n",
"1. **Impute data.** Here we replace missing values with estimated values based on other information available in the dataset.\n",
"\n",
"The question now is: which is the best solution? \n",
"\n",
"Assuming that our goal is to make predictions, the **best solution** is the one who leads us to the **most accurate model**. We can find out this through the following process:\n",
"\n",
"1. Choose an **error metric** (e.g. accuracy).\n",
"1. Select a **machine learning algorithm** (e.g. logistic regression).\n",
"1. Apply a **missing data solution** (e.g. impute data) to get a complete dataset.\n",
"1. Evaluate model performance through **cross-validation**.\n",
"\n",
"In the end, we should choose the solution that gives us the most accurate model. \n",
"\n",
"Let's see how to do this in practice."
]
},
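{
"cell_type": "markdown",
"metadata": {},
"source": [
"One quick preliminary: before choosing a solution, it helps to know how much data is actually missing. A minimal sketch of that check (assuming the data lives in a pandas DataFrame with missing entries marked as NaN; the toy values below are made up):\n",
"\n",
"```python\n",
"# Count missing values per feature (toy DataFrame with NaN markers)\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"toy = pd.DataFrame({'petal_length': [1.4, np.nan, 1.3],\n",
"                    'petal_width': [0.2, 0.2, np.nan]})\n",
"print(toy.isnull().sum())        # missing values per feature\n",
"print(toy.isnull().sum().sum())  # total number of missing values\n",
"```"
]
},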
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will start by creating a dataset with missing data. In this example, missing values will be artificially implanted into the [Iris dataset](https://en.wikipedia.org/wiki/Iris_flower_data_set). This dataset uses morphologic data (e.g. petal length) to characterize three different species of the Iris flower."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'DESCR': 'Iris Plants Database\\n====================\\n\\nNotes\\n-----\\nData Set Characteristics:\\n :Number of Instances: 150 (50 in each of three classes)\\n :Number of Attributes: 4 numeric, predictive attributes and the class\\n :Attribute Information:\\n - sepal length in cm\\n - sepal width in cm\\n - petal length in cm\\n - petal width in cm\\n - class:\\n - Iris-Setosa\\n - Iris-Versicolour\\n - Iris-Virginica\\n :Summary Statistics:\\n\\n ============== ==== ==== ======= ===== ====================\\n Min Max Mean SD Class Correlation\\n ============== ==== ==== ======= ===== ====================\\n sepal length: 4.3 7.9 5.84 0.83 0.7826\\n sepal width: 2.0 4.4 3.05 0.43 -0.4194\\n petal length: 1.0 6.9 3.76 1.76 0.9490 (high!)\\n petal width: 0.1 2.5 1.20 0.76 0.9565 (high!)\\n ============== ==== ==== ======= ===== ====================\\n\\n :Missing Attribute Values: None\\n :Class Distribution: 33.3% for each of 3 classes.\\n :Creator: R.A. Fisher\\n :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)\\n :Date: July, 1988\\n\\nThis is a copy of UCI ML iris datasets.\\nhttp://archive.ics.uci.edu/ml/datasets/Iris\\n\\nThe famous Iris database, first used by Sir R.A Fisher\\n\\nThis is perhaps the best known database to be found in the\\npattern recognition literature. Fisher\\'s paper is a classic in the field and\\nis referenced frequently to this day. (See Duda & Hart, for example.) The\\ndata set contains 3 classes of 50 instances each, where each class refers to a\\ntype of iris plant. One class is linearly separable from the other 2; the\\nlatter are NOT linearly separable from each other.\\n\\nReferences\\n----------\\n - Fisher,R.A. \"The use of multiple measurements in taxonomic problems\"\\n Annual Eugenics, 7, Part II, 179-188 (1936); also in \"Contributions to\\n Mathematical Statistics\" (John Wiley, NY, 1950).\\n - Duda,R.O., & Hart,P.E. (1973) Pattern Classification and Scene Analysis.\\n (Q327.D83) John Wiley & Sons. ISBN 0-471-22361-1. See page 218.\\n - Dasarathy, B.V. (1980) \"Nosing Around the Neighborhood: A New System\\n Structure and Classification Rule for Recognition in Partially Exposed\\n Environments\". IEEE Transactions on Pattern Analysis and Machine\\n Intelligence, Vol. PAMI-2, No. 1, 67-71.\\n - Gates, G.W. (1972) \"The Reduced Nearest Neighbor Rule\". IEEE Transactions\\n on Information Theory, May 1972, 431-433.\\n - See also: 1988 MLC Proceedings, 54-64. Cheeseman et al\"s AUTOCLASS II\\n conceptual clustering system finds 3 classes in the data.\\n - Many, many more ...\\n',\n",
" 'data': array([[5.1, 3.5, 1.4, 0.2],\n",
" [4.9, 3. , 1.4, 0.2],\n",
" [4.7, 3.2, 1.3, 0.2],\n",
" [4.6, 3.1, 1.5, 0.2],\n",
" [5. , 3.6, 1.4, 0.2],\n",
" [5.4, 3.9, 1.7, 0.4],\n",
" [4.6, 3.4, 1.4, 0.3],\n",
" [5. , 3.4, 1.5, 0.2],\n",
" [4.4, 2.9, 1.4, 0.2],\n",
" [4.9, 3.1, 1.5, 0.1],\n",
" [5.4, 3.7, 1.5, 0.2],\n",
" [4.8, 3.4, 1.6, 0.2],\n",
" [4.8, 3. , 1.4, 0.1],\n",
" [4.3, 3. , 1.1, 0.1],\n",
" [5.8, 4. , 1.2, 0.2],\n",
" [5.7, 4.4, 1.5, 0.4],\n",
" [5.4, 3.9, 1.3, 0.4],\n",
" [5.1, 3.5, 1.4, 0.3],\n",
" [5.7, 3.8, 1.7, 0.3],\n",
" [5.1, 3.8, 1.5, 0.3],\n",
" [5.4, 3.4, 1.7, 0.2],\n",
" [5.1, 3.7, 1.5, 0.4],\n",
" [4.6, 3.6, 1. , 0.2],\n",
" [5.1, 3.3, 1.7, 0.5],\n",
" [4.8, 3.4, 1.9, 0.2],\n",
" [5. , 3. , 1.6, 0.2],\n",
" [5. , 3.4, 1.6, 0.4],\n",
" [5.2, 3.5, 1.5, 0.2],\n",
" [5.2, 3.4, 1.4, 0.2],\n",
" [4.7, 3.2, 1.6, 0.2],\n",
" [4.8, 3.1, 1.6, 0.2],\n",
" [5.4, 3.4, 1.5, 0.4],\n",
" [5.2, 4.1, 1.5, 0.1],\n",
" [5.5, 4.2, 1.4, 0.2],\n",
" [4.9, 3.1, 1.5, 0.1],\n",
" [5. , 3.2, 1.2, 0.2],\n",
" [5.5, 3.5, 1.3, 0.2],\n",
" [4.9, 3.1, 1.5, 0.1],\n",
" [4.4, 3. , 1.3, 0.2],\n",
" [5.1, 3.4, 1.5, 0.2],\n",
" [5. , 3.5, 1.3, 0.3],\n",
" [4.5, 2.3, 1.3, 0.3],\n",
" [4.4, 3.2, 1.3, 0.2],\n",
" [5. , 3.5, 1.6, 0.6],\n",
" [5.1, 3.8, 1.9, 0.4],\n",
" [4.8, 3. , 1.4, 0.3],\n",
" [5.1, 3.8, 1.6, 0.2],\n",
" [4.6, 3.2, 1.4, 0.2],\n",
" [5.3, 3.7, 1.5, 0.2],\n",
" [5. , 3.3, 1.4, 0.2],\n",
" [7. , 3.2, 4.7, 1.4],\n",
" [6.4, 3.2, 4.5, 1.5],\n",
" [6.9, 3.1, 4.9, 1.5],\n",
" [5.5, 2.3, 4. , 1.3],\n",
" [6.5, 2.8, 4.6, 1.5],\n",
" [5.7, 2.8, 4.5, 1.3],\n",
" [6.3, 3.3, 4.7, 1.6],\n",
" [4.9, 2.4, 3.3, 1. ],\n",
" [6.6, 2.9, 4.6, 1.3],\n",
" [5.2, 2.7, 3.9, 1.4],\n",
" [5. , 2. , 3.5, 1. ],\n",
" [5.9, 3. , 4.2, 1.5],\n",
" [6. , 2.2, 4. , 1. ],\n",
" [6.1, 2.9, 4.7, 1.4],\n",
" [5.6, 2.9, 3.6, 1.3],\n",
" [6.7, 3.1, 4.4, 1.4],\n",
" [5.6, 3. , 4.5, 1.5],\n",
" [5.8, 2.7, 4.1, 1. ],\n",
" [6.2, 2.2, 4.5, 1.5],\n",
" [5.6, 2.5, 3.9, 1.1],\n",
" [5.9, 3.2, 4.8, 1.8],\n",
" [6.1, 2.8, 4. , 1.3],\n",
" [6.3, 2.5, 4.9, 1.5],\n",
" [6.1, 2.8, 4.7, 1.2],\n",
" [6.4, 2.9, 4.3, 1.3],\n",
" [6.6, 3. , 4.4, 1.4],\n",
" [6.8, 2.8, 4.8, 1.4],\n",
" [6.7, 3. , 5. , 1.7],\n",
" [6. , 2.9, 4.5, 1.5],\n",
" [5.7, 2.6, 3.5, 1. ],\n",
" [5.5, 2.4, 3.8, 1.1],\n",
" [5.5, 2.4, 3.7, 1. ],\n",
" [5.8, 2.7, 3.9, 1.2],\n",
" [6. , 2.7, 5.1, 1.6],\n",
" [5.4, 3. , 4.5, 1.5],\n",
" [6. , 3.4, 4.5, 1.6],\n",
" [6.7, 3.1, 4.7, 1.5],\n",
" [6.3, 2.3, 4.4, 1.3],\n",
" [5.6, 3. , 4.1, 1.3],\n",
" [5.5, 2.5, 4. , 1.3],\n",
" [5.5, 2.6, 4.4, 1.2],\n",
" [6.1, 3. , 4.6, 1.4],\n",
" [5.8, 2.6, 4. , 1.2],\n",
" [5. , 2.3, 3.3, 1. ],\n",
" [5.6, 2.7, 4.2, 1.3],\n",
" [5.7, 3. , 4.2, 1.2],\n",
" [5.7, 2.9, 4.2, 1.3],\n",
" [6.2, 2.9, 4.3, 1.3],\n",
" [5.1, 2.5, 3. , 1.1],\n",
" [5.7, 2.8, 4.1, 1.3],\n",
" [6.3, 3.3, 6. , 2.5],\n",
" [5.8, 2.7, 5.1, 1.9],\n",
" [7.1, 3. , 5.9, 2.1],\n",
" [6.3, 2.9, 5.6, 1.8],\n",
" [6.5, 3. , 5.8, 2.2],\n",
" [7.6, 3. , 6.6, 2.1],\n",
" [4.9, 2.5, 4.5, 1.7],\n",
" [7.3, 2.9, 6.3, 1.8],\n",
" [6.7, 2.5, 5.8, 1.8],\n",
" [7.2, 3.6, 6.1, 2.5],\n",
" [6.5, 3.2, 5.1, 2. ],\n",
" [6.4, 2.7, 5.3, 1.9],\n",
" [6.8, 3. , 5.5, 2.1],\n",
" [5.7, 2.5, 5. , 2. ],\n",
" [5.8, 2.8, 5.1, 2.4],\n",
" [6.4, 3.2, 5.3, 2.3],\n",
" [6.5, 3. , 5.5, 1.8],\n",
" [7.7, 3.8, 6.7, 2.2],\n",
" [7.7, 2.6, 6.9, 2.3],\n",
" [6. , 2.2, 5. , 1.5],\n",
" [6.9, 3.2, 5.7, 2.3],\n",
" [5.6, 2.8, 4.9, 2. ],\n",
" [7.7, 2.8, 6.7, 2. ],\n",
" [6.3, 2.7, 4.9, 1.8],\n",
" [6.7, 3.3, 5.7, 2.1],\n",
" [7.2, 3.2, 6. , 1.8],\n",
" [6.2, 2.8, 4.8, 1.8],\n",
" [6.1, 3. , 4.9, 1.8],\n",
" [6.4, 2.8, 5.6, 2.1],\n",
" [7.2, 3. , 5.8, 1.6],\n",
" [7.4, 2.8, 6.1, 1.9],\n",
" [7.9, 3.8, 6.4, 2. ],\n",
" [6.4, 2.8, 5.6, 2.2],\n",
" [6.3, 2.8, 5.1, 1.5],\n",
" [6.1, 2.6, 5.6, 1.4],\n",
" [7.7, 3. , 6.1, 2.3],\n",
" [6.3, 3.4, 5.6, 2.4],\n",
" [6.4, 3.1, 5.5, 1.8],\n",
" [6. , 3. , 4.8, 1.8],\n",
" [6.9, 3.1, 5.4, 2.1],\n",
" [6.7, 3.1, 5.6, 2.4],\n",
" [6.9, 3.1, 5.1, 2.3],\n",
" [5.8, 2.7, 5.1, 1.9],\n",
" [6.8, 3.2, 5.9, 2.3],\n",
" [6.7, 3.3, 5.7, 2.5],\n",
" [6.7, 3. , 5.2, 2.3],\n",
" [6.3, 2.5, 5. , 1.9],\n",
" [6.5, 3. , 5.2, 2. ],\n",
" [6.2, 3.4, 5.4, 2.3],\n",
" [5.9, 3. , 5.1, 1.8]]),\n",
" 'feature_names': ['sepal length (cm)',\n",
" 'sepal width (cm)',\n",
" 'petal length (cm)',\n",
" 'petal width (cm)'],\n",
" 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
" 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
" 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),\n",
" 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10')}"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load Iris dataset\n",
"from sklearn.datasets import load_iris\n",
"\n",
"df = load_iris()\n",
"df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, there are four features (sepal length, sepal width, petal length, petal width) and three different species (setosa, versicolor, virginica)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X\n",
" [[5.1 3.5 1.4 0.2]\n",
" [4.9 3. 1.4 0.2]\n",
" [4.7 3.2 1.3 0.2]\n",
" [4.6 3.1 1.5 0.2]\n",
" [5. 3.6 1.4 0.2]\n",
" [5.4 3.9 1.7 0.4]\n",
" [4.6 3.4 1.4 0.3]\n",
" [5. 3.4 1.5 0.2]\n",
" [4.4 2.9 1.4 0.2]\n",
" [4.9 3.1 1.5 0.1]\n",
" [5.4 3.7 1.5 0.2]\n",
" [4.8 3.4 1.6 0.2]\n",
" [4.8 3. 1.4 0.1]\n",
" [4.3 3. 1.1 0.1]\n",
" [5.8 4. 1.2 0.2]\n",
" [5.7 4.4 1.5 0.4]\n",
" [5.4 3.9 1.3 0.4]\n",
" [5.1 3.5 1.4 0.3]\n",
" [5.7 3.8 1.7 0.3]\n",
" [5.1 3.8 1.5 0.3]\n",
" [5.4 3.4 1.7 0.2]\n",
" [5.1 3.7 1.5 0.4]\n",
" [4.6 3.6 1. 0.2]\n",
" [5.1 3.3 1.7 0.5]\n",
" [4.8 3.4 1.9 0.2]\n",
" [5. 3. 1.6 0.2]\n",
" [5. 3.4 1.6 0.4]\n",
" [5.2 3.5 1.5 0.2]\n",
" [5.2 3.4 1.4 0.2]\n",
" [4.7 3.2 1.6 0.2]\n",
" [4.8 3.1 1.6 0.2]\n",
" [5.4 3.4 1.5 0.4]\n",
" [5.2 4.1 1.5 0.1]\n",
" [5.5 4.2 1.4 0.2]\n",
" [4.9 3.1 1.5 0.1]\n",
" [5. 3.2 1.2 0.2]\n",
" [5.5 3.5 1.3 0.2]\n",
" [4.9 3.1 1.5 0.1]\n",
" [4.4 3. 1.3 0.2]\n",
" [5.1 3.4 1.5 0.2]\n",
" [5. 3.5 1.3 0.3]\n",
" [4.5 2.3 1.3 0.3]\n",
" [4.4 3.2 1.3 0.2]\n",
" [5. 3.5 1.6 0.6]\n",
" [5.1 3.8 1.9 0.4]\n",
" [4.8 3. 1.4 0.3]\n",
" [5.1 3.8 1.6 0.2]\n",
" [4.6 3.2 1.4 0.2]\n",
" [5.3 3.7 1.5 0.2]\n",
" [5. 3.3 1.4 0.2]\n",
" [7. 3.2 4.7 1.4]\n",
" [6.4 3.2 4.5 1.5]\n",
" [6.9 3.1 4.9 1.5]\n",
" [5.5 2.3 4. 1.3]\n",
" [6.5 2.8 4.6 1.5]\n",
" [5.7 2.8 4.5 1.3]\n",
" [6.3 3.3 4.7 1.6]\n",
" [4.9 2.4 3.3 1. ]\n",
" [6.6 2.9 4.6 1.3]\n",
" [5.2 2.7 3.9 1.4]\n",
" [5. 2. 3.5 1. ]\n",
" [5.9 3. 4.2 1.5]\n",
" [6. 2.2 4. 1. ]\n",
" [6.1 2.9 4.7 1.4]\n",
" [5.6 2.9 3.6 1.3]\n",
" [6.7 3.1 4.4 1.4]\n",
" [5.6 3. 4.5 1.5]\n",
" [5.8 2.7 4.1 1. ]\n",
" [6.2 2.2 4.5 1.5]\n",
" [5.6 2.5 3.9 1.1]\n",
" [5.9 3.2 4.8 1.8]\n",
" [6.1 2.8 4. 1.3]\n",
" [6.3 2.5 4.9 1.5]\n",
" [6.1 2.8 4.7 1.2]\n",
" [6.4 2.9 4.3 1.3]\n",
" [6.6 3. 4.4 1.4]\n",
" [6.8 2.8 4.8 1.4]\n",
" [6.7 3. 5. 1.7]\n",
" [6. 2.9 4.5 1.5]\n",
" [5.7 2.6 3.5 1. ]\n",
" [5.5 2.4 3.8 1.1]\n",
" [5.5 2.4 3.7 1. ]\n",
" [5.8 2.7 3.9 1.2]\n",
" [6. 2.7 5.1 1.6]\n",
" [5.4 3. 4.5 1.5]\n",
" [6. 3.4 4.5 1.6]\n",
" [6.7 3.1 4.7 1.5]\n",
" [6.3 2.3 4.4 1.3]\n",
" [5.6 3. 4.1 1.3]\n",
" [5.5 2.5 4. 1.3]\n",
" [5.5 2.6 4.4 1.2]\n",
" [6.1 3. 4.6 1.4]\n",
" [5.8 2.6 4. 1.2]\n",
" [5. 2.3 3.3 1. ]\n",
" [5.6 2.7 4.2 1.3]\n",
" [5.7 3. 4.2 1.2]\n",
" [5.7 2.9 4.2 1.3]\n",
" [6.2 2.9 4.3 1.3]\n",
" [5.1 2.5 3. 1.1]\n",
" [5.7 2.8 4.1 1.3]\n",
" [6.3 3.3 6. 2.5]\n",
" [5.8 2.7 5.1 1.9]\n",
" [7.1 3. 5.9 2.1]\n",
" [6.3 2.9 5.6 1.8]\n",
" [6.5 3. 5.8 2.2]\n",
" [7.6 3. 6.6 2.1]\n",
" [4.9 2.5 4.5 1.7]\n",
" [7.3 2.9 6.3 1.8]\n",
" [6.7 2.5 5.8 1.8]\n",
" [7.2 3.6 6.1 2.5]\n",
" [6.5 3.2 5.1 2. ]\n",
" [6.4 2.7 5.3 1.9]\n",
" [6.8 3. 5.5 2.1]\n",
" [5.7 2.5 5. 2. ]\n",
" [5.8 2.8 5.1 2.4]\n",
" [6.4 3.2 5.3 2.3]\n",
" [6.5 3. 5.5 1.8]\n",
" [7.7 3.8 6.7 2.2]\n",
" [7.7 2.6 6.9 2.3]\n",
" [6. 2.2 5. 1.5]\n",
" [6.9 3.2 5.7 2.3]\n",
" [5.6 2.8 4.9 2. ]\n",
" [7.7 2.8 6.7 2. ]\n",
" [6.3 2.7 4.9 1.8]\n",
" [6.7 3.3 5.7 2.1]\n",
" [7.2 3.2 6. 1.8]\n",
" [6.2 2.8 4.8 1.8]\n",
" [6.1 3. 4.9 1.8]\n",
" [6.4 2.8 5.6 2.1]\n",
" [7.2 3. 5.8 1.6]\n",
" [7.4 2.8 6.1 1.9]\n",
" [7.9 3.8 6.4 2. ]\n",
" [6.4 2.8 5.6 2.2]\n",
" [6.3 2.8 5.1 1.5]\n",
" [6.1 2.6 5.6 1.4]\n",
" [7.7 3. 6.1 2.3]\n",
" [6.3 3.4 5.6 2.4]\n",
" [6.4 3.1 5.5 1.8]\n",
" [6. 3. 4.8 1.8]\n",
" [6.9 3.1 5.4 2.1]\n",
" [6.7 3.1 5.6 2.4]\n",
" [6.9 3.1 5.1 2.3]\n",
" [5.8 2.7 5.1 1.9]\n",
" [6.8 3.2 5.9 2.3]\n",
" [6.7 3.3 5.7 2.5]\n",
" [6.7 3. 5.2 2.3]\n",
" [6.3 2.5 5. 1.9]\n",
" [6.5 3. 5.2 2. ]\n",
" [6.2 3.4 5.4 2.3]\n",
" [5.9 3. 5.1 1.8]]\n",
"Y\n",
" [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
" 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n",
" 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2\n",
" 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n",
" 2 2]\n"
]
}
],
"source": [
"# Split data into features and target variable\n",
"X, y = df.data, df.target\n",
"print('X\\n', X)\n",
"print('Y\\n', y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dataset isn't shuffled. We need to shuffle it to avoid data segretation problems during cross-validation. If you want to read more about shuffling data and other data cleaning tasks, you can read this [blog post](http://pmarcelino.com/data-cleaning-general-techniques/)."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X\n",
" [[5.7 4.4 1.5 0.4]\n",
" [4.9 3.1 1.5 0.1]\n",
" [5.7 3. 4.2 1.2]\n",
" [5.8 2.7 5.1 1.9]\n",
" [6.9 3.1 5.4 2.1]\n",
" [6.3 3.4 5.6 2.4]\n",
" [6.7 3.3 5.7 2.1]\n",
" [5.1 3.8 1.5 0.3]\n",
" [5. 3.5 1.3 0.3]\n",
" [6.6 2.9 4.6 1.3]\n",
" [5.3 3.7 1.5 0.2]\n",
" [5. 3.4 1.6 0.4]\n",
" [5. 3. 1.6 0.2]\n",
" [6.1 3. 4.9 1.8]\n",
" [4.6 3.2 1.4 0.2]\n",
" [4.9 3.1 1.5 0.1]\n",
" [6. 2.2 5. 1.5]\n",
" [4.9 3.1 1.5 0.1]\n",
" [5.6 3. 4.1 1.3]\n",
" [5.6 2.8 4.9 2. ]\n",
" [6.3 2.3 4.4 1.3]\n",
" [6.4 3.2 4.5 1.5]\n",
" [6.7 3.1 4.4 1.4]\n",
" [5.1 3.7 1.5 0.4]\n",
" [5.7 2.5 5. 2. ]\n",
" [6.4 2.8 5.6 2.1]\n",
" [5.6 3. 4.5 1.5]\n",
" [5.4 3.4 1.7 0.2]\n",
" [6.7 3.1 4.7 1.5]\n",
" [5. 3.2 1.2 0.2]\n",
" [6.3 2.5 5. 1.9]\n",
" [4.8 3. 1.4 0.3]\n",
" [4.6 3.6 1. 0.2]\n",
" [6.1 3. 4.6 1.4]\n",
" [6.8 2.8 4.8 1.4]\n",
" [5.8 2.7 4.1 1. ]\n",
" [7.6 3. 6.6 2.1]\n",
" [6.1 2.6 5.6 1.4]\n",
" [6.5 3. 5.2 2. ]\n",
" [7. 3.2 4.7 1.4]\n",
" [5.2 3.4 1.4 0.2]\n",
" [5.9 3.2 4.8 1.8]\n",
" [7.1 3. 5.9 2.1]\n",
" [6. 2.2 4. 1. ]\n",
" [6.3 2.9 5.6 1.8]\n",
" [7.7 2.8 6.7 2. ]\n",
" [6.4 3.2 5.3 2.3]\n",
" [5.4 3. 4.5 1.5]\n",
" [5.1 3.8 1.9 0.4]\n",
" [6.1 2.8 4. 1.3]\n",
" [7.4 2.8 6.1 1.9]\n",
" [5.5 2.4 3.7 1. ]\n",
" [7.7 3.8 6.7 2.2]\n",
" [5.5 3.5 1.3 0.2]\n",
" [6.5 2.8 4.6 1.5]\n",
" [6.1 2.8 4.7 1.2]\n",
" [4.4 2.9 1.4 0.2]\n",
" [5.8 4. 1.2 0.2]\n",
" [5. 3.6 1.4 0.2]\n",
" [5.2 2.7 3.9 1.4]\n",
" [6.4 2.8 5.6 2.2]\n",
" [6.8 3.2 5.9 2.3]\n",
" [5.7 2.8 4.1 1.3]\n",
" [4.4 3. 1.3 0.2]\n",
" [6.3 2.8 5.1 1.5]\n",
" [4.7 3.2 1.3 0.2]\n",
" [4.5 2.3 1.3 0.3]\n",
" [6. 2.9 4.5 1.5]\n",
" [6.7 3. 5. 1.7]\n",
" [5.6 2.9 3.6 1.3]\n",
" [7.7 2.6 6.9 2.3]\n",
" [4.9 3. 1.4 0.2]\n",
" [6.4 3.1 5.5 1.8]\n",
" [6.7 3.1 5.6 2.4]\n",
" [4.8 3. 1.4 0.1]\n",
" [4.7 3.2 1.6 0.2]\n",
" [6.7 3. 5.2 2.3]\n",
" [4.4 3.2 1.3 0.2]\n",
" [5.1 3.5 1.4 0.2]\n",
" [6.3 2.5 4.9 1.5]\n",
" [6.2 2.9 4.3 1.3]\n",
" [5.1 3.4 1.5 0.2]\n",
" [5.4 3.9 1.3 0.4]\n",
" [5.1 3.5 1.4 0.3]\n",
" [6. 3.4 4.5 1.6]\n",
" [6.7 2.5 5.8 1.8]\n",
" [5.9 3. 4.2 1.5]\n",
" [5.8 2.6 4. 1.2]\n",
" [6.1 2.9 4.7 1.4]\n",
" [6.9 3.1 5.1 2.3]\n",
" [5.2 4.1 1.5 0.1]\n",
" [4.6 3.4 1.4 0.3]\n",
" [4.8 3.4 1.9 0.2]\n",
" [6.5 3. 5.5 1.8]\n",
" [5.1 2.5 3. 1.1]\n",
" [6.7 3.3 5.7 2.5]\n",
" [5.9 3. 5.1 1.8]\n",
" [5.4 3.4 1.5 0.4]\n",
" [7.9 3.8 6.4 2. ]\n",
" [4.8 3.4 1.6 0.2]\n",
" [6.6 3. 4.4 1.4]\n",
" [5. 2. 3.5 1. ]\n",
" [4.9 2.4 3.3 1. ]\n",
" [6.5 3. 5.8 2.2]\n",
" [5. 3.3 1.4 0.2]\n",
" [6.3 3.3 4.7 1.6]\n",
" [7.2 3.2 6. 1.8]\n",
" [5.8 2.7 3.9 1.2]\n",
" [7.2 3.6 6.1 2.5]\n",
" [5.5 2.3 4. 1.3]\n",
" [5.5 4.2 1.4 0.2]\n",
" [7.2 3. 5.8 1.6]\n",
" [6.3 3.3 6. 2.5]\n",
" [5.1 3.8 1.6 0.2]\n",
" [6.2 2.2 4.5 1.5]\n",
" [6.4 2.9 4.3 1.3]\n",
" [5.5 2.5 4. 1.3]\n",
" [5.5 2.4 3.8 1.1]\n",
" [5.7 3.8 1.7 0.3]\n",
" [5. 3.4 1.5 0.2]\n",
" [6.9 3.1 4.9 1.5]\n",
" [5.6 2.5 3.9 1.1]\n",
" [4.3 3. 1.1 0.1]\n",
" [7.3 2.9 6.3 1.8]\n",
" [5.8 2.7 5.1 1.9]\n",
" [6.4 2.7 5.3 1.9]\n",
" [5.7 2.8 4.5 1.3]\n",
" [5.7 2.6 3.5 1. ]\n",
" [5.1 3.3 1.7 0.5]\n",
" [6.8 3. 5.5 2.1]\n",
" [6. 2.7 5.1 1.6]\n",
" [4.6 3.1 1.5 0.2]\n",
" [6. 3. 4.8 1.8]\n",
" [5.5 2.6 4.4 1.2]\n",
" [5.8 2.8 5.1 2.4]\n",
" [4.8 3.1 1.6 0.2]\n",
" [5.4 3.7 1.5 0.2]\n",
" [6.3 2.7 4.9 1.8]\n",
" [5.7 2.9 4.2 1.3]\n",
" [6.2 3.4 5.4 2.3]\n",
" [5. 3.5 1.6 0.6]\n",
" [7.7 3. 6.1 2.3]\n",
" [5.6 2.7 4.2 1.3]\n",
" [5.4 3.9 1.7 0.4]\n",
" [5. 2.3 3.3 1. ]\n",
" [6.9 3.2 5.7 2.3]\n",
" [6.2 2.8 4.8 1.8]\n",
" [4.9 2.5 4.5 1.7]\n",
" [5.2 3.5 1.5 0.2]\n",
" [6.5 3.2 5.1 2. ]]\n",
"Y\n",
" [0 0 1 2 2 2 2 0 0 1 0 0 0 2 0 0 2 0 1 2 1 1 1 0 2 2 1 0 1 0 2 0 0 1 1 1 2\n",
" 2 2 1 0 1 2 1 2 2 2 1 0 1 2 1 2 0 1 1 0 0 0 1 2 2 1 0 2 0 0 1 1 1 2 0 2 2\n",
" 0 0 2 0 0 1 1 0 0 0 1 2 1 1 1 2 0 0 0 2 1 2 2 0 2 0 1 1 1 2 0 1 2 1 2 1 0\n",
" 2 2 0 1 1 1 1 0 0 1 1 0 2 2 2 1 1 0 2 1 0 2 1 2 0 0 2 1 2 0 2 1 0 1 2 2 2\n",
" 0 2]\n"
]
}
],
"source": [
"# Shuffle data\n",
"from sklearn.utils import shuffle\n",
"\n",
"X, y = shuffle(X, y)\n",
"print('X\\n', X)\n",
"print('Y\\n', y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dataset is loaded and converted into a workable format.\n",
"\n",
"Now, we need to artificially implant missing data. There are countless ways to do it. Here, we will do it by:\n",
"1. Generating a random number.\n",
"1. Subtracting it to the values in the dataset.\n",
"1. Assigning a missing value to the cases where the substraction is less than a certain threshold.\n",
"\n",
"Parameters and threshold values are defined randomly, but taking into account the order of magnitude of the numbers involved. "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[5.7, 4.4, 1.5, 0.4],\n",
" [nan, 3.1, 1.5, nan],\n",
" [5.7, 3. , 4.2, nan],\n",
" [5.8, 2.7, 5.1, nan],\n",
" [6.9, 3.1, nan, nan],\n",
" [6.3, 3.4, 5.6, 2.4],\n",
" [nan, 3.3, 5.7, 2.1],\n",
" [5.1, 3.8, 1.5, 0.3],\n",
" [5. , 3.5, nan, 0.3],\n",
" [nan, 2.9, 4.6, 1.3],\n",
" [5.3, 3.7, 1.5, 0.2],\n",
" [5. , 3.4, 1.6, 0.4],\n",
" [nan, nan, 1.6, 0.2],\n",
" [6.1, 3. , 4.9, 1.8],\n",
" [nan, 3.2, 1.4, 0.2],\n",
" [nan, nan, 1.5, 0.1],\n",
" [6. , 2.2, 5. , 1.5],\n",
" [4.9, 3.1, 1.5, 0.1],\n",
" [nan, 3. , 4.1, 1.3],\n",
" [5.6, nan, 4.9, nan],\n",
" [6.3, 2.3, 4.4, 1.3],\n",
" [nan, nan, 4.5, 1.5],\n",
" [nan, nan, nan, 1.4],\n",
" [5.1, 3.7, 1.5, 0.4],\n",
" [5.7, 2.5, 5. , 2. ],\n",
" [6.4, nan, 5.6, 2.1],\n",
" [5.6, 3. , 4.5, 1.5],\n",
" [5.4, 3.4, 1.7, 0.2],\n",
" [6.7, 3.1, 4.7, 1.5],\n",
" [5. , 3.2, nan, nan],\n",
" [nan, 2.5, nan, nan],\n",
" [nan, 3. , 1.4, 0.3],\n",
" [4.6, 3.6, nan, 0.2],\n",
" [6.1, 3. , 4.6, 1.4],\n",
" [6.8, nan, 4.8, nan],\n",
" [5.8, 2.7, nan, 1. ],\n",
" [7.6, 3. , 6.6, 2.1],\n",
" [6.1, 2.6, nan, nan],\n",
" [6.5, 3. , 5.2, 2. ],\n",
" [7. , 3.2, 4.7, 1.4],\n",
" [5.2, 3.4, 1.4, 0.2],\n",
" [5.9, 3.2, 4.8, 1.8],\n",
" [7.1, 3. , 5.9, 2.1],\n",
" [6. , 2.2, 4. , 1. ],\n",
" [6.3, nan, 5.6, 1.8],\n",
" [7.7, 2.8, 6.7, 2. ],\n",
" [6.4, nan, 5.3, 2.3],\n",
" [nan, nan, 4.5, 1.5],\n",
" [5.1, 3.8, nan, 0.4],\n",
" [6.1, 2.8, nan, 1.3],\n",
" [7.4, 2.8, 6.1, nan],\n",
" [nan, 2.4, 3.7, 1. ],\n",
" [7.7, nan, 6.7, nan],\n",
" [5.5, nan, 1.3, 0.2],\n",
" [6.5, nan, 4.6, 1.5],\n",
" [nan, 2.8, 4.7, 1.2],\n",
" [4.4, 2.9, 1.4, 0.2],\n",
" [5.8, 4. , 1.2, 0.2],\n",
" [5. , 3.6, 1.4, 0.2],\n",
" [5.2, 2.7, nan, 1.4],\n",
" [6.4, 2.8, nan, nan],\n",
" [6.8, nan, 5.9, 2.3],\n",
" [5.7, nan, 4.1, 1.3],\n",
" [4.4, 3. , nan, 0.2],\n",
" [6.3, 2.8, 5.1, 1.5],\n",
" [4.7, 3.2, 1.3, 0.2],\n",
" [4.5, 2.3, 1.3, nan],\n",
" [nan, nan, 4.5, 1.5],\n",
" [6.7, 3. , 5. , nan],\n",
" [5.6, 2.9, 3.6, 1.3],\n",
" [nan, 2.6, nan, 2.3],\n",
" [4.9, 3. , 1.4, 0.2],\n",
" [6.4, 3.1, nan, 1.8],\n",
" [6.7, nan, 5.6, 2.4],\n",
" [nan, 3. , 1.4, 0.1],\n",
" [4.7, 3.2, nan, 0.2],\n",
" [6.7, 3. , nan, nan],\n",
" [4.4, 3.2, 1.3, 0.2],\n",
" [5.1, 3.5, 1.4, 0.2],\n",
" [6.3, 2.5, 4.9, 1.5],\n",
" [6.2, 2.9, 4.3, 1.3],\n",
" [nan, 3.4, nan, nan],\n",
" [nan, 3.9, 1.3, 0.4],\n",
" [5.1, nan, 1.4, 0.3],\n",
" [6. , 3.4, 4.5, 1.6],\n",
" [6.7, nan, 5.8, 1.8],\n",
" [5.9, 3. , 4.2, 1.5],\n",
" [5.8, 2.6, 4. , 1.2],\n",
" [6.1, nan, 4.7, 1.4],\n",
" [6.9, nan, 5.1, 2.3],\n",
" [5.2, 4.1, nan, 0.1],\n",
" [4.6, 3.4, 1.4, 0.3],\n",
" [4.8, 3.4, 1.9, 0.2],\n",
" [6.5, 3. , 5.5, 1.8],\n",
" [nan, 2.5, 3. , 1.1],\n",
" [nan, 3.3, 5.7, 2.5],\n",
" [5.9, 3. , nan, 1.8],\n",
" [5.4, 3.4, 1.5, 0.4],\n",
" [7.9, 3.8, 6.4, 2. ],\n",
" [nan, 3.4, nan, 0.2],\n",
" [6.6, nan, 4.4, 1.4],\n",
" [5. , 2. , 3.5, nan],\n",
" [4.9, nan, nan, 1. ],\n",
" [6.5, 3. , nan, 2.2],\n",
" [5. , nan, 1.4, 0.2],\n",
" [6.3, nan, nan, 1.6],\n",
" [7.2, 3.2, nan, 1.8],\n",
" [5.8, 2.7, 3.9, 1.2],\n",
" [7.2, 3.6, 6.1, 2.5],\n",
" [5.5, nan, 4. , 1.3],\n",
" [5.5, 4.2, 1.4, 0.2],\n",
" [nan, 3. , nan, 1.6],\n",
" [6.3, 3.3, 6. , 2.5],\n",
" [5.1, 3.8, 1.6, 0.2],\n",
" [6.2, 2.2, 4.5, 1.5],\n",
" [nan, 2.9, nan, 1.3],\n",
" [5.5, nan, nan, nan],\n",
" [5.5, 2.4, nan, nan],\n",
" [nan, 3.8, 1.7, 0.3],\n",
" [nan, 3.4, 1.5, nan],\n",
" [nan, nan, nan, nan],\n",
" [5.6, 2.5, nan, 1.1],\n",
" [nan, 3. , 1.1, 0.1],\n",
" [7.3, nan, 6.3, nan],\n",
" [5.8, 2.7, 5.1, nan],\n",
" [6.4, 2.7, 5.3, 1.9],\n",
" [5.7, nan, 4.5, 1.3],\n",
" [5.7, 2.6, 3.5, 1. ],\n",
" [nan, 3.3, 1.7, 0.5],\n",
" [nan, 3. , nan, 2.1],\n",
" [6. , 2.7, 5.1, 1.6],\n",
" [nan, 3.1, 1.5, 0.2],\n",
" [6. , 3. , nan, 1.8],\n",
" [5.5, 2.6, nan, nan],\n",
" [5.8, nan, 5.1, 2.4],\n",
" [nan, nan, nan, 0.2],\n",
" [nan, 3.7, nan, 0.2],\n",
" [6.3, 2.7, 4.9, nan],\n",
" [nan, nan, nan, 1.3],\n",
" [6.2, 3.4, 5.4, 2.3],\n",
" [5. , 3.5, 1.6, nan],\n",
" [nan, 3. , 6.1, nan],\n",
" [5.6, 2.7, 4.2, 1.3],\n",
" [nan, 3.9, 1.7, 0.4],\n",
" [5. , 2.3, 3.3, nan],\n",
" [nan, 3.2, 5.7, nan],\n",
" [6.2, 2.8, nan, 1.8],\n",
" [4.9, nan, 4.5, 1.7],\n",
" [5.2, 3.5, 1.5, 0.2],\n",
" [6.5, nan, 5.1, 2. ]])"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Implant artificial missing values\n",
"import numpy as np\n",
"\n",
"rand = np.random.RandomState(0) # Random number generator\n",
"X_missing = X.copy()\n",
"\n",
"mask = [] # We will need this later to filter observations in y\n",
"features_missing = np.shape(X)[1] # Missing values in all features\n",
"loc = 0 # Mean in numpy.random.normal\n",
"scale = 5 # Standard deviation in numpy.random.normal \n",
"threshold = 1.75\n",
"for i in range(0, features_missing):\n",
" mask_partial = np.abs(X[:,1] - rand.normal(loc=loc, scale=scale, size=X.shape[0])) < threshold\n",
" X_missing[mask_partial, i] = np.NaN\n",
" mask.append(mask_partial)\n",
"\n",
"X_missing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Done! Our dataset is ready.\n",
"\n",
"Now, let's compare the two approaches we discussed in the beginning of the notebook:\n",
"1. Use only valid data.\n",
"1. Impute data."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Use only valid data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we said in the beginning, the procedure that we need to follow is:\n",
"1. Choose an **error metric** (e.g. accuracy).\n",
"1. Select a **machine learning algorithm** (e.g. logistic regression).\n",
"1. Apply a **missing data solution** (e.g. impute data) to get a complete dataset.\n",
"1. Evaluate model performance through **cross-validation**."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1.1. Error metric"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use **accuracy**. Accuracy is given by the ratio between the number of correct predicted labels and the total number of observations in the sample."
]
},
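{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick illustration of the metric (with made-up label vectors; scikit-learn's *accuracy_score* computes the same ratio):\n",
"\n",
"```python\n",
"# Accuracy = correctly predicted labels / total observations\n",
"import numpy as np\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"y_true = np.array([0, 1, 2, 2, 1])\n",
"y_pred = np.array([0, 1, 2, 1, 1])\n",
"print(np.mean(y_true == y_pred))      # 0.8, computed by hand\n",
"print(accuracy_score(y_true, y_pred)) # 0.8, same value via scikit-learn\n",
"```"
]
},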
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1.2. Machine learning algorithm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use **logistic regression**. It's a simple and well-known algorithm that fits the illustrative purposes of our example."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1.3. Imputation method"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the *only valid data* approach, the idea is to remove the observations with missing values and keep only the observations with complete data. Let's do it."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(54, 4)\n",
"(54,)\n"
]
}
],
"source": [
"# Delete observations with missing values\n",
"import pandas as pd\n",
"\n",
"X_filtered = pd.DataFrame(X_missing)\n",
"X_filtered.dropna(inplace=True)\n",
"\n",
"## To remove observations in y it's not trivial\n",
"## We need mask to know, in each column, which observations have missing values\n",
"## Then we keep all the observations without missing values\n",
"mask_total = mask[0]\n",
"for i in range(0, np.shape(mask)[0]):\n",
" mask_total += mask[i]\n",
"y_filtered = y[~(mask_total)]\n",
"\n",
"print(np.shape(X_filtered))\n",
"print(np.shape(y_filtered))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1.4. Evaluation through cross-validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, to estimate the performance of the model, we will use cross-validation. In scikit-learn, we can use cross-validation through the *cross_val_score* function. Since it applies k-fold cross-validation, we need to define the number of folds. Usually, 5 or 10 are enough. Here, we will use 5 folds (cv=5) because I like the number 5."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Score: 0.909 +/- 0.058\n"
]
}
],
"source": [
"# Estimate model's performance\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"lr = LogisticRegression()\n",
"score = cross_val_score(lr, X_filtered, y_filtered, cv=5)\n",
"print('Score: %.3f +/- %.3f' % (np.mean(score), np.std(score)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Impute data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, here it's a different solution but the is similar. Do you still remember the four steps to glory?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.1. Error metric"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to use **accuracy** because we want to compare the approaches."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.2. Machine learning algorithm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For the same reasons as above, we will go for **logistic regression**."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.3. Imputation method"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is where things start getting interesting. So, we want to impute values to replace missing values. These imputed values must be estimated. How do we estimate them? The easiest way to estimate them is to say that they result from the average of the known values (in each feature). This means that, for example, the missing values of petal length can be replaced by the mean value of all the known petal lengths values.\n",
"\n",
"The mean imputation is one of the simplest ways to estimate missing values. We can go for more complex solutions, but in scikit-learn we are somehow restricted. Let's see why."
]
},
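{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the idea concrete, here is a minimal sketch of mean imputation on a toy array (computed by hand with *np.nanmean*; the Imputer transformer used later does the same thing per feature):\n",
"\n",
"```python\n",
"# Replace each NaN with the mean of the known values in its column (toy array)\n",
"import numpy as np\n",
"\n",
"X_toy = np.array([[5.1, np.nan],\n",
"                  [4.9, 3.0],\n",
"                  [np.nan, 3.2]])\n",
"col_means = np.nanmean(X_toy, axis=0)   # mean of the known values per feature\n",
"rows, cols = np.where(np.isnan(X_toy))  # positions of the missing values\n",
"X_toy[rows, cols] = col_means[cols]     # fill each NaN with its column mean\n",
"print(X_toy)\n",
"```"
]
},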
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data leakage"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Imagine that you have a dataset with missing values and you will use all the data in the dataset to compute means:\n",
"\n",
"<img src=\"figures/missing_data_dataset_incomplete.jpeg\" style=\"max-width:50%; width: 25%\">\n",
"\n",
"Now, you will impute those means into your dataset to complete it:\n",
"\n",
"<img src=\"figures/missing_data_dataset_complete.jpeg\" style=\"max-width:50%; width: 25%\">\n",
"\n",
"Ok. You have a complete dataset. What's next? Next, you use cross-validation to evaluate the performance of the model:\n",
"\n",
"<img src=\"figures/missing_data_5foldcv.jpg\" style=\"max-width:50%; width: 25%\">\n",
"\n",
"And then you find out that what you're doing is **wrong**. \n",
"\n",
"Let me tell you why. If you use the entire dataset to compute the means, you'll be using information from the validation set to fill missing values in the training set. This corresponds to a **data leakage** situation. You fall into a data leakage situation everytime you train your model with data that, somehow, has information about the data used to evaluate model performance. When data leakage occurs, you'll get overly optimistic results during cross-validation because the model will be tested on seen data, instead of unseen data. That's why you should be careful when using imputation and cross-validation together."
]
},
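{
"cell_type": "markdown",
"metadata": {},
"source": [
"For contrast, this is roughly what the leaky version would look like (a sketch only, not something to copy): the means are computed from the full *X_missing* before cross-validation, so every training fold already contains information from its validation fold.\n",
"\n",
"```python\n",
"# Leaky approach (don't do this): impute on the full dataset, then cross-validate\n",
"from sklearn.preprocessing import Imputer\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import cross_val_score\n",
"\n",
"X_leaky = Imputer().fit_transform(X_missing)  # means computed on ALL rows, validation folds included\n",
"leaky_score = cross_val_score(LogisticRegression(), X_leaky, y, cv=5)\n",
"print(leaky_score.mean())  # risks an overly optimistic estimate\n",
"```"
]
},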
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pipelines"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So, what should we do? What we need to do is to make sure that we first split the data into train and validation sets, and only then we compute and impute means. In this way, we avoid mixing the datasets.\n",
"\n",
"The following diagram illustrates what I want to say:\n",
"\n",
"<img src=\"figures/missing_data_correct_pipeline_1.jpg\" style=\"width:30%\">\n",
"\n",
"<img src=\"figures/missing_data_correct_pipeline_2.jpeg\" style=\"width:30%\">\n",
"\n",
"<img src=\"figures/missing_data_correct_pipeline_3.jpeg\" style=\"width:30%\">\n",
"\n",
"We can easily implement these steps in scikit-learn using the **Pipeline** class. Pipelines allow us to integrate multiple steps into a single unit (the pipeline). In scikit-learn, you can use this unit in the same way you use an estimator. This means that the unit Pipeline works like LogisticRegression or any other model. It has fit, predict, and score methods, and you can use it as a classifier."
]
},
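{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of what 'works like an estimator' means in practice (using the Imputer and LogisticRegression from this notebook, and a hypothetical train/validation split just for illustration):\n",
"\n",
"```python\n",
"# A pipeline behaves like any other estimator: fit, predict, score\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import Imputer\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"pipe = make_pipeline(Imputer(strategy='mean'), LogisticRegression())\n",
"X_train, X_val, y_train, y_val = train_test_split(X_missing, y, test_size=0.3)\n",
"\n",
"pipe.fit(X_train, y_train)      # the Imputer learns the means on the training data only\n",
"print(pipe.predict(X_val[:5]))  # imputation + classification in a single call\n",
"print(pipe.score(X_val, y_val)) # accuracy on the validation set\n",
"```"
]
},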
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Building pipelines"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The easiest way to build a pipeline is through the function *make_pipeline*. The syntax for *make_pipeline* is as simples as: \n",
"\n",
"> make_pipeline(*steps you want to do*)\n",
"\n",
"The steps that you want to do must be 'transforms'. And that's a problem..."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pipelines limitations"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To say that the steps must be 'transforms', means that they must implement fit and transform methods. Accordingly, if we are trying an imputation method that does not have these methods, we can't use pipelines. And if we can't use pipelines, we can't evaluate our model through cross-validation because of the data leakage problem.\n",
"\n",
"This is the reason why I told you that scikit-learn restricts the use of more complex imputation methods. You can still apply them if you work around the code, but it will never be as straightforward as it is to apply the mean imputation method."
]
},
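{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an example of such a workaround, any imputation method you can wrap in a class with fit and transform methods will slot into a pipeline. A minimal sketch of a hypothetical median-based imputer (illustrating the mechanism, not a scikit-learn class):\n",
"\n",
"```python\n",
"# A custom imputer only needs fit/transform to be usable inside a pipeline\n",
"import numpy as np\n",
"from sklearn.base import BaseEstimator, TransformerMixin\n",
"\n",
"class MedianImputer(BaseEstimator, TransformerMixin):\n",
"    def fit(self, X, y=None):\n",
"        self.medians_ = np.nanmedian(X, axis=0)  # learned from the training fold only\n",
"        return self\n",
"\n",
"    def transform(self, X):\n",
"        X = np.array(X, dtype=float)             # copy, so the input is not modified\n",
"        rows, cols = np.where(np.isnan(X))\n",
"        X[rows, cols] = self.medians_[cols]      # fill NaNs with the learned medians\n",
"        return X\n",
"\n",
"# pipe = make_pipeline(MedianImputer(), LogisticRegression())  # then usable in cross_val_score\n",
"```"
]
},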
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.4. Evaluation through cross-validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we already discussed the need for pipelines, let's solve the problem."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Score: 0.813 +/- 0.078\n"
]
}
],
"source": [
"# Estimate score with pipeline\n",
"from sklearn.preprocessing import Imputer\n",
"from sklearn.pipeline import make_pipeline\n",
"\n",
"pipe = make_pipeline(Imputer(), LogisticRegression())\n",
"score = cross_val_score(pipe, X_missing, y, cv=5)\n",
"print('Score: %.3f +/- %.3f' % (np.mean(score), np.std(score)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is the score that should be compared with the score resulting from the approach in which we used only valid data. It's the comparison between these two scores that should guide our decision on which solution we should use to solve the missing data problem."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Summary"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example, we saw how to select a solution for the missing data problem. We discussed two solutions: \n",
"1. Use only valid data.\n",
"1. Impute data (mean imputation).\n",
"\n",
"While the first solution is easy to apply, the second one is tricky. In particular, the second solution can guide us to a data leakage situation. To avoid these situation, we must use pipelines. Since pipelines are restricted to 'transformers', not always we can apply complex imputation methods in scikit-learn.\n",
"\n",
"The general process to compare missing data solutions is:\n",
"1. Choose an error metric.\n",
"1. Select a machine learning algorithm.\n",
"1. Apply a missing data solution to get a complete dataset.\n",
"1. Evaluate model performance through cross-validation.\n",
"\n",
"Once you finished this process, you're able to select the missing data solution that leads you to the most accurate predictions."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}