@samueleverett01
Last active June 23, 2020 15:02
A K-NN model that outperforms standard Scikit-Learn K-NN classifiers.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building a Faster KNN Classifier\n",
"\n",
"In this notebook, we will build a parsimonious K-NN model that uses cosine similarity as a distance metric to classify MNIST images, in an attempt to find a speed and or accuracy improvement over the Scikit-Learn K-NN model.\n",
"\n",
"Start by importing required libraries, and building the same data sets as in the Scikit-Learn K-NN notebook."
]
},
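{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick illustration of the metric (a minimal sketch, not part of the original analysis): the cosine similarity of two vectors `a` and `b` is `a.b / (|a||b|)`, so it measures the angle between the vectors and ignores their magnitudes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics.pairwise import cosine_similarity\n",
"\n",
"# two toy vectors at a 45-degree angle: similarity should be cos(45 deg), about 0.707\n",
"a = np.array([[1.0, 0.0]])\n",
"b = np.array([[1.0, 1.0]])\n",
"cosine_similarity(a, b)"
]
},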
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((70000, 784), (70000,))"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"import heapq\n",
"from collections import Counter\n",
"from sklearn.metrics.pairwise import cosine_similarity\n",
"from sklearn import datasets, model_selection\n",
"from sklearn.metrics import classification_report\n",
"\n",
"mnist = datasets.fetch_mldata('MNIST original')\n",
"data, target = mnist.data, mnist.target\n",
"\n",
"# make sure everything was correctly imported\n",
"data.shape, target.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set up the exact same data sets with the same method as in the Scikit-Learn K-NN notebook."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# make an array of indices the size of MNIST to use for making the data sets.\n",
"# This array is in random order, so we can use it to scramble up the MNIST data\n",
"indx = np.random.choice(len(target), 70000, replace=False)\n",
"\n",
"# method for building datasets to test with\n",
"def mk_dataset(size):\n",
" \"\"\"makes a dataset of size \"size\", and returns that datasets images and targets\n",
" This is used to make the dataset that will be stored by a model and used in \n",
" experimenting with different stored dataset sizes\n",
" \"\"\"\n",
" train_img = [data[i] for i in indx[:size]]\n",
" train_img = np.array(train_img)\n",
" train_target = [target[i] for i in indx[:size]]\n",
" train_target = np.array(train_target)\n",
" \n",
" return train_img, train_target"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((50000, 784), (50000,))"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# lets make a dataset of size 50,000, meaning the model will have 50,000 data points to compare each \n",
"# new point it is to classify to\n",
"fifty_x, fifty_y = mk_dataset(50000)\n",
"fifty_x.shape, fifty_y.shape"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((20000, 784), (20000,))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# lets make one more of size 20,000 and see how classification accuracy decreases when we use that one\n",
"twenty_x, twenty_y = mk_dataset(20000)\n",
"twenty_x.shape, twenty_y.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((10000, 784), (10000,))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# build model testing dataset\n",
"test_img = [data[i] for i in indx[60000:70000]]\n",
"test_img1 = np.array(test_img)\n",
"test_target = [target[i] for i in indx[60000:70000]]\n",
"test_target1 = np.array(test_target)\n",
"test_img1.shape, test_target1.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Building the Model\n",
"\n",
"Below we will create the function `cos_knn()` that will act as our latest and greatest K-NN classifier for MNIST. Follow the comments in the function for details on how it works."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def cos_knn(k, test_data, test_target, stored_data, stored_target):\n",
" \"\"\"k: number of neighbors to use for voting\n",
" test_data: a set of unobserved images to classify\n",
" test_target: the labels for the test_data (for calculating accuracy)\n",
" stored_data: the images already observed and available to the model\n",
" stored_target: labels for stored_data\n",
" \"\"\"\n",
" \n",
" # find cosine similarity for every point in test_data between every other point in stored_data\n",
" cosim = cosine_similarity(test_data, stored_data)\n",
" \n",
" # get top k indices of images in stored_data that are most similar to any given test_data point\n",
" top = [(heapq.nlargest((k), range(len(i)), i.take)) for i in cosim]\n",
" # convert indices to numbers using stored target values\n",
" top = [[stored_target[j] for j in i[:k]] for i in top]\n",
" \n",
" # vote, and return prediction for every image in test_data\n",
" pred = [max(set(i), key=i.count) for i in top]\n",
" pred = np.array(pred)\n",
" \n",
" # print table giving classifier accuracy using test_target\n",
" print(classification_report(test_target, pred))"
]
},
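{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of a possible further speed-up (the helper name `cos_knn_fast` is ours, not from the original notebook): the per-row `heapq.nlargest` call can be replaced with a single vectorized `np.argpartition`, which finds the k largest similarities in each row without fully sorting it. The order within the top k is unspecified, but that does not matter for majority voting."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"def cos_knn_fast(k, test_data, test_target, stored_data, stored_target):\n",
"    \"\"\"Same logic as cos_knn, but selects the top-k neighbors with argpartition.\"\"\"\n",
"    cosim = cosine_similarity(test_data, stored_data)\n",
"    # indices of the k largest similarities in each row (unordered within the top k)\n",
"    top = np.argpartition(cosim, -k, axis=1)[:, -k:]\n",
"    # look up the labels of those neighbors in one fancy-indexing step\n",
"    top_labels = stored_target[top]\n",
"    # majority vote across each row of neighbor labels\n",
"    pred = np.array([Counter(row).most_common(1)[0][0] for row in top_labels])\n",
"    print(classification_report(test_target, pred))"
]
},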
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Testing the Model\n",
"Now, just as with the Scikit-Learn K-NN model, we will test the `cos_knn()` model on the two data sets and see how it stacks up against the Scikit-Learn K-NN model."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0.0 0.97 0.99 0.98 992\n",
" 1.0 0.98 0.99 0.98 1123\n",
" 2.0 0.98 0.98 0.98 984\n",
" 3.0 0.98 0.97 0.97 1089\n",
" 4.0 0.99 0.97 0.98 1016\n",
" 5.0 0.99 0.96 0.97 857\n",
" 6.0 0.98 0.99 0.98 979\n",
" 7.0 0.97 0.96 0.97 1001\n",
" 8.0 0.96 0.96 0.96 993\n",
" 9.0 0.95 0.97 0.96 966\n",
"\n",
"avg / total 0.97 0.97 0.97 10000\n",
"\n",
"CPU times: user 5min 17s, sys: 1.21 s, total: 5min 18s\n",
"Wall time: 4min 59s\n"
]
}
],
"source": [
"%%time\n",
"# stored data set size of 50,000\n",
"cos_knn(5, test_img1, test_target1, fifty_x, fifty_y)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0.0 0.96 0.99 0.98 992\n",
" 1.0 0.96 0.98 0.97 1123\n",
" 2.0 0.97 0.97 0.97 984\n",
" 3.0 0.97 0.95 0.96 1089\n",
" 4.0 0.98 0.95 0.97 1016\n",
" 5.0 0.97 0.94 0.96 857\n",
" 6.0 0.97 0.99 0.98 979\n",
" 7.0 0.96 0.96 0.96 1001\n",
" 8.0 0.96 0.95 0.95 993\n",
" 9.0 0.94 0.96 0.95 966\n",
"\n",
"avg / total 0.97 0.97 0.97 10000\n",
"\n",
"CPU times: user 2min 9s, sys: 528 ms, total: 2min 9s\n",
"Wall time: 2min 1s\n"
]
}
],
"source": [
"%%time\n",
"# stored data set size of 20,000\n",
"cos_knn(5, test_img1, test_target1, twenty_x, twenty_y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Fantastic! The cosine similarity model we built ourselves outperformed the Scikit-Learn K-NN! Remarkably, the model outperformed the Scikit-Learn K-NN in terms of both classification speed (by a sizeable margin) and accuracy, and yet the model is so simple!\n",
"\n",
"For furthur analysis into how the model works and how it stacked up against the Scikit-Learn K-NN in many different situations, see [this GitHub repository](https://github.com/samgrassi01/Cosine-Similarity-Classifier)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda root]",
"language": "python",
"name": "conda-root-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}