{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Tuning and Testing\n",
"This notebook contains the code for and documentation of the final stages of this machine learning task. These include:\n",
"1. Importing the input, target and info data from csv files.\n",
"2. Reformatting and normalising the input data.\n",
"3. Parameter tuning through multiple iterations of nested cross-validation.\n",
"4. Accuracy test on unseen examples from patches otherwise present in the dataset.\n",
"5. Accuracy test on unseen patches for models trained on all other patches."
]
},
{
"cell_type": "code",
"execution_count": 148,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"## ======== libraries ======== ##\n",
"\n",
"\n",
"# mathematical operations\n",
"import numpy as np\n",
"\n",
"# obtains principal componenets and projects new data\n",
"from sklearn.decomposition import PCA\n",
"\n",
"# Multilayer perceptron for classifcation\n",
"from sklearn.neural_network import MLPClassifier\n",
"\n",
"# to import data from csvs\n",
"import csv\n",
"\n",
"# plotting results\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 234,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"## ======== function definitions ======== ##\n",
"\n",
"\n",
"# extracts a list of sublists from a csv file\n",
"def read_csv(path, convert_to_float = False):\n",
"\n",
" data = []\n",
" \n",
" my_file = open(path)\n",
"\n",
" readCSV = csv.reader(my_file, delimiter=',')\n",
" \n",
" for row in readCSV:\n",
" \n",
" # if True strings are converted to floating points\n",
" # necessary as floats written to csv files become strings\n",
" if(convert_to_float):\n",
" \n",
" row = [float(element) for element in row]\n",
" \n",
" data.append(row)\n",
" \n",
" return data\n",
" \n",
"\n",
"# divides each sublist into multiple lists of a given length\n",
"# necessary as 2d lists of mfccs were flattened to be written to csv\n",
"def unflatten_sublists(data, length):\n",
" \n",
" unflattened_data = []\n",
" \n",
" for example in data:\n",
" \n",
" # divides long list in to smaller ones and stores as a 2d list\n",
" unflattened_example = [example[i:i+length] for i in range(0, len(example), length)]\n",
" \n",
" unflattened_data.append(unflattened_example)\n",
" \n",
" return unflattened_data\n",
" \n",
" \n",
"# normalizes values per sample across each mfcc\n",
"def normalize_data(x):\n",
" \n",
" # returns matrix of zeros in the same shape as input data\n",
" data = zeros(shape(x))\n",
"\n",
" # loops through each mfcc\n",
" for i in range(len(x[0])):\n",
" \n",
" # gathers the same sample across each mfcc\n",
" feature = [example[i] for example in x]\n",
" \n",
" max_val = np.max(feature)\n",
" \n",
" min_val = np.min(feature)\n",
" \n",
" # max and min values allow us to calculate the range of values\n",
" val_range = max_val - min_val\n",
" \n",
" for j in range(len(x)):\n",
" \n",
" # by subtracting the min and dividing by the range we normalise each value\n",
" data[j][i] = (x[j][i] - min_val) / val_range\n",
" \n",
" return data\n",
"\n",
"\n",
"# reduces dimensionality of data using principal componenet analysis\n",
"def reduce_dimensions(data, n_components):\n",
" \n",
" # a dimension is reduced to size n_components\n",
" pca = PCA(n_components=n_components)\n",
" \n",
" # first the principal components are calculated\n",
" pca.fit(data)\n",
" \n",
" # then data is projected to calculate new values\n",
" reduced_data = pca.transform(data)\n",
" \n",
" return reduced_data\n",
"\n",
"\n",
"# prepare data for training: trim n_mfccs, perform pca, flatten\n",
"def prepare_input(n_mfccs, n_components, norm_mfccs):\n",
" \n",
" X = []\n",
"\n",
" for example in norm_mfccs:\n",
" \n",
" # reduce number of mfccs, starting with higher mfccs first\n",
" trimmed_mfccs = example[:n_mfccs]\n",
" \n",
" reduced_mfccs = reduce_dimensions(trimmed_mfccs, n_components)\n",
" \n",
" # arrays flattened to single dimension as our MLP takes input in this format\n",
" flattened = [item for sublist in reduced_mfccs for item in sublist]\n",
" \n",
" X.append(flattened)\n",
" \n",
" return X\n",
"\n",
"\n",
"# compares predicted and true classes and returns accuracy as a percentage\n",
"def accuracy_score(pred, true):\n",
" \n",
" correct = np.sum(pred==true)\n",
" \n",
" accuracy = correct/len(pred)\n",
" \n",
" return accuracy\n",
"\n",
"\n",
"# computes confusion matrix from predicted and true classes\n",
"def confusion_matrix(pred, true, n_classes):\n",
" \n",
" # matrix of zeros created with both dimensions equal to number of classes in dataset\n",
" matrix = np.zeros([n_classes, n_classes], dtype = int)\n",
" \n",
" # each pair of predicted and true values increase their corresponding matrix position by 1\n",
" for i in range(len(pred)): matrix[pred[i]][true[i]] += 1\n",
" \n",
" return matrix\n",
"\n",
"\n",
"# splits indices between training, testing and validation folds for nested cross-val.\n",
"# split is determined by the current outer fold i.\n",
"def assign_bins(fold_k, i, bins):\n",
" \n",
" fold_test = []\n",
" \n",
" fold_val = []\n",
" \n",
" fold_train = []\n",
" \n",
" # each bins is assigned to a fold corresponding to the rule:\n",
" # bin[i] -> test, bin[i+1] -> validation, rest -> train\n",
" for j in range(0,len(bins)):\n",
" \n",
" if(j == i):\n",
" \n",
" fold_test.extend(bins[j])\n",
" \n",
" elif(j == i + 1):\n",
" \n",
" fold_val.extend(bins[j])\n",
" \n",
" elif((j == 0) & (i == fold_k - 1)):\n",
" \n",
" fold_val.extend(bins[j])\n",
" \n",
" else:\n",
" \n",
" fold_train.extend(bins[j])\n",
" \n",
" return fold_train, fold_test, fold_val\n",
"\n",
"\n",
"# nested cross validation performs parameter tuning and assesses the performance of models on unseen data.\n",
"# this is an altered version of one submitted as part of coursework for my machine learning module\n",
"def nested_cross_val(norm_mfccs, y, fold_k, n_components, n_hiddens, activations, my_seed):\n",
"\n",
" # final accuracy on test data will be saved across each of the k folds\n",
" accuracy_fold=[]\n",
" \n",
" np.random.seed(my_seed)\n",
" \n",
" # indices referring to each data point are created and their position randomised\n",
" indices = np.random.permutation(len(norm_mfccs))\n",
" \n",
" # these indices are then split in to k bins\n",
" bins = np.array_split(indices, fold_k)\n",
" \n",
" # each configuration of folds is looped through\n",
" for i in range(0,fold_k):\n",
" \n",
" fold_train, fold_test, fold_val = assign_bins(fold_k, i, bins)\n",
" \n",
" best_accuracy = - 1000\n",
" \n",
" # each configuration of folds is tested on each combination of parameters\n",
" for j in range(0, len(n_components)):\n",
" \n",
" # data is trimemd to correct numer of mfccs and dimensions reduced to correct numer of components\n",
" X = np.asarray(prepare_input(20, n_components[j], norm_mfccs))\n",
" \n",
" for k in range(0, len(n_hiddens)):\n",
" \n",
" for l in range(0, len(activations)):\n",
" \n",
" # network initiated with correct parameters and then fitted before predictions are made\n",
" net = MLPClassifier(activation=activations[l], hidden_layer_sizes=([n_hiddens[k]]), learning_rate_init=0.01)\n",
"\n",
" net.fit(X[fold_train], y[fold_train])\n",
"\n",
" preds = net.predict(X[fold_val])\n",
" \n",
" temp_accuracy = (accuracy_score(preds, y[fold_val]))\n",
" \n",
" print'~~ fold: ' + str(i) + ' n_comps: ' + str(n_components[j]) + ' n_hids: ' + str(n_hiddens[k]) + ' activ: ' + str(activations[l]) + ' acc: ' + str(temp_accuracy)\n",
" \n",
" # if the accuracy for current paramters is best all paramters are saved as best for fold\n",
" if(temp_accuracy > best_accuracy):\n",
" \n",
" best_components = n_components[j] \n",
" \n",
" best_hiddens = n_hiddens[k]\n",
" \n",
" best_activation = activations[l]\n",
" \n",
" best_accuracy = temp_accuracy\n",
" \n",
" print '**** End of val for this fold, best n_comps: '+ str(best_components) + ' best n_hids: ' + str(best_hiddens) + ' best_activ: ' + str(best_activation) + ' best acc: ' + str(best_accuracy)\n",
"\n",
" # training fold is extended to include the validation fold to give model maximum training data\n",
" fold_train.extend(fold_val)\n",
" \n",
" # fitting and prediction are carried out as before but with best parameters for the fold\n",
" X = np.asarray(prepare_input(20, best_components, norm_mfccs))\n",
"\n",
" net = MLPClassifier(activation=best_activation, hidden_layer_sizes=([best_hiddens]), learning_rate_init=0.01)\n",
"\n",
" net.fit(X[fold_train], y[fold_train])\n",
"\n",
" preds = net.predict(X[fold_test])\n",
" \n",
" temp_accuracy = (accuracy_score(preds, y[fold_test]))\n",
" \n",
" print 'test set accuracy: ' + str(temp_accuracy)\n",
" \n",
" print 'test set confusion Matrix:'\n",
" \n",
" print(confusion_matrix(preds, y[fold_test], 3))\n",
" \n",
" accuracy_fold.extend([temp_accuracy])\n",
"\n",
" return accuracy_fold;\n",
"\n",
"\n",
"# stores all indices for misclassified examples to be used for analysys\n",
"def failed_indices(preds, y, fold_test):\n",
" \n",
" indices = []\n",
" \n",
" for i in range(len(preds)):\n",
" \n",
" if(preds[i] != y[fold_test[i]]):\n",
" \n",
" indices.append(fold_test[i])\n",
" \n",
" return indices\n",
"\n",
"\n",
"# plots the frequency with which examples played at each note were misclassified\n",
"def plot_note_counts(note_counts):\n",
" \n",
" # trim count to start with the first note at which a misclassification is counted\n",
" for i in range(len(note_count)):\n",
" \n",
" if(note_count[i] != 0):\n",
" \n",
" minimum = i\n",
" \n",
" break\n",
" \n",
" trimmed_note_counts = note_counts[minimum:]\n",
"\n",
" plt.bar(np.arange(minimum, 109), trimmed_note_counts)\n",
" \n",
" plt.title(\"Midi Notes of Incorrectly Classified Examples\")\n",
" \n",
" plt.xlabel(\"Midi Notes\")\n",
" \n",
" plt.ylabel(\"Frequency\")\n",
" \n",
"\n",
"# split the indices in to bins by unique patch rather than randomly\n",
"def bins_by_patch(y, patch_index):\n",
" \n",
" bins = []\n",
" \n",
" # examples were labeled with a class index and a patch index beteween 0 and 7. each class had \n",
" # 7 unique patches within it. to seperate each unique patch within the data set it was \n",
" # necessary to check for indices that were present when the list of target classes and patch \n",
" # indexes were searched for each combination of class index and patch index.\n",
" # sorry dont know how to explain that any better!\n",
" for i in range(len(np.unique(y))):\n",
" \n",
" for j in range(len(np.unique(patch_index))):\n",
" \n",
" matches = []\n",
" \n",
" group1 = np.argwhere(patch_index == j)\n",
" \n",
" group2 = np.argwhere(y == i)\n",
" \n",
" for item1 in group1:\n",
" \n",
" for item2 in group2:\n",
" \n",
" if item1 == item2:\n",
" \n",
" matches.extend(list(item1))\n",
" \n",
" bins.append(matches)\n",
" \n",
" return bins"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Importing and Normalising Data\n",
"As our MFCC and info data was saved to csv files in the 'feature_extraction' notebook it is now necessary to import and unpack it. \n",
"\n",
"To save the 20 X 173 array of MFCC's for each example to a csv we had to reshape them into single dimensional lists, this must be undone. We will then perform the first step of pre-processing the data by normalising the MFCC's.\n",
"\n",
"From our info array we will unpack each feature into its own list. In order these are: y (target class), patch_index (the index of the patch within each class) and note (the midi note played for that example)."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"*** mfccs imported and unflattened ***\n",
"*** mfccs normalised ***\n",
"*** info and targets imported ***\n"
]
}
],
"source": [
"## ======== importing and normalising data ======== ##\n",
"\n",
"\n",
"mfcc_path = \"/users/angelorussell/Desktop/monopoly_data/mfccs.csv\"\n",
"\n",
"# 173 samples were taken from each MFCC so to reshape our array we need this value\n",
"length = 173\n",
"\n",
"flattened_mfccs = read_csv(mfcc_path, convert_to_float = True)\n",
"\n",
"mfccs = unflatten_sublists(flattened_mfccs, length)\n",
"\n",
"print \"*** mfccs imported and unflattened ***\"\n",
"\n",
"norm_mfccs = []\n",
"\n",
"# normalize_data function takes two dimensional arrays so we pass it one example at a time\n",
"for example in mfccs:\n",
" \n",
" norm = normalize_data(example)\n",
" \n",
" norm_mfccs.append(norm)\n",
" \n",
"print \"*** mfccs normalised ***\"\n",
"\n",
"info_path = \"/users/angelorussell/Desktop/monopoly_data/info.csv\"\n",
"\n",
"info = read_csv(info_path)\n",
"\n",
"# first element of each sublist is the target class\n",
"y = np.asarray([int(element[0]) for element in info])\n",
"\n",
"# the second is the patch index\n",
"patch_index = np.asarray([int(element[1]) for element in info])\n",
"\n",
"# third is the MIDI note\n",
"note = np.asarray([int(element[2]) for element in info])\n",
"\n",
"print \"*** info and targets imported ***\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Nested Cross-Validation\n",
"\n",
"Once our data is imported and formatted we begin tuning the hyperparamters our model. This means that we will attempt to find their configuration at which our model performs best on unseen data. \n",
"\n",
"We will be working with a Multi-Layer Perceptron with a single layer of hidden neurons which has a large number hyperparameters including: learning rate, activation function and number of hidden neurons. Simaltaneously we will be attempting to find the number of principal componenets extracted from our MFCCs at which our model performs best.\n",
"\n",
"From examing the final test accuracies we will also be able to assess how well our model will generalise to unseen data. From this variance in this accuracy we will also be able to tell how conisitent the model is."
]
},
{
"cell_type": "code",
"execution_count": 233,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"~~ fold: 0 n_comps: 1 n_hids: 20 activ: logistic acc: 0.789915966387\n",
"~~ fold: 0 n_comps: 1 n_hids: 20 activ: tanh acc: 0.795518207283\n",
"~~ fold: 0 n_comps: 1 n_hids: 20 activ: relu acc: 0.686274509804\n",
"~~ fold: 0 n_comps: 1 n_hids: 40 activ: logistic acc: 0.809523809524\n",
"~~ fold: 0 n_comps: 1 n_hids: 40 activ: tanh acc: 0.826330532213\n",
"~~ fold: 0 n_comps: 1 n_hids: 40 activ: relu acc: 0.789915966387\n",
"~~ fold: 0 n_comps: 1 n_hids: 60 activ: logistic acc: 0.843137254902\n",
"~~ fold: 0 n_comps: 1 n_hids: 60 activ: tanh acc: 0.817927170868\n",
"~~ fold: 0 n_comps: 1 n_hids: 60 activ: relu acc: 0.789915966387\n",
"~~ fold: 0 n_comps: 1 n_hids: 80 activ: logistic acc: 0.84593837535\n",
"~~ fold: 0 n_comps: 1 n_hids: 80 activ: tanh acc: 0.840336134454\n",
"~~ fold: 0 n_comps: 1 n_hids: 80 activ: relu acc: 0.837535014006\n",
"~~ fold: 0 n_comps: 1 n_hids: 100 activ: logistic acc: 0.736694677871\n",
"~~ fold: 0 n_comps: 1 n_hids: 100 activ: tanh acc: 0.817927170868\n",
"~~ fold: 0 n_comps: 1 n_hids: 100 activ: relu acc: 0.809523809524\n",
"~~ fold: 0 n_comps: 2 n_hids: 20 activ: logistic acc: 0.885154061625\n",
"~~ fold: 0 n_comps: 2 n_hids: 20 activ: tanh acc: 0.901960784314\n",
"~~ fold: 0 n_comps: 2 n_hids: 20 activ: relu acc: 0.851540616246\n",
"~~ fold: 0 n_comps: 2 n_hids: 40 activ: logistic acc: 0.924369747899\n",
"~~ fold: 0 n_comps: 2 n_hids: 40 activ: tanh acc: 0.913165266106\n",
"~~ fold: 0 n_comps: 2 n_hids: 40 activ: relu acc: 0.932773109244\n",
"~~ fold: 0 n_comps: 2 n_hids: 60 activ: logistic acc: 0.927170868347\n",
"~~ fold: 0 n_comps: 2 n_hids: 60 activ: tanh acc: 0.935574229692\n",
"~~ fold: 0 n_comps: 2 n_hids: 60 activ: relu acc: 0.859943977591\n",
"~~ fold: 0 n_comps: 2 n_hids: 80 activ: logistic acc: 0.910364145658\n",
"~~ fold: 0 n_comps: 2 n_hids: 80 activ: tanh acc: 0.913165266106\n",
"~~ fold: 0 n_comps: 2 n_hids: 80 activ: relu acc: 0.929971988796\n",
"~~ fold: 0 n_comps: 2 n_hids: 100 activ: logistic acc: 0.887955182073\n",
"~~ fold: 0 n_comps: 2 n_hids: 100 activ: tanh acc: 0.941176470588\n",
"~~ fold: 0 n_comps: 2 n_hids: 100 activ: relu acc: 0.918767507003\n",
"~~ fold: 0 n_comps: 3 n_hids: 20 activ: logistic acc: 0.904761904762\n",
"~~ fold: 0 n_comps: 3 n_hids: 20 activ: tanh acc: 0.904761904762\n",
"~~ fold: 0 n_comps: 3 n_hids: 20 activ: relu acc: 0.862745098039\n",
"~~ fold: 0 n_comps: 3 n_hids: 40 activ: logistic acc: 0.918767507003\n",
"~~ fold: 0 n_comps: 3 n_hids: 40 activ: tanh acc: 0.915966386555\n",
"~~ fold: 0 n_comps: 3 n_hids: 40 activ: relu acc: 0.924369747899\n",
"~~ fold: 0 n_comps: 3 n_hids: 60 activ: logistic acc: 0.913165266106\n",
"~~ fold: 0 n_comps: 3 n_hids: 60 activ: tanh acc: 0.921568627451\n",
"~~ fold: 0 n_comps: 3 n_hids: 60 activ: relu acc: 0.913165266106\n",
"~~ fold: 0 n_comps: 3 n_hids: 80 activ: logistic acc: 0.904761904762\n",
"~~ fold: 0 n_comps: 3 n_hids: 80 activ: tanh acc: 0.929971988796\n",
"~~ fold: 0 n_comps: 3 n_hids: 80 activ: relu acc: 0.90756302521\n",
"~~ fold: 0 n_comps: 3 n_hids: 100 activ: logistic acc: 0.899159663866\n",
"~~ fold: 0 n_comps: 3 n_hids: 100 activ: tanh acc: 0.927170868347\n",
"~~ fold: 0 n_comps: 3 n_hids: 100 activ: relu acc: 0.918767507003\n",
"~~ fold: 0 n_comps: 4 n_hids: 20 activ: logistic acc: 0.890756302521\n",
"~~ fold: 0 n_comps: 4 n_hids: 20 activ: tanh acc: 0.87675070028\n",
"~~ fold: 0 n_comps: 4 n_hids: 20 activ: relu acc: 0.887955182073\n",
"~~ fold: 0 n_comps: 4 n_hids: 40 activ: logistic acc: 0.901960784314\n",
"~~ fold: 0 n_comps: 4 n_hids: 40 activ: tanh acc: 0.913165266106\n",
"~~ fold: 0 n_comps: 4 n_hids: 40 activ: relu acc: 0.887955182073\n",
"~~ fold: 0 n_comps: 4 n_hids: 60 activ: logistic acc: 0.901960784314\n",
"~~ fold: 0 n_comps: 4 n_hids: 60 activ: tanh acc: 0.90756302521\n",
"~~ fold: 0 n_comps: 4 n_hids: 60 activ: relu acc: 0.913165266106\n",
"~~ fold: 0 n_comps: 4 n_hids: 80 activ: logistic acc: 0.879551820728\n",
"~~ fold: 0 n_comps: 4 n_hids: 80 activ: tanh acc: 0.918767507003\n",
"~~ fold: 0 n_comps: 4 n_hids: 80 activ: relu acc: 0.904761904762\n",
"~~ fold: 0 n_comps: 4 n_hids: 100 activ: logistic acc: 0.893557422969\n",
"~~ fold: 0 n_comps: 4 n_hids: 100 activ: tanh acc: 0.901960784314\n",
"~~ fold: 0 n_comps: 4 n_hids: 100 activ: relu acc: 0.918767507003\n",
"**** End of val for this fold, best n_comps: 2 best n_hids: 100 best_activ: tanh best acc: 0.941176470588\n",
"test set accuracy: 0.932773109244\n",
"test set confusion Matrix:\n",
"[[120 4 4]\n",
" [ 5 115 6]\n",
" [ 5 0 98]]\n",
"~~ fold: 1 n_comps: 1 n_hids: 20 activ: logistic acc: 0.759103641457\n",
"~~ fold: 1 n_comps: 1 n_hids: 20 activ: tanh acc: 0.750700280112\n",
"~~ fold: 1 n_comps: 1 n_hids: 20 activ: relu acc: 0.725490196078\n",
"~~ fold: 1 n_comps: 1 n_hids: 40 activ: logistic acc: 0.778711484594\n",
"~~ fold: 1 n_comps: 1 n_hids: 40 activ: tanh acc: 0.801120448179\n",
"~~ fold: 1 n_comps: 1 n_hids: 40 activ: relu acc: 0.75350140056\n",
"~~ fold: 1 n_comps: 1 n_hids: 60 activ: logistic acc: 0.801120448179\n",
"~~ fold: 1 n_comps: 1 n_hids: 60 activ: tanh acc: 0.767507002801\n",
"~~ fold: 1 n_comps: 1 n_hids: 60 activ: relu acc: 0.747899159664\n",
"~~ fold: 1 n_comps: 1 n_hids: 80 activ: logistic acc: 0.806722689076\n",
"~~ fold: 1 n_comps: 1 n_hids: 80 activ: tanh acc: 0.829131652661\n",
"~~ fold: 1 n_comps: 1 n_hids: 80 activ: relu acc: 0.837535014006\n",
"~~ fold: 1 n_comps: 1 n_hids: 100 activ: logistic acc: 0.759103641457\n",
"~~ fold: 1 n_comps: 1 n_hids: 100 activ: tanh acc: 0.742296918768\n",
"~~ fold: 1 n_comps: 1 n_hids: 100 activ: relu acc: 0.798319327731\n",
"~~ fold: 1 n_comps: 2 n_hids: 20 activ: logistic acc: 0.865546218487\n",
"~~ fold: 1 n_comps: 2 n_hids: 20 activ: tanh acc: 0.857142857143\n",
"~~ fold: 1 n_comps: 2 n_hids: 20 activ: relu acc: 0.820728291317\n",
"~~ fold: 1 n_comps: 2 n_hids: 40 activ: logistic acc: 0.893557422969\n",
"~~ fold: 1 n_comps: 2 n_hids: 40 activ: tanh acc: 0.899159663866\n",
"~~ fold: 1 n_comps: 2 n_hids: 40 activ: relu acc: 0.87675070028\n",
"~~ fold: 1 n_comps: 2 n_hids: 60 activ: logistic acc: 0.868347338936\n",
"~~ fold: 1 n_comps: 2 n_hids: 60 activ: tanh acc: 0.915966386555\n",
"~~ fold: 1 n_comps: 2 n_hids: 60 activ: relu acc: 0.809523809524\n",
"~~ fold: 1 n_comps: 2 n_hids: 80 activ: logistic acc: 0.899159663866\n",
"~~ fold: 1 n_comps: 2 n_hids: 80 activ: tanh acc: 0.918767507003\n",
"~~ fold: 1 n_comps: 2 n_hids: 80 activ: relu acc: 0.873949579832\n",
"~~ fold: 1 n_comps: 2 n_hids: 100 activ: logistic acc: 0.924369747899\n",
"~~ fold: 1 n_comps: 2 n_hids: 100 activ: tanh acc: 0.90756302521\n",
"~~ fold: 1 n_comps: 2 n_hids: 100 activ: relu acc: 0.834733893557\n",
"~~ fold: 1 n_comps: 3 n_hids: 20 activ: logistic acc: 0.865546218487\n",
"~~ fold: 1 n_comps: 3 n_hids: 20 activ: tanh acc: 0.893557422969\n",
"~~ fold: 1 n_comps: 3 n_hids: 20 activ: relu acc: 0.826330532213\n",
"~~ fold: 1 n_comps: 3 n_hids: 40 activ: logistic acc: 0.854341736695\n",
"~~ fold: 1 n_comps: 3 n_hids: 40 activ: tanh acc: 0.882352941176\n",
"~~ fold: 1 n_comps: 3 n_hids: 40 activ: relu acc: 0.899159663866\n",
"~~ fold: 1 n_comps: 3 n_hids: 60 activ: logistic acc: 0.887955182073\n",
"~~ fold: 1 n_comps: 3 n_hids: 60 activ: tanh acc: 0.904761904762\n",
"~~ fold: 1 n_comps: 3 n_hids: 60 activ: relu acc: 0.901960784314\n",
"~~ fold: 1 n_comps: 3 n_hids: 80 activ: logistic acc: 0.893557422969\n",
"~~ fold: 1 n_comps: 3 n_hids: 80 activ: tanh acc: 0.890756302521\n",
"~~ fold: 1 n_comps: 3 n_hids: 80 activ: relu acc: 0.882352941176\n",
"~~ fold: 1 n_comps: 3 n_hids: 100 activ: logistic acc: 0.899159663866\n",
"~~ fold: 1 n_comps: 3 n_hids: 100 activ: tanh acc: 0.887955182073\n",
"~~ fold: 1 n_comps: 3 n_hids: 100 activ: relu acc: 0.893557422969\n",
"~~ fold: 1 n_comps: 4 n_hids: 20 activ: logistic acc: 0.879551820728\n",
"~~ fold: 1 n_comps: 4 n_hids: 20 activ: tanh acc: 0.857142857143\n",
"~~ fold: 1 n_comps: 4 n_hids: 20 activ: relu acc: 0.890756302521\n",
"~~ fold: 1 n_comps: 4 n_hids: 40 activ: logistic acc: 0.868347338936\n",
"~~ fold: 1 n_comps: 4 n_hids: 40 activ: tanh acc: 0.885154061625\n",
"~~ fold: 1 n_comps: 4 n_hids: 40 activ: relu acc: 0.893557422969\n",
"~~ fold: 1 n_comps: 4 n_hids: 60 activ: logistic acc: 0.879551820728\n",
"~~ fold: 1 n_comps: 4 n_hids: 60 activ: tanh acc: 0.890756302521\n",
"~~ fold: 1 n_comps: 4 n_hids: 60 activ: relu acc: 0.87675070028\n",
"~~ fold: 1 n_comps: 4 n_hids: 80 activ: logistic acc: 0.829131652661\n",
"~~ fold: 1 n_comps: 4 n_hids: 80 activ: tanh acc: 0.882352941176\n",
"~~ fold: 1 n_comps: 4 n_hids: 80 activ: relu acc: 0.882352941176\n",
"~~ fold: 1 n_comps: 4 n_hids: 100 activ: logistic acc: 0.859943977591\n",
"~~ fold: 1 n_comps: 4 n_hids: 100 activ: tanh acc: 0.890756302521\n",
"~~ fold: 1 n_comps: 4 n_hids: 100 activ: relu acc: 0.882352941176\n",
"**** End of val for this fold, best n_comps: 2 best n_hids: 100 best_activ: logistic best acc: 0.924369747899\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"test set accuracy: 0.943977591036\n",
"test set confusion Matrix:\n",
"[[107 4 3]\n",
" [ 4 110 1]\n",
" [ 8 0 120]]\n",
"~~ fold: 2 n_comps: 1 n_hids: 20 activ: logistic acc: 0.837535014006\n",
"~~ fold: 2 n_comps: 1 n_hids: 20 activ: tanh acc: 0.703081232493\n",
"~~ fold: 2 n_comps: 1 n_hids: 20 activ: relu acc: 0.711484593838\n",
"~~ fold: 2 n_comps: 1 n_hids: 40 activ: logistic acc: 0.823529411765\n",
"~~ fold: 2 n_comps: 1 n_hids: 40 activ: tanh acc: 0.812324929972\n",
"~~ fold: 2 n_comps: 1 n_hids: 40 activ: relu acc: 0.759103641457\n",
"~~ fold: 2 n_comps: 1 n_hids: 60 activ: logistic acc: 0.795518207283\n",
"~~ fold: 2 n_comps: 1 n_hids: 60 activ: tanh acc: 0.806722689076\n",
"~~ fold: 2 n_comps: 1 n_hids: 60 activ: relu acc: 0.756302521008\n",
"~~ fold: 2 n_comps: 1 n_hids: 80 activ: logistic acc: 0.803921568627\n",
"~~ fold: 2 n_comps: 1 n_hids: 80 activ: tanh acc: 0.834733893557\n",
"~~ fold: 2 n_comps: 1 n_hids: 80 activ: relu acc: 0.829131652661\n",
"~~ fold: 2 n_comps: 1 n_hids: 100 activ: logistic acc: 0.823529411765\n",
"~~ fold: 2 n_comps: 1 n_hids: 100 activ: tanh acc: 0.689075630252\n",
"~~ fold: 2 n_comps: 1 n_hids: 100 activ: relu acc: 0.750700280112\n",
"~~ fold: 2 n_comps: 2 n_hids: 20 activ: logistic acc: 0.87675070028\n",
"~~ fold: 2 n_comps: 2 n_hids: 20 activ: tanh acc: 0.901960784314\n",
"~~ fold: 2 n_comps: 2 n_hids: 20 activ: relu acc: 0.809523809524\n",
"~~ fold: 2 n_comps: 2 n_hids: 40 activ: logistic acc: 0.882352941176\n",
"~~ fold: 2 n_comps: 2 n_hids: 40 activ: tanh acc: 0.893557422969\n",
"~~ fold: 2 n_comps: 2 n_hids: 40 activ: relu acc: 0.831932773109\n",
"~~ fold: 2 n_comps: 2 n_hids: 60 activ: logistic acc: 0.885154061625\n",
"~~ fold: 2 n_comps: 2 n_hids: 60 activ: tanh acc: 0.893557422969\n",
"~~ fold: 2 n_comps: 2 n_hids: 60 activ: relu acc: 0.809523809524\n",
"~~ fold: 2 n_comps: 2 n_hids: 80 activ: logistic acc: 0.871148459384\n",
"~~ fold: 2 n_comps: 2 n_hids: 80 activ: tanh acc: 0.915966386555\n",
"~~ fold: 2 n_comps: 2 n_hids: 80 activ: relu acc: 0.885154061625\n",
"~~ fold: 2 n_comps: 2 n_hids: 100 activ: logistic acc: 0.885154061625\n",
"~~ fold: 2 n_comps: 2 n_hids: 100 activ: tanh acc: 0.882352941176\n",
"~~ fold: 2 n_comps: 2 n_hids: 100 activ: relu acc: 0.879551820728\n",
"~~ fold: 2 n_comps: 3 n_hids: 20 activ: logistic acc: 0.859943977591\n",
"~~ fold: 2 n_comps: 3 n_hids: 20 activ: tanh acc: 0.871148459384\n",
"~~ fold: 2 n_comps: 3 n_hids: 20 activ: relu acc: 0.854341736695\n",
"~~ fold: 2 n_comps: 3 n_hids: 40 activ: logistic acc: 0.879551820728\n",
"~~ fold: 2 n_comps: 3 n_hids: 40 activ: tanh acc: 0.896358543417\n",
"~~ fold: 2 n_comps: 3 n_hids: 40 activ: relu acc: 0.865546218487\n",
"~~ fold: 2 n_comps: 3 n_hids: 60 activ: logistic acc: 0.879551820728\n",
"~~ fold: 2 n_comps: 3 n_hids: 60 activ: tanh acc: 0.910364145658\n",
"~~ fold: 2 n_comps: 3 n_hids: 60 activ: relu acc: 0.893557422969\n",
"~~ fold: 2 n_comps: 3 n_hids: 80 activ: logistic acc: 0.873949579832\n",
"~~ fold: 2 n_comps: 3 n_hids: 80 activ: tanh acc: 0.899159663866\n",
"~~ fold: 2 n_comps: 3 n_hids: 80 activ: relu acc: 0.873949579832\n",
"~~ fold: 2 n_comps: 3 n_hids: 100 activ: logistic acc: 0.885154061625\n",
"~~ fold: 2 n_comps: 3 n_hids: 100 activ: tanh acc: 0.893557422969\n",
"~~ fold: 2 n_comps: 3 n_hids: 100 activ: relu acc: 0.868347338936\n",
"~~ fold: 2 n_comps: 4 n_hids: 20 activ: logistic acc: 0.843137254902\n",
"~~ fold: 2 n_comps: 4 n_hids: 20 activ: tanh acc: 0.854341736695\n",
"~~ fold: 2 n_comps: 4 n_hids: 20 activ: relu acc: 0.837535014006\n",
"~~ fold: 2 n_comps: 4 n_hids: 40 activ: logistic acc: 0.843137254902\n",
"~~ fold: 2 n_comps: 4 n_hids: 40 activ: tanh acc: 0.885154061625\n",
"~~ fold: 2 n_comps: 4 n_hids: 40 activ: relu acc: 0.848739495798\n",
"~~ fold: 2 n_comps: 4 n_hids: 60 activ: logistic acc: 0.851540616246\n",
"~~ fold: 2 n_comps: 4 n_hids: 60 activ: tanh acc: 0.859943977591\n",
"~~ fold: 2 n_comps: 4 n_hids: 60 activ: relu acc: 0.873949579832\n",
"~~ fold: 2 n_comps: 4 n_hids: 80 activ: logistic acc: 0.854341736695\n",
"~~ fold: 2 n_comps: 4 n_hids: 80 activ: tanh acc: 0.899159663866\n",
"~~ fold: 2 n_comps: 4 n_hids: 80 activ: relu acc: 0.857142857143\n",
"~~ fold: 2 n_comps: 4 n_hids: 100 activ: logistic acc: 0.84593837535\n",
"~~ fold: 2 n_comps: 4 n_hids: 100 activ: tanh acc: 0.885154061625\n",
"~~ fold: 2 n_comps: 4 n_hids: 100 activ: relu acc: 0.862745098039\n",
"**** End of val for this fold, best n_comps: 2 best n_hids: 80 best_activ: tanh best acc: 0.915966386555\n",
"test set accuracy: 0.924369747899\n",
"test set confusion Matrix:\n",
"[[103 6 1]\n",
" [ 9 108 1]\n",
" [ 8 2 119]]\n",
"~~ fold: 3 n_comps: 1 n_hids: 20 activ: logistic acc: 0.717086834734\n",
"~~ fold: 3 n_comps: 1 n_hids: 20 activ: tanh acc: 0.728291316527\n",
"~~ fold: 3 n_comps: 1 n_hids: 20 activ: relu acc: 0.725490196078\n",
"~~ fold: 3 n_comps: 1 n_hids: 40 activ: logistic acc: 0.795518207283\n",
"~~ fold: 3 n_comps: 1 n_hids: 40 activ: tanh acc: 0.789915966387\n",
"~~ fold: 3 n_comps: 1 n_hids: 40 activ: relu acc: 0.756302521008\n",
"~~ fold: 3 n_comps: 1 n_hids: 60 activ: logistic acc: 0.795518207283\n",
"~~ fold: 3 n_comps: 1 n_hids: 60 activ: tanh acc: 0.848739495798\n",
"~~ fold: 3 n_comps: 1 n_hids: 60 activ: relu acc: 0.759103641457\n",
"~~ fold: 3 n_comps: 1 n_hids: 80 activ: logistic acc: 0.689075630252\n",
"~~ fold: 3 n_comps: 1 n_hids: 80 activ: tanh acc: 0.756302521008\n",
"~~ fold: 3 n_comps: 1 n_hids: 80 activ: relu acc: 0.773109243697\n",
"~~ fold: 3 n_comps: 1 n_hids: 100 activ: logistic acc: 0.801120448179\n",
"~~ fold: 3 n_comps: 1 n_hids: 100 activ: tanh acc: 0.717086834734\n",
"~~ fold: 3 n_comps: 1 n_hids: 100 activ: relu acc: 0.812324929972\n",
"~~ fold: 3 n_comps: 2 n_hids: 20 activ: logistic acc: 0.882352941176\n",
"~~ fold: 3 n_comps: 2 n_hids: 20 activ: tanh acc: 0.862745098039\n",
"~~ fold: 3 n_comps: 2 n_hids: 20 activ: relu acc: 0.840336134454\n",
"~~ fold: 3 n_comps: 2 n_hids: 40 activ: logistic acc: 0.887955182073\n",
"~~ fold: 3 n_comps: 2 n_hids: 40 activ: tanh acc: 0.929971988796\n",
"~~ fold: 3 n_comps: 2 n_hids: 40 activ: relu acc: 0.859943977591\n",
"~~ fold: 3 n_comps: 2 n_hids: 60 activ: logistic acc: 0.901960784314\n",
"~~ fold: 3 n_comps: 2 n_hids: 60 activ: tanh acc: 0.910364145658\n",
"~~ fold: 3 n_comps: 2 n_hids: 60 activ: relu acc: 0.859943977591\n",
"~~ fold: 3 n_comps: 2 n_hids: 80 activ: logistic acc: 0.848739495798\n",
"~~ fold: 3 n_comps: 2 n_hids: 80 activ: tanh acc: 0.882352941176\n",
"~~ fold: 3 n_comps: 2 n_hids: 80 activ: relu acc: 0.857142857143\n",
"~~ fold: 3 n_comps: 2 n_hids: 100 activ: logistic acc: 0.879551820728\n",
"~~ fold: 3 n_comps: 2 n_hids: 100 activ: tanh acc: 0.857142857143\n",
"~~ fold: 3 n_comps: 2 n_hids: 100 activ: relu acc: 0.910364145658\n",
"~~ fold: 3 n_comps: 3 n_hids: 20 activ: logistic acc: 0.871148459384\n",
"~~ fold: 3 n_comps: 3 n_hids: 20 activ: tanh acc: 0.862745098039\n",
"~~ fold: 3 n_comps: 3 n_hids: 20 activ: relu acc: 0.857142857143\n",
"~~ fold: 3 n_comps: 3 n_hids: 40 activ: logistic acc: 0.873949579832\n",
"~~ fold: 3 n_comps: 3 n_hids: 40 activ: tanh acc: 0.879551820728\n",
"~~ fold: 3 n_comps: 3 n_hids: 40 activ: relu acc: 0.879551820728\n",
"~~ fold: 3 n_comps: 3 n_hids: 60 activ: logistic acc: 0.868347338936\n",
"~~ fold: 3 n_comps: 3 n_hids: 60 activ: tanh acc: 0.879551820728\n",
"~~ fold: 3 n_comps: 3 n_hids: 60 activ: relu acc: 0.896358543417\n",
"~~ fold: 3 n_comps: 3 n_hids: 80 activ: logistic acc: 0.87675070028\n",
"~~ fold: 3 n_comps: 3 n_hids: 80 activ: tanh acc: 0.904761904762\n",
"~~ fold: 3 n_comps: 3 n_hids: 80 activ: relu acc: 0.885154061625\n",
"~~ fold: 3 n_comps: 3 n_hids: 100 activ: logistic acc: 0.887955182073\n",
"~~ fold: 3 n_comps: 3 n_hids: 100 activ: tanh acc: 0.915966386555\n",
"~~ fold: 3 n_comps: 3 n_hids: 100 activ: relu acc: 0.890756302521\n",
"~~ fold: 3 n_comps: 4 n_hids: 20 activ: logistic acc: 0.868347338936\n",
"~~ fold: 3 n_comps: 4 n_hids: 20 activ: tanh acc: 0.87675070028\n",
"~~ fold: 3 n_comps: 4 n_hids: 20 activ: relu acc: 0.873949579832\n",
"~~ fold: 3 n_comps: 4 n_hids: 40 activ: logistic acc: 0.851540616246\n",
"~~ fold: 3 n_comps: 4 n_hids: 40 activ: tanh acc: 0.879551820728\n",
"~~ fold: 3 n_comps: 4 n_hids: 40 activ: relu acc: 0.848739495798\n",
"~~ fold: 3 n_comps: 4 n_hids: 60 activ: logistic acc: 0.862745098039\n",
"~~ fold: 3 n_comps: 4 n_hids: 60 activ: tanh acc: 0.868347338936\n",
"~~ fold: 3 n_comps: 4 n_hids: 60 activ: relu acc: 0.868347338936\n",
"~~ fold: 3 n_comps: 4 n_hids: 80 activ: logistic acc: 0.84593837535\n",
"~~ fold: 3 n_comps: 4 n_hids: 80 activ: tanh acc: 0.882352941176\n",
"~~ fold: 3 n_comps: 4 n_hids: 80 activ: relu acc: 0.87675070028\n",
"~~ fold: 3 n_comps: 4 n_hids: 100 activ: logistic acc: 0.854341736695\n",
"~~ fold: 3 n_comps: 4 n_hids: 100 activ: tanh acc: 0.873949579832\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"~~ fold: 3 n_comps: 4 n_hids: 100 activ: relu acc: 0.859943977591\n",
"**** End of val for this fold, best n_comps: 2 best n_hids: 40 best_activ: tanh best acc: 0.929971988796\n",
"test set accuracy: 0.915966386555\n",
"test set confusion Matrix:\n",
"[[ 94 4 7]\n",
" [ 1 120 2]\n",
" [ 11 5 113]]\n",
"~~ fold: 4 n_comps: 1 n_hids: 20 activ: logistic acc: 0.719887955182\n",
"~~ fold: 4 n_comps: 1 n_hids: 20 activ: tanh acc: 0.72268907563\n",
"~~ fold: 4 n_comps: 1 n_hids: 20 activ: relu acc: 0.725490196078\n",
"~~ fold: 4 n_comps: 1 n_hids: 40 activ: logistic acc: 0.798319327731\n",
"~~ fold: 4 n_comps: 1 n_hids: 40 activ: tanh acc: 0.812324929972\n",
"~~ fold: 4 n_comps: 1 n_hids: 40 activ: relu acc: 0.739495798319\n",
"~~ fold: 4 n_comps: 1 n_hids: 60 activ: logistic acc: 0.677871148459\n",
"~~ fold: 4 n_comps: 1 n_hids: 60 activ: tanh acc: 0.806722689076\n",
"~~ fold: 4 n_comps: 1 n_hids: 60 activ: relu acc: 0.72268907563\n",
"~~ fold: 4 n_comps: 1 n_hids: 80 activ: logistic acc: 0.834733893557\n",
"~~ fold: 4 n_comps: 1 n_hids: 80 activ: tanh acc: 0.75350140056\n",
"~~ fold: 4 n_comps: 1 n_hids: 80 activ: relu acc: 0.778711484594\n",
"~~ fold: 4 n_comps: 1 n_hids: 100 activ: logistic acc: 0.806722689076\n",
"~~ fold: 4 n_comps: 1 n_hids: 100 activ: tanh acc: 0.756302521008\n",
"~~ fold: 4 n_comps: 1 n_hids: 100 activ: relu acc: 0.781512605042\n",
"~~ fold: 4 n_comps: 2 n_hids: 20 activ: logistic acc: 0.865546218487\n",
"~~ fold: 4 n_comps: 2 n_hids: 20 activ: tanh acc: 0.890756302521\n",
"~~ fold: 4 n_comps: 2 n_hids: 20 activ: relu acc: 0.848739495798\n",
"~~ fold: 4 n_comps: 2 n_hids: 40 activ: logistic acc: 0.871148459384\n",
"~~ fold: 4 n_comps: 2 n_hids: 40 activ: tanh acc: 0.865546218487\n",
"~~ fold: 4 n_comps: 2 n_hids: 40 activ: relu acc: 0.887955182073\n",
"~~ fold: 4 n_comps: 2 n_hids: 60 activ: logistic acc: 0.885154061625\n",
"~~ fold: 4 n_comps: 2 n_hids: 60 activ: tanh acc: 0.899159663866\n",
"~~ fold: 4 n_comps: 2 n_hids: 60 activ: relu acc: 0.873949579832\n",
"~~ fold: 4 n_comps: 2 n_hids: 80 activ: logistic acc: 0.873949579832\n",
"~~ fold: 4 n_comps: 2 n_hids: 80 activ: tanh acc: 0.899159663866\n",
"~~ fold: 4 n_comps: 2 n_hids: 80 activ: relu acc: 0.848739495798\n",
"~~ fold: 4 n_comps: 2 n_hids: 100 activ: logistic acc: 0.885154061625\n",
"~~ fold: 4 n_comps: 2 n_hids: 100 activ: tanh acc: 0.904761904762\n",
"~~ fold: 4 n_comps: 2 n_hids: 100 activ: relu acc: 0.859943977591\n",
"~~ fold: 4 n_comps: 3 n_hids: 20 activ: logistic acc: 0.865546218487\n",
"~~ fold: 4 n_comps: 3 n_hids: 20 activ: tanh acc: 0.848739495798\n",
"~~ fold: 4 n_comps: 3 n_hids: 20 activ: relu acc: 0.820728291317\n",
"~~ fold: 4 n_comps: 3 n_hids: 40 activ: logistic acc: 0.859943977591\n",
"~~ fold: 4 n_comps: 3 n_hids: 40 activ: tanh acc: 0.890756302521\n",
"~~ fold: 4 n_comps: 3 n_hids: 40 activ: relu acc: 0.854341736695\n",
"~~ fold: 4 n_comps: 3 n_hids: 60 activ: logistic acc: 0.882352941176\n",
"~~ fold: 4 n_comps: 3 n_hids: 60 activ: tanh acc: 0.885154061625\n",
"~~ fold: 4 n_comps: 3 n_hids: 60 activ: relu acc: 0.868347338936\n",
"~~ fold: 4 n_comps: 3 n_hids: 80 activ: logistic acc: 0.862745098039\n",
"~~ fold: 4 n_comps: 3 n_hids: 80 activ: tanh acc: 0.885154061625\n",
"~~ fold: 4 n_comps: 3 n_hids: 80 activ: relu acc: 0.865546218487\n",
"~~ fold: 4 n_comps: 3 n_hids: 100 activ: logistic acc: 0.862745098039\n",
"~~ fold: 4 n_comps: 3 n_hids: 100 activ: tanh acc: 0.862745098039\n",
"~~ fold: 4 n_comps: 3 n_hids: 100 activ: relu acc: 0.885154061625\n",
"~~ fold: 4 n_comps: 4 n_hids: 20 activ: logistic acc: 0.843137254902\n",
"~~ fold: 4 n_comps: 4 n_hids: 20 activ: tanh acc: 0.857142857143\n",
"~~ fold: 4 n_comps: 4 n_hids: 20 activ: relu acc: 0.823529411765\n",
"~~ fold: 4 n_comps: 4 n_hids: 40 activ: logistic acc: 0.868347338936\n",
"~~ fold: 4 n_comps: 4 n_hids: 40 activ: tanh acc: 0.87675070028\n",
"~~ fold: 4 n_comps: 4 n_hids: 40 activ: relu acc: 0.854341736695\n",
"~~ fold: 4 n_comps: 4 n_hids: 60 activ: logistic acc: 0.868347338936\n",
"~~ fold: 4 n_comps: 4 n_hids: 60 activ: tanh acc: 0.873949579832\n",
"~~ fold: 4 n_comps: 4 n_hids: 60 activ: relu acc: 0.84593837535\n",
"~~ fold: 4 n_comps: 4 n_hids: 80 activ: logistic acc: 0.854341736695\n",
"~~ fold: 4 n_comps: 4 n_hids: 80 activ: tanh acc: 0.882352941176\n",
"~~ fold: 4 n_comps: 4 n_hids: 80 activ: relu acc: 0.851540616246\n",
"~~ fold: 4 n_comps: 4 n_hids: 100 activ: logistic acc: 0.87675070028\n",
"~~ fold: 4 n_comps: 4 n_hids: 100 activ: tanh acc: 0.873949579832\n",
"~~ fold: 4 n_comps: 4 n_hids: 100 activ: relu acc: 0.890756302521\n",
"**** End of val for this fold, best n_comps: 2 best n_hids: 100 best_activ: tanh best acc: 0.904761904762\n",
"test set accuracy: 0.93837535014\n",
"test set confusion Matrix:\n",
"[[111 7 6]\n",
" [ 4 110 0]\n",
" [ 5 0 114]]\n",
"test accuracies:[0.9327731092436975, 0.94397759103641454, 0.92436974789915971, 0.91596638655462181, 0.93837535014005602]\n"
]
}
],
"source": [
"## ======== nested cross validation ======== ##\n",
"\n",
"\n",
"my_seed = 1234\n",
"\n",
"# the next three lines contain all the parameters for our nested cross-validation to test on\n",
"n_components = [1, 2, 3, 4]\n",
"\n",
"n_hiddens = np.arange(20, 101, 20)\n",
"\n",
"activations = ['logistic', 'tanh', 'relu']\n",
"\n",
"fold_k = 5\n",
"\n",
"# this function will print a running feed of its results and will return a final array of test accuracies\n",
"accuracies = nested_cross_val(norm_mfccs, y, fold_k, n_components, n_hiddens, activations, my_seed)\n",
"\n",
"print 'test accuracies:' + str(accuracies)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Analysing Results of Nested Cross-Validation\n",
"\n",
"First we will discuss the parameters that achieved the highest accuracies. \n",
"\n",
"Across all five folds the optimum number of componenets for PCA was deemed to be 2. This was the only parameter to receuve the same best choice across all folds. The next most unanimous decision was for activation function, which for four out of five folds was chosen to be tanh. The logistic function also performed well on the second fold however the results clearly point to using tanh. In terms of number of hidden nodes there was still a clear best choice with 100 performing the best across 3 of five folds.\n",
"\n",
"In terms of our test fold accuracies, results were very promising. All folds achived over 91% accuracy and there was very little variance with the highest accuracy being 94.4%. This not only means has the capacity to be very accuracte but that it is also consistent."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initial Accuracy Test\n",
"\n",
"We will now test our model using the best parameter from cross validation; 2 components, tanh activation function and 100 hidden neurons. It's accuracy when tested on unseen examples from patches otherwise present in the data will be averaged over 100 folds. We will also compute a confusion matrix and visualise how examples at different MIDI notes perform."
]
},
{
"cell_type": "code",
"execution_count": 153,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"average accuracy over 100 iterations: 0.9568\n",
"total confusion matrix: \n",
"[[3212 119 74]\n",
" [ 83 3281 37]\n",
" [ 100 19 3075]]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAHHtJREFUeJzt3Xm4XFWZ7/HvjwQJk0QgYAKEA0qj\niMoQEBrti6gtMsoVVFQEBFEvtNjXViLXFmylG+9VEC8qBEFGUQZRVBwQEFpbwATzMAVFIEJICAES\nE0FF8O0/1irYKU7VqTrUrjon6/d5nnrOHlbt/e51dtVba+1JEYGZmZVrtUEHYGZmg+VEYGZWOCcC\nM7PCORGYmRXOicDMrHBOBGZmhXMi6DFJZ0j61zbzQ9JLOyk7HklaU9L3JP1B0qWDjqdfqv/XHi93\nd0kLer3cyvJX2gclfUjSYkl/lLRB/rvlKJddS530S911P5Y4EXRI0nxJT0rasGn63LzDDwFExAcj\n4jOdLLNdWUmH5eV+rGn6Akm7dxDvoHbiA4GNgQ0i4qDmmZJOlHRh/8PqHUk/k3RkD5e3s6SrJC2T\n9JikmyUd3qvlt1PdByWtDpwC/GNErBMRj+a/9/Z6vbkO/5wTTeP1vV6vxzrjRNCd+4CDGyOSXgms\nWeP6HgOOk/TCGtfRa5sDv42IpwYdSIOkCZ1MGwRJuwLXAtcDLwU2AD4EvGUA4WwMTALu6NP6jsmJ\npvHat0/rtSZOBN25AHhvZfxQ4PxqAUnnSvpsZfxjkhZJWijpfe3KDmMe8Evgn4ebKWkNSV/My16Y\nh9eQtDbwQ2Ba5dfWNEmrSZop6R5Jj0q6RNL6eVmTJF2Ypy+T9CtJG7dY78vzL7plku6QtF+e/mng\nU8A78jqPaLNtjWWFpA9KulvSUklflqTK/PdLmidphaQ7Je3QLoZKvX41/8p+HHh9i2lrSPq8pPtz\nd8gZktasLGf/3OJbnutsT0knAa8DTs/beHrT9uyUlzWxMu1tkua2qIL/B5wXEZ+LiEcimRMRb29R\nX43/X6M+DqjMe6mk65W65R6R9K08XZJOlfRwnnerpG0rdfVZSX8H/CYvapmkayv/n0ZX5kj11XJf\n74ak4yTd2KhDpe6qOyRNyuOXSnoob8sNkl5Ree+5kr4i6Yf5//MLSS/On42lku6StH2l/HxJn8h1\nuVTS1xvrGSauaZIul7RE0n2SPlyZt7Ok2XlfWSzplNFu/0BEhF8dvID5wBtJH5aXAxOAB0i/gAMY\nyuXOBT6bh/cEFgPbAmsD38hlX9pcdpj1HQb8HNgOWAasn6cvAHbPw/8G3AhsBEwB/gv4TJ63O7Cg\naZkfyeU3BdYAzgQuzvM+AHwPWCtv247AC4eJa3Xgd8DxwAuAPYAVwNZ5/onAhW3qcaX5uT6+D0wG\npgNLgD3zvIOAB4GdAJF+MW/eQQznAn8AdiP92JnUYtoXgSuB9YF18/b/R17Gzrn8m3L5TYCX5Xk/\nA45s2q7q//VO4C2VeVcAHx2mLtYCngZe36a+Vvo/5jqZlmN6B/A4MDXPuxj4P5Xte22e/mZgTq5j\nkfbfqZW6auyvQ3k7JrbYrnb11XZfH2a7nlOHlXmrATfkfWUrYCmwfWX++/L618gxza3MOxd4hLT/\nTiK1tu4j/YCbAHwWuK7pc307sFnerl9U6uOZus8xzSH90HkBsCVwL/DmPP+XwCF5eB1gl0F/Z3X1\n/TboAMbLi2cTwSeB/8g7/tXARFongnOAkyvL+Du6TAR5+BLgc3m4mgjuAfaqvOfNwPw8/MxOXJk/\nD3hDZXwq8Ne8De8jJZJXjVAPrwMeAlarTLsYODEPn0j3ieC1lfFLgJl5+MfAsaOI4Vzg/Kb3rDSN\n9IX4OPCSyrRdgfvy8JnAqS224We0TwTHARfl4fWBJ8hfvE3v2SS/72Vt6us5/8em+XOB/fPw+cAs\nYNOmMnsAvwV2qdbZMPvrEC0SQQf11XZfb1GHT5B+5DRen6nMHyJ1jc4DPtFm+yfn9axX2Z6zKvP/\nCZhXGX8lsKzpc/3ByvhewD3NdQ+8Bri/ad2fAL6eh28APg1s2O7zM1Zf7hrq3gXAu0hf1Oe3L8o0\nUquh4fejXOengA9JevEwy68u8/d5WiubA1fk7pRlpA/Z06S+4QtIX7zfzE37/6t08LDZNOCBiPhb\n03o36WqLVvZQZfgJ0i8qSL/S7hllDA/wXNVpU0i/yOdU6uNHeXq7dXfiQmBfSesAbwf+MyIWDVNu\nKfA3UkLuiKT35u6qRszbAo0TGD5O+sK+OXelvA8gIq4FTge+DCyWNEvdH3caqb5Gs69/OCImV17P\nnL0UEfOB60gJ4cuV7Z8g6eTcPbac9EUOz9YBpJZJw5+GGV+HlTXHPdxnaHNSV+uyyvYfT/rsABxB\nSn53KXWr7tNmu8ccJ4IuRcTvSU3NvYBvj1B8EekLpWH6KNd5V17X8U2zFpJ20OryFzbeNsyiHiB1\nWVQ/fJMi4sGI+GtEfDoitgH+HtiHlY+HVNe5maTqvjOd1IXTaw8ALxllDMNtf3XaI6QvhVdU6mK9\niGh8SbRad6tlPzsz4kFSV8EBwCGkJDtcuSdyube1W16DpM2Bs4BjSGdlTSZ1aygv76GIeH9ETCN1\n9X2l0b8fEV+KiB2BV5C+sD423DraGKm+erKvN0jai9TiuIZ0HKXhXcD+pNb5eqREAbkORqk57oXD\nlHmA1PqpfnbWjYi9ACLi7og4mNRN+zngMqVjdeOCE8HoHAHsERGPj1DuEuAwSdtIWgs44Xms89PA\n4aSmcMPFwCclTVE6rfVTpF+jkH4FbSBpvUr5M4CT8hcK+X375+HXS3ql0tk0y0ldRk8PE8dNpC6C\nj0taXelU1n2Bbz6PbWvla8C/SNoxH/B8aY79eceQWxNnAadK2ghA0iaS3pyLnA0cLukNSgfZN5H0\nsjxvMamPuJ3zSb/QX0k6RtDKx0n7yMckbZDjeLWk4bZlbVISWpLLHU5qEZDHD5K0aR5dmss+rXQA\n+zW5hfc48GeG/9+21EF99Wxfz/vy2cCRpBMy9s2JAdKxgb8Aj5JaKP8+2vVUHC1pU6UTJ44HvjVM\nmZuB5flA9pq5ZbKtpJ1yzO+RNCXX07L8nq7qeJCcCEYhIu6JiNkdlPsh6WDWtaSDm9c+j3XeR/pl\nWf2V8VlgNnArcBtwS57WaEVcDNybm7LTgNNIB/t+ImkF6cDxa/KyXgxcRkoC80inMz7nfP+IeBLY\nj3R64yPAV4D35vX1VERcCpxEOvC4AvgO6aB5r2I4jvR/uTF3M/wU2Dqv+2ZS4j2VdND4ep5tfZ0G\nHJjPMvlSi2Vfkctf0e4HQ0T8F6kPfw/S/+oxUj//VcOUvRP4AqkVsZiUZH5RKbITcJOkP5L+z8fm\n/eaFpC/xpaSuj0eBz7eulpba1ddo9vX
GmVeN15w8fRbw3Yi4KiIeJf3w+lpOlOfnbXiQdFD+xlFs\nR7NvAD8hHfy9l/wZqoqIp0k/NrYj9Qg8Qvqh0vihtSdwR67704B3RsSfexBbXygf6DCzHpN0D/CB\niPjpoGOx4UmaTzrwX/T/yC0CsxpIehupa2bUrUCzfpk4chEz64aknwHbkM4r/9sIxc0Gzl1DZmaF\nc9eQmVnhxkXX0IYbbhhDQ0ODDsPMbFyZM2fOIxExZaRy4yIRDA0NMXv2iGdrmplZhaSO7mbgriEz\ns8I5EZiZFc6JwMyscE4EZmaFcyIwMyucE4GZWeFqSwSSNpN0ndLzZu+QdGyefqKkB/PDNeZWbi9r\nZmYDUOd1BE+RntN6i6R1SU82ujrPOzUiRnMbXDMz67HaEkF+NN+iPLxC0jye3+MMzcysBn25sljS\nELA96clSuwHHSHov6aEqH42IpcO85yjgKIDp05/XU+/MzAZuaOYPnhmef/LeA4zkuWo/WJwf4H05\n8JGIWA58lfQs2O1ILYYvDPe+iJgVETMiYsaUKSPeKsPMzEap1kSQn5F6OXBRRHwbICIWR8TTlWeg\n7lxnDGZm1l6dZw2J9ADqeRFxSmX61EqxA4Db64rBzMxGVucxgt2AQ4DbJM3N044HDpa0HekxfvOB\nD9QYg5mZjaDOs4Z+DmiYWVfVtU4zM+ueryw2MyucE4GZWeHGxRPKzMxK0u9rDtwiMDMrnBOBmVnh\nnAjMzArnRGBmVjgnAjOzwjkRmJkVzonAzKxwTgRmZoVzIjAzK5wTgZlZ4ZwIzMwK50RgZlY4JwIz\ns8I5EZiZFc6JwMyscE4EZmaF84NpzMz6rPnBM/1+EE0ztwjMzArnRGBmVjgnAjOzwjkRmJkVzonA\nzKxwTgRmZoVzIjAzK5wTgZlZ4ZwIzMwK50RgZlY4JwIzs8I5EZiZFa62RCBpM0nXSZon6Q5Jx+bp\n60u6WtLd+e+L6orBzMxGVmeL4CngoxHxcmAX4GhJ2wAzgWsiYivgmjxuZmYDUlsiiIhFEXFLHl4B\nzAM2AfYHzsvFzgPeWlcMZmY2sr4cI5A0BGwP3ARsHBGLICULYKMW7zlK0mxJs5csWdKPMM3MilR7\nIpC0DnA58JGIWN7p+yJiVkTMiIgZU6ZMqS9AM7PC1ZoIJK1OSgIXRcS38+TFkqbm+VOBh+uMwczM\n2qvzrCEBZwPzIuKUyqwrgUPz8KHAd+uKwczMRlbnM4t3Aw4BbpM0N087HjgZuETSEcD9wEE1xmBm\nZiOoLRFExM8BtZj9hrrWa2Zm3fGVxWZmhXMiMDMrXJ3HCMzMijI08wfPDM8/ee8BRtIdtwjMzArn\nRGBmVjgnAjOzwjkRmJkVzonAzKxwTgRmZoVzIjAzK5wTgZlZ4ZwIzMwK50RgZlY4JwIzs8I5EZiZ\nFc6JwMyscE4EZmaFcyIwMyucE4GZWeGcCMzMCudEYGZWOCcCM7PCORGYmRXOicDMrHBOBGZmhXMi\nMDMrnBOBmVnhOkoEkratOxAzMxuMTlsEZ0i6WdL/kjS51ojMzKyvOkoEEfFa4N3AZsBsSd+Q9KZa\nIzMzs77o+BhBRNwNfBI4DvgfwJck3SXpf9YVnJmZ1a/TYwSvknQqMA/YA9g3Il6eh0+tMT4zM6tZ\npy2C04FbgFdHxNERcQtARCwktRKeQ9I5kh6WdHtl2omSHpQ0N7/2er4bYGZmz8/EDsvtBfwpIp4G\nkLQaMCkinoiIC1q851xSAjm/afqpEfH50QRrZma912mL4KfAmpXxtfK0liLiBuCxUcZlZmZ90mki\nmBQRf2yM5OG1RrnOYyTdmruOXjTKZZiZWY90mggel7RDY0TSjsCfRrG+rwIvAbYDFgFfaFVQ0lGS\nZkuavWTJklGsysyse0Mzf/DMqxSdHiP4CHCppIV5fCrwjm5XFhGLG8OSzgK+36bsLGAWwIwZM6Lb\ndZmZWWc6SgQR8StJLwO2BgTcFRF/7XZlkqZGxKI8egBwe7vyZmZWv05bBAA7AUP5PdtLIiKazwh6\nhqSLgd2BDSUtAE4Adpe0HRDAfOADowvbzMx6paNEIOkCUt/+XODpPDl47qmhz4iIg4eZfHa3AZqZ\nWb06bRHMALaJCPfVm5mtYjo9a+h24MV1BmJmZoPRaYtgQ+BOSTcDf2lMjIj9aonKzMz6ptNEcGKd\nQZiZjReN6wvmn7z3gCPpnU5PH71e0ubAVhHxU0lrARPqDc3MzPqh09tQvx+4DDgzT9oE+E5dQZmZ\nWf90erD4aGA3YDk885CajeoKyszM+qfTRPCXiHiyMSJpIuk6AjMzG+c6TQTXSzoeWDM/q/hS4Hv1\nhWVmZv3SaSKYCSwBbiPdFuIqWjyZzMzMxpdOzxr6G3BWfpmZ2Sqk03sN3ccwxwQiYsueR2RmVrPq\nswbquh6gH+volW7uNdQwCTgIWL/34ZiZWb91dIwgIh6tvB6MiC8Ce9Qcm5mZ9UGnXUM7VEZXI7UQ\n1q0lIjMz66tOu4aqzxZ+ivRQmbf3PBozM+u7Ts8aen3dgZiZ2WB02jX0v9vNj4hTehOOmZn1Wzdn\nDe0EXJnH9wVuAB6oIygzM+ufbh5Ms0NErACQdCJwaUQcWVdgZnUZT+d328rG2v9urMUzWp3eYmI6\n8GRl/ElgqOfRmJlZ33XaIrgAuFnSFaQrjA8Azq8tKjMz65tOzxo6SdIPgdflSYdHxK/rC8vMzPql\n064hgLWA5RFxGrBA0hY1xWRmZn3U6aMqTwCOAz6RJ60OXFhXUGZm1j+dtggOAPYDHgeIiIX4FhNm\nZquEThPBkxER5FtRS1q7vpDMzKyfOk0El0g6E5gs6f3AT/FDaszMVgmdnjX0+fys4uXA1sCnIuLq\nWiMzK8SqclGSjV8jJgJJE4AfR8QbAX/5m5mtYkbsGoqIp4EnJK3Xh3jMzKzPOr2y+M/AbZKuJp85\nBBARH64lKjMz65tOE8EP8svMzFYxbROBpOkRcX9EnNftgiWdA+wDPBwR2+Zp6wPfIt2wbj7w9ohY\n2u2yzcysd0Y6RvCdxoCky7tc9rnAnk3TZgLXRMRWwDV53MzMBmikRKDK8JbdLDgibgAea5q8P9Bo\nXZwHvLWbZZqZWe+NlAiixfBobRwRiwDy341aFZR0lKTZkmYvWbKkB6s2s14bmvmDla6DsPFppETw\naknLJa0AXpWHl0taIWl5nYFFxKyImBERM6ZMmVLnqszMitb2YHFETOjx+hZLmhoRiyRNBR7u8fLN\nzKxL3TyPoBeuBA7Nw4cC3+3z+s3MrEltiUDSxcAvga0lLZB0BHAy8CZJdwNvyuNmZjZAnV5Q1rWI\nOLjFrDfUtU4zM+tev7uGzMxsjHEiMDMrXG1dQ2ZmrYy1ZzC0uxZirMVaB7cIzMwK50RgZlY4JwIz\ns8I5EZiZFc6JwMyscE4EZmaFcyIwMyucE4GZWeF8QZlZxfO5eKjx3rFw0dGqdBGUH3xTP7cIzMwK
\n50RgZlY4JwIzs8I5EZiZFc6JwMyscE4EZmaFcyIwMyucryOwVd5YOKd+LMRQtxK2cVXlFoGZWeGc\nCMzMCudEYGZWOCcCM7PCORGYmRXOicDMrHBOBGZmhfN1BDZudHOeeq/uYT+ezo3vZptH++yE8VQf\n1jm3CMzMCudEYGZWOCcCM7PCORGYmRVuIAeLJc0HVgBPA09FxIxBxGFmZoM9a+j1EfHIANdvZma4\na8jMrHiDahEE8BNJAZwZEbOaC0g6CjgKYPr06aNe0WjPPfc50mZWikG1CHaLiB2AtwBHS/qH5gIR\nMSsiZkTEjClTpvQ/QjOzQgwkEUTEwvz3YeAKYOdBxGFmZgNIBJLWlrRuYxj4R+D2fsdhZmbJII4R\nbAxcIamx/m9ExI8GEIeZmTGARBAR9wKv7vd6zcxseD591MyscE4EZmaFcyIwMytccQ+mGe1FY77Y\nzKw7rR6U08vPT7vPZfO80T6MpwRuEZiZFc6JwMyscE4EZmaFcyIwMyucE4GZWeGcCMzMCudEYGZW\nuOKuI+iGzzsePwZxnUc357B3Om8QWp1v3xgfdDxWP7cIzMwK50RgZlY4JwIzs8I5EZiZFc6JwMys\ncE4EZmaFcyIwMytc0dcR9ON85V6tw+dW12csXS9S1zn9Y3n/GcuxlcItAjOzwjkRmJkVzonAzKxw\nTgRmZoVzIjAzK5wTgZlZ4ZwIzMwK50RgZla4oi8oez6qF8FUNV8E1O59I10802453ej3hXNVw62v\n04e2dLqcUi9C8oVYK3N9jJ5bBGZmhXMiMDMrnBOBmVnhnAjMzAo3kEQgaU9Jv5H0O0kzBxGDmZkl\nfU8EkiYAXwbeAmwDHCxpm37HYWZmySBaBDsDv4uIeyPiSeCbwP4DiMPMzABFRH9XKB0I7BkRR+bx\nQ4DXRMQxTeWOAo7Ko1sDvwE2BB7pY7jjjeunPddPe66f1sZr3WweEVNGKjSIC8o0zLTnZKOImAXM\nWumN0uyImFFXYOOd66c91097rp/WVvW6GUTX0AJgs8r4psDCAcRhZmYMJhH8CthK0haSXgC8E7hy\nAHGYmRkD6BqKiKckHQP8GJgAnBMRd3T49lkjFyma66c91097rp/WVum66fvBYjMzG1t8ZbGZWeGc\nCMzMCjdmE4GkzSRdJ2mepDskHZunry/pakl3578vGnSsgyJpgqRfS/p+Ht9C0k25br6VD8YXSdJk\nSZdJuivvQ7t633mWpH/On6vbJV0saVLJ+4+kcyQ9LOn2yrRh9xclX8q3yLlV0g6Di7w3xmwiAJ4C\nPhoRLwd2AY7Ot6KYCVwTEVsB1+TxUh0LzKuMfw44NdfNUuCIgUQ1NpwG/CgiXga8mlRP3ncASZsA\nHwZmRMS2pJM23knZ+8+5wJ5N01rtL28Btsqvo4Cv9inG2ozZRBARiyLiljy8gvRB3oR0O4rzcrHz\ngLcOJsLBkrQpsDfwtTwuYA/gslyk5Lp5IfAPwNkAEfFkRCzD+07VRGBNSROBtYBFFLz/RMQNwGNN\nk1vtL/sD50dyIzBZ0tT+RFqPMZsIqiQNAdsDNwEbR8QiSMkC2GhwkQ3UF4GPA3/L4xsAyyLiqTy+\ngJQ4S7QlsAT4eu46+5qktfG+A0BEPAh8HriflAD+AMzB+0+zVvvLJsADlXLjvq7GfCKQtA5wOfCR\niFg+6HjGAkn7AA9HxJzq5GGKlnpu8ERgB+CrEbE98DiFdgMNJ/d17w9sAUwD1iZ1dzQrdf8ZySr3\nWRvTiUDS6qQkcFFEfDtPXtxohuW/Dw8qvgHaDdhP0nzS3Vv3ILUQJuemPpR9644FwIKIuCmPX0ZK\nDN53kjcC90XEkoj4K/Bt4O/x/tOs1f6yyt0mZ8wmgtznfTYwLyJOqcy6Ejg0Dx8KfLffsQ1aRHwi\nIjaNiCHSQb5rI+LdwHXAgblYkXUDEBEPAQ9I2jpPegNwJ953Gu4HdpG0Vv6cNerH+8/KWu0vVwLv\nzWcP7QL8odGFNF6N2SuLJb0W+E/gNp7tBz+edJzgEmA6aYc+KCKaD/IUQ9LuwL9ExD6StiS1ENYH\nfg28JyL+Msj4BkXSdqQD6S8A7gUOJ/3w8b4DSPo08A7S2Xm/Bo4k9XMXuf9IuhjYnXS76cXACcB3\nGGZ/ycnzdNJZRk8Ah0fE7EHE3StjNhGYmVl/jNmuITMz6w8nAjOzwjkRmJkVzonAzKxwTgRmZoVz\nIrBVlqSQdEFlfKKkJZW7te4nadgrjiX9Mf+dJumyYeYP5eX/U2Xa6ZIOGyGmwyRNG+UmmdXCicBW\nZY8D20paM4+/CXiwMTMiroyIk9stICIWRsSBLWY/DBzb5e2aDyPd1sFszHAisFXdD0l3aQU4GLi4\nMSP/Oj89D28h6ZeSfiXpM5UyQ9V71DdZQro98aHNMyRtJ+nGfL/6KyS9SNKBwAzgIklzJa0paUdJ\n10uaI+nHlVsafFjSnfn93+xFRZi14kRgq7pvAu+UNAl4FenK9OGcRrpJ3U7AQ10s/2Tgo5ImNE0/\nHzguIl5Fujr+hIi4DJgNvDsitiNd1fv/gQMjYkfgHOCk/P6ZwPb5/R/sIh6zrjkR2CotIm4Fhkit\ngavaFN2NZ1sLF7Qp17z8+4CbgXc1pklaD5gcEdfnSeeRno/QbGtgW+BqSXOBT5JuYAZwK6nl8B5S\nwjCrzcSRi5iNe1eS7r+/O+m5Da2M9n4r/066w+kNXb5PwB0Rsesw8/YmJY/9gH+V9IrKswLMesot\nAivBOcC/RcRtbcr8gnQnV4B3d7PwiLiLdPfOffL4H4Clkl6XixwCNFoHK4B18/BvgCmSdoV023VJ\nr5C0GrBZRFxHevjQZGCdbmIy64ZbBLbKi4gFpGMA7RwLfEPSsaRnYHTrJNIdOxsOBc6QtBbP3v0U\n0rNxz5D0J2BX0m2fv5S7kyaSnivxW+DCPE2k5wgvG0VMZh3x3UfNzArnriEzs8I5EZiZFc6JwMys\ncE4EZmaFcyIwMyucE4GZWeGcCMzMCvffTLzYTKWQ23sAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x111e6d7d0>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"## ======== accuracy test ======== ##\n",
"\n",
"accuracies = []\n",
"\n",
"# the example indices for all failed classification will be stored here\n",
"failed = []\n",
"\n",
"n_components = 2\n",
"\n",
"n_hiddens = 100\n",
"\n",
"X = np.asarray(prepare_input(20, n_components, norm_mfccs))\n",
"\n",
"# an empty matrix is created to store the total confusion matrix after all iterations\n",
"total_matrix = np.zeros([3, 3], dtype = int)\n",
"\n",
"# this process will be iterated 100 times to gain a reliable estimate for the models accuracy\n",
"for i in range(100):\n",
"\n",
" indices = np.random.permutation(len(X))\n",
" \n",
" # this training and test split 100 random examples for testing and the rest for training\n",
" fold_train, fold_test = indices[100:], indices[:100]\n",
" \n",
" net = MLPClassifier(activation='tanh', hidden_layer_sizes=([n_hiddens]), learning_rate_init=0.01)\n",
"\n",
" net.fit(X[fold_train], y[fold_train])\n",
"\n",
" preds = net.predict(X[fold_test])\n",
" \n",
" accuracies.extend([accuracy_score(preds, y[fold_test])])\n",
" \n",
" failed.extend(np.asarray(failed_indices(preds, y, fold_test)))\n",
" \n",
" total_matrix += confusion_matrix(preds, y[fold_test])\n",
" \n",
"print 'average accuracy over 100 iterations: ' + str(np.sum(accuracies) / 100)\n",
"\n",
"print 'total confusion matrix: ' + '\\n' + str(total_matrix)\n",
"\n",
"# an array containing the count of examples misclassified at each MIDI note is returned\n",
"# the index of the element refers to the note and the count is the value of the element\n",
"note_counts = np.bincount(note[failed])\n",
"\n",
"# this count is then plotted\n",
"plot_note_counts(note_counts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Secondary Accuracy Test\n",
"Upon reflection I realised that the initial accuracy test reflected the models ability to recognise the properties of individual patches rather than their patch type. To truly test whether the model will generalize to unseen patches we must instead split our examples into bins by their individual patches."
]
},
{
"cell_type": "code",
"execution_count": 229,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"average accuracy for unseen patches: 0.620168067227\n",
"total confusion matrix: \n",
"[[306 175 122]\n",
" [157 370 42]\n",
" [132 50 431]]\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAHRdJREFUeJzt3XmcXFWd9/HPlwRIwhaBZklC0yBM\nECKyBAXRGQR5RNbxERSUHYz6qKCPyiYjMIKDzzBsDyqGRVaDgKDogLLDjMNiAgxbQAQiCWEJSwBZ\nROA3f5xTcFNUV1dXaunu+32/Xv3qu5y659xTt+p3z7m3zlVEYGZm5bVEtwtgZmbd5UBgZlZyDgRm\nZiXnQGBmVnIOBGZmJedAYGZWcg4ELSbpdEn/VGd9SFqnkbTDkaSxkn4t6QVJl3S7PJ1SfF9bvN2t\nJM1r9XYL21/kGJT0ZUlPSfqLpJXy/7Wb3HZb6qRT2l33Q4kDQYMkzZH0uqSVq5bflQ/4PoCI+FJE\nfK+RbdZLK2nfvN1vVy2fJ2mrBsrbrYN4V2BVYKWI2K16paSjJV3Q+WK1jqQbJR3Ywu19UNKVkhZK\nek7S7ZL2a9X26ykeg5KWBE4E/ldELBsRz+b/j7Q631yHr+VAU/n7davzscY4EAzOo8AelRlJ7wfG\ntjG/54BDJS3fxjxabU3gjxHxRrcLUiFpVCPLukHSFsD1wE3AOsBKwJeBT3ahOKsCY4D7OpTfV3Og\nqfzt1KF8rYoDweCcD+xdmN8HOK+YQNI5ko4tzH9b0hOS5kvav17aGmYDtwDfqLVS0tKSTs7bnp+n\nl5a0DHAVMKFwtjVB0hKSDpP0sKRnJV0sacW8rTGSLsjLF0r6g6RV+8n3ffmMbqGk+yTtnJcfA3wX\n+GzO84A6+1bZVkj6kqSHJD0v6YeSVFj/BUmzJb0k6X5Jm9QrQ6Fef5zPsl8GPtbPsqUlnSDpsdwd\ncrqksYXt7JJbfC/mOttO0nHAR4HT8j6eVrU/m+VtjS4s+7Sku/qpgn8Fzo2IH0TEM5HMiojP9FNf\nlfevUh+fKqxbR9JNSt1yz0j6eV4uSSdJejqvu1vSlEJdHSvp74AH86YWSrq+8P5UujIHqq9+j/XB\nkHSopFsrdajUXXWfpDF5/hJJT+Z9uVnSBoXXniPpR5Kuyu/P7yWtlj8bz0t6QNLGhfRzJB2e6/J5\nST+t5FOjXBMk/ULSAkmPSjqosO6DkmbmY+UpSSc2u/9dERH+a+APmAN8nPRheR8wCphLOgMOoC+n\nOwc4Nk9vBzwFTAGWAX6W065TnbZGfvsC/wlsBCwEVszL5wFb5el/Bm4FVgF6gP8CvpfXbQXMq9rm\n13P6ScDSwE+AGXndF4FfA+Pyvm0KLF+jXEsCfwKOAJYCtgZeAibn9UcDF9Spx0XW5/r4DTAe6AUW\nANvldbsBjwObASKdMa/ZQBnOAV4AtiSd7IzpZ9nJwBXAisByef//JW/jgzn9tjn9RGC9vO5G4MCq\n/Sq+r/cDnyysuxz4Zo26GAe8CXysTn0t8j7mOpmQy/RZ4GVg9bxuBvCdwv59JC//BDAr17FIx+/q\nhbqqHK99eT9G97Nf9eqr7rFeY7/eVYeFdUsAN+djZV3geWDjwvr9c/5L5zLdVVh3DvAM6fgdQ2pt\nPUo6gRsFHAvcUPW5vhdYI+/X7wv18Xbd5zLNIp3oLAWsDTwCfCKvvwXYK08vC2ze7e+sQX2/dbsA\nw+WPdwLBkcC/5AP/GmA0/QeCs4HjC9v4OwYZCPL0xcAP8nQxEDwMbF94zSeAOXn67YO4sH42sE1h\nfnXgb3kf9icFkg0HqIePAk8CSxSWzQCOztNHM/hA8JHC/MXAYXn6d8DBTZThHOC8qtcssoz0hfgy\n8N7Csi2AR/P0T4CT+tmHG6kfCA4FLszTKwKvkL94q14zMb9uvTr19a73sWr9XcAuefo8YDowqSrN\n1sAfgc2LdVbjeO2jn0DQQH3VPdb7qcNXSCc5lb/vFdb3kbpGZwOH19n/8TmfFQr7c0Zh/deA2YX5\n9wMLqz7XXyrMbw88XF33wIeAx6ryPhz4aZ6+GTgGWLne52eo/rlraPDOBz5H+qI+r35SJpBaDRV/\nbjLP7wJflrRaje0Xt/nnvKw/awKX5+6UhaQP2ZukvuHzSV+8F+Wm/f9TunhYbQIwNyLeqsp34qD2\naFFPFqZfIZ1RQTpLe7jJMszl3YrLekhn5LMK9fHbvLxe3o24ANhJ0rLAZ4D/iIgnaqR7HniLFJAb\nImnv3F1VKfMUoHIDwyGkL+zbc1fK/gARcT1wGvBD4ClJ0zX4604D1Vczx/pBETG+8Pf23UsRMQe4\ngRQQfljY/1GSjs/dYy+SvsjhnTqA1DKpeLXG/LIsqrrctT5Da5K6WhcW9v8I0mcH4ABS8HtAqVt1\nxzr7PeQ4EAxSRPyZ1NTcHrhsgORPkL5QKnqbzPOBnNcRVavmkw7Q4vbnV15WY1NzSV0WxQ/fmIh4\nPCL+FhHHRMT6wIeBHVn0ekgxzzUkFY+dXlIXTqvNBd7bZBlq7X9x2TOkL4UNCnWxQkRUviT6y7u/\nbb+zMuJxUlfBp4C9SEG2VrpXcrpP19tehaQ1gTOAr5LuyhpP6tZQ3t6TEfGFiJhA6ur7UaV/PyJO\njYhNgQ1IX1jfrpVHHQPVV0uO9QpJ25NaHNeRrqNUfA7YhdQ6X4EUKCDXQZOqyz2/Rpq5pNZP8bOz\nXERsDxARD0XEHqRu2h8AlypdqxsWHAiacwCwdUS8PEC6i4F9Ja0vaRxw1GLkeQywH6kpXDEDOFJS\nj9Jtrd8lnY1COgtaSdIKhfSnA8flLxTy63bJ0x+T9H6lu2leJHUZvVmjHLeRuggOkbSk0q2sOwEX\nLca+9edM4FuSNs0XPNfJZV/sMuTWxBnASZJWAZA0UdIncpKzgP0kbaN0kX2ipPXyuqdIfcT1nEc6\nQ38/6RpBfw4hHSPflrRSLscHJNXal2VIQWhBTrcfqUVAnt9N0qQ8+3xO+6bSBewP5Rbey8Br1H5v\n+9VAfbXsWM/H8lnAgaQbMnbKgQHStYG/As+SWijfbzafgq9ImqR048QRwM9rpLkdeDFfyB6bWyZT\nJG2Wy7ynpJ5cTwvzawZVx93kQNCEiHg4ImY2kO4q0sWs60kXN69fjDwfJZ1ZFs8yjgVmAncD9wB3\n5GWVVsQM4JHclJ0AnEK62He1pJdIF44/lLe1GnApKQjMJt3O+K77/SPidWBn0u2NzwA/AvbO+bVU\nRFwCHEe68PgS8EvSRfNWleFQ0vtya+5muBaYnPO+nRR4TyJdNL6Jd1pfpwC75rtMTu1n25fn9JfX\nO2GIiP8i9eFvTXqvniP1819ZI+39wL+RWhFPkYLM7wtJNgNuk/QX0vt8cD5ulid9iT9P6vp4Fjih\n/2rpV736auZYr9x5VfmblZdPB34VEVd
[base64 PNG data omitted: bar chart produced by plot_note_counts, counting how often each note appears among the misclassified examples]\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x11d06a410>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"## ======== accuracy test for unseen patch ======== ##\n",
"\n",
"\n",
"n_components = 2\n",
"\n",
"n_hiddens = 100\n",
"\n",
"accuracies = []\n",
"\n",
"failed = []\n",
"\n",
"n_classes = 3\n",
"\n",
"total_matrix = np.zeros([3, 3], dtype = int)\n",
"\n",
"X = np.asarray(prepare_input(20, n_components, norm_mfccs))\n",
"\n",
"# the only significant difference between these tests is that our bins are seperated by patch.\n",
"bins = bins_by_patch(y, patch_index)\n",
"\n",
"for i in range(len(bins)):\n",
" \n",
" fold_test = bins[i]\n",
" \n",
" fold_test = reshape(fold_test, -1)\n",
" \n",
" fold_train = bins[:i]\n",
" \n",
" fold_train.extend(bins[i+1:])\n",
" \n",
" fold_train = reshape(fold_train, -1)\n",
" \n",
" net = MLPClassifier(activation='tanh', hidden_layer_sizes=([n_hiddens]), learning_rate_init=0.01)\n",
"\n",
" net.fit(X[fold_train], y[fold_train])\n",
"\n",
" preds = net.predict(X[fold_test])\n",
" \n",
" accuracies.extend([accuracy_score(preds, y[fold_test])])\n",
" \n",
" failed.extend(np.asarray(failed_indices(preds, y, fold_test)))\n",
" \n",
" total_matrix += confusion_matrix(preds, y[fold_test], n_classes)\n",
" \n",
"print 'average accuracy for unseen patches: ' + str(np.sum(accuracies) / len(bins))\n",
"\n",
"print 'total confusion matrix: ' + '\\n' + str(total_matrix)\n",
"\n",
"note_counts = np.bincount(note[failed])\n",
"\n",
"plot_note_counts(note_counts)"
]
},
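{
"cell_type": "markdown",
"metadata": {},
"source": [
"The helper `bins_by_patch` used above is defined earlier in this notebook; the sketch below is only a minimal illustration of the grouping behaviour the test relies on, assuming `patch_index[i]` gives the patch that example `i` was sampled from. Note that flattening the bins with `np.reshape` additionally assumes every patch contributes the same number of examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"## ======== illustrative sketch, not the original definition ======== ##\n",
"\n",
"\n",
"# groups example indices into one bin per patch\n",
"# assumes patch_index[i] identifies the patch of example i\n",
"def bins_by_patch_sketch(y, patch_index):\n",
"    \n",
"    bins = {}\n",
"    \n",
"    for i in range(len(y)):\n",
"        \n",
"        # examples sharing a patch id fall into the same bin\n",
"        bins.setdefault(patch_index[i], []).append(i)\n",
"    \n",
"    # returns one list of indices per patch, in patch order\n",
"    return [bins[key] for key in sorted(bins)]"
]
},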
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Full analysis of the results from both tests are present in my report."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}