Radial Basis Function initialized on K-means and trained on Pseudo Inverse for MNIST digit classification
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Modified from several gists of https://gist.github.com/tarlanahad
For Homework Assignment 12 (CpE 520)
""" | |
import numpy as np | |
import torchvision | |
from matplotlib import pyplot as plt | |
from sklearn.metrics import confusion_matrix | |
import seaborn as sns | |
# Directory where torchvision downloads / caches the MNIST files.
data_dir = "./data"
# Directory name reserved for saved figures.
saved_image_dir = "./images"
# Number of digit classes (0-9); sets the width of the one-hot targets.
num_classes = 10
# Number of hidden RBF units, i.e. the number of K-means centroids requested.
num_rbf_unit = 40
def get_distance(x1, x2):
    """Return the Euclidean distance between ``x1`` and ``x2``.

    Both arguments are converted to float arrays first, so plain sequences
    are accepted and unsigned integer inputs (e.g. uint8 pixel values)
    cannot wrap around when subtracted. For multi-dimensional inputs the
    sum runs over every element.
    """
    diff = np.asarray(x1, dtype=float) - np.asarray(x2, dtype=float)
    return np.sqrt(np.sum(diff * diff))
def kmeans(X, num_centroids, max_iters=60):
    """Cluster the rows of ``X`` with Lloyd's K-means algorithm.

    Args:
        X: 2-D array-like of shape (n_samples, n_features).
        num_centroids: number of initial centroids, drawn without
            replacement from the rows of ``X`` using ``np.random``.
        max_iters: maximum number of assign/update iterations (>= 1).

    Returns:
        A tuple ``(centroids, stds)``. ``centroids`` is an array with one
        row per non-empty cluster; clusters that end up empty are dropped,
        so it may have fewer than ``num_centroids`` rows. ``stds`` is a list
        holding, for each returned centroid, the standard deviation taken
        over all elements of the points assigned to that cluster.

    Raises:
        ValueError: if ``max_iters`` is smaller than 1, or (from
            ``np.random.choice``) if ``num_centroids`` exceeds ``len(X)``.
    """
    if max_iters < 1:
        raise ValueError("max_iters must be at least 1")

    X = np.asarray(X, dtype=float)
    # Sample the seeds from the actual number of rows rather than assuming
    # the 60000-sample MNIST training set.
    indices = np.random.choice(len(X), size=num_centroids, replace=False)
    centroids = X[indices]

    # ||x||^2 does not change between iterations, so compute it once.
    x_sq = np.sum(X * X, axis=1)
    labels = None
    converged = False
    current_iter = 0
    print("Kmeans algorithm running ....")
    while (not converged) and (current_iter < max_iters):
        print("Iter: ", current_iter)
        # Squared distances via ||x||^2 - 2 x.c + ||c||^2; the argmin is the
        # same as for the true distance and no (n, k, d) array is built.
        c_sq = np.sum(centroids * centroids, axis=1)
        sq_dist = x_sq[:, None] - 2.0 * (X @ centroids.T) + c_sq[None, :]
        nearest = np.argmin(sq_dist, axis=1)

        # Drop empty clusters and relabel the remaining ones 0..m-1.
        occupied, labels = np.unique(nearest, return_inverse=True)
        labels = labels.reshape(-1)
        new_centroids = np.array(
            [X[labels == j].mean(axis=0) for j in range(len(occupied))]
        )

        # Converged only when every centroid is unchanged; comparing the
        # centroids themselves avoids false positives that a comparison of
        # coordinate sums could produce.
        if new_centroids.shape == centroids.shape:
            shift = float(np.max(np.abs(new_centroids - centroids)))
        else:
            shift = float("inf")  # a cluster was dropped this iteration
        print('K-MEANS: ', shift)
        converged = (shift == 0.0)

        centroids = new_centroids
        current_iter += 1

    stds = [np.std(X[labels == j]) for j in range(len(centroids))]
    return centroids, stds
class RBF:
    """Radial basis function network for multi-class classification.

    Hidden-unit centres come from ``kmeans``; the linear output weights are
    obtained in closed form with the Moore-Penrose pseudo-inverse against
    one-hot encoded labels.
    """

    def __init__(self, train_data, train_label, val_data,
                 val_label, k, std_from_clusters=True):
        """Store the datasets and hyper-parameters; no training happens here.

        Args:
            train_data: training samples, one row per sample.
            train_label: integer class label for each training sample.
            val_data: evaluation samples, one row per sample.
            val_label: integer class label for each evaluation sample.
            k: number of RBF units (centroids requested from K-means).
            std_from_clusters: if True, use the per-cluster standard
                deviations returned by ``kmeans`` as unit widths; otherwise
                use one shared width derived from the largest distance
                between centroids.
        """
        self.X = train_data
        self.y = train_label
        self.tX = val_data
        self.ty = val_label
        self.num_classes = num_classes  # module-level constant
        self.k = k
        self.std_from_clusters = std_from_clusters

    def convert_to_one_hot(self, x):
        """Return a (len(x), num_classes) array with a 1 at each label."""
        labels = np.asarray(x).astype(int)
        arr = np.zeros((len(labels), self.num_classes))
        arr[np.arange(len(labels)), labels] = 1
        return arr

    def rbf(self, x, c, s):
        """Gaussian activation ``exp(-||x - c||^2 / (2 s^2))``.

        The value is 1 at the centre and decays towards 0 with distance.
        ``s`` must be non-zero.
        """
        diff = np.asarray(x, dtype=float) - np.asarray(c, dtype=float)
        return np.exp(-np.sum(diff * diff) / (2 * s**2))

    def rbf_list(self, X, centroids, std_list):
        """Return the activation matrix of shape (n_samples, n_units).

        Entry ``[i, j]`` equals ``rbf(X[i], centroids[j], std_list[j])``.
        If ``centroids`` and ``std_list`` differ in length, only the common
        leading units are used.
        """
        X = np.asarray(X, dtype=float)
        centroids = np.asarray(centroids, dtype=float)
        widths = np.asarray(std_list, dtype=float)
        n_units = min(len(centroids), len(widths))
        centroids = centroids[:n_units]
        widths = widths[:n_units]

        # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, for all pairs at once.
        sq_dist = (np.sum(X * X, axis=1)[:, None]
                   - 2.0 * (X @ centroids.T)
                   + np.sum(centroids * centroids, axis=1)[None, :])
        # Rounding can leave tiny negative values where x coincides with c.
        np.maximum(sq_dist, 0.0, out=sq_dist)
        return np.exp(-sq_dist / (2.0 * widths**2))

    def fit(self):
        """Train the network and report results.

        Runs K-means, solves for the output weights, stores predictions in
        ``pred_train`` / ``pred_test``, prints the accuracy on the
        evaluation set, shows both confusion matrices, and writes the
        centroids to ``file.npz``.
        """
        self.centroids, self.std_list = kmeans(self.X, self.k)
        n_units = len(self.centroids)

        if not self.std_from_clusters:
            # Shared width: d_max / sqrt(2k), with d_max the largest
            # distance between any two centroids.
            diffs = self.centroids[:, None, :] - self.centroids[None, :, :]
            dMax = np.sqrt(np.max(np.sum(diffs * diffs, axis=2)))
            self.std_list = np.repeat(dMax / np.sqrt(2 * self.k), n_units)

        RBF_X = self.rbf_list(self.X, self.centroids, self.std_list)

        # Least-squares weights. pinv(A) equals pinv(A.T @ A) @ A.T but does
        # not square the condition number of the activation matrix.
        self.w = np.linalg.pinv(RBF_X) @ self.convert_to_one_hot(self.y)

        RBF_list_tst = self.rbf_list(self.tX, self.centroids, self.std_list)

        # predicting on train dataset
        self.pred_train = np.argmax(RBF_X @ self.w, axis=1)

        # predicting on test dataset
        self.pred_test = np.argmax(RBF_list_tst @ self.w, axis=1)

        # find accuracy
        accuracy = float(np.mean(self.pred_test == np.asarray(self.ty)))
        print('Accuracy: ', accuracy)

        # confusion matrix
        self.draw_confusion_matrix(self.y, self.pred_train)
        self.draw_confusion_matrix(self.ty, self.pred_test)

        np.savez("file.npz", self.centroids)

    def draw_confusion_matrix(self, true_labels, pred_labels):
        """Show a heatmap of the confusion matrix for the given labels."""
        # make confusion matrix
        c_matrix = confusion_matrix(y_true=true_labels, y_pred=pred_labels)

        plt.figure(figsize=(10, 10))
        sns.set(font_scale=1.4)
        sns.heatmap(c_matrix, annot=True, fmt='g', linewidths=.5)

        # labels, title
        plt.xlabel('Predicted Label', fontsize=10, labelpad=11)
        plt.ylabel('True Label', fontsize=10)
        plt.show()
class Train:
    """Load MNIST through torchvision and fit an ``RBF`` network on it."""

    def __init__(self):
        """Download (if needed) and load MNIST, then train the model.

        The fitted network is kept on ``self.model``.
        """
        train_dataset = torchvision.datasets.MNIST(
            root=data_dir,
            train=True,
            download=True)
        val_dataset = torchvision.datasets.MNIST(
            root=data_dir,
            train=False,
            download=True)

        train_data, train_label = self._to_arrays(train_dataset)
        val_data, val_label = self._to_arrays(val_dataset)

        print("Running RBF network")
        model = RBF(train_data, train_label, val_data, val_label,
                    k=num_rbf_unit, std_from_clusters=False)
        model.fit()
        self.model = model

    @staticmethod
    def _to_arrays(dataset):
        """Turn an iterable of (image, label) pairs into (data, labels) arrays.

        Each image is divided by 255 and reshaped to a vector of 784 values.
        """
        data = []
        labels = []
        for img, label in dataset:
            data.append((np.asarray(img) / 255).reshape(784))
            labels.append(label)
        return np.array(data), np.array(labels)
# Run the full pipeline (data loading, training, evaluation) only when this
# file is executed as a script, not when it is imported.
if __name__ == "__main__":
    t = Train()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
# [REDACTED: an access-token-like string appeared here in the scraped page text; treat it as exposed and revoke it]