@DragaDoncila
Last active November 24, 2020 04:34
Code plotting a confusion matrix for a classifier
import itertools
import os
import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from imblearn.under_sampling import RandomUnderSampler
from matplotlib.colors import ListedColormap
from sklearn import neighbors
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_score
PATH = "/home/draga/FIT2082/Data_CSVs/"
OUT_PATH = "/home/draga/FIT2082/Data_Scatters/"
pd.set_option('display.max_columns', None)
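# NOTE: PATH and OUT_PATH are machine-specific; point them at your own input
# CSVs and output directory before running.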
def main():
    h = 0.05  # step size in the mesh (.02 initially caused memory error)
    for filename in os.listdir(PATH):
        # pattern = r"(Data)_*(TSNE)(_group).csv"
        pattern = r"(Data2*)_*(.*)(_group)\.csv"
        match = re.search(pattern, filename)
        if match:
            # data we'll be plotting on top of decision boundaries
            point_data_full = pd.read_csv(PATH + filename)
            # split labelled rows from the 'NONE' rows, which get greyed out
            no_nones = point_data_full.loc[point_data_full['group'] != 'NONE'].copy()
            nones = point_data_full.loc[point_data_full['group'] == 'NONE']

            # generate dummy integer values for the group column
            group_dummies = pd.get_dummies(no_nones['group']).values.argmax(1)
            no_nones['group.dummy'] = group_dummies

            # get equal(ish) samples of each category
            sample_sizes = get_sample_sizes(no_nones, 30)
            no_nones = get_sampled_data(no_nones, group_dummies, sample_sizes)

            # define x and y values for plotting and classifying
            x_points = no_nones['x'].values
            y_points = no_nones['y'].values
            X = np.stack((x_points, y_points), axis=1)
            y = no_nones['group.dummy'].values

            cmap_light = ListedColormap(sns.color_palette('pastel', 13))
            colours_bold = list(sns.color_palette('bright', 13))

            # cross-reference colour to group for plotting
            cats = np.unique(no_nones['group'])
            color_dict = dict(zip(cats, colours_bold))
            colour_col = no_nones['group'].apply(lambda x: color_dict[x])

            # find the best model for this data and fit it
            k, accuracy = get_best_k(X, y)
            clf = neighbors.KNeighborsClassifier(k, weights='distance')
            clf.fit(X, y)

            # Plot the decision boundary
            x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
            y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
            xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                                 np.arange(y_min, y_max, h))
            Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
            # plot_decision_boundaries(Z, xx, yy, x_points, y_points, nones, cats, cmap_light, colour_col, match, k, accuracy)

            # predictions on the training set itself, so this is 100% correct
            # by construction (see note below)
            cnf_matrix = confusion_matrix(y, clf.predict(X))
            np.set_printoptions(precision=2)
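            # With weights='distance', every training point is its own
            # zero-distance nearest neighbour, so the matrix above is
            # trivially perfect. A cross-validated matrix would be a more
            # honest estimate; a minimal sketch, assuming the same X, y and
            # clf as above:
            #   from sklearn.model_selection import cross_val_predict
            #   y_pred = cross_val_predict(clf, X, y, cv=5)
            #   cnf_matrix = confusion_matrix(y, y_pred)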
            # Plot non-normalized confusion matrix
            plt.figure()
            plot_confusion_matrix(cnf_matrix, classes=list(cats),
                                  title='Confusion matrix, without normalization')

            # Plot normalized confusion matrix
            plt.figure()
            plot_confusion_matrix(cnf_matrix, classes=list(cats), normalize=True,
                                  title='Normalized confusion matrix')
            plt.show()
def get_counts(full_data):
    """
    Get counts of each group category in the dataset.

    Parameters
    ----------
    full_data: pd DataFrame
        Data to search for categories

    Returns
    -------
    count_df: pd DataFrame
        DataFrame containing each unique category, its dummy value and the
        count of its occurrences in full_data
    """
    categories = list(np.unique(full_data['group']))
    dummies = []
    counts = []
    for category in categories:
        dummies.append(full_data.loc[full_data['group'] == category, 'group.dummy'].iloc[0])
        counts.append(full_data.loc[full_data['group'] == category, 'group'].agg(['count']).iloc[0])
    categories = pd.DataFrame(categories)
    counts = pd.DataFrame(counts)
    dummies = pd.DataFrame(dummies)
    count_df = pd.concat([categories, dummies, counts], axis=1)
    count_df.columns = ['Category', 'Category.Dummy', 'Count']
    return count_df
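# get_counts builds its table with an explicit loop; a minimal pandas sketch
# (assuming the same 'group' and 'group.dummy' columns) produces the same
# counts in one line:
#   full_data.groupby(['group', 'group.dummy']).size().reset_index(name='Count')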
def get_sampled_data(data, y, sample_sizes):
    """
    Use random under-sampling to sample data. Each category in y will be
    sampled according to its corresponding value in sample_sizes.

    Parameters
    ----------
    data: pd DataFrame
        Data to be sampled
    y: np array
        Target categories to use for under-sampling
    sample_sizes: dict
        {category_dummy: desired_count} for each category in data

    Returns
    -------
    x_resample: pd DataFrame
        Resampled data
    """
    # `ratio` and `fit_sample` are the pre-0.6 imbalanced-learn names; newer
    # releases call them `sampling_strategy` and `fit_resample`
    us = RandomUnderSampler(random_state=42, ratio=sample_sizes)
    # convert dataframe to a numpy array of strings, so mixed dtypes survive
    col_names = data.columns
    col_types = data.dtypes
    data = data.values.astype("U")
    x_resample, y_resample = us.fit_sample(data, y)
    # convert back to a dataframe, restoring the original column dtypes
    x_resample = pd.DataFrame(x_resample)
    x_resample.columns = col_names
    for col_name, col_type in zip(col_names, col_types):
        x_resample[col_name] = x_resample[col_name].astype(col_type)
    return x_resample
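# Usage sketch for imbalanced-learn >= 0.6, where the two deprecated names
# above were renamed (the dict maps class label -> target sample count):
#   us = RandomUnderSampler(random_state=42, sampling_strategy=sample_sizes)
#   x_resample, y_resample = us.fit_resample(data, y)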
def get_sample_sizes(data, desired_n):
    """
    Return a dictionary of sample sizes for each category in data, to use for
    under-sampling. Categories with fewer than desired_n samples keep their
    full count.

    Parameters
    ----------
    data: pd DataFrame
        Data to get sample sizes for
    desired_n: int
        Desired number of points in each category

    Returns
    -------
    count_dict: dictionary
        {category dummy: sample size} for each category in data
    """
    counts = get_counts(data)
    cat_list = list(counts['Category.Dummy'])
    count_list = []
    for cat in cat_list:
        count_list.append(min(counts.loc[counts['Category.Dummy'] == cat, 'Count'].iloc[0], desired_n))
    count_dict = dict(zip(cat_list, count_list))
    return count_dict
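# Worked example with hypothetical counts: if dummy 0 has 12 rows and dummy 1
# has 80, get_sample_sizes(data, 30) returns {0: 12, 1: 30} -- small
# categories are kept whole, large ones are capped at desired_n.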
def get_best_k(X, y):
    """
    Perform cross-validation on fitting X using a KNN classifier with k from
    3 to 20 inclusive. Returns the k which gives the maximum mean accuracy
    across folds.

    Parameters
    ----------
    X: np array
        Data to perform fit on
    y: np array
        Target values

    Returns
    -------
    (k, accuracy): tuple
        The value of k which gave the best mean accuracy score, and that accuracy
    """
    mean_accuracies = []
    for n_neighbors in range(3, 21):
        # create an instance of the neighbours classifier and cross-validate it
        clf = neighbors.KNeighborsClassifier(n_neighbors, weights='distance')
        scores = cross_val_score(clf, X, y, cv=5)
        mean_accuracy = np.mean(scores)
        mean_accuracies.append(mean_accuracy)
    best_accuracy = max(mean_accuracies)
    best_k = mean_accuracies.index(best_accuracy) + 3
    return (best_k, best_accuracy)
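# For classification targets, an integer cv here means stratified folds
# (scikit-learn substitutes StratifiedKFold for classifiers), so each fold
# keeps roughly the original class proportions. Usage sketch:
#   k, accuracy = get_best_k(X, y)
#   clf = neighbors.KNeighborsClassifier(k, weights='distance').fit(X, y)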
def plot_decision_boundaries(Z, xx, yy, x_points, y_points, nones, cats, cmap_light, colour_col, match, k, accuracy):
    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light, zorder=1)

    # Plot the training points, with the unlabelled 'NONE' points greyed out
    # underneath
    plt.scatter(x_points, y_points, c=colour_col.values,
                edgecolor='k', s=20, zorder=3, label=list(cats))
    plt.scatter(nones['x'].values, nones['y'].values, c=(135 / 255, 135 / 255, 135 / 255, 0.15), s=15, zorder=2,
                label="UNKNOWN")
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.xlabel("TSNE Component 1")
    plt.ylabel("TSNE Component 2")
    title = ("PCA" if len(match.group(2)) == 0 else match.group(2)) + \
            (" Single Object Sample" if len(match.group(1)) != 5 else " Multi Object Sample")
    plt.title(title + ", k={} accuracy={:.2f}".format(k, accuracy))
    plt.savefig(OUT_PATH + title)
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # annotate each cell with its count (or rate), in a colour that contrasts
    # with the cell background
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
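# Note: plot_confusion_matrix follows the classic scikit-learn docs example.
# Newer scikit-learn releases (>= 1.0) ship a built-in equivalent that could
# replace it, e.g.:
#   from sklearn.metrics import ConfusionMatrixDisplay
#   ConfusionMatrixDisplay.from_predictions(y_true, y_pred)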
if __name__ == "__main__":
    main()